PyTorch の object detection チュートリアルでつまったところを残しておく

torchvision にある Faster RCNN を使わせてもらおうと思って、こちらのチュートリアルを眺めていました。

https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

Finetuning from a pretrained model というところで下記の様に書いてあり、一瞬これはなんだ、と詰まってしまったのでメモを残しておこうと思います。(今みたらきちんとコメントアウトで何してるか書いてありますね、、。)

 

モデルの構成を見てみる

学習済みモデルを読み込んで、list(model.children) としてモデル構成を見てみます。一番最後のブロックだけ表示すると下記のようになります。

RoIHead という名前で、中に座標回帰と分類器が格納されている box_predictor なるものがいることがわかります。

デフォルトでは、COCOデータセットで学習がされているので、分類器の出力は 91, それぞれい対して bbox の 4 座標ということで座標回帰の出力は 4 x 91 = 364 となっています。あれ、91 クラス? 80 クラスだったような、、。

ドキュメントを見に行くと、Faster RCNN と Mask RCNN が一緒くたになっているようで、インスタンスセグメンテーションとして学習させているようなので、91 stuff ということのようです。

http://cocodataset.org/#home

 

それで、チュートリアルではこれを「人物/背景」の 2 クラス分類器にするということで、分類器の出力は 2, 座標の出力は 8 にしたいです。

以下のように書き換えています。

FasterRCNNPredictor の引数に 入力チャネル数と出力チャネル数を指定してあげることで変更できるため、それぞれを取得して書き換えています。

改めてモデル構成の最後のブロックを表示してみるときちんと望み通りになっていることがわかります。

 

おわり