Scalaで誤差逆伝播法を実装(行列計算版)

公開日: : 最終更新日:2016/01/23 scala, 技術

githubにアップしました。
前回実装した誤差逆伝播法を行列で実装しました。
breeze-nlpを使いました。
import scala.util.Random;
import breeze.linalg._;
import breeze.plot._;

object MatrixStyle {

  def sigmoid(v:Double):Double = {
    1.0 / (1.0 + Math.exp(-v));
  }
  
  def main(args: Array[String]): Unit = {
    val signal_num = 2; //バイアス含めた入力次元数
    val hidden_num = 4; //隠れ層の次元数
    val output_num = 1;
    val eta = 0.1;
    
    val N = 100;
    val temp_input = ( -1.0 to 1.0 by 2.0 / N ).toArray;
    val temp_inst = temp_input.map { v => 0.5 * ( Math.sin ( Math.PI * v ) + 1)};
    //val temp_inst = temp_input.map { v => if(v > 0.0) 1.0 else 0.0};
    //val temp_inst = temp_input.map { v => v*v };
    //val temp_inst = temp_input.map { v => sigmoid(v) };
    val input = ((arr:Array[Double]) => {
      	var a = Array.ofDim[Double](arr.length, signal_num-1);
      	for( i <- 0 until arr.length){
      		a(i)(0) = arr(i);
      	}
      	a;
      }:Array[Array[Double]])(temp_input);
    
    val instraction = ((arr:Array[Double]) => {
      var a = Array.ofDim[Double](arr.length, output_num);
      for ( i <- 0 until arr.length){
        a(i)(0) = arr(i);
      }
      a;
    }:Array[Array[Double]])(temp_inst);
    
    println(input.deep);
    println(instraction.deep);

    val rand = new Random;
 		var w1 = DenseMatrix.zeros[Double]( hidden_num, signal_num).map{v => rand.nextDouble()*2 - 1};
    var w2 = DenseMatrix.zeros[Double]( output_num, hidden_num).map{v => rand.nextDouble()*2 - 1};

    for ( loop <- 1 to 10000){
    	for( input_num <- 0 until input.length ){
        
        val in = input(input_num);
        //入力ベクトル
        val x1 = DenseVector(in :+ 1.0); //バイアスベクトルを追加して入力層を生成
    		val x2 = (w1 * x1).map { v => sigmoid(v) }; //順伝播で隠れ層の出力 バイアスベクトルは追加しない
        val x3 = (w2 * x2).map { v => sigmoid(v) }; //順伝播で出力層へ出力
        
        //誤差を計算
        val inst = DenseVector(instraction(input_num)); //教師信号
        val error = (x3 - inst).map{ v => 0.5 * v * v }.sum; 
        //出力層の誤差を計算
        val error_out = (x3-inst) * x3.t * (1.0-x3);        
        //隠れ層の誤差を計算
        val error_hidden = diag(diag(x2*(w2.t*error_out).t) * (1.0 - x2).t);
        /*
        //デバッグ
        println("w2 " + w2);
        println("error_out "+error_out);
        println("w2.t * error_out " + w2.t*error_out);
        println("x2 "+x2);
        println("x2*(w2.t*error_out).t " + x2*(w2.t*error_out).t);
        println("diag(x2*(w2.t*error_out).t) " + diag(x2*(w2.t*error_out).t));
        System.exit(0);
        */
        
        //重みを更新
        w1 -= (x1 * error_hidden.t).t :* eta;
        w2 -= (x2 * error_out.t).t :* eta;
    	} 
    }
    
    //ここから出力
    def output(in:Array[Double]):Double = {
    		val x1 = DenseVector(in :+ 1.0); //バイアスベクトルを追加して入力層を生成
    		val x2 = (w1 * x1).map { v => sigmoid(v) }; //順伝播で隠れ層の出力 バイアスベクトルは追加しない
    		val x3 = (w2 * x2).map { v => sigmoid(v) }; //順伝播で出力層へ出力
        x3(0);
    }
    val o = input.map{v => output(v)};
    val f = Figure();
    val p = f.subplot(0);
    p += plot(temp_input,temp_inst,'-');
    p += plot(temp_input,o,'+');
  }  
}

sin

関連記事

no image

ScalaNLP Breezeでjarファイル化に失敗する場合の対策

エラー ScalaNLPをeclipseで使いたいために、jarファイルにしようと思うのですが、い

記事を読む

亀山ダムのグラフ表示

ソースコードの紛失などの凡ミスを経て、ようやく亀山ダムの釣果情報をグラフで表示できるようにな

記事を読む

no image

macのlsが死んでしまった。 ls : invalid line width: Fが出る問題。

基本的にlinuxで開発をすることが多いのですが、お茶でもしながらノマドワークしようと、macに開発

記事を読む

Github EvolutionalComputationにjDEのScala実装を追加

https://github.com/yawaip/EvolutionalComputation

記事を読む

no image

周波数スペクトルの見方と意味がわからなかった

フーリエ変換 ← わかる 周波数スペクトル ← わからん ヤワイは画像処理をすこしかじっ

記事を読む

no image

オンライン学習の勉強してみる

クラウド関係を触っているものですから、どうしても気になるわけですね。 学会等でお会いしたことがある

記事を読む

no image

fstabがうまくいかなかった

raspberry pi で radikoを録音しようと思って、まずはNASに接続しなきゃと思ったの

記事を読む

no image

水中探索ロボット(水中ドローン)の決定版 OpenROV

以前も記事に書きましたが、私は水中探索カメラに興味があります。 先日、フランス人に釣りのことを

記事を読む

no image

Scalaで配列同士の二乗誤差を出力する

配列 Tと配列 Xがあり、その二乗誤差を求める実装。 数式だと以下の様な感じ。誰もが見たことある。

記事を読む

Github EvolutionalComputationにCMA-ESのScala実装を追加

https://github.com/yawaip/EvolutionalComputation

記事を読む

no image
メタニウムMGLとカシータスMGLの飛距離について

メタニウムMGLとカシータスMGLは飛距離については一緒らしいです。

no image
ゾディアス166m-2とクロノス662mbの比較

どうも2ピースロッドが大好きなヤワイです。 同じスペックの竿を2

no image
エリアトラウトで村田基さんの真似を辞めたら釣れるようになった

半年ぶりの更新です。 今年一年は、仕事が忙しくほとんど釣りに行け

上からクロノス、エアエッジ、ブレイゾン、ゾディアス
クロノス 672MHB 使ってみてのレビュー

クロノスロッドはかなりイケてるロッドの気がして、3月に購入してから2回

趣味にエリアフィッシング(ニジマス釣り)をおすすめする7つの理由

休みの日が雨だと何もやることがなくて嫌ですね。昔は雷さえならなければ釣

大好き片倉ダム(笹川湖)~ボート屋紹介~

亀山ダムの釣果情報を出しているページを運営しているものの、なんだかんだ

ピニオンギアのスペースチューニング
アルテグラチューニング

久しぶりの投稿です。 最近、エリアトラウトにどっぷりはまっており

素人がエリアトラウトで人並みに釣ることができるようになるまで

エリアトラウトに始めて2ヶ月が経ちました。完全にハマってます。 技量

王禅寺フィッシュオンへのアクセス1
初エリアトラウトにフィッシュオン王禅寺に行ってきた(公共交通機関でアクセス編)

前回の投稿以降、更新を忘れていたので書きます。 私は車やバイクを

【最安か?】上州屋渋谷の創業祭でジリオン SV TWが10%オフだった

なかなか今年の初バス釣りの予定がたちませんね。 グラフから釣行日

→もっと見る

PAGE TOP ↑