2次元版、勾配降下法1(最急降下法)
# Loss function: evaluate f(x, y) = x^2 + y^2 on a 50 x 50 grid over
# [-3, 3] x [-3, 3] and draw it as a perspective surface.
x = seq(-3, 3, length.out = 50)
y = seq(-3, 3, length.out = 50)

# Quadratic bowl whose unique minimum is at the origin.
# Vectorized: x and y may be scalars or equal-length vectors.
lossfunc1 = function(x, y) {
  return(x^2 + y^2)
}

# z[i, j] = lossfunc1(x[i], y[j]).
# The grid must be filled with the loss function itself; pdf() is R's
# graphics-device opener and is not a density/loss evaluator.
z = outer(x, y, lossfunc1)

# persp() expects zlim to cover the range of the plotted surface, so the upper
# limit is taken from the data (18 at the corners) rather than hard-coded.
persp(x, y, z,
      theta = 0, phi = 20, expand = 0.5, col = "lightblue",
      xlim = c(-3, 3), ylim = c(-3, 3), zlim = c(-0.5, max(z)))
# Gradient of the loss f(x, y) = x^2 + y^2.
#   vecw: length-2 vector or 2 x 1 matrix holding (x, y).
# Returns the gradient (2x, 2y) as a 2 x 1 column matrix.
dffunc1 = function(vecw) {
  x = vecw[1]
  y = vecw[2]
  return(matrix(c(2 * x, 2 * y), 2, 1))
}

# Global trace buffers appended to by grad201() via `<<-`.
# They are never reset by grad201(), so repeated runs accumulate.
vecw_list = c()   # iterates, flattened: two values (x, y) per iteration
cntlist = c()     # iteration counter, one value per iteration
errorlist = c()   # per-component |step|, two values per iteration

# Steepest descent on the loss whose gradient is dffunc1().
#   vecw : starting point (length-2 vector or 2 x 1 matrix)
#   alpha: step size (learning rate)
#   times: maximum number of iterations
# Each iteration performs vecw <- vecw - alpha * dffunc1(vecw), records the
# iterate in the global trace buffers, and prints one progress line.
# Iteration stops early once the squared Euclidean length of the step is
# <= 1.0e-8. Returns the final vecw invisibly (a 2 x 1 matrix after at least
# one iteration; the input unchanged when times is 0).
grad201 = function(vecw, alpha, times) {
  cnt = 1
  # seq_len() yields an empty sequence for times = 0, whereas 1:0 would
  # iterate over c(1, 0) and run the update twice.
  for (i in seq_len(times)) {
    vecw_before = vecw
    vecw = vecw - alpha * dffunc1(vecw)

    step = c(vecw - vecw_before)
    # Squared Euclidean norm of the step; this is the "error" that is printed
    # and compared against the stopping tolerance.
    sq_step = as.numeric(step %*% step)

    vecw_list <<- c(vecw_list, vecw)
    cntlist <<- c(cntlist, cnt)
    # Note: the trace keeps the per-component absolute step, not sq_step.
    errorlist <<- c(errorlist, abs(vecw - vecw_before))

    # Base R has no printf(); cat(sprintf(...)) emits the same text.
    cat(sprintf("cnt:%d;vecw=(%.8f,%.8f),error=%.8f\n",
                cnt, vecw[1], vecw[2], sq_step))

    if (sq_step <= 1.0e-8) {
      break
    }
    cnt = cnt + 1
  }
  invisible(vecw)
}
実行結果:
> vecw=matrix(c(1,2),2,1) > grad201(vecw,0.1,1e+3) cnt:1;vecw=(0.80000000,1.60000000),error=0.20000000 cnt:2;vecw=(0.64000000,1.28000000),error=0.12800000 cnt:3;vecw=(0.51200000,1.02400000),error=0.08192000 cnt:4;vecw=(0.40960000,0.81920000),error=0.05242880 cnt:5;vecw=(0.32768000,0.65536000),error=0.03355443 cnt:6;vecw=(0.26214400,0.52428800),error=0.02147484 cnt:7;vecw=(0.20971520,0.41943040),error=0.01374390 cnt:8;vecw=(0.16777216,0.33554432),error=0.00879609 cnt:9;vecw=(0.13421773,0.26843546),error=0.00562950 cnt:10;vecw=(0.10737418,0.21474836),error=0.00360288 cnt:11;vecw=(0.08589935,0.17179869),error=0.00230584 cnt:12;vecw=(0.06871948,0.13743895),error=0.00147574 cnt:13;vecw=(0.05497558,0.10995116),error=0.00094447 cnt:14;vecw=(0.04398047,0.08796093),error=0.00060446 cnt:15;vecw=(0.03518437,0.07036874),error=0.00038686 cnt:16;vecw=(0.02814750,0.05629500),error=0.00024759 cnt:17;vecw=(0.02251800,0.04503600),error=0.00015846 cnt:18;vecw=(0.01801440,0.03602880),error=0.00010141 cnt:19;vecw=(0.01441152,0.02882304),error=0.00006490 cnt:20;vecw=(0.01152922,0.02305843),error=0.00004154 cnt:21;vecw=(0.00922337,0.01844674),error=0.00002658 cnt:22;vecw=(0.00737870,0.01475740),error=0.00001701 cnt:23;vecw=(0.00590296,0.01180592),error=0.00001089 cnt:24;vecw=(0.00472237,0.00944473),error=0.00000697 cnt:25;vecw=(0.00377789,0.00755579),error=0.00000446 cnt:26;vecw=(0.00302231,0.00604463),error=0.00000285 cnt:27;vecw=(0.00241785,0.00483570),error=0.00000183 cnt:28;vecw=(0.00193428,0.00386856),error=0.00000117 cnt:29;vecw=(0.00154743,0.00309485),error=0.00000075 cnt:30;vecw=(0.00123794,0.00247588),error=0.00000048 cnt:31;vecw=(0.00099035,0.00198070),error=0.00000031 cnt:32;vecw=(0.00079228,0.00158456),error=0.00000020 cnt:33;vecw=(0.00063383,0.00126765),error=0.00000013 cnt:34;vecw=(0.00050706,0.00101412),error=0.00000008 cnt:35;vecw=(0.00040565,0.00081130),error=0.00000005 cnt:36;vecw=(0.00032452,0.00064904),error=0.00000003 
cnt:37;vecw=(0.00025961,0.00051923),error=0.00000002 cnt:38;vecw=(0.00020769,0.00041538),error=0.00000001 cnt:39;vecw=(0.00016615,0.00033231),error=0.00000001 >