MLFlowやらComet.mlやら、最近実験管理ツールなるものの存在を知りました。
とあるGithubのレポジトリを解読していたら、そこではwandbというライブラリが使われていました。
これがとても簡単で使いやすくよい実験管理ツールでしたので、メモがてらご紹介しようと思います。
Weights & Biases の使い方
事前にこちらのサイトからサインアップしておきましょう。(無料です)
https://www.wandb.com/
使い方としては、まずはインストール
1 |
!pip install wandb |
Kerasではこれだけ!簡単です。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import wandb from wandb.keras import WandbCallback #プロジェクトを指定(なかったら作成される) wandb.init(project="sample_project") #調整するパラメータを登録 config = {"optimizer": "Adam", "lr" : 1e-3, "epochs": 20} wandb.config.update(config) ~~~~~~~~~~~ #WandCallback()を指定 model.fit(train_images, train_labels, validation_data = (test_images, test_labels), epochs=20, callbacks=[WandbCallback()]) |
TF2.0のサンプルコードで試してみる
こちらのコードでwandb使ってみようと思います。(ファッションMNIST)
https://www.tensorflow.org/tutorials/keras/classification?hl=ja
まずはライブラリのインポートとデータの取得、前処理です。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from __future__ import absolute_import, division, print_function, unicode_literals # TensorFlow と tf.keras のインポート %tensorflow_version 2.x import tensorflow as tf from tensorflow import keras # ヘルパーライブラリのインポート import numpy as np import matplotlib.pyplot as plt print(tf.__version__) ``` 2.1.0 ``` #データのダウンロード fashion_mnist = keras.datasets.fashion_mnist (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() #正規化 train_images = train_images / 255.0 test_images = test_images / 255.0 |
ここで wandb.init
しておきます
1 2 3 4 5 6 7 8 9 |
import wandb from wandb.keras import WandbCallback wandb.init(project="sample_project") config = {"optimizer": "Adam", "lr" : 1e-3, "epochs": 20} wandb.config.update(config) |
モデルを構築して、学習します。このとき、callback=[WandCallback()]
を忘れずに
1 2 3 4 5 6 7 8 9 10 11 12 13 |
model = keras.Sequential([ keras.layers.Flatten(input_shape=(28, 28)), keras.layers.Dense(128, activation='relu'), keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(train_images, train_labels, validation_data = (test_images, test_labels), epochs=20, callbacks=[WandbCallback()]) |
ハイパーパラメータを変えて何度か学習を回します。(wandb.initから3回やりました。)
これでwandbのサイトに行き、ログインすると・・・

おお!これはテンションが上がります。学習ログを自動的に重ねて可視化してくれます。
さらに諸条件はテーブルで確認できます。

さらにさらに!自動的に学習時の環境や、ベスト時の重みの保存等してくれており、ダウンロードができます!

たった数行付け加えただけなのに。これは神。
もっといろんな機能ついてそうですし、触ってみよう。
(公式ドキュメント:https://docs.wandb.com/)
他の実験管理ツールはもっと便利なのでしょうか。
最後まで読んでいただきありがとうございました。