trsing’s diary

勉強、読んだ本、仕事で調べたこととかのメモ。

PRMLメモ 3章 図3.8、図3.9

図3.8と図3.9の図を描画するpythonのコード
参考(というかコメント以外ほぼ丸写し…)
https://qiita.com/naoya_t/items/80ea108cebc694f5cd63qiita.com
コード見ると理解しやすいですね。

#https://qiita.com/naoya_t/items/80ea108cebc694f5cd63参照
from pylab import *

S=0.1
ALPHA=0.1
BETA=9

#(xs,ts):訓練データ。観測値と目標値。
#訓練データから事後分布の平均(3.53)と分散の逆数(3.54)を求める
#これらから予測分布(3.58)を求める
#予測分布(3.58)の平均と分散から図3.8を描画
#事後分布(3.49)からモデルパラメータwを決定して図3.9を描画
def sub(xs,ts):
    #s:標準偏差、mu:平均
    #\phi_j(x)作成する関数を返す。ガウス基底関数。
    def gaussian_basis_func(s,mu):
        return lambda x:exp(-(x - mu)**2/(2 * s**2))
    
    #s:標準偏差、xs:サイズ
    #基底関数\phi(x)を作成する関数を返す。
    def gaussian_basis_funcs(s,xs):
        return [gaussian_basis_func(s,mu) for mu in xs]
    
    xs9=arange(0,1.01,0.125)#9個
    bases=gaussian_basis_funcs(S,xs9)
    
    N=size(xs)
    M=size(bases)
    #\phi(x)作成
    def Phi(x):
        return array([basis(x) for basis in bases])
    
    PHI=array(list(map(Phi,xs)))#\Phiを作成
    PHI.resize(N,M)

    #重みの事後分布(3.49)の平均、分散、予測分布(3.58)の平均と分散を返す関数を返す
    def predictive_dist_func(alpha, beta):#wの平均m_N、分散S_N^{-1}、推定値の関数を返す
        #(3.49)の平均(m_N)と分散の逆行列(S_N^{-1})を作成
        S_N_inv=alpha*eye(M)+beta*dot(PHI.T,PHI)#(3.54):dotは積。.Tは転値
        m_N=beta*solve(S_N_inv,dot(PHI.T,ts))#(3.53):solve(A,b)はAx=bの解(A^(-1)b)
    
        #入力xに対応する予測分布(3.58)の平均と分散を返す
        def func(x):
            Phi_x=Phi(x)#\phi(x)
            mu=dot(m_N.T, Phi_x)#予測分布の平均m_N^T\phi(x)
            s2_N=1.0/beta+dot(Phi_x.T,solve(S_N_inv,Phi_x))#予測分布の分散1/\beta+\phi(x)^TS_N\Phi(x)
            return (mu,s2_N)
        
        return m_N,S_N_inv,func#(3.53),(3.54),(3.58)のxに対する平均、分散を返す関数
    
    xmin=-0.05
    xmax=1.05
    ymin=-1.5
    ymax=1.5
    
    
    clf()
    axis([xmin,xmax,ymin,ymax])
    title("Fig 3.8 (%d sample%s)" % (N,'s' if N > 1 else ''))
    
    x_=arange(xmin,xmax,0.01)
    plot(x_,sin(x_*pi*2),color='gray')#元のデータ

    m_N,S_N_inv,f=predictive_dist_func(ALPHA,BETA)#ハイパーパラメータに対応する平均、分散、予測式
    
    y_h=[]
    y_m=[]
    y_l=[]
    #yの予測の平均(y_m)、平均±標準偏差(y_h,y_l)作成
    for mu,s2 in map(f,x_):#mu,s2:各x_の要素に対応する予測の平均と分散
        s=sqrt(s2)
        y_m.append(mu)
        y_h.append(mu+s)
        y_l.append(mu-s)
    #図3.8描画
    fill_between(x_,y_h,y_l,color='#cccccc')
    plot(x_,y_m,color='#000000')
    scatter(xs,ts,color='r', marker='o')#入力と目標値のペア
    show()
    clf()
    
    x_=arange(xmin,xmax,0.01)
    #図3.9描画
    plot(x_,sin(x_*pi*2),color='gray')
    #推定したwでグラフを描画
    for i in range(5):
        w=multivariate_normal(m_N,inv(S_N_inv),1).T#wを作成(平均m_N、分散S_Nガウス分布より)
        y=lambda x:dot(w.T,Phi(x))[0]#推定値を算出する関数
        plot(x_,y(x_),color='#cccccc')
    scatter(xs,ts,color='r',marker='o')
    show()
 
    
def main():
    #xと対応するtを作成(t_n=sin(x_n*2*\pi)+ガウスノイズ)
    xs=arange(0,1.01,0.02)#x_n
    ts=sin(xs*pi*2)+normal(loc=0.0,scale=0.1,size=size(xs))#t_n
    
    #n:範囲、k:個数
    #0~n-1からk個選んでソートしたものを返す。抜き出すデータ点
    def randidx(n,k):
        r = list(range(n))
        shuffle(r)
        return sort(r[0:k])
    #それぞれの訓練データ数(1,2,5,20)で処理
    for k in(1,2,5,20):
        indices=randidx(size(xs),k)
        sub(xs[indices],ts[indices])#使用する(x_n,t_n)
        
if __name__=='__main__':
    main()