アプリ開発記録です。前回はどんなアプリを作るか決めました。
今回は、核となる処理が実現できるか見ていこうと思います。
・COCOデータセット学習済みモデルを用いて、画像内の Person をセグメンテーション
・Person の箇所だけ除去
・除去した箇所を、モザイクパターン画像で埋める
Before

After

それではやっていこうと思います。
COCOデータセットでの学習済みモデルを用いて、セグメンテーション
Person クラスが分類ラベルに含まれている代表的なデータセットでいうと、COCOデータセットがあります。(Person は代表的なデータセットにはだいたい含まれているような気もしますが・・)
ですので、COCOで学習された重みを用いれば、わざわざ再学習せずとも人間をセグメンテーションすることが簡単にできるはずです。
torchvision に組み込みのモデルが COCO データセットで Pascal VOC の 21クラス(中に Person 含む)を学習させたものでしたので、今回はそれを使わせてもらうことにしましょう。
まずは画像を読み込みます。
1 2 3 4 5 6 7 8 9 10 |
import torch import torchvision from torchvision import models from torchvision import transforms import numpy as np from PIL import Image # 画像の読み込み sample1 = Image.open("data/sample1.jpg") |
前処理を行います。この画像が 4016 x 6016 の高解像度画像だったので、リサイズして小さくします。(高解像度画像の対応は、ミニマムにいったん一通りできてから考えます。)
1 2 3 4 5 6 7 |
transform = transforms.Compose([ transforms.Resize((800, 1200)), transforms.ToTensor() ]) inputs = transform(sample1) inputs = inputs.unsqueeze(0) |
準備ができたのでモデルを用意して推論。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
#推論 model = models.segmentation.deeplabv3_resnet101(pretrained=True, num_classes=21) model.eval() pred = model(inputs) # 後処理 mask = torch.argmax(pred["out"][0], 0) print(mask.size()) """ torch.Size([800, 1200]) """ print(mask.unique()) """ tensor([ 0, 15]) """こ |
マスク画像を獲得することができました。今回はラベル 15 が Person です。
Person の箇所だけ除去
では、マスク画像の情報をもとに、元画像から Person の箇所だけ除去します。numpy で処理します。
1 2 3 4 5 6 7 8 9 10 |
# 元画像を再度読み込みリサイズ img = Image.open("data/sample1.jpg") img = np.array(img.resize((1200, 800))) mask = np.stack([mask, mask, mask], 2) print(mask.shape) """ (800, 1200, 3) """ |
元画像の中で、mask 画像の ラベルが 15 出ない位置を、輝度 0 で埋めます。
1 2 3 4 5 6 |
import matplotlib.pyplot as plt %matplotlib inline result = np.where(mask != 15, img , 0) plt.imshow(result) |
除去した箇所を、モザイクパターンで埋める
では、最後に先程できた画像で輝度が 0 の箇所を、任意のモザイク画像で埋めていきます。今回は木目のようなパターンでやってみます。
1 2 3 4 5 6 7 8 9 10 11 |
mosaic = Image.open("mosaic.png") mosaic = np.array(mosaic.resize((1200, 800))) print(mosaic.shape) """ (800, 1200, 3) """ output = np.where(result == 0, mosaic, result) plt.imshow(out) |
うまくいきました!いいぞいいぞ。
(続き)