【初心者向け】基礎&実践プログラミング

初心者がつまづきやすいところ、最短で実力が身につく方法をお伝えします。

【PyTorch】サンプル⑥ 〜 nn パッケージ 〜

f:id:AIProgrammer:20200413020244p:plain

目的

  • PyTorch: nnを参考にPyTorchのnnパッケージを扱う。
  • nnパッケージの便利さを感じる。

前準備

PyTorchのインストールはこちらから。

初めて、Google Colaboratoryを使いたい方は、こちらをご覧ください。

コマンドラインの「>>>」の行がPythonで実行するコマンドです。 それ以外の行は、コマンドの実行結果です。

>>> print("PyTorch")    # コマンドラインに打ち込む
PyTorch    # 出力結果

これまで、勉強してきた内容は以下。

チュートリアル

サンプル

nn パッケージ

nnは、ニューラルネットワークの構築に用いる。

PyTorchの自動微分autogradによって、計算グラフやパラメータの勾配を簡単に計算することができます。ですが、自動微分だけで複雑なニューラルネットワークを定義するのは困難です。そこで活躍するのがnnパッケージです。

nnパッケージを用いて、ニューラルネットワークを一つのモジュールとして定義することができます。

PyTorchのインポート

import torch

使用するデータ

バッチサイズNを64、入力の次元D_inを1000、隠れ層の次元Hを100、出力の次元D_outを10とします。

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

入力(x)と予測したい(y)を乱数で定義します。

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

ニューラルネットワークのモデルを定義

ニューラルネットワークのモデルをnnパッケージを用いて定義します。

定義の仕方は、大きく2つありますがここでは一番簡単なtorch.nn.Sequentialを使います。

作り方は簡単で、任意の層を積み重ねていくだけです。この例では、input > Linear(線型結合) > ReLU(活性化関数) > Linear(線型結合) > outputの順に層が積み重なっています。

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

損失(loss)の定義

二乗誤差もnnパッケージを用いて計算することができます。

reductionのデフォルトはmeanですので、何も指定しなければtorch.nn.MSELossは平均二乗誤差を返します。

reduction=sumとした場合は、累積二乗誤差を算出します。

loss_fn = torch.nn.MSELoss(reduction='sum')

(参考) nnパッケージを使う前は以下のように記述していた。

    loss = (y_pred - y).pow(2).sum()

学習パラメータ

学習率を1e-4として、学習回数を500回とします。

learning_rate = 1e-4
for t in range(500):

モデルへのデータ入出力 (Forward pass)

定義したニューラルネットワークモデルへデータxを入力し、予測値y_predを取得します。

y_pred = model(x)

(参考) nnパッケージを使う前は以下のように記述していた。

    y_pred = x.mm(w1).clamp(min=0).mm(w2)

損失(loss)の計算

定義した損失関数で予測値y_predと真値yとの間の損失を計算します。

    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

勾配の初期化

逆伝播(backward)させる前に、モデルのパラメータが持つ勾配を0(ゼロ)で初期化します。

    model.zero_grad()

勾配の計算

backwardメソッドでモデルパラメータ(Weight)の勾配を算出します。

    loss.backward()

パラメータ(Weight)の更新

確率勾配降下法(SGD: stochastic gradient descent)で、Weightを更新する。

    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

(参考) nnパッケージを使う前は以下のように記述していた。

    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

実行

以下のコードを6_nn_package.pyとして保存します。

6_nn_package.py

import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. 
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. 
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    loss.backward()

    # Update the weights using gradient descent.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

保存ができたら実行しましょう。

左の数字が学習回数、右の数値がパーセプトロンの推定値と実際の答えと二乗誤差です。

学習を重ねるごとに、二乗誤差が小さくなることがわかります。

$ python3 6_nn_package.py 
99 2.5003600120544434
199 0.06977272033691406
299 0.003307548351585865
399 0.00018405442824587226
499 1.1299152902211063e-05

終わりに

多層のニューラルネットワークには、膨大な量のパラメータが存在しています。

nnパッケージを用いる前は、各パラメータごとに勾配計算やパラメータの更新などをしていましたが、それでは記述が困難です。

nnパッケージを用いることで、楽に勾配計算やパラメータの更新等が実行できることが感じてもらえたら嬉しいです(^^)!。



頑張れ!喝!!の代わりにB!ブックマークを押していただけるとただただうれしいです(^^)! ↓