【Python実装】ノンパラベイズ3次元無限関係モデル(3D-IRM)をギブスサンプリング(MCMC)で推論
今回は、書籍「続・わかりやすいパターン認識」の13章で紹介されている無限関係モデル(Infinite Relational Model)のギブズサンプリング(MCMC)による推論を、3次元にカスタマイズした3D-IRM(勝手に名前)をPythonで実装します。
モデルと推論方法に関しては、書籍「続・わかりやすいパターン認識」の13章を参考にしています。詳しくはこちらをご参照ください。
- 作者: 石井健一郎,上田修功
- 出版社/メーカー: オーム社
- 発売日: 2014/08/26
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (2件) を見る
今回のコードを全てgithubに載せています。遊べるようにnotebookもつけてます。githubはこちら
Twitterフォローよろしくお願いいたします! twitterはこちら
無限関係モデルとは?
詳しくは書籍「続・わかりやすいパターン認識」の13章を参考にしてください。ここでは簡単に説明します。
無限関係モデル(Infinite Relational Model:IRM)とは、異種オブジェクトを同時にクラスタリングする共クラスタリング手法の一つで、ノンパラメトリックベイズに基づきます。
ざっくり言うと、データの値は「1」か「0」で、軸毎にカチャカチャ入れ替えて、「1」が比較的多いグループと「0」が比較的多いグループにクラスタリングするイメージです。
下図が3D-IRMによるクラスタ結果の例です。「1」のみをプロットしています。ルービックキューブみたいです。
ポイントは「X軸を動かすときはY, Z軸は固定」「Y軸を動かすときはX, Z軸は固定」と動かし方に制限があるところです。
元のトイデータを4×4×4のキューブで全て統一していますが、データ次第ではクラスタ毎にキューブの大きさがバラバラになります。むしろ大半がそうです。(同軸上の長さは固定)
ノンパラメトリックベイズなので、クラスタ数は予め決めず、データから自動で推論してくれます。書籍では、中華料理店過程(Chinese Restaurant Process、略してCRP)に従って独立に各軸がクラスタリングされると仮定しています。
仮定するデータの生成過程としては数式で表すと(各定義は書籍)
$${\bf s}^{1}| \alpha \sim C R P(\alpha)$$$${\bf s}^{2}| \alpha \sim C R P(\alpha) $$
$${\bf s}^{3}| \alpha \sim C R P(\alpha) $$$$ \theta({s_x}^{1}, {s_y}^{2}, {s_z}^{3})| a,b \sim Be(a,b)$$
$$R_{xyz}|{s_x}^{1} = {w_i}^{1}, {s_y}^{2} = {w_j}^{2}, {s_z}^{3} = {w_k}^{3} \Theta \sim Bern(R_{xyz} ; \theta_{ijk})$$
CRPによってクラスタ番号を表す\({\bf s}^{1}, {\bf s}^{2}, {\bf s}^{3}\)が割り当てられ、3軸の各々のクラスタの直積からなる各クラスタ\((i, j, k)\)毎にパラメータ\(\theta_{ijk}\)がベータ分布から別々に決まり、1セルずつ\(R_{xyz}\)(「1」か「0」)がベルヌーイ分布から生成されるといった具合です。
何に使えるのか?
例えば、(x, y, z)軸に(店舗、顧客、商品)をセットして、購入したを「1」、購入しないを「0」とすると、各クラスターから「この(地域A)では、(顧客層B)に対して、(商品カテゴリーC)が売れやすい、売れにくい」などが分かります。(x, y, z)に何をセットするかは自由です。応用幅は広いと思います。
モデルに関しては以上で、推論方法について書籍では、崩壊型ギブスサンプリングで数式と手続きが書かれているのですが、今回は自分の勉強のために普通のギブスサンプリングに直して実装しました。
ギブズサンプリングとは?
MCMCの一手法で、各パラメーター毎にそれ以外のパラメーターで条件付けした分布から、順にサンプリングしていく手法です。
パラメーター毎に条件付き分布を手計算する必要があり、stanでよく使われる自動微分すればOKなHMC(Hamiltonian Monte Carlo)とは違い、自動化が難しい手法になっています。
ギブズサンプリングは、推移核を工夫したメトロポリス・ヘイスティング法(MH法)とも解釈ができ、MH法で言うところの採択率が1になり、サンプルが棄却されないことで知られています。1パラメーターずつサンプルしていくため、各軸平行に直角にサンプル点が動いていくイメージです。そのためパラメーター間の相関が強い分布に対しては、サンプリング数が多く必要になります。また、パラメーターの条件付き分布がいつも簡単にサンプリングできる分布になるとは限らないので要注意です。
各クラスタ番号\({\bf s}\)の条件付き分布が中華料理店過程にちょうど対応します。
ギブズサンプリングによってサンプリングした\({\bf s}^{1}, {\bf s}^{2}, \Theta\)の内、対数事後分布\(logP({\bf s}^{1}, {\bf s}^{2}, \Theta\|R)\)を最も大きくする\({\bf s}^{1}, {\bf s}^{2}, \Theta\)を最終結果として採用します。
実装
ほとんどNumpyとScipyで実装しました。データも自分で作成しています。出来るだけfor文を使わずにNumpyで書くようにしたので可読性は低いです(Numpyのせいにする)。それでも次元を増やしすぎると極端に遅くなります。
# gibbs sampling def predict_S(R, alpha,a,b, iter_num=500, reset_iter_num=100): X, Y, Z = R.shape # set first values ########################## sx = CRP(alpha=alpha, sample_num=X) sy = CRP(alpha=alpha, sample_num=Y) sz = CRP(alpha=alpha, sample_num=Z) theta = posterier_theta(sx, sy, sz, R, a, b) ############################################## max_v = -np.inf # to recycle 'def s_update' R_transpose_y = R.transpose((1,2,0)) R_transpose_z = R.transpose((2,0,1)) # gibbs sampling for t in range(iter_num): print("\r calculating... t={}".format(t), end="") sx, theta = s_update(sx, sy, sz, theta, R, a, b, alpha, axis=0) sy, theta = s_update(sy, sz, sx, theta, R_transpose_y, a, b, alpha, axis=1) sz, theta = s_update(sz, sx, sy, theta, R_transpose_z, a, b, alpha, axis=2) log_p_sx = np.log(Ewens_sampling_formula(sx, alpha)) log_p_sy = np.log(Ewens_sampling_formula(sy, alpha)) log_p_sz = np.log(Ewens_sampling_formula(sz, alpha)) log_p_theta = np.sum(st.beta.logpdf(theta, a,b)) log_p_R_theta = log_R_theta_probability(sx, sy, sz, R, theta) #log_p_R_ijk = log_R_probability(sx, sy, sz, R, a, b) # logP(sx, sy, sz, theta| R) v = log_p_sx + log_p_sy + log_p_sz + log_p_theta + log_p_R_theta # update if over max if v > max_v: max_v = v max_sx = sx max_sy = sy max_sz = sz max_theta = theta print(" update S and theta : logP(sx, sy, sz, theta| R) = ", v) # to prevent getting stuck local minima, reset S and theta if t%reset_iter_num==0: sx = CRP(alpha=alpha, sample_num=X) sy = CRP(alpha=alpha, sample_num=Y) sz = CRP(alpha=alpha, sample_num=Z) theta = posterier_theta(sx, sy, sz, R, a, b) return max_sx, max_sy, max_sz, max_theta
早い段階で更新されなくなり局所解にはまっている印象を受けたので数十サンプル毎にパラメーターを初期化して、広範囲探索するようにしました。
具体的なギブズサンプリングの部分です。
# update s def s_update(s1, s2, s3, theta, R, a, b, alpha, axis): if axis==1: theta = theta.transpose((1,2,0)) elif axis==2: theta = theta.transpose((2,0,1))
# sort orderby s2,s3 for easy calculation sorted_s2_index = s2.argsort() sorted_s3_index = s3.argsort() sorted_s2 = s2[sorted_s2_index] sorted_s3 = s3[sorted_s3_index] R_sorted = R[:,sorted_s2_index,:][:,:,sorted_s3_index] for idx in range(len(s1)): # remove s1_x for gibbs sampling s1_delete = s1[idx] s1_left = np.delete(s1, idx) theta_left = theta[np.unique(s1_left),:, :] # if category_num is decreased, fill empty category number s1_left = reset_s_number(s1_left) # count n_ij num_n_ijk_left = count_n_ijk(s1_left, s2, s3) # log_p(s1_k | s1_left) by Dirichlet Process n_i = np.add.reduce(num_n_ijk_left, axis=(1,2)) n_i = np.append(n_i, alpha) ln_p_s1_idx_s1_left = np.log(n_i/(np.add.reduce(num_n_ijk_left,axis=(0,1,2)) + alpha)) # log_p(R| s1_left, s2, s3) R_idx_sorted =R_sorted[idx,:,:] R_ijk_1, R_ijk_0 = count_one_zero_2D(R_idx_sorted, sorted_s2, sorted_s3) ln_p_R_xyz_new = np.sum(betaln(R_ijk_1+a, R_ijk_0 +b) - betaln(a,b)) ln_p_R_xyz_exist= np.sum(R_ijk_1 * np.log(theta_left), axis=(1,2))+ np.sum(R_ijk_0 * np.log(1-theta_left),axis=(1,2)) # Ratio for choosing new s1_x '+100' is for preventing underflow p_s1_idx = np.exp(ln_p_s1_idx_s1_left + np.append(ln_p_R_xyz_exist, ln_p_R_xyz_new)+100) p_s1_idx/=np.sum(p_s1_idx) s_new = np.argmax(np.random.multinomial(n=1, pvals=p_s1_idx)) # new s1 updated s1 = np.insert(s1_left, idx, s_new) # update theta theta = posterier_theta(s1, s2, s3, R, a,b) if axis==1: theta = theta.transpose((2,0,1)) elif axis==2: theta = theta.transpose((1,2,0)) return s1, theta
今回、ギブズサンプリングを使用していますが、ベルヌーイ分布のパラメーターだけはサンプリングせず、条件付き分布の平均値をサンプルとみなして使っています。特に理由はないです。(この手法もなんか名前があるらしい)サンプリングしてもOKです。また軸毎に\(s\)をサンプルする毎にベルヌーイ分布のパラメーターもサンプル。\(s_x\)→\(\Theta\)→\(s_y\)→\(\Theta\)→\(s_z\)→\(\Theta\)と更新しています。
今回のコードを全てgithubに載せています。遊べるようにnotebookもつけてます。githubはこちら
動かしてみた
まずは、見やすさのためベルヌーイ分布のパラメーターを1か0にした超理想的なデータに適用しました。クラスタリング手法なのでデータによって、局所解が多く存在します。
次にノイジーなデータにも試しました。 パラメーターが0.8と0.2と0の3種類のベルヌーイ分布で生成しました。
クラスタ数(4,4,4)で推定できました。X軸で切ってクラスタ毎のパラメータ\(\Theta\)を可視化しました。0.2や0.8付近ででまとまっています。
今回のコードを全てgithubに載せています。遊べるようにnotebookもつけてます。githubはこちら
Twitterフォローよろしくお願いいたします! twitterはこちら