Scalaで誤差逆伝播法を実装(行列計算版)

公開日: : 最終更新日:2016/01/23 scala, 技術

githubにアップしました。
前回実装した誤差逆伝播法を行列で実装しました。
breeze-nlpを使いました。
import scala.util.Random;
import breeze.linalg._;
import breeze.plot._;

object MatrixStyle {

  def sigmoid(v:Double):Double = {
    1.0 / (1.0 + Math.exp(-v));
  }
  
  def main(args: Array[String]): Unit = {
    val signal_num = 2; //バイアス含めた入力次元数
    val hidden_num = 4; //隠れ層の次元数
    val output_num = 1;
    val eta = 0.1;
    
    val N = 100;
    val temp_input = ( -1.0 to 1.0 by 2.0 / N ).toArray;
    val temp_inst = temp_input.map { v => 0.5 * ( Math.sin ( Math.PI * v ) + 1)};
    //val temp_inst = temp_input.map { v => if(v > 0.0) 1.0 else 0.0};
    //val temp_inst = temp_input.map { v => v*v };
    //val temp_inst = temp_input.map { v => sigmoid(v) };
    val input = ((arr:Array[Double]) => {
      	var a = Array.ofDim[Double](arr.length, signal_num-1);
      	for( i <- 0 until arr.length){
      		a(i)(0) = arr(i);
      	}
      	a;
      }:Array[Array[Double]])(temp_input);
    
    val instraction = ((arr:Array[Double]) => {
      var a = Array.ofDim[Double](arr.length, output_num);
      for ( i <- 0 until arr.length){
        a(i)(0) = arr(i);
      }
      a;
    }:Array[Array[Double]])(temp_inst);
    
    println(input.deep);
    println(instraction.deep);

    val rand = new Random;
 		var w1 = DenseMatrix.zeros[Double]( hidden_num, signal_num).map{v => rand.nextDouble()*2 - 1};
    var w2 = DenseMatrix.zeros[Double]( output_num, hidden_num).map{v => rand.nextDouble()*2 - 1};

    for ( loop <- 1 to 10000){
    	for( input_num <- 0 until input.length ){
        
        val in = input(input_num);
        //入力ベクトル
        val x1 = DenseVector(in :+ 1.0); //バイアスベクトルを追加して入力層を生成
    		val x2 = (w1 * x1).map { v => sigmoid(v) }; //順伝播で隠れ層の出力 バイアスベクトルは追加しない
        val x3 = (w2 * x2).map { v => sigmoid(v) }; //順伝播で出力層へ出力
        
        //誤差を計算
        val inst = DenseVector(instraction(input_num)); //教師信号
        val error = (x3 - inst).map{ v => 0.5 * v * v }.sum; 
        //出力層の誤差を計算
        val error_out = (x3-inst) * x3.t * (1.0-x3);        
        //隠れ層の誤差を計算
        val error_hidden = diag(diag(x2*(w2.t*error_out).t) * (1.0 - x2).t);
        /*
        //デバッグ
        println("w2 " + w2);
        println("error_out "+error_out);
        println("w2.t * error_out " + w2.t*error_out);
        println("x2 "+x2);
        println("x2*(w2.t*error_out).t " + x2*(w2.t*error_out).t);
        println("diag(x2*(w2.t*error_out).t) " + diag(x2*(w2.t*error_out).t));
        System.exit(0);
        */
        
        //重みを更新
        w1 -= (x1 * error_hidden.t).t :* eta;
        w2 -= (x2 * error_out.t).t :* eta;
    	} 
    }
    
    //ここから出力
    def output(in:Array[Double]):Double = {
    		val x1 = DenseVector(in :+ 1.0); //バイアスベクトルを追加して入力層を生成
    		val x2 = (w1 * x1).map { v => sigmoid(v) }; //順伝播で隠れ層の出力 バイアスベクトルは追加しない
    		val x3 = (w2 * x2).map { v => sigmoid(v) }; //順伝播で出力層へ出力
        x3(0);
    }
    val o = input.map{v => output(v)};
    val f = Figure();
    val p = f.subplot(0);
    p += plot(temp_input,temp_inst,'-');
    p += plot(temp_input,o,'+');
  }  
}

sin

関連記事

亀山ダムのグラフ表示

ソースコードの紛失などの凡ミスを経て、ようやく亀山ダムの釣果情報をグラフで表示できるようにな

記事を読む

no image

cloudera managerを動かすためにメモリを増設した

Cloudera Managerの推奨スペックが 8GBと非常に要求が高かったので、Amazo

記事を読む

no image

Scala IDE ( Eclipse )でgithubに接続する

何回やっても忘れてしまって、作業に手間取ってしまうのでメモします。 プロジェクトのところで→[

記事を読む

Githubに遺伝的アルゴリズム(Real Coded Genetic Algorithm)のScala実装をアップしました。

私のGithubです。 遺伝的アルゴリズムを用いて実数問題を解くということに、興味を持

記事を読む

Github EvolutionalComputationにwPSOのScala実装を追加

https://github.com/yawaip/EvolutionalComputation

記事を読む

no image

水中探索ロボット(水中ドローン)の決定版 OpenROV

以前も記事に書きましたが、私は水中探索カメラに興味があります。 先日、フランス人に釣りのことを

記事を読む

no image

Virtual Box quickstart vmでSparkを動かす その1

Cloudera ManagerのVirtual Box Sparkを使ってみたのだけど、 以下の

記事を読む

scalaでベクトルの重み計算

以下のような重みの計算を行う場合、数式通りに配列を用いて実装する場合と、行列計算で行う方法をメモする

記事を読む

no image

ScalaNLP Breezeでjarファイル化に失敗する場合の対策

エラー ScalaNLPをeclipseで使いたいために、jarファイルにしようと思うのですが、い

記事を読む

no image

Scala言語でベクトルの単位ベクトルを生成

式はこれ メモ def main(args: Array) { val vec = L

記事を読む

no image
メタニウムMGLとカシータスMGLの飛距離について

メタニウムMGLとカシータスMGLは飛距離については一緒らしいです。

no image
ゾディアス166m-2とクロノス662mbの比較

どうも2ピースロッドが大好きなヤワイです。 同じスペックの竿を2

no image
エリアトラウトで村田基さんの真似を辞めたら釣れるようになった

半年ぶりの更新です。 今年一年は、仕事が忙しくほとんど釣りに行け

上からクロノス、エアエッジ、ブレイゾン、ゾディアス
クロノス 672MHB 使ってみてのレビュー

クロノスロッドはかなりイケてるロッドの気がして、3月に購入してから2回

趣味にエリアフィッシング(ニジマス釣り)をおすすめする7つの理由

休みの日が雨だと何もやることがなくて嫌ですね。昔は雷さえならなければ釣

大好き片倉ダム(笹川湖)~ボート屋紹介~

亀山ダムの釣果情報を出しているページを運営しているものの、なんだかんだ

ピニオンギアのスペースチューニング
アルテグラチューニング

久しぶりの投稿です。 最近、エリアトラウトにどっぷりはまっており

素人がエリアトラウトで人並みに釣ることができるようになるまで

エリアトラウトに始めて2ヶ月が経ちました。完全にハマってます。 技量

王禅寺フィッシュオンへのアクセス1
初エリアトラウトにフィッシュオン王禅寺に行ってきた(公共交通機関でアクセス編)

前回の投稿以降、更新を忘れていたので書きます。 私は車やバイクを

【最安か?】上州屋渋谷の創業祭でジリオン SV TWが10%オフだった

なかなか今年の初バス釣りの予定がたちませんね。 グラフから釣行日

→もっと見る

PAGE TOP ↑