勾配降下法1(最急降下法)
勾配降下法をやっています。
問題設定:
損失関数J(w)を最小にするようなwを求めることが目的です。
今、としておきます。Jのグラフは、
のようになります。
のとき最小値は0です。
func2=function(t){ return(t^4) } dffunc2=function(t){ return(4*t^3) } xlist=c() cntlist=c() errorlist=c() grad2=function(x,alpha,times){ cnt=1 for(i in 1:times){ xbefore=x x=x-alpha*dffunc2(x)#更新則 if(cnt==1||0==cnt%%1000000){ xlist<<-c(xlist,x) cntlist<<-c(cntlist,cnt) errorlist<<-c(errorlist,abs(x-xbefore)) printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore)) } if(abs(x-xbefore)<=1.0e-11){ printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore)) break } cnt=cnt+1 } }
alpha=0.1のとき。
> grad2(0.01,0.1,1e+8) cnt:1;x=0.010000,abs(x-xbefore)=0.0000004000000 cnt:1000000;x=0.001111,abs(x-xbefore)=0.0000000005487 cnt:2000000;x=0.000788,abs(x-xbefore)=0.0000000001958 cnt:3000000;x=0.000644,abs(x-xbefore)=0.0000000001069 cnt:4000000;x=0.000558,abs(x-xbefore)=0.0000000000696 cnt:5000000;x=0.000499,abs(x-xbefore)=0.0000000000498 cnt:6000000;x=0.000456,abs(x-xbefore)=0.0000000000379 cnt:7000000;x=0.000422,abs(x-xbefore)=0.0000000000301 cnt:8000000;x=0.000395,abs(x-xbefore)=0.0000000000246 cnt:9000000;x=0.000372,abs(x-xbefore)=0.0000000000207 cnt:10000000;x=0.000353,abs(x-xbefore)=0.0000000000176 cnt:11000000;x=0.000337,abs(x-xbefore)=0.0000000000153 cnt:12000000;x=0.000323,abs(x-xbefore)=0.0000000000134 cnt:13000000;x=0.000310,abs(x-xbefore)=0.0000000000119 cnt:14000000;x=0.000299,abs(x-xbefore)=0.0000000000107 cnt:14607585;x=0.000292,abs(x-xbefore)=0.0000000000100
更新回数によるxの変化を表したグラフ
plot(cntlist,xlist,type="l")
更新回数による誤差|x-xbefore|の変化を表したグラフ
plot(cntlist,errorlist,type="l")
どちらのグラフもなめらか(なはず)です。
(数が大きすぎて一部だけを採用しています)
発散しちゃう場合(alpha=1のとき):
xlist=c() cntlist=c() errorlist=c() grad21=function(x,alpha,times){ cnt=1 for(i in 1:times){ xbefore=x x=x-alpha*dffunc2(x) #if(cnt==1||0==cnt%%1000000){ xlist<<-c(xlist,x) cntlist<<-c(cntlist,cnt) errorlist<<-c(errorlist,abs(x-xbefore)) printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore)) #} if(abs(x-xbefore)<=1.0e-11){ #printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore)) break } cnt=cnt+1 } }
> grad21(-2,1,6) cnt:1;x=30.000000,abs(x-xbefore)=32.0000000000000 cnt:2;x=-107970.000000,abs(x-xbefore)=108000.0000000000000 cnt:3;x=5034650126184030.000000,abs(x-xbefore)=5034650126292000.0000000000000 cnt:4;x=-510467242137979669844282662082828080260080064684.000000,abs(x-xbefore)=510467242137979669844282662082828080260080064684.0000000000000 cnt:5;x=532063692658205331924864608888064082424424628800086842620862222602804040662048448264440204404406488004486208466682268824660688084444820826866848.000000,abs(x-xbefore)=532063692658205331924864608888064082424424628800086842620862222602804040662048448264440204404406488004486208466682268824660688084444820826866848.0000000000000 cnt:6;x=-Inf,abs(x-xbefore)=Inf
最初、これでやってて完全にはまってました。
alphaは、刻み幅や学習係数と呼ばれ、xが収束するかを大きく左右するようですね。
やってみた感じ、alpha>=1ときは、発散しちゃいます。
損失関数をとし、結果を表示する件数を変えるために新しく関数grad3を定義します。
損失関数のグラフは、
x=seq(-5,5,,1000) plot(x,log(cosh(x)),xlim=c(-5,5),ylim=c(-2,2),type="l") par(new=T) abline(h = 0)
となります。
損失関数の微分は、です。グラフは以下です。
x=seq(-5,5,,1000) plot(x,tanh(x),xlim=c(-5,5),ylim=c(-2,2),type="l") par(new=T) abline(h = 0)
func3=function(t){ return(log(cosh(t))) } dffunc3=function(t){ return((sinh(t))/(cosh(t))) } xlist=c() cntlist=c() errorlist=c() grad3=function(x,alpha,times){ cnt=1 for(i in 1:times){ xbefore=x x=x-alpha*dffunc3(x) if(cnt==1||0==cnt%%100){ xlist<<-c(xlist,x) cntlist<<-c(cntlist,cnt) errorlist<<-c(errorlist,abs(x-xbefore)) printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore)) } if(abs(x-xbefore)<=1.0e-8){ printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore)) break } cnt=cnt+1 } }
> grad3(5,0.01,100000) cnt:1;x=4.990001,abs(x-xbefore)=0.0099990920426 cnt:100;x=4.000287,abs(x-xbefore)=0.0099934294467 cnt:200;x=3.002400,abs(x-xbefore)=0.0099517513467 cnt:300;x=2.017618,abs(x-xbefore)=0.0096590492535 cnt:400;x=1.113095,abs(x-xbefore)=0.0080797706330 cnt:500;x=0.479389,abs(x-xbefore)=0.0044934793548 cnt:600;x=0.181362,abs(x-xbefore)=0.0018115185266 cnt:700;x=0.066704,abs(x-xbefore)=0.0006727509644 cnt:800;x=0.024432,abs(x-xbefore)=0.0002467350759 cnt:900;x=0.008944,abs(x-xbefore)=0.0000903369547 cnt:1000;x=0.003274,abs(x-xbefore)=0.0000330674215 cnt:1100;x=0.001198,abs(x-xbefore)=0.0000121038033 cnt:1200;x=0.000439,abs(x-xbefore)=0.0000044303863 cnt:1300;x=0.000161,abs(x-xbefore)=0.0000016216648 cnt:1400;x=0.000059,abs(x-xbefore)=0.0000005935818 cnt:1500;x=0.000022,abs(x-xbefore)=0.0000002172701 cnt:1600;x=0.000008,abs(x-xbefore)=0.0000000795279 cnt:1700;x=0.000003,abs(x-xbefore)=0.0000000291098 cnt:1800;x=0.000001,abs(x-xbefore)=0.0000000106551 cnt:1807;x=0.000001,abs(x-xbefore)=0.0000000099313
plot(cntlist,xlist,type="l")
plot(cntlist,errorlist,type="l")
さっきよりずっと早く収束しましたねヾ(=´・∀・`=)