読者です 読者をやめる 読者になる 読者になる

INFINITY -数学とかプログラミングとか-

統計とプログラムを使って役に立たせたい

TeX用コマンド入力を支援するための辞書をご利用ください。
sanctuary's blogは,適当なことが書いてあります。

勾配降下法1(最急降下法)

勾配降下法をやっています。


問題設定:

損失関数J(w)を最小にするようなwを求めることが目的です。


今、J(t)=t^4としておきます。Jのグラフは、

f:id:infinity_th4:20120614234653j:image:w360

のようになります。

t=0のとき最小値は0です。

勾配は\frac{d}{dt}J(t)=4t^3です。グラフは、
f:id:infinity_th4:20120614234652j:image:w360
です。


func2=function(t){
return(t^4)
}

dffunc2=function(t){
return(4*t^3)
}

xlist=c()
cntlist=c()
errorlist=c()

grad2=function(x,alpha,times){
	cnt=1
	for(i in 1:times){
		xbefore=x
		x=x-alpha*dffunc2(x)#更新則
		if(cnt==1||0==cnt%%1000000){
			xlist<<-c(xlist,x)
			cntlist<<-c(cntlist,cnt)
			errorlist<<-c(errorlist,abs(x-xbefore))
			printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore))
		}
		if(abs(x-xbefore)<=1.0e-11){
			printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore))
			break
		}
		cnt=cnt+1
	}

}

alpha=0.1のとき。

> grad2(0.01,0.1,1e+8)
cnt:1;x=0.010000,abs(x-xbefore)=0.0000004000000
cnt:1000000;x=0.001111,abs(x-xbefore)=0.0000000005487
cnt:2000000;x=0.000788,abs(x-xbefore)=0.0000000001958
cnt:3000000;x=0.000644,abs(x-xbefore)=0.0000000001069
cnt:4000000;x=0.000558,abs(x-xbefore)=0.0000000000696
cnt:5000000;x=0.000499,abs(x-xbefore)=0.0000000000498
cnt:6000000;x=0.000456,abs(x-xbefore)=0.0000000000379
cnt:7000000;x=0.000422,abs(x-xbefore)=0.0000000000301
cnt:8000000;x=0.000395,abs(x-xbefore)=0.0000000000246
cnt:9000000;x=0.000372,abs(x-xbefore)=0.0000000000207
cnt:10000000;x=0.000353,abs(x-xbefore)=0.0000000000176
cnt:11000000;x=0.000337,abs(x-xbefore)=0.0000000000153
cnt:12000000;x=0.000323,abs(x-xbefore)=0.0000000000134
cnt:13000000;x=0.000310,abs(x-xbefore)=0.0000000000119
cnt:14000000;x=0.000299,abs(x-xbefore)=0.0000000000107
cnt:14607585;x=0.000292,abs(x-xbefore)=0.0000000000100

更新回数によるxの変化を表したグラフ

plot(cntlist,xlist,type="l")

f:id:infinity_th4:20120614231933j:image:w360


更新回数による誤差|x-xbefore|の変化を表したグラフ

plot(cntlist,errorlist,type="l")

f:id:infinity_th4:20120614231934j:image:w360

どちらのグラフもなめらか(なはず)です。
(数が大きすぎて一部だけを採用しています)


発散しちゃう場合(alpha=1のとき):

xlist=c()
cntlist=c()
errorlist=c()

grad21=function(x,alpha,times){
	cnt=1
	for(i in 1:times){
		xbefore=x
		x=x-alpha*dffunc2(x)
		#if(cnt==1||0==cnt%%1000000){
			xlist<<-c(xlist,x)
			cntlist<<-c(cntlist,cnt)
			errorlist<<-c(errorlist,abs(x-xbefore))
			printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore))
		#}
		if(abs(x-xbefore)<=1.0e-11){
			#printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore))
			break
		}
		cnt=cnt+1
	}

}
> grad21(-2,1,6)
cnt:1;x=30.000000,abs(x-xbefore)=32.0000000000000
cnt:2;x=-107970.000000,abs(x-xbefore)=108000.0000000000000
cnt:3;x=5034650126184030.000000,abs(x-xbefore)=5034650126292000.0000000000000
cnt:4;x=-510467242137979669844282662082828080260080064684.000000,abs(x-xbefore)=510467242137979669844282662082828080260080064684.0000000000000
cnt:5;x=532063692658205331924864608888064082424424628800086842620862222602804040662048448264440204404406488004486208466682268824660688084444820826866848.000000,abs(x-xbefore)=532063692658205331924864608888064082424424628800086842620862222602804040662048448264440204404406488004486208466682268824660688084444820826866848.0000000000000
cnt:6;x=-Inf,abs(x-xbefore)=Inf

最初、これでやってて完全にはまってました。


alphaは、刻み幅や学習係数と呼ばれ、xが収束するかを大きく左右するようですね。
やってみた感じ、alpha>=1ときは、発散しちゃいます。




損失関数をJ(t)=\log(cosh(t))とし、結果を表示する件数を変えるために新しく関数grad3を定義します。


損失関数のグラフは、

x=seq(-5,5,,1000)
plot(x,log(cosh(x)),xlim=c(-5,5),ylim=c(-2,2),type="l")
par(new=T)
abline(h = 0)

f:id:infinity_th4:20120615001131j:image:w360

となります。


損失関数の微分は、tanh(t)です。グラフは以下です。

x=seq(-5,5,,1000)
plot(x,tanh(x),xlim=c(-5,5),ylim=c(-2,2),type="l")
par(new=T)
abline(h = 0)

f:id:infinity_th4:20120615001130j:image:w360

func3=function(t){
return(log(cosh(t)))
}

dffunc3=function(t){
return((sinh(t))/(cosh(t)))
}

xlist=c()
cntlist=c()
errorlist=c()

grad3=function(x,alpha,times){
	cnt=1
	for(i in 1:times){
		xbefore=x
		x=x-alpha*dffunc3(x)
		if(cnt==1||0==cnt%%100){
			xlist<<-c(xlist,x)
			cntlist<<-c(cntlist,cnt)
			errorlist<<-c(errorlist,abs(x-xbefore))
			printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore))
		}
		if(abs(x-xbefore)<=1.0e-8){
			printf("cnt:%d;x=%f,abs(x-xbefore)=%.13f\n",cnt,x,abs(x-xbefore))		
			break
		}
		cnt=cnt+1
	}

}
> grad3(5,0.01,100000)
cnt:1;x=4.990001,abs(x-xbefore)=0.0099990920426
cnt:100;x=4.000287,abs(x-xbefore)=0.0099934294467
cnt:200;x=3.002400,abs(x-xbefore)=0.0099517513467
cnt:300;x=2.017618,abs(x-xbefore)=0.0096590492535
cnt:400;x=1.113095,abs(x-xbefore)=0.0080797706330
cnt:500;x=0.479389,abs(x-xbefore)=0.0044934793548
cnt:600;x=0.181362,abs(x-xbefore)=0.0018115185266
cnt:700;x=0.066704,abs(x-xbefore)=0.0006727509644
cnt:800;x=0.024432,abs(x-xbefore)=0.0002467350759
cnt:900;x=0.008944,abs(x-xbefore)=0.0000903369547
cnt:1000;x=0.003274,abs(x-xbefore)=0.0000330674215
cnt:1100;x=0.001198,abs(x-xbefore)=0.0000121038033
cnt:1200;x=0.000439,abs(x-xbefore)=0.0000044303863
cnt:1300;x=0.000161,abs(x-xbefore)=0.0000016216648
cnt:1400;x=0.000059,abs(x-xbefore)=0.0000005935818
cnt:1500;x=0.000022,abs(x-xbefore)=0.0000002172701
cnt:1600;x=0.000008,abs(x-xbefore)=0.0000000795279
cnt:1700;x=0.000003,abs(x-xbefore)=0.0000000291098
cnt:1800;x=0.000001,abs(x-xbefore)=0.0000000106551
cnt:1807;x=0.000001,abs(x-xbefore)=0.0000000099313

plot(cntlist,xlist,type="l")

f:id:infinity_th4:20120615000503j:image:w360


plot(cntlist,errorlist,type="l")

f:id:infinity_th4:20120615000800j:image:w360

さっきよりずっと早く収束しましたねヾ(=´・∀・`=)