[Pythonで実装]MCMCを使ったパラメータ推定

未分類
こんな人におすすめ
  • MCMCについて勉強してみたけど、実際のコードはどう書けばいいの?
  • MCMCを用いたパラメータ推定の手法が知りたい!
    • 誤差も評価したい
  • MCMCに何となく興味がある

当記事では、マルコフ連鎖モンテカルロ法(MCMC)を用いて、パラメータを推定する手法を紹介します。Pythonのコードを載せているので、ぜひお手元で動かしてみてください。Jupyter Notebookでコードは書いています。

まずは一変数の場合

イメージを掴むために簡単な線形モデルのパラメータを推定しましょう。

まずはデータを生成します。

import numpy as np
import matplotlib.pyplot as plt

# 再現できるように乱数のseed値を固定
np.random.seed(1)

# データ作成
x = np.arange(10, dtype=float)
y = 9.6*x

# 乱数で誤差を付加する
y += np.random.normal(loc=0, scale=1.0, size=10)

#グラフ作成
plt.plot(x, y, 'o')
plt.xlabel("x")
plt.ylabel("y")

y = 9.6x という関係のデータを作成。誤差付きのデータを想定しているので、yに平均0、標準偏差1.0の正規分布に従う乱数を誤差として加えます。

というわけで、ここまでの過程は早く忘れてしまいましょう。問題はこのようなデータが仮に得られたとき(傾きはまだ知らない)、どうやって傾きを求めるかです。つまりこの場合、傾きがモデルパラメータであり、これを推定しようというわけです。最小二乗法のような最尤法ではデータに最も合う値を求めることができますが、それはパラメータは100%これだ!!と言っているようなものです。もしあなたが大人でしたら、100%と言い切ることはできないはずです。なぜならあなたは手元にあるデータ以外のすべてのデータの確認ができていないからです。実際、データが増えれば(新しい観測が得られれば)答えは変化するでしょう。

このことを踏まえると、可能性まで含めて議論する方が理にかなっているということに納得できるでしょう。

メトロポリス・ヘイスティング法でサンプリング
import numpy as np
import matplotlib.pyplot as plt

def ssr(y, sample_y): #残差二乗和を計算 y:実際のデータ sample_y:候補のデータ
    r = [np.square(yy-s_yy) for yy, s_yy in zip(y, sample_y)]
    return sum(r)

n = 10**5 #反復回数
sigma2 = 1.0 # 標準偏差 とりあえず1.0にしておく。
step_size = 0.1 #ステップ幅

#初期値
a = 0 #モデルパラメータ(傾き)
count = 0
naccept = 0 #受理された回数

list_a = [] #パラメータの履歴

while count < n:
    backup = a #パラメータのバックアップをとっておく
    action_init = ssr(y, a*x) #残差二乗和を計算
    
    #遷移候補
    a += (np.random.rand()-0.50)*step_size*2.0 
    action_fin = ssr(y, a*x)

    #メトロポリステスト
    metropolis = np.random.rand() # 0~1の一様乱数を生成
    action = action_init-action_fin #残差二乗和を計算
    
    alpha = np.exp(action/(2*sigma2))
    if alpha > metropolis: #受理
        list_a.append(a)
        naccept += 1
    else: #棄却
        a = backup
    count += 1

# 図を作成
burn_in = int(n/10)
plt.hist(list_a[burn_in:], bins = 100)
plt.xlabel('a')

print("受理された回数: {}, 棄却された回数: {}, 更新確率: {}".format(naccept, n-naccept, round(naccept/n*100, 1)))
print("中央値: {:.3f}, 平均値: {:.3f}, 標準偏差: {:.3f}".format(np.median(list_a), np.mean(list_a), np.std(list_a)))
結果
受理された回数: 68120, 棄却された回数: 31880, 更新確率: 68.1
中央値: 9.597, 平均値: 9.584, 標準偏差: 0.298

多変数の場合

準備中…

コメント

タイトルとURLをコピーしました