2次元版、ニュートン法(関数の勾配が0である解を求める)
# Loss function 1: f(x, y) = x * y * exp(-x^2 - y^2), examined on [-2, 2] x [-2, 2].

# Grid coordinates used for the surface plot.
x=seq(-2,2,length.out=50)
y=seq(-2,2,length.out=50)

# Value of the loss at (x, y). Built only from element-wise arithmetic,
# so it also works on vectors of equal length.
lossfunc1=function(x,y){
  return(x*y*exp(-x^2-y^2))
}

# z[i, j] = lossfunc1(x[i], y[j]) for every grid point.
z=outer(x,y,lossfunc1)

# Perspective plot of the loss surface.
persp(x,y,z,ticktype="detailed",theta = 30, phi = 20, expand = 0.5,col = "lightblue",xlim=c(-2,2),ylim=c(-2,2),zlim=c(-0.5,0.5))

# Gradient of lossfunc1 at vecw = (x, y), returned as a 2x1 column matrix:
#   df/dx = exp(-x^2-y^2) * y * (1 - 2x^2)
#   df/dy = exp(-x^2-y^2) * x * (1 - 2y^2)
dffunc1=function(vecw){
  x=vecw[1];y=vecw[2]
  e=exp(-x^2-y^2)
  return(matrix(c(e*y*(1-2*x^2),
                  e*x*(1-2*y^2)),2,1))
}

# Hessian of lossfunc1 at vecw = (x, y), returned as a symmetric 2x2 matrix:
#   d2f/dx2  = 2 * exp(-x^2-y^2) * x * y * (2x^2 - 3)
#   d2f/dxdy = exp(-x^2-y^2) * (2x^2 - 1) * (2y^2 - 1)
#   d2f/dy2  = 2 * exp(-x^2-y^2) * x * y * (2y^2 - 3)
ddffunc1=function(vecw){
  x=vecw[1];y=vecw[2]
  e=exp(-x^2-y^2)
  fxx=2*e*x*(-3+2*x^2)*y
  fxy=e*(-1+2*x^2)*(-1+2*y^2)
  fyy=2*e*x*(-3+2*y^2)*y
  return(matrix(c(fxx,fxy,fxy,fyy),2,2))
}

# Two-dimensional Newton's method for a stationary point (gradient = 0) of lossfunc1.
#
# Arguments:
#   vecw  : starting point as a 2x1 matrix (or length-2 vector).
#   times : maximum number of Newton iterations.
#   tol   : stop once the squared norm of the gradient at the new point is <= tol.
#
# Prints one progress line per iteration and returns the final iterate invisibly.
# Newton's method only looks for a zero of the gradient, so depending on the
# starting point the result may be a maximum, a minimum or a saddle point.
# solve() raises an error if the Hessian is singular at an iterate.
newton202=function(vecw,times,tol=1.0e-8){
  # seq_len() yields an empty loop for times = 0 (1:0 would iterate twice).
  for(cnt in seq_len(times)){
    # Newton step: solve H d = g directly instead of forming the inverse of H.
    vecw=vecw-solve(ddffunc1(vecw),dffunc1(vecw))
    grad=dffunc1(vecw)
    err=(grad[1])^2+(grad[2])^2
    cat(sprintf("cnt:%d;vecw=(%.8f,%.8f),error=%.8f\n",cnt,vecw[1],vecw[2],err))
    if(err<=tol){
      break
    }
  }
  invisible(vecw)
}
> vecw=matrix(c(0.5,0.5),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(0.75000000,0.75000000),error=0.00185272 cnt:2;vecw=(0.70522388,0.70522388),error=0.00000385 cnt:3;vecw=(0.70710432,0.70710432),error=0.00000000 > lossfunc1(0.70710432,0.70710432) [1] 0.1839397 > > vecw=matrix(c(-0.5,0.5),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-0.75000000,0.75000000),error=0.00185272 cnt:2;vecw=(-0.70522388,0.70522388),error=0.00000385 cnt:3;vecw=(-0.70710432,0.70710432),error=0.00000000 > lossfunc1(-0.70710432,0.70710432) [1] -0.1839397 > > vecw=matrix(c(0.5,-0.5),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(0.75000000,-0.75000000),error=0.00185272 cnt:2;vecw=(0.70522388,-0.70522388),error=0.00000385 cnt:3;vecw=(0.70710432,-0.70710432),error=0.00000000 > lossfunc1(0.70710432,-0.70710432) [1] -0.1839397 > > vecw=matrix(c(-0.5,-0.5),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-0.75000000,-0.75000000),error=0.00185272 cnt:2;vecw=(-0.70522388,-0.70522388),error=0.00000385 cnt:3;vecw=(-0.70710432,-0.70710432),error=0.00000000 > lossfunc1(-0.70710432,-0.70710432) [1] 0.1839397
> vecw=matrix(c(-0.41,0.41),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-1.00822949,1.00822949),error=0.03719689 cnt:2;vecw=(0.15078630,-0.15078630),error=0.03782969 cnt:3;vecw=(-0.03450600,0.03450600),error=0.00235874 cnt:4;vecw=(0.00033224,-0.00033224),error=0.00000022 cnt:5;vecw=(-0.00000000,0.00000000),error=0.00000000 > > vecw=matrix(c(-0.42,0.42),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-0.94774768,0.94774768),error=0.03135856 cnt:2;vecw=(-0.45367526,0.45367526),error=0.06255394 cnt:3;vecw=(-0.82475590,0.82475590),error=0.01163297 cnt:4;vecw=(-0.68323527,0.68323527),error=0.00063577 cnt:5;vecw=(-0.70679739,0.70679739),error=0.00000010 cnt:6;vecw=(-0.70710671,0.70710671),error=0.00000000 > > vecw=matrix(c(-0.43,0.43),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-0.90087389,0.90087389),error=0.02452967 cnt:2;vecw=(-0.59685511,0.59685511),error=0.01416716 cnt:3;vecw=(-0.70777844,0.70777844),error=0.00000049 cnt:4;vecw=(-0.70710646,0.70710646),error=0.00000000 > > vecw=matrix(c(-0.44,0.44),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-0.86384818,0.86384818),error=0.01829444 cnt:2;vecw=(-0.65192581,0.65192581),error=0.00349310 cnt:3;vecw=(-0.70609653,0.70609653),error=0.00000111 cnt:4;vecw=(-0.70710607,0.70710607),error=0.00000000 > > vecw=matrix(c(-0.45,0.45),2,1) > newton202(vecw,1e+3) cnt:1;vecw=(-0.83417390,0.83417390),error=0.01320154 cnt:2;vecw=(-0.67745283,0.67745283),error=0.00098713 cnt:3;vecw=(-0.70666376,0.70666376),error=0.00000021 cnt:4;vecw=(-0.70710664,0.70710664),error=0.00000000 >
vecw=matrix(c(-0.41,0.41),2,1)の場合はダメでした。極値(-0.7071,0.7071)ではなく、鞍点(0,0)に収束してしまいました。ニュートン法は勾配が0になる点を探すだけなので、初期値によっては極値ではない停留点に収束することがあります。