Pyroのチュートリアルを読んでみた(2)

今回は前回に引き続きPyroのチュートリアルを読み進めてみました!今回のチュートリアルはここです。

このチュートリアルの目標は、簡単な生成モデルを作成することです。生成モデルについては書籍や他Webサイトなどで調べてみてください。

ライブラリのインストール

まずはチュートリアルで必要なライブラリをインポートします。

今回必要なライブラリは以下のようになっています。

import matplotlib.pyplot as plt
import numpy as np
import torch

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

pyro.set_rng_seed(101)

何をモデル化するのか

このチュートリアルでモデル化する対象は、ライブラリをインポートしているコードの下に書かれています。主に以下のような内容です。

物体の重さを測るある体重計は、同じ物体に対して毎回違った重さを表示してしまう。この変動性を物体の密度や材料のような前提知識を用いた推測によって補おう。」

これが今回モデル化したい対象です。毎回違う重さを表示する体重計なんて筆者の人生で一度もない経験でありなかなか現実性がなさそうな事象ですが、この問題を考えていきます。

簡単なモデル例

前の節で話したことは、次のプロセスで表すことにします。

$$weight|guess \sim Normal(guess, 1)$$

$$ measurement | guess, weight \sim Normal(weight, 0.75)$$

1行目は、平均がguess、分散が1のガウス分布から値をサンプルし、weightに代入するという操作をしています。

2行目は、1行目の操作によって得られたweightを平均にし、分散を0.75としてガウス分布から値をサンプルしmeasurementに代入しています。

これらの操作をプログラムにすると以下のようになります。

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

条件付け

次に、観測データが得られた場合の重さを推測することを考えます。つまり、measurementが得られたとき、weightはどれくらいかを推測します(ただし、guessは既知)。これは次のように表されます。

$$ (weight | guess, measurement=9.5) \sim ?$$

これをPyroを使って表すと次のようなプログラムになります。

conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})

pyro.conditionには、第1引数に関数を、第2引数に観測データを辞書型で渡します。

そしてpyro.conditionは、Pythonの関数と同じような振る舞いをするので、lambdaやdefを使ってパラメータ化したりすることができます。(この部分はうまく文章にできなかったので、チュートリアル本文を読むことをオススメします)。

def deferred_conditioned_scale(measurement, guess):
return pyro.condition(scale, data={"measurement": measurement})(guess)

また、pyro.conditionは、pyro.sampleによって書き表すことができます。それが以下のコードになります。

def scale_obs(guess):  # equivalent to conditioned_scale above
weight = pyro.sample("weight", dist.Normal(guess, 1.)) # here we condition on measurement == 9.5 return pyro.sample("measurement", dist.Normal(weight, 0.75), obs=9.5)

Guide関数による推論

Guide関数は、MCMCの提案分布、変分推論の近似分布などで利用することができる関数です。

Guide関数は以下の2点を満たす必要があります。

  1. モデルに現れる全ての未観測sample statementがGuideにも現れること(sample statementが何を指しているのかがわかりません)
  2. モデルとガイドは全く同じ引数をとる

次に、自作でGuide関数を作成してみます。そのためには、まずweightに関する事後分布を計算によって求める必要があります。チュートリアルのリンク先の3.4節にあるように(一応リンク先はこちらとなっています)、weightに関する事後分布はガウス分布となります。あとは、事後分布のパラメータを更新します。事後分布のパラメータは、平均パラメータと分散パラメータで、リンク先に書いてある更新式を利用すると、以下のようなプログラムを書くことができます。

def perfect_guide(guess):
loc =(0.75**2 * guess + 9.5) / (1 + 0.75**2) # 9.14 scale = np.sqrt(0.75**2/(1 + 0.75**2)) # 0.6 return pyro.sample("weight", dist.Normal(loc, scale))

locが平均、scaleが分散です。こちらの計算結果は、Normal(9.14, 0.6)となり、最終的にこれらのパラメータをセットしたガウス分布からサンプルされた値がperfect_guideの返り値になります。

Parametrized Stochastic Functions and Variational Inference

前節では事後分布は解析的に求めることができました。これは、尤度関数と平均パラメータの事前分布が両方ガウス分布であったため、事後分布もガウス分布となります。しかし、必ずしも全ての分布がガウス分布で表されることはなく、むしろこのようなケースは実際は稀なのです。解析的に求めることができない事後分布は、変分推論やMCMCなどの近侍手法を用いることによって分布を近似し、その近似分布を最適化することになります。

Pyroにおいては、Pyro.paramが用意されているので、こちらを用いることにします。

例えば以下のように使うことができます。

def scale_parametrized_guide(guess):
a = pyro.param("a", torch.tensor(guess)) b = pyro.param("b", torch.tensor(1.)) return pyro.sample("weight", dist.Normal(a, torch.abs(b)))

こうすることによって、”a”を呼び出せば、guessの値が参照され、”b”を呼び出せば1.0が取り出されます。

また、ガウス分布の分散は正の値をとりますが、このようなパラメータの制約をconstraintsを使って表現することができます。

from torch.distributions import constraints

def scale_parametrized_guide_constrained(guess):
    a = pyro.param("a", torch.tensor(guess))
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
    return pyro.sample("weight", dist.Normal(a, b))  # no more torch.abs

このようにすることで一つ上の例で示したtorch.absを用いることなく、正の値しか取れないという制約を加えることができます。

まとめ

今回はパラメータの事後分布の推定をPyroを使って実装してきました。その他にも変分推論などで用いるGuide関数など、新しい機能を紹介しました。

次回は、SVIのチュートリアルについて書きたいと思います!