Pyroのチュートリアルを読んでみた(1)

今回はPyroのチュートリアルを読んで遊んでみたので紹介していきたいと思います!

pyroとは?

pyroとは、確率的プログラミング言語の1種です。確率的プログラミング言語とは、一言で言ってしまえば「統計的モデルを(比較的)お手軽に構築することができるプログラミング言語」です。(あくまで筆者の個人的な考えです。正確に理解したい方はこちらなどが参考になるのではないでしょうか。)

pyroの良いところは、ガウス過程やVAEなどが利用できることです。他にもモンテカルロ法、変分推論など多くの手法が実装されているのでとてもありがたいライブラリとなっています。

筆者はまだDocumentを読んでみたくらいの初心者ですので何かと間違っているところはあると思いますが、この記事を読んでpyroの雰囲気だけでもわかっていただければ嬉しいです!

pytorchを使ったプログラム

では実際にpyroのチュートリアルをざっくりと読んでいきたいと思います。今回読んでいくチュートリアルはここにあります。今回のチュートリアルでは、pytorchでガウス分布やベルヌーイ分布どういった形で利用するかということと、pytorchで実装したプログラムをpyroで書くとどうなるかについて説明しています。

ライブラリのインストール

まず、pytorchとpyroのインストールについてですが、これは各ライブラリのホームページを参照していただくとインストール方法が丁寧に書かれているのでそちらを確認してください。一応どちらもpipでインストールすることができます。そして以下のようにライブラリを読み込みます。

import torch
import pyro 

pyro.set_rng_seed(101)

4行目の”pyro.est_rng_seed(101)”は、おそらくですがランダムシードを固定しているプログラムかと思います。

ガウス分布の作成

次に以下のコードについて見ていきます。

loc = 0. # mean zero
scale = 1. # unit variance
normal = torch.distributions.Normal(loc, scale) # create a normal distribution object
x = normal.rsample() # draw a sample from N(0,1)
print("sample", x)
print("log prob", normal.log_prob(x)) # score the sample from N(0,1)

まず1行目と2行目で平均と分散の値を定めています。そしてその値を3行目のように渡し、オブジェクトを作成します。これで平均0、分散1のガウス分布が作成されました。4行目では、作成したオブジェクトからサンプルを抽出しています。

そして最終行のlog_prob関数についてですが、これは損失関数を構築する際に使用するそうです。(詳しくはDocument参照https://pytorch.org/docs/stable/distributions.html#score-function

簡単なモデルの作成

それでは次に進みます。次は、これらの知識を利用して簡単なモデルを作成します。モデルのプログラムは以下のようなものになります。

def weather():
    cloudy = torch.distributions.Bernoulli(0.3).sample()
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
    return cloudy, temp.item()

2行目ではベルヌーイ分布からサンプルをとっています。ベルヌーイ分布は0か1を出力する分布でパラメータとして0.3を渡しているところから、1になる確率が0.3、0になる確率が0.7となっています。よって2行目の変数cloudyには0か1のどちらかが代入されます。そして次の行では、前で代入された値が1の場合、文字列”cloudy”、0の場合は文字列”sunny”を代入します。

4、5行目ではガウス分布の平均と分散パラメータを定めています。そして6行目ではパラメータをセットしたガウス分布から値を抽出し、その値をtempに代入しています。

ここまでで分布から値を抽出することから、簡単なモデルを作成するところまでみてきました。しかし、これらのプログラムはpyroを使ったプログラムではありません。ここでは基本的にpytorchに備わっている関数を使っているだけでした。

次からは、今までのプログラムをpyroを使ったプログラムに書き直すとどうなるかについてみていきます。

pyroを使ったプログラム

pyroによるサンプル方法

まずガウス分布からのサンプルを抽出するプログラムをpyroで書くと以下のようになります。

x = pyro.sample("my_sample", pyro.distributions.Normal(loc, scale))
print(x)

変数locと変数scaleは最初に使った値と同じです。”my_sample”というのはおそらくサンプルの名前をつけることができて任意の文字列を渡すことができると思います(たぶん)。

pyroによるモデルの作成

次にweatherモデルをpyroで書くと次のようになります。

def weather():
    cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
    return cloudy, temp.item()

変更点は、分布からサンプル抽出する際にpyro.sampleを使うところです。それ以外は変わりありませんね。

チュートリアルのこれより下の部分はここまでやってきたことで全て理解できるので割愛します。

最後にチュートリアルのコードをコメント付きで書いたものを以下に貼ります。

In [1]:
import torch
import pyro
# random_seedの固定?
pyro.set_rng_seed(101)
In [2]:
# 平均0, 分散1のガウス分布からサンプル
loc = 0.   # mean zero
scale = 1. # unit variance
# locとscaleをセット
normal = torch.distributions.Normal(loc, scale) # create a normal distribution object
# 値をセットしたガウス分布から実際にxをサンプルする
x = normal.rsample() # draw a sample from N(0,1)
print("sample", x)
print("log prob", normal.log_prob(x)) # score the sample from N(0,1)
sample tensor(-1.3905)
log prob tensor(-1.8857)
In [11]:
def weather():
    # ベルヌーイ分布からのサンプル(1か0を出力)
    # Bernoulli(0.3)より, cloudyになる確率が0.3, sunnyが0.7になる
    cloudy = torch.distributions.Bernoulli(0.3).sample()
    # 前の行でcloudyに代入された値が1だったら'cloudy', 0だったら'sunny'が代入される
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
    return cloudy, temp.item()

for _ in range(3):
    print(weather())
('sunny', 58.44780731201172)
('sunny', 80.37315368652344)
('sunny', 101.10397338867188)
In [9]:
# "my_sample"の部分は任意の文字列で良い(?)
x = pyro.sample("my_sample", pyro.distributions.Normal(loc, scale))
print(x)
tensor(0.6033)
In [10]:
def weather():
    cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
    return cloudy, temp.item()

for _ in range(3):
    print(weather())
('sunny', 83.63148498535156)
('sunny', 75.91869354248047)
('sunny', 122.68901062011719)
In [12]:
def ice_cream_sales():
    cloudy, temp = weather()
    expected_sales = 200. if cloudy == 'sunny' and temp > 80.0 else 50.
    ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.0))
    return ice_cream
In [16]:
def geometric(p, t=None):
    if t is None:
        t = 0
    x = pyro.sample("x_{}".format(t), pyro.distributions.Bernoulli(p))
    if x.item() == 1:
        return 0
    else:
        return 1 + geometric(p, t + 1)

print(geometric(0.5))
0
In [18]:
def normal_product(loc, scale):
    z1 = pyro.sample("z1", pyro.distributions.Normal(loc, scale))
    z2 = pyro.sample("z2", pyro.distributions.Normal(loc, scale))
    y = z1 * z2
    return y

def make_normal_normal():
    mu_latent = pyro.sample("mu_latent", pyro.distributions.Normal(0, 1))
    fn = lambda scale: normal_product(mu_latent, scale)
    return fn

print(make_normal_normal()(1.))
tensor(3.1065)