Scalaで誤差逆伝播法を実装(行列計算版)

公開日: : 最終更新日:2016/01/23 scala, 技術

githubにアップしました。
前回実装した誤差逆伝播法を行列で実装しました。
breeze-nlpを使いました。
import scala.util.Random;
import breeze.linalg._;
import breeze.plot._;

object MatrixStyle {

  def sigmoid(v:Double):Double = {
    1.0 / (1.0 + Math.exp(-v));
  }
  
  def main(args: Array[String]): Unit = {
    val signal_num = 2; //バイアス含めた入力次元数
    val hidden_num = 4; //隠れ層の次元数
    val output_num = 1;
    val eta = 0.1;
    
    val N = 100;
    val temp_input = ( -1.0 to 1.0 by 2.0 / N ).toArray;
    val temp_inst = temp_input.map { v => 0.5 * ( Math.sin ( Math.PI * v ) + 1)};
    //val temp_inst = temp_input.map { v => if(v > 0.0) 1.0 else 0.0};
    //val temp_inst = temp_input.map { v => v*v };
    //val temp_inst = temp_input.map { v => sigmoid(v) };
    val input = ((arr:Array[Double]) => {
      	var a = Array.ofDim[Double](arr.length, signal_num-1);
      	for( i <- 0 until arr.length){
      		a(i)(0) = arr(i);
      	}
      	a;
      }:Array[Array[Double]])(temp_input);
    
    val instraction = ((arr:Array[Double]) => {
      var a = Array.ofDim[Double](arr.length, output_num);
      for ( i <- 0 until arr.length){
        a(i)(0) = arr(i);
      }
      a;
    }:Array[Array[Double]])(temp_inst);
    
    println(input.deep);
    println(instraction.deep);

    val rand = new Random;
 		var w1 = DenseMatrix.zeros[Double]( hidden_num, signal_num).map{v => rand.nextDouble()*2 - 1};
    var w2 = DenseMatrix.zeros[Double]( output_num, hidden_num).map{v => rand.nextDouble()*2 - 1};

    for ( loop <- 1 to 10000){
    	for( input_num <- 0 until input.length ){
        
        val in = input(input_num);
        //入力ベクトル
        val x1 = DenseVector(in :+ 1.0); //バイアスベクトルを追加して入力層を生成
    		val x2 = (w1 * x1).map { v => sigmoid(v) }; //順伝播で隠れ層の出力 バイアスベクトルは追加しない
        val x3 = (w2 * x2).map { v => sigmoid(v) }; //順伝播で出力層へ出力
        
        //誤差を計算
        val inst = DenseVector(instraction(input_num)); //教師信号
        val error = (x3 - inst).map{ v => 0.5 * v * v }.sum; 
        //出力層の誤差を計算
        val error_out = (x3-inst) * x3.t * (1.0-x3);        
        //隠れ層の誤差を計算
        val error_hidden = diag(diag(x2*(w2.t*error_out).t) * (1.0 - x2).t);
        /*
        //デバッグ
        println("w2 " + w2);
        println("error_out "+error_out);
        println("w2.t * error_out " + w2.t*error_out);
        println("x2 "+x2);
        println("x2*(w2.t*error_out).t " + x2*(w2.t*error_out).t);
        println("diag(x2*(w2.t*error_out).t) " + diag(x2*(w2.t*error_out).t));
        System.exit(0);
        */
        
        //重みを更新
        w1 -= (x1 * error_hidden.t).t :* eta;
        w2 -= (x2 * error_out.t).t :* eta;
    	} 
    }
    
    //ここから出力
    def output(in:Array[Double]):Double = {
    		val x1 = DenseVector(in :+ 1.0); //バイアスベクトルを追加して入力層を生成
    		val x2 = (w1 * x1).map { v => sigmoid(v) }; //順伝播で隠れ層の出力 バイアスベクトルは追加しない
    		val x3 = (w2 * x2).map { v => sigmoid(v) }; //順伝播で出力層へ出力
        x3(0);
    }
    val o = input.map{v => output(v)};
    val f = Figure();
    val p = f.subplot(0);
    p += plot(temp_input,temp_inst,'-');
    p += plot(temp_input,o,'+');
  }  
}

sin

関連記事

no image

Sparkでhdfsに出力 その2

前回のhdfsに出力するはなしの続き hdfsはnamenode(ネームノード)を経由して出力

記事を読む

亀山ダムのグラフ表示

ソースコードの紛失などの凡ミスを経て、ようやく亀山ダムの釣果情報をグラフで表示できるようにな

記事を読む

no image

scalaで複数ファイルをコンパイルして実行する方法

scalaで複数ファイルを作って、実行する方法がわからなくて2時間ほど時間を無駄にした。(無職なんだ

記事を読む

no image

水中探索ロボット(水中ドローン)の決定版 OpenROV

以前も記事に書きましたが、私は水中探索カメラに興味があります。 先日、フランス人に釣りのことを

記事を読む

no image

Scalaで配列同士の二乗誤差を出力する

配列 Tと配列 Xがあり、その二乗誤差を求める実装。 数式だと以下の様な感じ。誰もが見たことある。

記事を読む

no image

Sparkを最小労力で動かす その3

結局わかったことは、Cloudera manager難しいってことと、 Spark on Ya

記事を読む

no image

cloudera managerを動かすためにメモリを増設した

Cloudera Managerの推奨スペックが 8GBと非常に要求が高かったので、Amazo

記事を読む

no image

ScalaでSparse Codingの実装を試す

Scala言語でスパースコーディング(Sparse Coding)を実装している人が居たので試してみ

記事を読む

no image

scala 多重ループ

関数型言語のscalaで多重ループを表現する方法を教えて下さい。 書きかけ

記事を読む

no image

true positive, false positive, ROCカーブがわからなかった

true positive, false positive, false negative, tru

記事を読む

no image
メタニウムMGLとカシータスMGLの飛距離について

メタニウムMGLとカシータスMGLは飛距離については一緒らしいです。

no image
ゾディアス166m-2とクロノス662mbの比較

どうも2ピースロッドが大好きなヤワイです。 同じスペックの竿を2

no image
エリアトラウトで村田基さんの真似を辞めたら釣れるようになった

半年ぶりの更新です。 今年一年は、仕事が忙しくほとんど釣りに行け

上からクロノス、エアエッジ、ブレイゾン、ゾディアス
クロノス 672MHB 使ってみてのレビュー

クロノスロッドはかなりイケてるロッドの気がして、3月に購入してから2回

趣味にエリアフィッシング(ニジマス釣り)をおすすめする7つの理由

休みの日が雨だと何もやることがなくて嫌ですね。昔は雷さえならなければ釣

大好き片倉ダム(笹川湖)~ボート屋紹介~

亀山ダムの釣果情報を出しているページを運営しているものの、なんだかんだ

ピニオンギアのスペースチューニング
アルテグラチューニング

久しぶりの投稿です。 最近、エリアトラウトにどっぷりはまっており

素人がエリアトラウトで人並みに釣ることができるようになるまで

エリアトラウトに始めて2ヶ月が経ちました。完全にハマってます。 技量

王禅寺フィッシュオンへのアクセス1
初エリアトラウトにフィッシュオン王禅寺に行ってきた(公共交通機関でアクセス編)

前回の投稿以降、更新を忘れていたので書きます。 私は車やバイクを

【最安か?】上州屋渋谷の創業祭でジリオン SV TWが10%オフだった

なかなか今年の初バス釣りの予定がたちませんね。 グラフから釣行日

→もっと見る

PAGE TOP ↑