{ "cells": [ { "cell_type": "markdown", "id": "ab23c813", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "\n", "# SSD目标检测\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/application/zh_cn/cv/mindspore_ssd.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/application/zh_cn/cv/mindspore_ssd.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/application/source_zh_cn/cv/ssd.ipynb)" ] }, { "cell_type": "markdown", "id": "9eb3db9c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 模型简介\n", "\n", "SSD,全称Single Shot MultiBox Detector,是Wei Liu在ECCV 2016上提出的一种目标检测算法。使用Nvidia Titan X在VOC 2007测试集上,SSD对于输入尺寸300x300的网络,达到74.3%mAP(mean Average Precision)以及59FPS;对于512x512的网络,达到了76.9%mAP ,超越当时最强的Faster RCNN(73.2%mAP)。具体可参考论文[1]。\n", "SSD目标检测主流算法分成可以两个类型:\n", "\n", "1. two-stage方法:RCNN系列
\n", "\n", " 通过算法产生候选框,然后再对这些候选框进行分类和回归。
\n", "\n", "2. one-stage方法:YOLO和SSD
\n", "\n", " 直接通过主干网络给出类别位置信息,不需要区域生成。
\n", "\n", "SSD是单阶段的目标检测算法,通过卷积神经网络进行特征提取,取不同的特征层进行检测输出,所以SSD是一种多尺度的检测方法。在需要检测的特征层,直接使用一个3 $\\times$ 3卷积,进行通道的变换。SSD采用了anchor的策略,预设不同长宽比例的anchor,每一个输出特征层基于anchor预测多个检测框(4或者6)。采用了多尺度检测方法,浅层用于检测小目标,深层用于检测大目标。SSD的框架如下图:\n", "\n", "![SSD-1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_1.png)\n" ] }, { "cell_type": "markdown", "id": "8b982901", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 模型结构\n", "\n", "SSD采用VGG16作为基础模型,然后在VGG16的基础上新增了卷积层来获得更多的特征图以用于检测。SSD的网络结构如图所示。上面是SSD模型,下面是YOLO模型,可以明显看到SSD利用了多尺度的特征图做检测。\n", "\n", "![SSD-2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_2.jpg)\n", "
\n", "\n", "两种单阶段目标检测算法的比较:
\n", "SSD先通过卷积不断进行特征提取,在需要检测物体的网络,直接通过一个3 $\\times$ 3卷积得到输出,卷积的通道数由anchor数量和类别数量决定,具体为(anchor数量*(类别数量+4))。 \n", "SSD对比了YOLO系列目标检测方法,不同的是SSD通过卷积得到最后的边界框,而YOLO对最后的输出采用全连接的形式得到一维向量,对向量进行拆解得到最终的检测框。" ] }, { "cell_type": "markdown", "id": "f176a899", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 模型特点\n", "\n", "- 多尺度检测\n", "\n", " 在SSD的网络结构图中我们可以看到,SSD使用了多个特征层,特征层的尺寸分别是38 $\\times$ 38,19 $\\times$ 19,10 $\\times$ 10,5 $\\times$ 5,3 $\\times$ 3,1 $\\times$ 1,一共6种不同的特征图尺寸。大尺度特征图(较靠前的特征图)可以用来检测小物体,而小尺度特征图(较靠后的特征图)用来检测大物体。多尺度检测的方式,可以使得检测更加充分(SSD属于密集检测),更能检测出小目标。\n", "\n", "- 采用卷积进行检测\n", "\n", " 与YOLO最后采用全连接层不同,SSD直接采用卷积对不同的特征图来进行提取检测结果。对于形状为m $\\times$ n $\\times$ p的特征图,只需要采用3 $\\times$ 3 $\\times$ p这样比较小的卷积核得到检测值。\n", "\n", "- 预设anchor\n", "\n", " 在YOLOv1中,直接由网络预测目标的尺寸,这种方式使得预测框的长宽比和尺寸没有限制,难以训练。在SSD中,采用预设边界框,我们习惯称它为anchor(在SSD论文中叫default bounding boxes),预测框的尺寸在anchor的指导下进行微调。" ] }, { "cell_type": "markdown", "id": "6c469bbf", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 环境准备\n", "\n", "本案例基于MindSpore实现,开始实验前,请确保本地已经安装了mindspore、download、pycocotools、opencv-python。" ] }, { "cell_type": "markdown", "id": "ff194502", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 数据准备与处理\n", "\n", "本案例所使用的数据集为COCO 2017。为了更加方便地保存和加载数据,本案例中在数据读取前首先将COCO数据集转换成MindRecord格式。使用MindSpore Record数据格式可以减少磁盘IO、网络IO开销,从而获得更好的使用体验和性能提升。\n", "首先我们需要下载处理好的MindRecord格式的COCO数据集。\n", "运行以下代码将数据集下载并解压到指定路径。" ] }, { "cell_type": "code", "execution_count": 2, "id": "365ae83b", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/ssd_datasets.zip (16.6 MB)\n", "\n", "file_sizes: 100%|██████████████████████████| 17.4M/17.4M [00:00<00:00, 26.9MB/s]\n", "Extracting zip file...\n", "Successfully downloaded / unzipped to ./\n" ] } ], "source": [ "from download import download\n", "\n", "dataset_url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/ssd_datasets.zip\"\n", "path = \"./\"\n", "path = download(dataset_url, path, kind=\"zip\", replace=True)" ] }, { "cell_type": "markdown", "id": "ee1b5ac8", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "然后我们为数据处理定义一些输入:" ] }, { "cell_type": "code", "execution_count": 3, "id": "54bafd5a", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "coco_root = \"./datasets/\"\n", "anno_json = \"./datasets/annotations/instances_val2017.json\"\n", "\n", "train_cls = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',\n", " 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',\n", " 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',\n", " 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',\n", " 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',\n", " 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',\n", " 'kite', 'baseball bat', 'baseball glove', 'skateboard',\n", " 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',\n", " 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',\n", " 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',\n", " 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',\n", " 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',\n", " 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',\n", " 'refrigerator', 'book', 'clock', 'vase', 'scissors',\n", " 'teddy bear', 'hair drier', 'toothbrush']\n", "\n", 
"train_cls_dict = {}\n", "for i, cls in enumerate(train_cls):\n", " train_cls_dict[cls] = i" ] }, { "cell_type": "markdown", "id": "8292a570", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 数据采样\n", "\n", "为了使模型对于各种输入对象大小和形状更加鲁棒,SSD算法每个训练图像通过以下选项之一随机采样:\n", "\n", "- 使用整个原始输入图像\n", "\n", "- 采样一个区域,使采样区域和原始图片最小的交并比重叠为0.1,0.3,0.5,0.7或0.9\n", "\n", "- 随机采样一个区域\n", "\n", "每个采样区域的大小为原始图像大小的[0.3,1],长宽比在1/2和2之间。如果真实标签框中心在采样区域内,则保留两者重叠部分作为新图片的真实标注框。在上述采样步骤之后,将每个采样区域大小调整为固定大小,并以0.5的概率水平翻转。" ] }, { "cell_type": "code", "execution_count": 4, "id": "cc75d5a8", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import cv2\n", "import numpy as np\n", "\n", "def _rand(a=0., b=1.):\n", " return np.random.rand() * (b - a) + a\n", "\n", "def intersect(box_a, box_b):\n", " \"\"\"Compute the intersect of two sets of boxes.\"\"\"\n", " max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])\n", " min_yx = np.maximum(box_a[:, :2], box_b[:2])\n", " inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)\n", " return inter[:, 0] * inter[:, 1]\n", "\n", "def jaccard_numpy(box_a, box_b):\n", " \"\"\"Compute the jaccard overlap of two sets of boxes.\"\"\"\n", " inter = intersect(box_a, box_b)\n", " area_a = ((box_a[:, 2] - box_a[:, 0]) *\n", " (box_a[:, 3] - box_a[:, 1]))\n", " area_b = ((box_b[2] - box_b[0]) *\n", " (box_b[3] - box_b[1]))\n", " union = area_a + area_b - inter\n", " return inter / union\n", "\n", "def random_sample_crop(image, boxes):\n", " \"\"\"Crop images and boxes randomly.\"\"\"\n", " height, width, _ = image.shape\n", " min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])\n", "\n", " if min_iou is None:\n", " return image, boxes\n", "\n", " for _ in range(50):\n", " image_t = image\n", " w = _rand(0.3, 1.0) * width\n", " h = _rand(0.3, 1.0) * height\n", " # aspect ratio constraint b/t .5 & 2\n", " if h / w < 0.5 or h / w > 2:\n", " continue\n", "\n", " left = _rand() * (width - w)\n", " top = _rand() * (height - h)\n", " rect = np.array([int(top), int(left), int(top + h), int(left + w)])\n", " overlap = jaccard_numpy(boxes, rect)\n", "\n", " # dropout some boxes\n", " drop_mask = overlap > 0\n", " if not drop_mask.any():\n", " continue\n", "\n", " if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):\n", " continue\n", "\n", " image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]\n", " centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0\n", " m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])\n", " m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])\n", "\n", " # mask in that both m1 and m2 are true\n", " mask = m1 * m2 * drop_mask\n", "\n", " # have any valid boxes? 
try again if not\n", " if not mask.any():\n", " continue\n", "\n", " # take only matching gt boxes\n", " boxes_t = boxes[mask, :].copy()\n", " boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])\n", " boxes_t[:, :2] -= rect[:2]\n", " boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])\n", " boxes_t[:, 2:4] -= rect[:2]\n", "\n", " return image_t, boxes_t\n", " return image, boxes\n", "\n", "def ssd_bboxes_encode(boxes):\n", " \"\"\"Labels anchors with ground truth inputs.\"\"\"\n", "\n", " def jaccard_with_anchors(bbox):\n", " \"\"\"Compute jaccard score a box and the anchors.\"\"\"\n", " # Intersection bbox and volume.\n", " ymin = np.maximum(y1, bbox[0])\n", " xmin = np.maximum(x1, bbox[1])\n", " ymax = np.minimum(y2, bbox[2])\n", " xmax = np.minimum(x2, bbox[3])\n", " w = np.maximum(xmax - xmin, 0.)\n", " h = np.maximum(ymax - ymin, 0.)\n", "\n", " # Volumes.\n", " inter_vol = h * w\n", " union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol\n", " jaccard = inter_vol / union_vol\n", " return np.squeeze(jaccard)\n", "\n", " pre_scores = np.zeros((8732), dtype=np.float32)\n", " t_boxes = np.zeros((8732, 4), dtype=np.float32)\n", " t_label = np.zeros((8732), dtype=np.int64)\n", " for bbox in boxes:\n", " label = int(bbox[4])\n", " scores = jaccard_with_anchors(bbox)\n", " idx = np.argmax(scores)\n", " scores[idx] = 2.0\n", " mask = (scores > matching_threshold)\n", " mask = mask & (scores > pre_scores)\n", " pre_scores = np.maximum(pre_scores, scores * mask)\n", " t_label = mask * label + (1 - mask) * t_label\n", " for i in range(4):\n", " t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]\n", "\n", " index = np.nonzero(t_label)\n", "\n", " # Transform to tlbr.\n", " bboxes = np.zeros((8732, 4), dtype=np.float32)\n", " bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2\n", " bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]\n", "\n", " # Encode features.\n", " bboxes_t = bboxes[index]\n", " default_boxes_t = default_boxes[index]\n", " bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * 0.1)\n", " tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001)\n", " bboxes_t[:, 2:4] = np.log(tmp) / 0.2\n", " bboxes[index] = bboxes_t\n", "\n", " num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)\n", " return bboxes, t_label.astype(np.int32), num_match\n", "\n", "def preprocess_fn(img_id, image, box, is_training):\n", " \"\"\"Preprocess function for dataset.\"\"\"\n", " cv2.setNumThreads(2)\n", "\n", " def _infer_data(image, input_shape):\n", " img_h, img_w, _ = image.shape\n", " input_h, input_w = input_shape\n", "\n", " image = cv2.resize(image, (input_w, input_h))\n", "\n", " # When the channels of image is 1\n", " if len(image.shape) == 2:\n", " image = np.expand_dims(image, axis=-1)\n", " image = np.concatenate([image, image, image], axis=-1)\n", "\n", " return img_id, image, np.array((img_h, img_w), np.float32)\n", "\n", " def _data_aug(image, box, is_training, image_size=(300, 300)):\n", " ih, iw, _ = image.shape\n", " h, w = image_size\n", " if not is_training:\n", " return _infer_data(image, image_size)\n", " # Random crop\n", " box = box.astype(np.float32)\n", " image, box = random_sample_crop(image, box)\n", " ih, iw, _ = image.shape\n", " # Resize image\n", " image = cv2.resize(image, (w, h))\n", " # Flip image or not\n", " flip = _rand() < .5\n", " if flip:\n", " image = cv2.flip(image, 1, dst=None)\n", " # When the channels of image is 1\n", " if 
len(image.shape) == 2:\n", " image = np.expand_dims(image, axis=-1)\n", " image = np.concatenate([image, image, image], axis=-1)\n", " box[:, [0, 2]] = box[:, [0, 2]] / ih\n", " box[:, [1, 3]] = box[:, [1, 3]] / iw\n", " if flip:\n", " box[:, [1, 3]] = 1 - box[:, [3, 1]]\n", " box, label, num_match = ssd_bboxes_encode(box)\n", " return image, box, label, num_match\n", "\n", " return _data_aug(image, box, is_training, image_size=[300, 300])" ] }, { "cell_type": "markdown", "id": "ac590832", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 数据集创建" ] }, { "cell_type": "code", "execution_count": 5, "id": "46071849", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from mindspore import Tensor\n", "from mindspore.dataset import MindDataset\n", "from mindspore.dataset.vision import Decode, HWC2CHW, Normalize, RandomColorAdjust\n", "\n", "\n", "def create_ssd_dataset(mindrecord_file, batch_size=32, device_num=1, rank=0,\n", " is_training=True, num_parallel_workers=1, use_multiprocessing=True):\n", " \"\"\"Create SSD dataset with MindDataset.\"\"\"\n", " dataset = MindDataset(mindrecord_file, columns_list=[\"img_id\", \"image\", \"annotation\"], num_shards=device_num,\n", " shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)\n", "\n", " decode = Decode()\n", " dataset = dataset.map(operations=decode, input_columns=[\"image\"])\n", "\n", " change_swap_op = HWC2CHW()\n", " # Computed from random subset of ImageNet training images\n", " normalize_op = Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n", " std=[0.229 * 255, 0.224 * 255, 0.225 * 255])\n", " color_adjust_op = RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)\n", " compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))\n", "\n", " if is_training:\n", " output_columns = [\"image\", \"box\", \"label\", \"num_match\"]\n", " trans = [color_adjust_op, normalize_op, change_swap_op]\n", " else:\n", " output_columns = [\"img_id\", \"image\", \"image_shape\"]\n", " trans = [normalize_op, change_swap_op]\n", "\n", " dataset = dataset.map(operations=compose_map_func, input_columns=[\"img_id\", \"image\", \"annotation\"],\n", " output_columns=output_columns, python_multiprocessing=use_multiprocessing,\n", " num_parallel_workers=num_parallel_workers)\n", "\n", " dataset = dataset.map(operations=trans, input_columns=[\"image\"], python_multiprocessing=use_multiprocessing,\n", " num_parallel_workers=num_parallel_workers)\n", "\n", " dataset = dataset.batch(batch_size, drop_remainder=True)\n", " return dataset" ] }, { "cell_type": "markdown", "id": "dfb61c0c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 模型构建\n", "\n", "SSD的网络结构主要分为以下几个部分:\n", "\n", "![SSD-3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_3.jpg)\n", "\n", "- VGG16 Base Layer\n", "\n", "- Extra Feature Layer\n", "\n", "- Detection Layer\n", "\n", "- NMS\n", "\n", "- Anchor\n", "\n", "### Backbone Layer\n", "\n", "![SSD-4](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_4.png)\n", "\n", "输入图像经过预处理后大小固定为300×300,首先经过backbone,本案例中使用的是VGG16网络的前13个卷积层,然后分别将VGG16的全连接层fc6和fc7转换成3 $\\times$ 3卷积层block6和1 $\\times$ 1卷积层block7,进一步提取特征。 在block6中,使用了空洞数为6的空洞卷积,其padding也为6,这样做同样也是为了增加感受野的同时保持参数量与特征图尺寸的不变。\n", "\n", "### Extra Feature Layer\n", "\n", 
"在VGG16的基础上,SSD进一步增加了4个深度卷积层,用于提取更高层的语义信息:\n", "\n", "![SSD-5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_5.png)\n", "\n", "block8-11,用于更高语义信息的提取。block8的通道数为512,而block9、block10与block11的通道数都为256。从block7到block11,这5个卷积后输出特征图的尺寸依次为19×19、10×10、5×5、3×3和1×1。为了降低参数量,使用了1×1卷积先降低通道数为该层输出通道数的一半,再利用3×3卷积进行特征提取。\n", "\n", "### Anchor\n", "\n", "SSD采用了PriorBox来进行区域生成。将固定大小宽高的PriorBox作为先验的感兴趣区域,利用一个阶段完成能够分类与回归。设计大量的密集的PriorBox保证了对整幅图像的每个地方都有检测。PriorBox位置的表示形式是以中心点坐标和框的宽、高(cx,cy,w,h)来表示的,同时都转换成百分比的形式。\n", "PriorBox生成规则:\n", "SSD由6个特征层来检测目标,在不同特征层上,PriorBox的尺寸scale大小是不一样的,最低层的scale=0.1,最高层的scale=0.95,其他层的计算公式如下:\n", "\n", "![SSD-6](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_6.jpg)\n", "\n", "在某个特征层上其scale一定,那么会设置不同长宽比ratio的PriorBox,其长和宽的计算公式如下:\n", "\n", "![SSD-7](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_7.jpg)\n", "\n", "在ratio=1的时候,还会根据该特征层和下一个特征层计算一个特定scale的PriorBox(长宽比ratio=1),计算公式如下:\n", "\n", "![SSD-8](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_8.jpg)\n", "\n", "每个特征层的每个点都会以上述规则生成PriorBox,(cx,cy)由当前点的中心点来确定,由此每个特征层都生成大量密集的PriorBox,如下图:\n", "\n", "![SSD-9](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_9.png)\n", "\n", "SSD使用了第4、7、8、9、10和11这6个卷积层得到的特征图,这6个特征图尺寸越来越小,而其对应的感受野越来越大。6个特征图上的每一个点分别对应4、6、6、6、4、4个PriorBox。某个特征图上的一个点根据下采样率可以得到在原图的坐标,以该坐标为中心生成4个或6个不同大小的PriorBox,然后利用特征图的特征去预测每一个PriorBox对应类别与位置的预测量。例如:第8个卷积层得到的特征图大小为10×10×512,每个点对应6个PriorBox,一共有600个PriorBox。定义MultiBox类,生成多个预测框。\n", "\n", "### Detection Layer\n", "\n", "![SSD-10](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_10.jpg)\n", "\n", "SSD模型一共有6个预测特征图,对于其中一个尺寸为m\\*n,通道为p的预测特征图,假设其每个像素点会产生k个anchor,每个anchor会对应c个类别和4个回归偏移量,使用(4+c)k个尺寸为3x3,通道为p的卷积核对该预测特征图进行卷积操作,得到尺寸为m\\*n,通道为(4+c)m\\*k的输出特征图,它包含了预测特征图上所产生的每个anchor的回归偏移量和各类别概率分数。所以对于尺寸为m\\*n的预测特征图,总共会产生(4+c)k\\*m\\*n个结果。cls分支的输出通道数为k\\*class_num,loc分支的输出通道数为k\\*4。" ] }, { "cell_type": "code", "execution_count": 6, "id": "79967cd2", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from mindspore import nn\n", "\n", "def _make_layer(channels):\n", " in_channels = channels[0]\n", " layers = []\n", " for out_channels in channels[1:]:\n", " layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3))\n", " layers.append(nn.ReLU())\n", " in_channels = out_channels\n", " return nn.SequentialCell(layers)\n", "\n", "class Vgg16(nn.Cell):\n", " \"\"\"VGG16 module.\"\"\"\n", "\n", " def __init__(self):\n", " super(Vgg16, self).__init__()\n", " self.b1 = _make_layer([3, 64, 64])\n", " self.b2 = _make_layer([64, 128, 128])\n", " self.b3 = _make_layer([128, 256, 256, 256])\n", " self.b4 = _make_layer([256, 512, 512, 512])\n", " self.b5 = _make_layer([512, 512, 512, 512])\n", "\n", " self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')\n", " self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')\n", " self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')\n", " self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')\n", " self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME')\n", "\n", " def 
construct(self, x):\n", " # block1\n", " x = self.b1(x)\n", " x = self.m1(x)\n", "\n", " # block2\n", " x = self.b2(x)\n", " x = self.m2(x)\n", "\n", " # block3\n", " x = self.b3(x)\n", " x = self.m3(x)\n", "\n", " # block4\n", " x = self.b4(x)\n", " block4 = x\n", " x = self.m4(x)\n", "\n", " # block5\n", " x = self.b5(x)\n", " x = self.m5(x)\n", "\n", " return block4, x" ] }, { "cell_type": "code", "execution_count": 7, "id": "1ba72ae2", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import mindspore as ms\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "\n", "def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):\n", " in_channels = in_channel\n", " out_channels = in_channel\n", " depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',\n", " padding=pad, group=in_channels)\n", " conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, pad_mode='same', has_bias=True)\n", " bn = nn.BatchNorm2d(in_channel, eps=1e-3, momentum=0.97,\n", " gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)\n", "\n", " return nn.SequentialCell([depthwise_conv, bn, nn.ReLU6(), conv])\n", "\n", "class FlattenConcat(nn.Cell):\n", " \"\"\"FlattenConcat module.\"\"\"\n", "\n", " def __init__(self):\n", " super(FlattenConcat, self).__init__()\n", " self.num_ssd_boxes = 8732\n", "\n", " def construct(self, inputs):\n", " output = ()\n", " batch_size = ops.shape(inputs[0])[0]\n", " for x in inputs:\n", " x = ops.transpose(x, (0, 2, 3, 1))\n", " output += (ops.reshape(x, (batch_size, -1)),)\n", " res = ops.concat(output, axis=1)\n", " return ops.reshape(res, (batch_size, self.num_ssd_boxes, -1))\n", "\n", "class MultiBox(nn.Cell):\n", " \"\"\"\n", " Multibox conv layers. 
Each multibox layer contains class conf scores and localization predictions.\n", " \"\"\"\n", "\n", " def __init__(self):\n", " super(MultiBox, self).__init__()\n", " num_classes = 81\n", " out_channels = [512, 1024, 512, 256, 256, 256]\n", " num_default = [4, 6, 6, 6, 4, 4]\n", "\n", " loc_layers = []\n", " cls_layers = []\n", " for k, out_channel in enumerate(out_channels):\n", " loc_layers += [_last_conv2d(out_channel, 4 * num_default[k],\n", " kernel_size=3, stride=1, pad_mod='same', pad=0)]\n", " cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k],\n", " kernel_size=3, stride=1, pad_mod='same', pad=0)]\n", "\n", " self.multi_loc_layers = nn.CellList(loc_layers)\n", " self.multi_cls_layers = nn.CellList(cls_layers)\n", " self.flatten_concat = FlattenConcat()\n", "\n", " def construct(self, inputs):\n", " loc_outputs = ()\n", " cls_outputs = ()\n", " for i in range(len(self.multi_loc_layers)):\n", " loc_outputs += (self.multi_loc_layers[i](inputs[i]),)\n", " cls_outputs += (self.multi_cls_layers[i](inputs[i]),)\n", " return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)\n", "\n", "class SSD300Vgg16(nn.Cell):\n", " \"\"\"SSD300Vgg16 module.\"\"\"\n", "\n", " def __init__(self):\n", " super(SSD300Vgg16, self).__init__()\n", "\n", " # VGG16 backbone: block1~5\n", " self.backbone = Vgg16()\n", "\n", " # SSD blocks: block6~7\n", " self.b6_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6, pad_mode='pad')\n", " self.b6_2 = nn.Dropout(p=0.5)\n", "\n", " self.b7_1 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1)\n", " self.b7_2 = nn.Dropout(p=0.5)\n", "\n", " # Extra Feature Layers: block8~11\n", " self.b8_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, padding=1, pad_mode='pad')\n", " self.b8_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, pad_mode='valid')\n", "\n", " self.b9_1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, padding=1, pad_mode='pad')\n", " self.b9_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, pad_mode='valid')\n", "\n", " self.b10_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)\n", " self.b10_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')\n", "\n", " self.b11_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)\n", " self.b11_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')\n", "\n", " # boxes\n", " self.multi_box = MultiBox()\n", "\n", " def construct(self, x):\n", " # VGG16 backbone: block1~5\n", " block4, x = self.backbone(x)\n", "\n", " # SSD blocks: block6~7\n", " x = self.b6_1(x) # 1024\n", " x = self.b6_2(x)\n", "\n", " x = self.b7_1(x) # 1024\n", " x = self.b7_2(x)\n", " block7 = x\n", "\n", " # Extra Feature Layers: block8~11\n", " x = self.b8_1(x) # 256\n", " x = self.b8_2(x) # 512\n", " block8 = x\n", "\n", " x = self.b9_1(x) # 128\n", " x = self.b9_2(x) # 256\n", " block9 = x\n", "\n", " x = self.b10_1(x) # 128\n", " x = self.b10_2(x) # 256\n", " block10 = x\n", "\n", " x = self.b11_1(x) # 128\n", " x = self.b11_2(x) # 256\n", " block11 = x\n", "\n", " # boxes\n", " multi_feature = (block4, block7, block8, block9, block10, block11)\n", " pred_loc, pred_label = self.multi_box(multi_feature)\n", " if not self.training:\n", " pred_label = ops.sigmoid(pred_label)\n", " pred_loc = pred_loc.astype(ms.float32)\n", " pred_label = pred_label.astype(ms.float32)\n", " return pred_loc, pred_label" ] }, { 
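"cell_type": "markdown", "id": "c3d4e5f6", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "网络定义完成后,可以先用一个全零的示例输入做一次前向形状检查(补充示例,非教程原有代码):两个输出应分别对应8732个预测框的位置偏移与各类别分数。" ] }, { "cell_type": "code", "execution_count": null, "id": "d4e5f6a7", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# 前向形状检查示意:验证输出张量与8732个预测框一一对应(示例代码)\n", "check_net = SSD300Vgg16()\n", "check_x = Tensor(np.zeros((1, 3, 300, 300), np.float32))\n", "check_loc, check_label = check_net(check_x)\n", "print(check_loc.shape, check_label.shape)  # 预期: (1, 8732, 4) (1, 8732, 81)" ] }, { 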
"cell_type": "markdown", "id": "d1942062", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 损失函数\n", "\n", "SSD算法的目标函数分为两部分:计算相应的预选框与目标类别的置信度误差(confidence loss, conf)以及相应的位置误差(locatization loss, loc):\n", "\n", "![SSD-11](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_11.jpg)\n", "\n", "其中:
\n", "N 是先验框的正样本数量;
\n", "c 为类别置信度预测值;
\n", "l 为先验框的所对应边界框的位置预测值;
\n", "g 为ground truth的位置参数
\n", "α 用以调整confidence loss和location loss之间的比例,默认为1。\n", "\n", "### 对于位置损失函数\n", "\n", "针对所有的正样本,采用 Smooth L1 Loss, 位置信息都是 encode 之后的位置信息。\n", "\n", "![SSD-12](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_12.jpg)\n", "\n", "### 对于置信度损失函数\n", "\n", "置信度损失是多类置信度(c)上的softmax损失。\n", "\n", "![SSD-13](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_13.jpg)" ] }, { "cell_type": "code", "execution_count": 8, "id": "d9a35c7c", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def class_loss(logits, label):\n", " \"\"\"Calculate category losses.\"\"\"\n", " label = ops.one_hot(label, ops.shape(logits)[-1], Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))\n", " weight = ops.ones_like(logits)\n", " pos_weight = ops.ones_like(logits)\n", " sigmiod_cross_entropy = ops.binary_cross_entropy_with_logits(logits, label, weight.astype(ms.float32), pos_weight.astype(ms.float32))\n", " sigmoid = ops.sigmoid(logits)\n", " label = label.astype(ms.float32)\n", " p_t = label * sigmoid + (1 - label) * (1 - sigmoid)\n", " modulating_factor = ops.pow(1 - p_t, 2.0)\n", " alpha_weight_factor = label * 0.75 + (1 - label) * (1 - 0.75)\n", " focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy\n", " return focal_loss" ] }, { "cell_type": "markdown", "id": "10871aeb", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Metrics\n", "\n", "在SSD中,训练过程是不需要用到非极大值抑制(NMS),但当进行检测时,例如输入一张图片要求输出框的时候,需要用到NMS过滤掉那些重叠度较大的预测框。
\n", "非极大值抑制的流程如下:\n", "\n", "1. 根据置信度得分进行排序\n", "\n", "2. 选择置信度最高的比边界框添加到最终输出列表中,将其从边界框列表中删除
\n", "\n", "3. 计算所有边界框的面积
\n", "\n", "4. 计算置信度最高的边界框与其它候选框的IoU
\n", "\n", "5. 删除IoU大于阈值的边界框
\n", "\n", "6. 重复上述过程,直至边界框列表为空
" ] }, { "cell_type": "code", "execution_count": 9, "id": "cea96244", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import json\n", "from pycocotools.coco import COCO\n", "from pycocotools.cocoeval import COCOeval\n", "\n", "\n", "def apply_eval(eval_param_dict):\n", " net = eval_param_dict[\"net\"]\n", " net.set_train(False)\n", " ds = eval_param_dict[\"dataset\"]\n", " anno_json = eval_param_dict[\"anno_json\"]\n", " coco_metrics = COCOMetrics(anno_json=anno_json,\n", " classes=train_cls,\n", " num_classes=81,\n", " max_boxes=100,\n", " nms_threshold=0.6,\n", " min_score=0.1)\n", " for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):\n", " img_id = data['img_id']\n", " img_np = data['image']\n", " image_shape = data['image_shape']\n", "\n", " output = net(Tensor(img_np))\n", "\n", " for batch_idx in range(img_np.shape[0]):\n", " pred_batch = {\n", " \"boxes\": output[0].asnumpy()[batch_idx],\n", " \"box_scores\": output[1].asnumpy()[batch_idx],\n", " \"img_id\": int(np.squeeze(img_id[batch_idx])),\n", " \"image_shape\": image_shape[batch_idx]\n", " }\n", " coco_metrics.update(pred_batch)\n", " eval_metrics = coco_metrics.get_metrics()\n", " return eval_metrics\n", "\n", "\n", "def apply_nms(all_boxes, all_scores, thres, max_boxes):\n", " \"\"\"Apply NMS to bboxes.\"\"\"\n", " y1 = all_boxes[:, 0]\n", " x1 = all_boxes[:, 1]\n", " y2 = all_boxes[:, 2]\n", " x2 = all_boxes[:, 3]\n", " areas = (x2 - x1 + 1) * (y2 - y1 + 1)\n", "\n", " order = all_scores.argsort()[::-1]\n", " keep = []\n", "\n", " while order.size > 0:\n", " i = order[0]\n", " keep.append(i)\n", "\n", " if len(keep) >= max_boxes:\n", " break\n", "\n", " xx1 = np.maximum(x1[i], x1[order[1:]])\n", " yy1 = np.maximum(y1[i], y1[order[1:]])\n", " xx2 = np.minimum(x2[i], x2[order[1:]])\n", " yy2 = np.minimum(y2[i], y2[order[1:]])\n", "\n", " w = np.maximum(0.0, xx2 - xx1 + 1)\n", " h = np.maximum(0.0, yy2 - yy1 + 1)\n", " inter = w * h\n", "\n", " ovr = inter / (areas[i] + areas[order[1:]] - inter)\n", "\n", " inds = np.where(ovr <= thres)[0]\n", "\n", " order = order[inds + 1]\n", " return keep\n", "\n", "\n", "class COCOMetrics:\n", " \"\"\"Calculate mAP of predicted bboxes.\"\"\"\n", "\n", " def __init__(self, anno_json, classes, num_classes, min_score, nms_threshold, max_boxes):\n", " self.num_classes = num_classes\n", " self.classes = classes\n", " self.min_score = min_score\n", " self.nms_threshold = nms_threshold\n", " self.max_boxes = max_boxes\n", "\n", " self.val_cls_dict = {i: cls for i, cls in enumerate(classes)}\n", " self.coco_gt = COCO(anno_json)\n", " cat_ids = self.coco_gt.loadCats(self.coco_gt.getCatIds())\n", " self.class_dict = {cat['name']: cat['id'] for cat in cat_ids}\n", "\n", " self.predictions = []\n", " self.img_ids = []\n", "\n", " def update(self, batch):\n", " pred_boxes = batch['boxes']\n", " box_scores = batch['box_scores']\n", " img_id = batch['img_id']\n", " h, w = batch['image_shape']\n", "\n", " final_boxes = []\n", " final_label = []\n", " final_score = []\n", " self.img_ids.append(img_id)\n", "\n", " for c in range(1, self.num_classes):\n", " class_box_scores = box_scores[:, c]\n", " score_mask = class_box_scores > self.min_score\n", " class_box_scores = class_box_scores[score_mask]\n", " class_boxes = pred_boxes[score_mask] * [h, w, h, w]\n", "\n", " if score_mask.any():\n", " nms_index = apply_nms(class_boxes, class_box_scores, self.nms_threshold, self.max_boxes)\n", " class_boxes = class_boxes[nms_index]\n", " class_box_scores = 
class_box_scores[nms_index]\n", "\n", " final_boxes += class_boxes.tolist()\n", " final_score += class_box_scores.tolist()\n", " final_label += [self.class_dict[self.val_cls_dict[c]]] * len(class_box_scores)\n", "\n", " for loc, label, score in zip(final_boxes, final_label, final_score):\n", " res = {}\n", " res['image_id'] = img_id\n", " res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]\n", " res['score'] = score\n", " res['category_id'] = label\n", " self.predictions.append(res)\n", "\n", " def get_metrics(self):\n", " with open('predictions.json', 'w') as f:\n", " json.dump(self.predictions, f)\n", "\n", " coco_dt = self.coco_gt.loadRes('predictions.json')\n", " E = COCOeval(self.coco_gt, coco_dt, iouType='bbox')\n", " E.params.imgIds = self.img_ids\n", " E.evaluate()\n", " E.accumulate()\n", " E.summarize()\n", " return E.stats[0]\n", "\n", "\n", "class SsdInferWithDecoder(nn.Cell):\n", " \"\"\"SSD Infer wrapper to decode the bbox locations.\"\"\"\n", "\n", " def __init__(self, network, default_boxes, ckpt_path):\n", " super(SsdInferWithDecoder, self).__init__()\n", " param_dict = ms.load_checkpoint(ckpt_path)\n", " ms.load_param_into_net(network, param_dict)\n", " self.network = network\n", " self.default_boxes = default_boxes\n", " self.prior_scaling_xy = 0.1\n", " self.prior_scaling_wh = 0.2\n", "\n", " def construct(self, x):\n", " pred_loc, pred_label = self.network(x)\n", "\n", " default_bbox_xy = self.default_boxes[..., :2]\n", " default_bbox_wh = self.default_boxes[..., 2:]\n", " pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy\n", " pred_wh = ops.exp(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh\n", "\n", " pred_xy_0 = pred_xy - pred_wh / 2.0\n", " pred_xy_1 = pred_xy + pred_wh / 2.0\n", " pred_xy = ops.concat((pred_xy_0, pred_xy_1), -1)\n", " pred_xy = ops.maximum(pred_xy, 0)\n", " pred_xy = ops.minimum(pred_xy, 1)\n", " return pred_xy, pred_label" ] }, { "cell_type": "markdown", "id": "2ce744ef", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 训练过程\n", "\n", "### (1)先验框匹配\n", "\n", "在训练过程中,首先要确定训练图片中的ground truth(真实目标)与哪个先验框来进行匹配,与之匹配的先验框所对应的边界框将负责预测它。\n", "\n", "SSD的先验框与ground truth的匹配原则主要有两点:\n", "\n", "1. 对于图片中每个ground truth,找到与其IOU最大的先验框,该先验框与其匹配,这样可以保证每个ground truth一定与某个先验框匹配。通常称与ground truth匹配的先验框为正样本,反之,若一个先验框没有与任何ground truth进行匹配,那么该先验框只能与背景匹配,就是负样本。\n", "\n", "2. 对于剩余的未匹配先验框,若其与某个ground truth的IOU大于某个阈值(一般是0.5),那么该先验框也与这个ground truth进行匹配。尽管一个ground truth可以与多个先验框匹配,但是ground truth相对先验框还是太少了,所以负样本相对正样本会很多。为了保证正负样本尽量平衡,SSD采用了hard negative mining,就是对负样本进行抽样,抽样时按照置信度误差(预测背景的置信度越小,误差越大)进行降序排列,选取误差较大的top-k作为训练的负样本,以保证正负样本比例接近1:3。\n", "\n", "注意点:\n", "\n", "1. 通常称与gt匹配的prior为正样本,反之,若某一个prior没有与任何一个gt匹配,则为负样本。\n", "\n", "2. 某个gt可以和多个prior匹配,而每个prior只能和一个gt进行匹配。\n", "\n", "3. 
如果多个gt和某一个prior的IOU均大于阈值,那么prior只与IOU最大的那个进行匹配。\n", "\n", "![SSD-14](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_14.jpg)\n", "\n", "如上图所示,训练过程中的 prior boxes 和 ground truth boxes 的匹配,基本思路是:让每一个 prior box 回归到 ground truth box,这个过程需要损失层的调控,它会计算真实值和预测值之间的误差,从而指导学习的走向。\n", "\n", "### (2)损失函数\n", "\n", "损失函数使用的是上文提到的位置损失函数和置信度损失函数的加权和。\n", "\n", "### (3)数据增强\n", "\n", "使用之前定义好的数据增强方式,对创建好的数据集进行数据增强。\n", "\n", "模型训练时,设置模型训练的epoch次数为60,然后通过create_ssd_dataset函数创建数据集。batch_size大小为5,图像尺寸统一调整为300×300。损失函数使用位置损失函数和置信度损失函数的加权和,优化器使用Momentum,并设置初始学习率为0.001。训练过程中,每个epoch结束后打印损失值Loss与该epoch的运行时间,训练完成后保存checkpoint文件。" ] }, { "cell_type": "code", "execution_count": 10, "id": "3b4bbd84", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import math\n", "import itertools as it\n", "\n", "from mindspore.common import set_seed\n", "\n", "class GenerateDefaultBoxes():\n", " \"\"\"\n", " Generate Default boxes for SSD, follows the order of (W, H, anchor_sizes).\n", " `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [y, x, h, w].\n", " `self.default_boxes_tlbr` has the same shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].\n", " \"\"\"\n", "\n", " def __init__(self):\n", " fk = 300 / np.array([8, 16, 32, 64, 100, 300])\n", " scale_rate = (0.95 - 0.1) / (len([4, 6, 6, 6, 4, 4]) - 1)\n", " scales = [0.1 + scale_rate * i for i in range(len([4, 6, 6, 6, 4, 4]))] + [1.0]\n", " self.default_boxes = []\n", " for idex, feature_size in enumerate([38, 19, 10, 5, 3, 1]):\n", " sk1 = scales[idex]\n", " sk2 = scales[idex + 1]\n", " sk3 = math.sqrt(sk1 * sk2)\n", " if idex == 0 and not [[2], [2, 3], [2, 3], [2, 3], [2], [2]][idex]:\n", " w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)\n", " all_sizes = [(0.1, 0.1), (w, h), (h, w)]\n", " else:\n", " all_sizes = [(sk1, sk1)]\n", " for aspect_ratio in [[2], [2, 3], [2, 3], [2, 3], [2], [2]][idex]:\n", " w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)\n", " all_sizes.append((w, h))\n", " all_sizes.append((h, w))\n", " all_sizes.append((sk3, sk3))\n", "\n", " assert len(all_sizes) == [4, 6, 6, 6, 4, 4][idex]\n", "\n", " for i, j in it.product(range(feature_size), repeat=2):\n", " for w, h in all_sizes:\n", " cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]\n", " self.default_boxes.append([cy, cx, h, w])\n", "\n", " def to_tlbr(cy, cx, h, w):\n", " return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2\n", "\n", " # For IoU calculation\n", " self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32')\n", " self.default_boxes = np.array(self.default_boxes, dtype='float32')\n", "\n", "default_boxes_generator = GenerateDefaultBoxes()\n", "default_boxes_tlbr = default_boxes_generator.default_boxes_tlbr\n", "default_boxes = default_boxes_generator.default_boxes\n", "\n", "y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1)\n", "vol_anchors = (x2 - x1) * (y2 - y1)\n", "matching_threshold = 0.5" ] }, { "cell_type": "code", "execution_count": 11, "id": "faae8a3f", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from mindspore.common.initializer import initializer, TruncatedNormal\n", "\n", "\n", "def init_net_param(network, initialize_mode='TruncatedNormal'):\n", " \"\"\"Init the parameters in net.\"\"\"\n", " params = network.trainable_params()\n", " for p in params:\n", " if 'beta' not in p.name and 'gamma' not in p.name and 'bias' 
not in p.name:\n", " if initialize_mode == 'TruncatedNormal':\n", " p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype))\n", " else:\n", " p.set_data(initialize_mode, p.data.shape, p.data.dtype)\n", "\n", "\n", "def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):\n", " \"\"\" generate learning rate array\"\"\"\n", " lr_each_step = []\n", " total_steps = steps_per_epoch * total_epochs\n", " warmup_steps = steps_per_epoch * warmup_epochs\n", " for i in range(total_steps):\n", " if i < warmup_steps:\n", " lr = lr_init + (lr_max - lr_init) * i / warmup_steps\n", " else:\n", " lr = lr_end + (lr_max - lr_end) * (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.\n", " if lr < 0.0:\n", " lr = 0.0\n", " lr_each_step.append(lr)\n", "\n", " current_step = global_step\n", " lr_each_step = np.array(lr_each_step).astype(np.float32)\n", " learning_rate = lr_each_step[current_step:]\n", "\n", " return learning_rate\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "7c739ddf", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "=================== Starting Training =====================\n", "Epoch:[1/60], loss:1365.3849 , time:42.76231384277344s \n", "Epoch:[2/60], loss:1350.9009 , time:43.63900399208069s \n", "Epoch:[3/60], loss:1325.2102 , time:48.01434779167175s \n", "Epoch:[4/60], loss:1297.8125 , time:40.65014576911926s \n", "Epoch:[5/60], loss:1269.7281 , time:40.627623081207275s \n", "Epoch:[6/60], loss:1240.8068 , time:42.14572191238403s \n", "Epoch:[7/60], loss:1210.52 , time:41.091148853302s \n", "Epoch:[8/60], loss:1178.0127 , time:41.88719820976257s \n", "Epoch:[9/60], loss:1142.2338 , time:41.147764444351196s \n", "Epoch:[10/60], loss:1101.929 , time:42.21702218055725s \n", "Epoch:[11/60], loss:1055.7747 , time:40.66824555397034s \n", "Epoch:[12/60], loss:1002.66125 , time:40.70291781425476s \n", "Epoch:[13/60], loss:942.0149 , time:42.10250663757324s \n", "Epoch:[14/60], loss:874.245 , time:41.27074885368347s \n", "Epoch:[15/60], loss:801.06055 , time:40.62501621246338s \n", "Epoch:[16/60], loss:725.4527 , time:41.78050708770752s \n", "Epoch:[17/60], loss:651.15564 , time:40.619580030441284s \n", "Epoch:[18/60], loss:581.7435 , time:41.07759237289429s \n", "Epoch:[19/60], loss:519.85223 , time:41.74708104133606s \n", "Epoch:[20/60], loss:466.71866 , time:40.79696846008301s \n", "Epoch:[21/60], loss:422.35846 , time:40.40337634086609s \n", "Epoch:[22/60], loss:385.95758 , time:41.0706627368927s \n", "Epoch:[23/60], loss:356.3252 , time:41.02973508834839s \n", "Epoch:[24/60], loss:332.2302 , time:41.101938009262085s \n", "Epoch:[25/60], loss:312.56158 , time:40.12760329246521s \n", "Epoch:[26/60], loss:296.3943 , time:40.62085247039795s \n", "Epoch:[27/60], loss:282.99237 , time:42.20474720001221s \n", "Epoch:[28/60], loss:271.7844 , time:40.27843761444092s \n", "Epoch:[29/60], loss:262.32687 , time:40.6625394821167s \n", "Epoch:[30/60], loss:254.28302 , time:41.42288422584534s \n", "Epoch:[31/60], loss:247.38882 , time:40.49200940132141s \n", "Epoch:[32/60], loss:241.44067 , time:41.48827362060547s \n", "Epoch:[33/60], loss:236.28123 , time:41.1355299949646s \n", "Epoch:[34/60], loss:231.78201 , time:40.45781660079956s \n", "Epoch:[35/60], loss:227.84433 , time:40.92684364318848s \n", "Epoch:[36/60], loss:224.38614 , time:40.89856195449829s \n", "Epoch:[37/60], loss:221.34372 , time:41.585039138793945s \n", "Epoch:[38/60], 
loss:218.66156 , time:40.8972954750061s \n", "Epoch:[39/60], loss:216.29553 , time:42.22093486785889s \n", "Epoch:[40/60], loss:214.20854 , time:40.75188755989075s \n", "Epoch:[41/60], loss:212.36868 , time:41.51768183708191s \n", "Epoch:[42/60], loss:210.74985 , time:40.3460476398468s \n", "Epoch:[43/60], loss:209.32901 , time:40.65240502357483s \n", "Epoch:[44/60], loss:208.08626 , time:41.250218629837036s \n", "Epoch:[45/60], loss:207.00375 , time:40.334686040878296s \n", "Epoch:[46/60], loss:206.06656 , time:40.822086811065674s \n", "Epoch:[47/60], loss:205.2609 , time:40.492422103881836s \n", "Epoch:[48/60], loss:204.57387 , time:41.39555335044861s \n", "Epoch:[49/60], loss:203.9947 , time:40.29546666145325s \n", "Epoch:[50/60], loss:203.51189 , time:39.61115860939026s \n", "Epoch:[51/60], loss:203.11642 , time:41.232492446899414s \n", "Epoch:[52/60], loss:202.79791 , time:40.896180152893066s \n", "Epoch:[53/60], loss:202.54779 , time:40.62282419204712s \n", "Epoch:[54/60], loss:202.35779 , time:40.751235485076904s \n", "Epoch:[55/60], loss:202.2188 , time:41.790447473526s \n", "Epoch:[56/60], loss:202.12277 , time:41.371476888656616s \n", "Epoch:[57/60], loss:202.05978 , time:41.00389575958252s \n", "Epoch:[58/60], loss:202.02513 , time:40.384965658187866s \n", "Epoch:[59/60], loss:202.00772 , time:40.91265916824341s \n", "Epoch:[60/60], loss:201.9999 , time:41.31216502189636s \n", "=================== Training Success =====================\n" ] } ], "source": [ "import time\n", "\n", "from mindspore.amp import DynamicLossScaler\n", "\n", "set_seed(1)\n", "\n", "# load data\n", "mindrecord_dir = \"./datasets/MindRecord_COCO\"\n", "mindrecord_file = \"./datasets/MindRecord_COCO/ssd.mindrecord0\"\n", "\n", "dataset = create_ssd_dataset(mindrecord_file, batch_size=5, rank=0, use_multiprocessing=True)\n", "dataset_size = dataset.get_dataset_size()\n", "\n", "image, get_loc, gt_label, num_matched_boxes = next(dataset.create_tuple_iterator())\n", "\n", "# Network definition and initialization\n", "network = SSD300Vgg16()\n", "init_net_param(network)\n", "\n", "# Define the learning rate\n", "lr = Tensor(get_lr(global_step=0 * dataset_size,\n", " lr_init=0.001, lr_end=0.001 * 0.05, lr_max=0.05,\n", " warmup_epochs=2, total_epochs=60, steps_per_epoch=dataset_size))\n", "\n", "# Define the optimizer\n", "opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr,\n", " 0.9, 0.00015, float(1024))\n", "\n", "# Define the forward procedure\n", "def forward_fn(x, gt_loc, gt_label, num_matched_boxes):\n", " pred_loc, pred_label = network(x)\n", " mask = ops.less(0, gt_label).astype(ms.float32)\n", " num_matched_boxes = ops.sum(num_matched_boxes.astype(ms.float32))\n", "\n", " # Positioning loss\n", " mask_loc = ops.tile(ops.expand_dims(mask, -1), (1, 1, 4))\n", " smooth_l1 = nn.SmoothL1Loss()(pred_loc, gt_loc) * mask_loc\n", " loss_loc = ops.sum(ops.sum(smooth_l1, -1), -1)\n", "\n", " # Category loss\n", " loss_cls = class_loss(pred_label, gt_label)\n", " loss_cls = ops.sum(loss_cls, (1, 2))\n", "\n", " return ops.sum((loss_cls + loss_loc) / num_matched_boxes)\n", "\n", "grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters, has_aux=False)\n", "loss_scaler = DynamicLossScaler(1024, 2, 1000)\n", "\n", "# Gradient updates\n", "def train_step(x, gt_loc, gt_label, num_matched_boxes):\n", " loss, grads = grad_fn(x, gt_loc, gt_label, num_matched_boxes)\n", " opt(grads)\n", " return loss\n", "\n", "print(\"=================== Starting Training =====================\")\n", 
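"\n", "# 训练主循环:每个epoch遍历一次训练集,train_step完成前向计算、反向传播与参数更新,\n", "# epoch结束后打印最后一个step的loss与该epoch的耗时\n", 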
"for epoch in range(60):\n", " network.set_train(True)\n", " begin_time = time.time()\n", " for step, (image, get_loc, gt_label, num_matched_boxes) in enumerate(dataset.create_tuple_iterator()):\n", " loss = train_step(image, get_loc, gt_label, num_matched_boxes)\n", " end_time = time.time()\n", " times = end_time - begin_time\n", " print(f\"Epoch:[{int(epoch + 1)}/{int(60)}], \"\n", " f\"loss:{loss} , \"\n", " f\"time:{times}s \")\n", "ms.save_checkpoint(network, \"ssd-60_9.ckpt\")\n", "print(\"=================== Training Success =====================\")" ] }, { "cell_type": "markdown", "id": "8a978b0f", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 评估\n", "\n", "自定义eval_net()类对训练好的模型进行评估,调用了上述定义的SsdInferWithDecoder类返回预测的坐标及标签,然后分别计算了在不同的IoU阈值、area和maxDets设置下的Average Precision(AP)和Average Recall(AR)。使用COCOMetrics类计算mAP。模型在测试集上的评估指标如下。\n", "\n", "### 精确率(AP)和召回率(AR)的解释\n", "\n", "- TP:IoU>设定的阈值的检测框数量(同一Ground Truth只计算一次)。\n", "\n", "- FP:IoU<=设定的阈值的检测框,或者是检测到同一个GT的多余检测框的数量。\n", "\n", "- FN:没有检测到的GT的数量。\n", "\n", "### 精确率(AP)和召回率(AR)的公式\n", "\n", "- 精确率(Average Precision,AP):\n", "\n", " ![SSD-15](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_15.jpg)\n", "\n", " 精确率是将正样本预测正确的结果与正样本预测的结果和预测错误的结果的和的比值,主要反映出预测结果错误率。\n", "\n", "- 召回率(Average Recall,AR):\n", "\n", " ![SSD-16](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/SSD_16.jpg)\n", "\n", " 召回率是正样本预测正确的结果与正样本预测正确的结果和正样本预测错误的和的比值,主要反映出来的是预测结果中的漏检率。\n", "\n", "### 关于以下代码运行结果的输出指标\n", "\n", "- 第一个值即为mAP(mean Average Precision), 即各类别AP的平均值。\n", "\n", "- 第二个值是iou取0.5的mAP值,是voc的评判标准。\n", "\n", "- 第三个值是评判较为严格的mAP值,可以反应算法框的位置精准程度;中间几个数为物体大小的mAP值。\n", "\n", "对于AR看一下maxDets=10/100的mAR值,反应检出率,如果两者接近,说明对于这个数据集来说,不用检测出100个框,可以提高性能。" ] }, { "cell_type": "code", "execution_count": 13, "id": "61007796", "metadata": { "pycharm": { "name": "#%%\n" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Start Eval!\n", "Load Checkpoint!\n" ] }, { "name": "stdout", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "========================================\n", "\n", "total images num: 9\n", "loading annotations into memory...\n", "Done (t=0.00s)\n", "creating index...\n", "index created!\n", "Loading and preparing results...\n", "DONE (t=0.47s)\n", "creating index...\n", "index created!\n", "Running per image evaluation...\n", "Evaluate annotation type *bbox*\n", "DONE (t=0.97s).\n", "Accumulating evaluation results...\n", "DONE (t=0.20s).\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.003\n", " Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.006\n", " Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.000\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.052\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.016\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.005\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.037\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.071\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | 
maxDets=100 ] = 0.057\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.328\n", "\n", "========================================\n", "\n", "mAP: 0.0025924737758294216\n" ] } ], "source": [ "mindrecord_file = \"./datasets/MindRecord_COCO/ssd_eval.mindrecord0\"\n", "\n", "def ssd_eval(dataset_path, ckpt_path, anno_json):\n", " \"\"\"SSD evaluation.\"\"\"\n", " batch_size = 1\n", " ds = create_ssd_dataset(dataset_path, batch_size=batch_size,\n", " is_training=False, use_multiprocessing=False)\n", "\n", " network = SSD300Vgg16()\n", " print(\"Load Checkpoint!\")\n", " net = SsdInferWithDecoder(network, Tensor(default_boxes), ckpt_path)\n", "\n", " net.set_train(False)\n", " total = ds.get_dataset_size() * batch_size\n", " print(\"\\n========================================\\n\")\n", " print(\"total images num: \", total)\n", " eval_param_dict = {\"net\": net, \"dataset\": ds, \"anno_json\": anno_json}\n", " mAP = apply_eval(eval_param_dict)\n", " print(\"\\n========================================\\n\")\n", " print(f\"mAP: {mAP}\")\n", "\n", "def eval_net():\n", " print(\"Start Eval!\")\n", " ssd_eval(mindrecord_file, \"./ssd-60_9.ckpt\", anno_json)\n", "\n", "eval_net()" ] }, { "cell_type": "markdown", "id": "1126555c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 引用\n", "\n", "[1] Liu W, Anguelov D, Erhan D, et al. Ssd: Single shot multibox detector[C]//European conference on computer vision. Springer, Cham, 2016: 21-37.
" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" }, "toc-autonumbering": false, "toc-showmarkdowntxt": true, "toc-showtags": true }, "nbformat": 4, "nbformat_minor": 5 }