【tensorflow2.x 入門】@tf.function で学習速度が上がるか試してみる

tf2,x 系に最近入門してみて、subclassing API まわりの書き方を勉強しています。

学習ループの中で、逆伝播を行う tf.GradientTape() 以下の箇所が計算負荷が大きいが、@tf.function を追加することで、関数を計算グラフに変換してより早く実行してくれる。ということだそうなので、実際に追加してみて速度がどの程度変わるのかを試してみました。

小さなデータセットでは速度の差をあまり実感できなさそうなので、今回は food101 データセットを利用します。

データセットについてはこちらに書いてあります。512×512サイズの画像で、101クラス分類です。

 

まずは、tensorflow_datasets から food101 をダウンロードして、224×224 にリサイズします。メモリの問題で、バッチサイズは32にして、データセットを作成します。

 

tensorflow-hub から、mobilenet_v2 の学習済みモデルの特徴抽出器の部分のみを取り出します。重みは固定せず学習させます。

 

まずは、@tf.function なしで学習とテストの関数を作成します。

スタートタイムとエンドタイムの差で、1 エポックで何秒かかったか測ります。

約 855 秒ですので、14 分ちょっとですね。

 

では train 関数に @tf.function をつけて再度 1 エポック回します。

596 秒!およそ 10 分なので、1/3 ほど学習時間が短縮されました。

 

▼公式ドキュメントの解説

https://www.tensorflow.org/tutorials/customization/performance?hl=ja

おわり