物体検出を改めてお勉強し直しています。
自前のデータセットで再学習をさせる場合には、レポジトリの案内に従って所定の位置にファイル配置したりすれば学習は回ってしまったりしますが、一からカスタムデータセット作成まで隠れた部分をきちんと実装するとどうなるのか、見ていきたいと思います。
Oxford-IIIT Pet Dataset のデータを使って見ていきます。こちらは37種類の犬猫のデータセットです。(物体検出にも対応)
・アノテーションデータのフォーマット
・XMLファイルをパースして必要な情報を取り出す
・カスタムデータセットクラスの作成
アノテーションデータのフォーマット
アノテーションデータは、Pascal VOC形式のものは以下のような xml ファイルで得られます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
<annotation> <folder>OXIIIT</folder> <filename>Abyssinian_1.jpg</filename> <source> <database>OXFORD-IIIT Pet Dataset</database> <annotation>OXIIIT</annotation> <image>flickr</image> </source> <size> <width>600</width> <height>400</height> <depth>3</depth> </size> <segmented>0</segmented> <object> <name>cat</name> <pose>Frontal</pose> <truncated>0</truncated> <occluded>0</occluded> <bndbox> <xmin>333</xmin> <ymin>72</ymin> <xmax>425</xmax> <ymax>158</ymax> </bndbox> <difficult>0</difficult> </object> </annotation> |
画像のサイズや、バウンディングボックスの座標の位置などが記載されています。
truncated はオブジェクトが部分的に見えている場合は0、完全に見えていてbboxで完全に囲えている場合は1とつく。
difficult はオブジェクトの認識が困難であれば1、そうでなければ0とつくそうです。
このままでは使用できないため、中から使用する情報だけを取り出して上げる必要があります。
XMLファイルをパースして必要な情報を取り出す
今回は、画像のサイズ(width, height)とbboxの座標、クラスの番号を取り出します。
※「作りながら学ぶ!PyTorchによる発展ディープラーニング」のコードをもとに微修正
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import xml.etree.ElementTree as ET class xml2list(object): def __init__(self, classes): self.classes = classes def __call__(self, xml_path): ret = [] xml = ET.parse(xml_path).getroot() for size in xml.iter("size"): width = float(size.find("width").text) height = float(size.find("height").text) for obj in xml.iter("object"): difficult = int(obj.find("difficult").text) if difficult == 1: continue bndbox = [width, height] name = obj.find("name").text.lower().strip() bbox = obj.find("bndbox") pts = ["xmin", "ymin", "xmax", "ymax"] for pt in pts: cur_pixel = float(bbox.find(pt).text) bndbox.append(cur_pixel) label_idx = self.classes.index(name) bndbox.append(label_idx) ret += [bndbox] return np.array(ret) # [width, height, xmin, ymin, xamx, ymax, label_idx] |
パースするクラスを定義し、試しに1つのxmlファイルに対して実行して、のぞみの情報が帰ってくるか確認します。
1 2 3 4 5 6 7 8 9 10 11 |
xml_paths = glob("./annotations/*.xml") classes = ["dog", "cat"] transform_anno = xml2list(classes) # 動作確認 transform_anno(xml_paths[0]) ''' array([[500., 333., 112., 14., 393., 298., 1.]]) ''' |
すべての画像の情報を取り出して、データフレームで見やすくします。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
df = pd.DataFrame(columns=["image_id", "width", "height", "xmin", "ymin", "xmax", "ymax", "class"]) for path in xml_paths: image_id = path.split("/")[-1].split(".")[0] bboxs = transform_anno(path) for bbox in bboxs: tmp = pd.Series(bbox, index=["width", "height", "xmin", "ymin", "xmax", "ymax", "class"]) tmp["image_id"] = image_id df = df.append(tmp, ignore_index=True) df = df.sort_values(by="image_id", ascending=True) df.head() |
使用するモデルに応じて、バウンディングボックスに与える座標の形式が異なります。今回のような [xmin, ymin, xmax, ymax] (あるいは[x1, y1, x2, y2])のような座標は Faster R-CNN 等で使えますが、YOLO を使う際には、[center_x, center_y, width, height] のような形式で座標情報を与える必要があるため、計算をして上げる必要があります。
カスタムデータセットクラスの作成
アノテーションデータに関しては、boxes, area, labels, image_id 等の情報を辞書型で整理して渡してあげる必要があります。
各情報は下記のとおり。
target: a dict containing the following fields
boxes (FloatTensor[N, 4])
: the coordinates of theN
bounding boxes in[x0, y0, x1, y1]
format, ranging from0
toW
and0
toH
labels (Int64Tensor[N])
: the label for each bounding box.0
represents always the background class.image_id (Int64Tensor[1])
: an image identifier. It should be unique between all the images in the dataset, and is used during evaluationarea (Tensor[N])
: The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.iscrowd (UInt8Tensor[N])
: instances with iscrowd=True will be ignored during evaluation.- (optionally)
masks (UInt8Tensor[N, H, W])
: The segmentation masks for each one of the objects- (optionally)
keypoints (FloatTensor[N, K, 3])
: For each one of the N objects, it contains the K keypoints in[x, y, visibility]
format, defining the object. visibility=0 means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adaptreferences/detection/transforms.py
for your new keypoint representation
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
class MyDataset(torch.utils.data.Dataset): def __init__(self, df, image_dir): super().__init__() self.image_ids = df["image_id"].unique() self.df = df self.image_dir = image_dir def __getitem__(self, index): transform = transforms.Compose([ transforms.ToTensor() ]) # 入力画像の読み込み image_id = self.image_ids[index] image = Image.open(f"{self.image_dir}/{image_id}.jpg") image = transform(image) # アノテーションデータの読み込み records = self.df[self.df["image_id"] == image_id] boxes = torch.tensor(records[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32) area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) area = torch.as_tensor(area, dtype=torch.float32) labels = torch.tensor(records["class"].values, dtype=torch.int64) iscrowd = torch.zeros((records.shape[0], ), dtype=torch.int64) target = {} target["boxes"] = boxes target["labels"]= labels target["image_id"] = torch.tensor([index]) target["area"] = area target["iscrowd"] = iscrowd return image, target, image_id def __len__(self): return self.image_ids.shape[0] |
1 2 |
image_dir = "./images/" dataset = MyDataset(df, image_dir) |
いったんここまで。
最後まで読んでいただきありがとうございました。
参考
作りながら学ぶ!PyTorchによる発展ディープラーニング
COCO and Pascal VOC data format for Object detection
TORCHVISION OBJECT DETECTION FINETUNING TUTORIAL