PyTorch Lightningでwandbを使って実験管理をしたい!しかしドキュメント見てもよくわからなかった…という方に向けて、こちらに残しておこうと思います。
今回も、前に書いたときと同様に公式ドキュメントにあったCoolSystemの実装コードをもとに改修していきます。
※PyTorch Lightningのgithubの内容が最近変わったようで、この実装コードが現在は乗っていませんでした…。
もとのスクリプト
デフォルトではTensorboardLoggerとのことで、公式Githubにはこのようなスクリプトとなっていました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
class CoolSystem(pl.LightningModule): def __init__(self): super(CoolSystem, self).__init__() # not the best model... self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) def training_step(self, batch, batch_idx): # REQUIRED x, y = batch y_hat = self.forward(x) loss = F.cross_entropy(y_hat, y) tensorboard_logs = {'train_loss': loss} return {'loss': loss, 'log': tensorboard_logs} def validation_step(self, batch, batch_idx): # OPTIONAL x, y = batch y_hat = self.forward(x) return {'val_loss': F.cross_entropy(y_hat, y)} def validation_end(self, outputs): # OPTIONAL avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': avg_loss} return {'avg_val_loss': avg_loss, 'log': tensorboard_logs} def test_step(self, batch, batch_idx): # OPTIONAL x, y = batch y_hat = self.forward(x) return {'test_loss': F.cross_entropy(y_hat, y)} def test_end(self, outputs): # OPTIONAL avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() tensorboard_logs = {'test_loss': avg_loss} return {'avg_test_loss': avg_loss, 'log': tensorboard_logs} def configure_optimizers(self): # REQUIRED # can return multiple optimizers and learning_rate schedulers # (LBFGS it is automatically supported, no need for closure function) return torch.optim.Adam(self.parameters(), lr=0.02) @pl.data_loader def train_dataloader(self): # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def val_dataloader(self): # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def test_dataloader(self): # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32) |
実行は以下です。
1 2 3 4 |
model = CoolSystem() trainer = Trainer(max_epochs=10) trainer.fit(model) trainer.test() |
こちらをwandbで残せるよう変えていきます。
WandbLogger
の使い方
pytorch lightningではカスタムロガーとしてWandbLoggerが用意されています。
使い方としては最低限以下を押さえておけばよさそうです。(他にもできることありそうですが、私はいったん必要なく試していません)
①loggerの定義
1 2 3 |
from pytorch_lightning.loggers.wandb import WandbLogger logger = WandbLogger(project="test_project") |
project引数で、今回の実験のプロジェクト名を指定します。
他にも、試行ごとにランダムで名前が付けられるものに関して、name引数で指定できたりします。
②ハイパーパラメータの記録
1 2 3 4 5 6 |
params = {"optimizer": "Adam", "lr": 1e-2, "batch_size": 32, "epochs": 10} logger.log_hyperparams(params=params) |
辞書型で、ハイパーパラメータを記述しておき、log_hyperparamsの引数に与えてあげることで記録できます。
③メトリクスの記録
※例として、学習データのロスを残す場合
1 2 3 4 5 6 7 8 |
def training_step(self, batch, batch_idx): # REQUIRED x, y = batch y_hat = self.forward(x) loss = F.cross_entropy(y_hat, y) #tensorboard_logs = {'train_loss': loss} self.logger.log_metrics({'loss': loss}) return {'loss': loss} |
メトリクスの記録に関しては、Class内で記述を行います。
logger.log_metrics()内に、キーバリューペアの形で渡してあげます。
④Trainerのインスタンス化のさいにLoggerを指定
1 |
trainer = Trainer(max_epochs=10, logger=logger) |
logger引数に①で定義したWandbLoggerを指定してあげればOKです。
修正後スクリプト
というわけで、以下が修正後スクリプトです。
もとのスクリプトで不要になった部分(Tensorboardに残す記述)に関してはコメントアウトしています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
class CoolSystem(pl.LightningModule): def __init__(self): super(CoolSystem, self).__init__() # not the best model... self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) def training_step(self, batch, batch_idx): # REQUIRED x, y = batch y_hat = self.forward(x) loss = F.cross_entropy(y_hat, y) #tensorboard_logs = {'train_loss': loss} self.logger.log_metrics({'loss': loss}) #追加 return {'loss': loss} def validation_step(self, batch, batch_idx): # OPTIONAL x, y = batch y_hat = self.forward(x) return {'val_loss': F.cross_entropy(y_hat, y)} def validation_end(self, outputs): # OPTIONAL avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() #tensorboard_logs = {'val_loss': avg_loss} self.logger.log_metrics({'avg_val_loss': avg_loss}) #追加 return {'avg_val_loss': avg_loss} def test_step(self, batch, batch_idx): # OPTIONAL x, y = batch y_hat = self.forward(x) return {'test_loss': F.cross_entropy(y_hat, y)} def test_end(self, outputs): # OPTIONAL avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() #tensorboard_logs = {'test_loss': avg_loss} self.logger.log_metrics({'avg_test_loss': avg_loss}) #追加 return {'avg_test_loss': avg_loss} def configure_optimizers(self): # REQUIRED # can return multiple optimizers and learning_rate schedulers # (LBFGS it is automatically supported, no need for closure function) return torch.optim.Adam(self.parameters(), lr=0.02) @pl.data_loader def train_dataloader(self): # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def val_dataloader(self): # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def test_dataloader(self): # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32) |
※メトリクスの記録をClassに追加。 #追加 の部分です。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
from pytorch_lightning.loggers.wandb import WandbLogger logger = WandbLogger(project="test_project") params = {"optimizer": "Adam", "lr": 1e-2, "batch_size": 32, "epochs": 10} logger.log_hyperparams(params=params) model = CoolSystem() trainer = Trainer(max_epochs=10, logger=logger) trainer.fit(model) trainer.test() |
Wandbのサイトから結果を確認
test_projectが作成されているはずですので、中を見に行くと無事に以下のように記録が残りました!



というわけで、PyTorch Lightningのログをwandbで残す方法でした。
どなたかのお役に立てば幸いです。
最後まで読んでいただきありがとうございました。