【PyTorch】torchvisionから学習済みモデルを使用する際のTips

torchvisionから学習済みモデルを使用したい場合に、

・任意の層まで、パラメータを固定したい

・学習済み重みののっかった層をとってきたい

なんてことがありました。

前者はファインチューニングの際に必要ですし、後者はSkip Connection的な実装をする際に必要となってきます。

(Encoderで学習済みモデルを利用して、Upsampling時にEncoderの情報を使うようなシーン等)

というわけで、これらのやり方をメモしておこうと思います。

 

任意の層までパラメータを固定したい

まずは、サンプルとして vgg16 を持ってきます。

モデルのパラメータは vgg.parameters() で取り出すことができます。

Generatorオブジェクトですので、いったんリストに格納して中身を見てみます。

VGG16なので、要素数は16個かと思いきや、32個あります。

各パラメータのshape を見てあげると、これは重みとバイアスでそれぞれ別にあることがわかります。

 

各パラメータはデフォルトでrequires_grad=True となっており、再学習されてしまいます。

ですので、例えば12層目まで重みバイアスを固定で、それ以降は再学習可能にしたい場合はfor文を回してあげて、任意の層でrequires_grad=Falseとしてあげればよいことになります。

 

学習済み重みののっかった層をとってきたい

例えば、VGGの5層目を通した後の出力を保存しておいて、Decoderのある層に加算しようだとかいう場合。

モデルの重みは、.parameters()で取得することができましたが、定義された層の情報は、.children()で取得することができます。

さらに、中は大きく3つの要素で構成されていることがわかります。

一つ目の、nn.Sequential() の中身はConv層で、特徴抽出器の部分。

二つ目は AdaptiveAvgpool2d()

三つ目は、nn.Sequential()で中身は全結合層なので、分類器の部分ですね。

というわけですので、学習済み重みがセットされた層の、はじめの5層だけを取り出す際には、

こんな風にすればnn.Sequentialのまとまりで層をとってくることができました。

 

ちなみにtorchvisionのソースを見るとこのようになっており、 ( https://pytorch.org/docs/stable/_modules/torchvision/models/vgg.html )

vggモデルでいえば、同様のことは vgg.features[:5] でも行えますね。

 

最後まで読んでいただきありがとうございました。