[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/cv/ssd.md)

# SSD for Object Detection

## Model Introduction

Single shot multibox detector (SSD) is an object detection algorithm proposed by Wei Liu et al. at ECCV 2016. On the VOC 2007 test set with an NVIDIA Titan X, the SSD reaches 74.3% mAP (mean Average Precision) at 59 FPS for a 300 x 300 input network. For the 512 x 512 network, the SSD reaches 76.9% mAP, surpassing Faster RCNN (73.2% mAP). For details, see the paper[1].

Mainstream object detection algorithms are classified into the following two types:

1. Two-stage method: RCNN series
    The candidate boxes are first generated by an algorithm, and then classified and regressed.
2. One-stage method: YOLO and SSD
    Location information is provided directly by the backbone network, without a separate region proposal stage.
SSD is a one-stage object detection algorithm. Feature extraction is performed by a convolutional neural network, and different feature layers are used for detection output, so the SSD is a multi-scale detection method. At each feature layer to be detected, a 3 $\times$ 3 convolution is applied directly to transform the channels. SSD uses the anchor policy: anchors with different aspect ratios are preset, and each output feature layer predicts several detection boxes (4 or 6) per location based on the anchors. Shallow layers are used to detect small objects, and deep layers are used to detect large objects. The following figure shows the SSD framework.

![SSD-1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_1.png)

### Model Structure

The SSD uses VGG-16 as its base network and adds convolutional layers on top of it to obtain more feature maps for detection. The following figure shows the SSD network structure. The upper part is the SSD model, and the lower part is the YOLO model. It can be seen that the SSD uses multi-scale feature maps for detection.

![SSD-2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_2.jpg)
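The multi-scale layout can be sanity-checked numerically: combining the six feature map sizes with the number of default boxes predicted per location (4, 6, 6, 6, 4, 4, the values used in the model code later in this tutorial) recovers the total number of prior boxes. The snippet below is only an illustration of this arithmetic.

```python
# Sanity check: total number of SSD default boxes across the six feature maps.
feature_sizes = [38, 19, 10, 5, 3, 1]  # spatial size of each detection layer
boxes_per_loc = [4, 6, 6, 6, 4, 4]     # default boxes predicted per location

total = sum(s * s * k for s, k in zip(feature_sizes, boxes_per_loc))
print(total)  # 8732, matching num_ssd_boxes in the model code below
```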
Comparison of two one-stage object detection algorithms:
The SSD first performs feature extraction through successive convolutions. On each feature layer used for detection, the output is obtained directly through a 3 $\times$ 3 convolution. The number of convolution channels is determined by the number of anchors and the number of classes, and is equal to (Number of anchors x (Number of classes + 4)). Compared with the YOLO series of object detection methods, the SSD obtains the final bounding boxes directly through convolution, whereas YOLO produces a one-dimensional vector through fully-connected layers and then disassembles the vector to obtain the final detection boxes.

### Model Features

- Multi-scale object detection

  As shown in the SSD network structure, the SSD uses multiple feature layers with sizes of 38 $\times$ 38, 19 $\times$ 19, 10 $\times$ 10, 5 $\times$ 5, 3 $\times$ 3, and 1 $\times$ 1, six feature map sizes in total. Large-scale feature maps (those in the front) are used to detect small objects, and small-scale feature maps (those in the rear) are used to detect large objects. Multi-scale features have been proven highly effective for small object detection (the SSD is a dense detection method).

- Convolution used for detection

  Different from YOLO, which finally uses fully-connected layers, the SSD directly uses convolution to extract detection results from the different feature maps. For a feature map whose shape is m $\times$ n $\times$ p, only a relatively small convolution kernel such as 3 $\times$ 3 $\times$ p is needed to obtain a detection value.

- Preset anchors

  In YOLOv1, the size of an object is predicted directly by the network, so the aspect ratio and size of the prediction boxes are unconstrained and training is difficult. In the SSD, preset anchors (also called default bounding boxes in the SSD paper) are used, and the size of each prediction box is fine-tuned under the guidance of its anchor.

## Environment Preparation

This case is based on MindSpore. Before the experiment, ensure that mindspore, download, pycocotools, and opencv-python have been installed on the local host.

## Data Preparation and Processing

The dataset used in this case is COCO 2017. To facilitate data saving and loading, the COCO dataset is converted into the MindRecord format before data reading. The MindSpore Record data format reduces disk I/O and network I/O overheads, improving user experience and performance.

First, we need to download the processed COCO dataset in MindRecord format. Run the following code to download and decompress the dataset to a specified path.

```python
from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/ssd_datasets.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
```

```text
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/ssd_datasets.zip (16.6 MB)

file_sizes: 100%|██████████████████████████| 17.4M/17.4M [00:00<00:00, 26.9MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./
```

Then we define some inputs for data processing.
```python
coco_root = "./datasets/"
anno_json = "./datasets/annotations/instances_val2017.json"

train_cls = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
             'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
             'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
             'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
             'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
             'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
             'kite', 'baseball bat', 'baseball glove', 'skateboard',
             'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
             'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
             'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
             'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
             'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
             'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
             'refrigerator', 'book', 'clock', 'vase', 'scissors',
             'teddy bear', 'hair drier', 'toothbrush']

train_cls_dict = {}
for i, cls in enumerate(train_cls):
    train_cls_dict[cls] = i
```

### Data Sampling

To make the model more robust to various input object sizes and shapes, the SSD algorithm randomly samples each training image using one of the following options:

- Use the entire original input image.
- Sample a region so that the minimum intersection-over-union (IoU) with the objects is 0.1, 0.3, 0.5, 0.7, or 0.9.
- Randomly sample a region.

The size of each sampled region is [0.3, 1] of the size of the original image, and the aspect ratio is between 1/2 and 2. If the center of a ground truth box is within the sampled region, the overlapping part is retained as the ground truth box of the new image. After the foregoing sampling steps, each sampled region is resized to a fixed size and flipped horizontally with a probability of 0.5.

```python
import cv2
import numpy as np

def _rand(a=0., b=1.):
    return np.random.rand() * (b - a) + a

def intersect(box_a, box_b):
    """Compute the intersect of two sets of boxes."""
    max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
    min_yx = np.maximum(box_a[:, :2], box_b[:2])
    inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
    return inter[:, 0] * inter[:, 1]

def jaccard_numpy(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes."""
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]))
    area_b = ((box_b[2] - box_b[0]) * (box_b[3] - box_b[1]))
    union = area_a + area_b - inter
    return inter / union

def random_sample_crop(image, boxes):
    """Crop images and boxes randomly."""
    height, width, _ = image.shape
    min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])

    if min_iou is None:
        return image, boxes

    for _ in range(50):
        image_t = image
        w = _rand(0.3, 1.0) * width
        h = _rand(0.3, 1.0) * height
        # aspect ratio constraint b/t .5 & 2
        if h / w < 0.5 or h / w > 2:
            continue

        left = _rand() * (width - w)
        top = _rand() * (height - h)
        rect = np.array([int(top), int(left), int(top + h), int(left + w)])
        overlap = jaccard_numpy(boxes, rect)

        # dropout some boxes
        drop_mask = overlap > 0
        if not drop_mask.any():
            continue

        if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):
            continue

        image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]

        centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0
        m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
        m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])

        # mask in that both m1 and m2 are true
        mask = m1 * m2 * drop_mask

        # have any valid boxes? try again if not
        if not mask.any():
            continue

        # take only matching gt boxes
        boxes_t = boxes[mask, :].copy()
        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
        boxes_t[:, :2] -= rect[:2]
        boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
        boxes_t[:, 2:4] -= rect[:2]

        return image_t, boxes_t
    return image, boxes

def ssd_bboxes_encode(boxes):
    """Labels anchors with ground truth inputs."""

    def jaccard_with_anchors(bbox):
        """Compute jaccard score between a box and the anchors."""
        # Intersection bbox and volume.
        ymin = np.maximum(y1, bbox[0])
        xmin = np.maximum(x1, bbox[1])
        ymax = np.minimum(y2, bbox[2])
        xmax = np.minimum(x2, bbox[3])
        w = np.maximum(xmax - xmin, 0.)
        h = np.maximum(ymax - ymin, 0.)

        # Volumes.
        inter_vol = h * w
        union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
        jaccard = inter_vol / union_vol
        return np.squeeze(jaccard)

    pre_scores = np.zeros((8732), dtype=np.float32)
    t_boxes = np.zeros((8732, 4), dtype=np.float32)
    t_label = np.zeros((8732), dtype=np.int64)
    for bbox in boxes:
        label = int(bbox[4])
        scores = jaccard_with_anchors(bbox)
        idx = np.argmax(scores)
        scores[idx] = 2.0
        mask = (scores > matching_threshold)
        mask = mask & (scores > pre_scores)
        pre_scores = np.maximum(pre_scores, scores * mask)
        t_label = mask * label + (1 - mask) * t_label
        for i in range(4):
            t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]

    index = np.nonzero(t_label)

    # Transform to tlbr.
    bboxes = np.zeros((8732, 4), dtype=np.float32)
    bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
    bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]

    # Encode features.
    bboxes_t = bboxes[index]
    default_boxes_t = default_boxes[index]
    bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * 0.1)
    tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001)
    bboxes_t[:, 2:4] = np.log(tmp) / 0.2
    bboxes[index] = bboxes_t

    num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
    return bboxes, t_label.astype(np.int32), num_match

def preprocess_fn(img_id, image, box, is_training):
    """Preprocess function for dataset."""
    cv2.setNumThreads(2)

    def _infer_data(image, input_shape):
        img_h, img_w, _ = image.shape
        input_h, input_w = input_shape

        image = cv2.resize(image, (input_w, input_h))

        # When the channels of image is 1
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        return img_id, image, np.array((img_h, img_w), np.float32)

    def _data_aug(image, box, is_training, image_size=(300, 300)):
        ih, iw, _ = image.shape
        h, w = image_size
        if not is_training:
            return _infer_data(image, image_size)
        # Random crop
        box = box.astype(np.float32)
        image, box = random_sample_crop(image, box)
        ih, iw, _ = image.shape
        # Resize image
        image = cv2.resize(image, (w, h))
        # Flip image or not
        flip = _rand() < .5
        if flip:
            image = cv2.flip(image, 1, dst=None)
        # When the channels of image is 1
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)
        box[:, [0, 2]] = box[:, [0, 2]] / ih
        box[:, [1, 3]] = box[:, [1, 3]] / iw
        if flip:
            box[:, [1, 3]] = 1 - box[:, [3, 1]]
        box, label, num_match = ssd_bboxes_encode(box)
        return image, box, label, num_match

    return _data_aug(image, box, is_training, image_size=[300, 300])
```

### Creating a Dataset

```python
from mindspore import Tensor
from mindspore.dataset import MindDataset
from mindspore.dataset.vision import Decode, HWC2CHW, Normalize, RandomColorAdjust

def create_ssd_dataset(mindrecord_file, batch_size=32, device_num=1, rank=0,
                       is_training=True, num_parallel_workers=1, use_multiprocessing=True):
    """Create SSD dataset with MindDataset."""
    dataset = MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"],
                          num_shards=device_num, shard_id=rank,
                          num_parallel_workers=num_parallel_workers, shuffle=is_training)

    decode = Decode()
    dataset = dataset.map(operations=decode, input_columns=["image"])

    change_swap_op = HWC2CHW()
    # Computed from random subset of ImageNet training images
    normalize_op = Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                             std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    color_adjust_op = RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
    compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))

    if is_training:
        output_columns = ["image", "box", "label", "num_match"]
        trans = [color_adjust_op, normalize_op, change_swap_op]
    else:
        output_columns = ["img_id", "image", "image_shape"]
        trans = [normalize_op, change_swap_op]

    dataset = dataset.map(operations=compose_map_func,
                          input_columns=["img_id", "image", "annotation"],
                          output_columns=output_columns,
                          python_multiprocessing=use_multiprocessing,
                          num_parallel_workers=num_parallel_workers)

    dataset = dataset.map(operations=trans, input_columns=["image"],
                          python_multiprocessing=use_multiprocessing,
                          num_parallel_workers=num_parallel_workers)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
```

## Model Building

The SSD network structure consists of the following parts:

![SSD-3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_3.jpg)

- VGG16 Base Layer
- Extra Feature Layer
- Detection Layer
- NMS
- Anchor

### Backbone Layer

![SSD-4](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_en/cv/images/SSD_4.png)

After preprocessing, the size of the input image is fixed at 300 x 300. The image first passes through the backbone; in this case, the first 13 convolutional layers of the VGG-16 network are used. Then, the fully-connected layers fc6 and fc7 of VGG-16 are converted into a 3 $\times$ 3 convolutional layer (block 6) and a 1 $\times$ 1 convolutional layer (block 7), respectively, and features are further extracted. In block 6, a dilated convolution with a dilation rate of 6 and a padding of 6 is used, which increases the receptive field while keeping the parameter quantity and the feature map size unchanged.

### Extra Feature Layer

On the basis of VGG-16, the SSD further adds four deep convolutional layers to extract higher-level semantic information:

![SSD-5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_5.png)

Blocks 8 to 11 are used to extract higher-level semantic information. The number of output channels in block 8 is 512, and the number of output channels in blocks 9, 10, and 11 is 256. From block 7 to block 11, the sizes of the five convolutional output feature maps are 19 x 19, 10 x 10, 5 x 5, 3 x 3, and 1 x 1 in sequence. To reduce the number of parameters, a 1 x 1 convolution is first used to reduce the number of channels to half of the layer's output channels, and then a 3 x 3 convolution is used for feature extraction.

### Anchor

The SSD uses the PriorBox to generate regions.
The PriorBox, with fixed width and height, is used as the prior region of interest, and a single stage completes classification and regression. A large number of dense PriorBoxes are designed so that every region of the image is covered by detection. The PriorBox location is represented by the coordinates of its center point and the width and height of the box (cx, cy, w, h), all converted into percentages of the image size.

PriorBox generation rule: The SSD uses six feature layers to detect objects, and the scale of the PriorBox differs across feature layers. The scale of the lowest layer is 0.1, and the scale of the highest layer is 0.95. The scales of the other layers are calculated as follows:

![SSD-6](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_6.jpg)

With the scale of a feature layer fixed, PriorBoxes with different aspect ratios are set. Their widths and heights are calculated as follows:

![SSD-7](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_7.jpg)

When the ratio is 1, a PriorBox of an additional scale (aspect ratio = 1) is calculated from the scales of the current feature layer and the next one, as follows:

![SSD-8](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_8.jpg)

A PriorBox is generated for each point at each feature layer based on the preceding rules, with (cx, cy) determined by the current center point. Therefore, a large number of dense PriorBoxes are generated at each feature layer, as shown in the following figure.

![SSD-9](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_9.png)

The SSD uses feature maps obtained from six convolutional layers: the fourth, seventh, eighth, ninth, tenth, and eleventh layers. The sizes of the six feature maps decrease in turn, while their corresponding receptive fields increase. Each point on the six feature maps corresponds to 4, 6, 6, 6, 4, and 4 PriorBoxes respectively. A point's coordinates in a feature map can be mapped back to the original image through the downsampling rate, and four or six PriorBoxes of different sizes are generated around that center. The class and location of each PriorBox are then predicted from the features of the feature map. For example, the feature map produced by the eighth convolutional layer has a size of 10 x 10 x 512 and each point corresponds to six PriorBoxes, giving 600 PriorBoxes in total. The MultiBox class is defined later to generate these prediction boxes.

### Detection Layer

![SSD-10](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_en/cv/images/SSD_10.jpg)

The SSD model has six prediction feature maps in total. For a prediction feature map of size m\*n with p channels, it is assumed that each pixel generates k anchors, and each anchor corresponds to c classes and four regression offsets. A convolution with (4+c)k kernels of size 3x3 and channel p is applied to the prediction feature map, producing an output feature map of size m\*n with (4+c)\*k channels. It contains the regression offsets and probability scores of every anchor generated on the prediction feature map.
Therefore, for a prediction feature map whose size is m\*n, a total of (4+c)k\*m\*n results are generated. The number of output channels of the cls branch is k\*class_num, and the number of output channels of the loc branch is k\*4. ```python from mindspore import nn def _make_layer(channels): in_channels = channels[0] layers = [] for out_channels in channels[1:]: layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)) layers.append(nn.ReLU()) in_channels = out_channels return nn.SequentialCell(layers) class Vgg16(nn.Cell): """VGG16 module.""" def __init__(self): super(Vgg16, self).__init__() self.b1 = _make_layer([3, 64, 64]) self.b2 = _make_layer([64, 128, 128]) self.b3 = _make_layer([128, 256, 256, 256]) self.b4 = _make_layer([256, 512, 512, 512]) self.b5 = _make_layer([512, 512, 512, 512]) self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME') def construct(self, x): # block1 x = self.b1(x) x = self.m1(x) # block2 x = self.b2(x) x = self.m2(x) # block3 x = self.b3(x) x = self.m3(x) # block4 x = self.b4(x) block4 = x x = self.m4(x) # block5 x = self.b5(x) x = self.m5(x) return block4, x ``` ```python import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0): in_channels = in_channel out_channels = in_channel depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=pad, group=in_channels) conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, pad_mode='same', has_bias=True) bn = nn.BatchNorm2d(in_channel, eps=1e-3, momentum=0.97, gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) return nn.SequentialCell([depthwise_conv, bn, nn.ReLU6(), conv]) class FlattenConcat(nn.Cell): """FlattenConcat module.""" def __init__(self): super(FlattenConcat, self).__init__() self.num_ssd_boxes = 8732 def construct(self, inputs): output = () batch_size = ops.shape(inputs[0])[0] for x in inputs: x = ops.transpose(x, (0, 2, 3, 1)) output += (ops.reshape(x, (batch_size, -1)),) res = ops.concat(output, axis=1) return ops.reshape(res, (batch_size, self.num_ssd_boxes, -1)) class MultiBox(nn.Cell): """ Multibox conv layers. Each multibox layer contains class conf scores and localization predictions. 
""" def __init__(self): super(MultiBox, self).__init__() num_classes = 81 out_channels = [512, 1024, 512, 256, 256, 256] num_default = [4, 6, 6, 6, 4, 4] loc_layers = [] cls_layers = [] for k, out_channel in enumerate(out_channels): loc_layers += [_last_conv2d(out_channel, 4 * num_default[k], kernel_size=3, stride=1, pad_mod='same', pad=0)] cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k], kernel_size=3, stride=1, pad_mod='same', pad=0)] self.multi_loc_layers = nn.CellList(loc_layers) self.multi_cls_layers = nn.CellList(cls_layers) self.flatten_concat = FlattenConcat() def construct(self, inputs): loc_outputs = () cls_outputs = () for i in range(len(self.multi_loc_layers)): loc_outputs += (self.multi_loc_layers[i](inputs[i]),) cls_outputs += (self.multi_cls_layers[i](inputs[i]),) return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) class SSD300Vgg16(nn.Cell): """SSD300Vgg16 module.""" def __init__(self): super(SSD300Vgg16, self).__init__() # VGG16 backbone: block1~5 self.backbone = Vgg16() # SSD blocks: block6~7 self.b6_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6, pad_mode='pad') self.b6_2 = nn.Dropout(p=0.5) self.b7_1 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1) self.b7_2 = nn.Dropout(p=0.5) # Extra Feature Layers: block8~11 self.b8_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, padding=1, pad_mode='pad') self.b8_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, pad_mode='valid') self.b9_1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, padding=1, pad_mode='pad') self.b9_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, pad_mode='valid') self.b10_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1) self.b10_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid') self.b11_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1) self.b11_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid') # boxes self.multi_box = MultiBox() def construct(self, x): # VGG16 backbone: block1~5 block4, x = self.backbone(x) # SSD blocks: block6~7 x = self.b6_1(x) # 1024 x = self.b6_2(x) x = self.b7_1(x) # 1024 x = self.b7_2(x) block7 = x # Extra Feature Layers: block8~11 x = self.b8_1(x) # 256 x = self.b8_2(x) # 512 block8 = x x = self.b9_1(x) # 128 x = self.b9_2(x) # 256 block9 = x x = self.b10_1(x) # 128 x = self.b10_2(x) # 256 block10 = x x = self.b11_1(x) # 128 x = self.b11_2(x) # 256 block11 = x # boxes multi_feature = (block4, block7, block8, block9, block10, block11) pred_loc, pred_label = self.multi_box(multi_feature) if not self.training: pred_label = ops.sigmoid(pred_label) pred_loc = pred_loc.astype(ms.float32) pred_label = pred_label.astype(ms.float32) return pred_loc, pred_label ``` ## Loss Function The object function of the SSD algorithm is divided into two parts: calculating a confidence loss (conf) between a corresponding preselection box and a target category and a corresponding location loss (loc): ![SSD-11](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_11.jpg) In the preceding information:
- N indicates the number of prior boxes that are matched as positive samples.
- c indicates the class confidence prediction value.
- l indicates the location prediction value of the bounding box corresponding to the prior box.
- g indicates the location parameters of the ground truth box.
- α is used to adjust the ratio of the confidence loss to the location loss. The default value is **1**.

### Location Loss Function

Smooth L1 Loss is used for all positive samples. The location information is encoded before the loss is computed.

![SSD-12](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_12.jpg)

### Confidence Loss Function

The confidence loss in the paper is the softmax loss over the multi-class confidences (c); the implementation below uses a sigmoid focal loss instead.

![SSD-13](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_en/cv/images/SSD_13.jpg)

```python
def class_loss(logits, label):
    """Calculate category losses."""
    label = ops.one_hot(label, ops.shape(logits)[-1], Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
    weight = ops.ones_like(logits)
    pos_weight = ops.ones_like(logits)
    sigmoid_cross_entropy = ops.binary_cross_entropy_with_logits(logits, label, weight.astype(ms.float32),
                                                                 pos_weight.astype(ms.float32))
    sigmoid = ops.sigmoid(logits)
    label = label.astype(ms.float32)
    p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
    modulating_factor = ops.pow(1 - p_t, 2.0)
    alpha_weight_factor = label * 0.75 + (1 - label) * (1 - 0.75)
    focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy
    return focal_loss
```

## Metrics

In SSD, non-maximum suppression (NMS) is not required during training. However, during inference, when an image is input and boxes are required as output, NMS is used to filter out prediction boxes that overlap with each other.
The NMS process is as follows:

1. Sort by confidence score.
2. Select the bounding box with the highest confidence, add it to the final output list, and remove it from the bounding box list.
3. Calculate the areas of all bounding boxes.
4. Calculate the IoU between the bounding box with the highest confidence and the other candidate boxes.
5. Delete the bounding boxes whose IoU is greater than the threshold.
6. Repeat the preceding steps until the bounding box list is empty.
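Before the full evaluation code below, these steps can be illustrated on toy data. The following is only a minimal NumPy sketch of the procedure, with made-up box coordinates and an assumed IoU threshold of 0.5; the `apply_nms` function in the next code block is the version actually used for evaluation.

```python
import numpy as np

# Toy NMS walkthrough: three boxes as (y1, x1, y2, x2) plus confidence scores.
boxes = np.array([[0., 0., 10., 10.],     # box A
                  [1., 1., 10., 10.],     # box B, heavily overlaps A
                  [20., 20., 30., 30.]])  # box C, far away
scores = np.array([0.9, 0.8, 0.7])

def iou(a, b):
    """IoU of two (y1, x1, y2, x2) boxes."""
    y1, x1 = np.maximum(a[:2], b[:2])
    y2, x2 = np.minimum(a[2:], b[2:])
    inter = max(0., y2 - y1) * max(0., x2 - x1)
    area = lambda box: (box[2] - box[0]) * (box[3] - box[1])
    return inter / (area(a) + area(b) - inter)

keep, order = [], list(np.argsort(-scores))   # step 1: sort by confidence
while order:
    best = order.pop(0)                       # step 2: take the most confident box
    keep.append(int(best))
    # steps 3-5: drop remaining boxes whose IoU with `best` exceeds the threshold
    order = [i for i in order if iou(boxes[best], boxes[i]) <= 0.5]
print(keep)  # [0, 2]: box B is suppressed by box A
```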
```python
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def apply_eval(eval_param_dict):
    net = eval_param_dict["net"]
    net.set_train(False)
    ds = eval_param_dict["dataset"]
    anno_json = eval_param_dict["anno_json"]
    coco_metrics = COCOMetrics(anno_json=anno_json,
                               classes=train_cls,
                               num_classes=81,
                               max_boxes=100,
                               nms_threshold=0.6,
                               min_score=0.1)
    for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
        img_id = data['img_id']
        img_np = data['image']
        image_shape = data['image_shape']

        output = net(Tensor(img_np))

        for batch_idx in range(img_np.shape[0]):
            pred_batch = {
                "boxes": output[0].asnumpy()[batch_idx],
                "box_scores": output[1].asnumpy()[batch_idx],
                "img_id": int(np.squeeze(img_id[batch_idx])),
                "image_shape": image_shape[batch_idx]
            }
            coco_metrics.update(pred_batch)
    eval_metrics = coco_metrics.get_metrics()
    return eval_metrics

def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes."""
    y1 = all_boxes[:, 0]
    x1 = all_boxes[:, 1]
    y2 = all_boxes[:, 2]
    x2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        i = order[0]
        keep.append(i)

        if len(keep) >= max_boxes:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thres)[0]

        order = order[inds + 1]
    return keep

class COCOMetrics:
    """Calculate mAP of predicted bboxes."""

    def __init__(self, anno_json, classes, num_classes, min_score, nms_threshold, max_boxes):
        self.num_classes = num_classes
        self.classes = classes
        self.min_score = min_score
        self.nms_threshold = nms_threshold
        self.max_boxes = max_boxes

        self.val_cls_dict = {i: cls for i, cls in enumerate(classes)}
        self.coco_gt = COCO(anno_json)
        cat_ids = self.coco_gt.loadCats(self.coco_gt.getCatIds())
        self.class_dict = {cat['name']: cat['id'] for cat in cat_ids}

        self.predictions = []
        self.img_ids = []

    def update(self, batch):
        pred_boxes = batch['boxes']
        box_scores = batch['box_scores']
        img_id = batch['img_id']
        h, w = batch['image_shape']

        final_boxes = []
        final_label = []
        final_score = []
        self.img_ids.append(img_id)

        for c in range(1, self.num_classes):
            class_box_scores = box_scores[:, c]
            score_mask = class_box_scores > self.min_score
            class_box_scores = class_box_scores[score_mask]
            class_boxes = pred_boxes[score_mask] * [h, w, h, w]

            if score_mask.any():
                nms_index = apply_nms(class_boxes, class_box_scores, self.nms_threshold, self.max_boxes)
                class_boxes = class_boxes[nms_index]
                class_box_scores = class_box_scores[nms_index]

                final_boxes += class_boxes.tolist()
                final_score += class_box_scores.tolist()
                final_label += [self.class_dict[self.val_cls_dict[c]]] * len(class_box_scores)

        for loc, label, score in zip(final_boxes, final_label, final_score):
            res = {}
            res['image_id'] = img_id
            res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
            res['score'] = score
            res['category_id'] = label
            self.predictions.append(res)

    def get_metrics(self):
        with open('predictions.json', 'w') as f:
            json.dump(self.predictions, f)

        coco_dt = self.coco_gt.loadRes('predictions.json')
        E = COCOeval(self.coco_gt, coco_dt, iouType='bbox')
        E.params.imgIds = self.img_ids
        E.evaluate()
        E.accumulate()
        E.summarize()
        return E.stats[0]

class SsdInferWithDecoder(nn.Cell):
    """SSD infer wrapper to decode the bbox locations."""

    def __init__(self, network, default_boxes, ckpt_path):
        super(SsdInferWithDecoder, self).__init__()
        param_dict = ms.load_checkpoint(ckpt_path)
        ms.load_param_into_net(network, param_dict)
        self.network = network
        self.default_boxes = default_boxes
        self.prior_scaling_xy = 0.1
        self.prior_scaling_wh = 0.2

    def construct(self, x):
        pred_loc, pred_label = self.network(x)

        default_bbox_xy = self.default_boxes[..., :2]
        default_bbox_wh = self.default_boxes[..., 2:]
        pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy
        pred_wh = ops.exp(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh

        pred_xy_0 = pred_xy - pred_wh / 2.0
        pred_xy_1 = pred_xy + pred_wh / 2.0
        pred_xy = ops.concat((pred_xy_0, pred_xy_1), -1)
        pred_xy = ops.maximum(pred_xy, 0)
        pred_xy = ops.minimum(pred_xy, 1)
        return pred_xy, pred_label
```

## Training Process

### (1) Prior box matching

During training, you need to determine which prior box each ground truth in the training image matches; the bounding box corresponding to the matched prior box is responsible for predicting that ground truth. The SSD matches prior boxes with ground truths according to the following principles:

1. For each ground truth in the image, find the prior box with the largest IoU; that prior box matches the ground truth. In this way, each ground truth is guaranteed to match one prior box. A prior box that matches a ground truth is referred to as a positive sample; conversely, a prior box that matches no ground truth can only match the background and is a negative sample.

2. For the remaining unmatched prior boxes, if the IoU with some ground truth is greater than a threshold (generally 0.5), the prior box is also matched with that ground truth.

Although a ground truth can match multiple prior boxes, there are far fewer ground truths than prior boxes, so negative samples greatly outnumber positive samples. To keep positive and negative samples balanced, the SSD uses hard negative mining, that is, negative samples are subsampled: they are sorted in descending order of confidence loss (a lower predicted background confidence means a larger loss), and the top-k samples with the largest loss are selected as the negative samples for training, so that the ratio of positive to negative samples is close to 1:3 (a minimal sketch of this selection strategy follows after the loss function summary below).

Notes:

1. Generally, a prior box that matches a ground truth is referred to as a positive sample; a prior box that matches no ground truth is a negative sample.
2. A ground truth can match multiple prior boxes, but each prior box can match only one ground truth.
3. If the IoUs between multiple ground truths and a prior box all exceed the threshold, the prior box matches only the ground truth with the largest IoU.

![SSD-14](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_en/cv/images/SSD_14.jpg)

As shown in the preceding figure, the basic idea of matching prior boxes and ground truth boxes during training is as follows: each prior box is regressed toward its ground truth box, and this process is controlled by the loss layer, which calculates the error between the predicted and actual values to guide the learning direction.

### (2) Loss function

The loss function is the weighted sum of the location loss function and the confidence loss function described above.
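The hard negative mining strategy referenced above is not implemented verbatim in this tutorial; the `class_loss` defined earlier uses a sigmoid focal loss, which down-weights easy negatives instead of discarding them. The following is only a minimal NumPy sketch (hypothetical helper name, toy values) of selecting the hardest negatives at a 3:1 ratio:

```python
import numpy as np

def hard_negative_mining(conf_loss, pos_mask, neg_ratio=3):
    """Keep all positives plus the `neg_ratio * num_pos` negatives with the
    largest confidence loss (a sketch of the strategy described above)."""
    num_pos = int(pos_mask.sum())
    num_neg = min(neg_ratio * num_pos, int((~pos_mask).sum()))
    neg_loss = np.where(pos_mask, -np.inf, conf_loss)  # exclude positives from the ranking
    hardest = np.argsort(-neg_loss)[:num_neg]          # indices of the hardest negatives
    keep = pos_mask.copy()
    keep[hardest] = True
    return keep  # boolean mask of anchors contributing to the class loss

# Toy example: 10 anchors, 2 positives -> the 2 positives + 6 hardest negatives are kept.
loss = np.array([0.2, 2.5, 0.1, 1.7, 0.05, 0.9, 3.0, 0.4, 1.1, 0.3], dtype=np.float32)
pos = np.zeros(10, dtype=bool)
pos[[0, 4]] = True
print(hard_negative_mining(loss, pos).sum())  # 8
```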
### (3) Data augmentation

Apply the previously defined data augmentation operations to the created dataset.

During model training, the number of epochs is set to 60, and the datasets are created using the create_ssd_dataset function. The value of **batch_size** is **5**, and the image size is adjusted to 300 x 300. The loss function is the weighted sum of the location loss function and the confidence loss function, and the optimizer is Momentum with an initial learning rate of 0.001. The loss value and the running time of each epoch are printed during training, and the trained model is saved to a checkpoint file.

```python
import math
import itertools as it

from mindspore import set_seed

class GeneratDefaultBoxes():
    """
    Generate Default boxes for SSD, follows the order of (W, H, anchor_sizes).
    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [y, x, h, w].
    `self.default_boxes_tlbr` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
    """

    def __init__(self):
        fk = 300 / np.array([8, 16, 32, 64, 100, 300])
        scale_rate = (0.95 - 0.1) / (len([4, 6, 6, 6, 4, 4]) - 1)
        scales = [0.1 + scale_rate * i for i in range(len([4, 6, 6, 6, 4, 4]))] + [1.0]
        self.default_boxes = []
        for idex, feature_size in enumerate([38, 19, 10, 5, 3, 1]):
            sk1 = scales[idex]
            sk2 = scales[idex + 1]
            sk3 = math.sqrt(sk1 * sk2)
            if idex == 0 and not [[2], [2, 3], [2, 3], [2, 3], [2], [2]][idex]:
                w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)
                all_sizes = [(0.1, 0.1), (w, h), (h, w)]
            else:
                all_sizes = [(sk1, sk1)]
                for aspect_ratio in [[2], [2, 3], [2, 3], [2, 3], [2], [2]][idex]:
                    w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
                    all_sizes.append((w, h))
                    all_sizes.append((h, w))
                all_sizes.append((sk3, sk3))

            assert len(all_sizes) == [4, 6, 6, 6, 4, 4][idex]

            for i, j in it.product(range(feature_size), repeat=2):
                for w, h in all_sizes:
                    cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
                    self.default_boxes.append([cy, cx, h, w])

        def to_tlbr(cy, cx, h, w):
            return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

        # For IoU calculation
        self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32')
        self.default_boxes = np.array(self.default_boxes, dtype='float32')

default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr
default_boxes = GeneratDefaultBoxes().default_boxes

y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = 0.5
```

```python
from mindspore.common.initializer import initializer, TruncatedNormal

def init_net_param(network, initialize_mode='TruncatedNormal'):
    """Init the parameters in net."""
    params = network.trainable_params()
    for p in params:
        if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            if initialize_mode == 'TruncatedNormal':
                p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype))
            else:
                p.set_data(initialize_mode, p.data.shape, p.data.dtype)

def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """Generate learning rate array."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_end + (lr_max -
lr_end) * (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. if lr < 0.0: lr = 0.0 lr_each_step.append(lr) current_step = global_step lr_each_step = np.array(lr_each_step).astype(np.float32) learning_rate = lr_each_step[current_step:] return learning_rate ``` ```python import time from mindspore.amp import DynamicLossScaler set_seed(1) # load data mindrecord_dir = "./datasets/MindRecord_COCO" mindrecord_file = "./datasets/MindRecord_COCO/ssd.mindrecord0" dataset = create_ssd_dataset(mindrecord_file, batch_size=5, rank=0, use_multiprocessing=True) dataset_size = dataset.get_dataset_size() image, get_loc, gt_label, num_matched_boxes = next(dataset.create_tuple_iterator()) # Network definition and initialization network = SSD300Vgg16() init_net_param(network) # Define the learning rate lr = Tensor(get_lr(global_step=0 * dataset_size, lr_init=0.001, lr_end=0.001 * 0.05, lr_max=0.05, warmup_epochs=2, total_epochs=60, steps_per_epoch=dataset_size)) # Define the optimizer opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, 0.9, 0.00015, float(1024)) # Define the forward procedure def forward_fn(x, gt_loc, gt_label, num_matched_boxes): pred_loc, pred_label = network(x) mask = ops.less(0, gt_label).astype(ms.float32) num_matched_boxes = ops.sum(num_matched_boxes.astype(ms.float32)) # Positioning loss mask_loc = ops.tile(ops.expand_dims(mask, -1), (1, 1, 4)) smooth_l1 = nn.SmoothL1Loss()(pred_loc, gt_loc) * mask_loc loss_loc = ops.sum(ops.sum(smooth_l1, -1), -1) # Category loss loss_cls = class_loss(pred_label, gt_label) loss_cls = ops.sum(loss_cls, (1, 2)) return ops.sum((loss_cls + loss_loc) / num_matched_boxes) grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters, has_aux=False) loss_scaler = DynamicLossScaler(1024, 2, 1000) # Gradient updates def train_step(x, gt_loc, gt_label, num_matched_boxes): loss, grads = grad_fn(x, gt_loc, gt_label, num_matched_boxes) opt(grads) return loss print("=================== Starting Training =====================") iterator = dataset.create_tuple_iterator(num_epochs=60) for epoch in range(60): network.set_train(True) begin_time = time.time() for step, (image, get_loc, gt_label, num_matched_boxes) in enumerate(iterator): loss = train_step(image, get_loc, gt_label, num_matched_boxes) end_time = time.time() times = end_time - begin_time print(f"Epoch:[{int(epoch + 1)}/{int(60)}], " f"loss:{loss} , " f"time:{times}s ") ms.save_checkpoint(network, "ssd-60_9.ckpt") print("=================== Training Success =====================") ``` ```text =================== Starting Training ===================== Epoch:[1/60], loss:1365.3849 , time:42.76231384277344s Epoch:[2/60], loss:1350.9009 , time:43.63900399208069s Epoch:[3/60], loss:1325.2102 , time:48.01434779167175s Epoch:[4/60], loss:1297.8125 , time:40.65014576911926s Epoch:[5/60], loss:1269.7281 , time:40.627623081207275s Epoch:[6/60], loss:1240.8068 , time:42.14572191238403s Epoch:[7/60], loss:1210.52 , time:41.091148853302s Epoch:[8/60], loss:1178.0127 , time:41.88719820976257s Epoch:[9/60], loss:1142.2338 , time:41.147764444351196s Epoch:[10/60], loss:1101.929 , time:42.21702218055725s Epoch:[11/60], loss:1055.7747 , time:40.66824555397034s Epoch:[12/60], loss:1002.66125 , time:40.70291781425476s Epoch:[13/60], loss:942.0149 , time:42.10250663757324s Epoch:[14/60], loss:874.245 , time:41.27074885368347s Epoch:[15/60], loss:801.06055 , time:40.62501621246338s Epoch:[16/60], loss:725.4527 , time:41.78050708770752s Epoch:[17/60], loss:651.15564 , 
time:40.619580030441284s
Epoch:[18/60], loss:581.7435 , time:41.07759237289429s
Epoch:[19/60], loss:519.85223 , time:41.74708104133606s
Epoch:[20/60], loss:466.71866 , time:40.79696846008301s
Epoch:[21/60], loss:422.35846 , time:40.40337634086609s
Epoch:[22/60], loss:385.95758 , time:41.0706627368927s
Epoch:[23/60], loss:356.3252 , time:41.02973508834839s
Epoch:[24/60], loss:332.2302 , time:41.101938009262085s
Epoch:[25/60], loss:312.56158 , time:40.12760329246521s
Epoch:[26/60], loss:296.3943 , time:40.62085247039795s
Epoch:[27/60], loss:282.99237 , time:42.20474720001221s
Epoch:[28/60], loss:271.7844 , time:40.27843761444092s
Epoch:[29/60], loss:262.32687 , time:40.6625394821167s
Epoch:[30/60], loss:254.28302 , time:41.42288422584534s
Epoch:[31/60], loss:247.38882 , time:40.49200940132141s
Epoch:[32/60], loss:241.44067 , time:41.48827362060547s
Epoch:[33/60], loss:236.28123 , time:41.1355299949646s
Epoch:[34/60], loss:231.78201 , time:40.45781660079956s
Epoch:[35/60], loss:227.84433 , time:40.92684364318848s
Epoch:[36/60], loss:224.38614 , time:40.89856195449829s
Epoch:[37/60], loss:221.34372 , time:41.585039138793945s
Epoch:[38/60], loss:218.66156 , time:40.8972954750061s
Epoch:[39/60], loss:216.29553 , time:42.22093486785889s
Epoch:[40/60], loss:214.20854 , time:40.75188755989075s
Epoch:[41/60], loss:212.36868 , time:41.51768183708191s
Epoch:[42/60], loss:210.74985 , time:40.3460476398468s
Epoch:[43/60], loss:209.32901 , time:40.65240502357483s
Epoch:[44/60], loss:208.08626 , time:41.250218629837036s
Epoch:[45/60], loss:207.00375 , time:40.334686040878296s
Epoch:[46/60], loss:206.06656 , time:40.822086811065674s
Epoch:[47/60], loss:205.2609 , time:40.492422103881836s
Epoch:[48/60], loss:204.57387 , time:41.39555335044861s
Epoch:[49/60], loss:203.9947 , time:40.29546666145325s
Epoch:[50/60], loss:203.51189 , time:39.61115860939026s
Epoch:[51/60], loss:203.11642 , time:41.232492446899414s
Epoch:[52/60], loss:202.79791 , time:40.896180152893066s
Epoch:[53/60], loss:202.54779 , time:40.62282419204712s
Epoch:[54/60], loss:202.35779 , time:40.751235485076904s
Epoch:[55/60], loss:202.2188 , time:41.790447473526s
Epoch:[56/60], loss:202.12277 , time:41.371476888656616s
Epoch:[57/60], loss:202.05978 , time:41.00389575958252s
Epoch:[58/60], loss:202.02513 , time:40.384965658187866s
Epoch:[59/60], loss:202.00772 , time:40.91265916824341s
Epoch:[60/60], loss:201.9999 , time:41.31216502189636s
=================== Training Success =====================
```

## Evaluation

Customize the eval_net() function to evaluate the trained model, and invoke the SsdInferWithDecoder class to return the predicted coordinates and labels. The average precision (AP) and average recall (AR) are then calculated for different IoU threshold, area, and maxDets settings. The COCOMetrics class is used to calculate the mAP. The evaluation metrics of the model on the test set are as follows:

### AP and AR Explanations

- TP: the number of detection boxes whose IoU with a ground truth is greater than the specified threshold (each ground truth is counted only once).
- FP: the number of detection boxes whose IoU is less than or equal to the threshold, plus redundant detection boxes matching an already-detected ground truth.
- FN: the number of ground truths (GTs) that are not detected.
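From these counts, the precision and recall used in the next subsection follow directly as TP / (TP + FP) and TP / (TP + FN); the snippet below only illustrates this arithmetic with made-up counts.

```python
# Illustration only: precision and recall from toy TP/FP/FN counts.
tp, fp, fn = 80, 20, 40

precision = tp / (tp + fp)  # fraction of predictions that are correct
recall = tp / (tp + fn)     # fraction of ground truths that are found

print(f"precision={precision:.2f}, recall={recall:.2f}")  # precision=0.80, recall=0.67
```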
### AP and AR Formulas

- Average precision (AP):

  ![SSD-15](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_15.jpg)

  The AP is the ratio of correctly predicted positive samples (TP) to all predicted positive samples (TP + FP). It mainly reflects how reliable the prediction results are.

- Average recall (AR):

  ![SSD-16](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/SSD_16.jpg)

  The AR is the ratio of correctly predicted positive samples (TP) to all actual positive samples (TP + FN). It mainly reflects the missed detection rate of the prediction results.

### Output Metrics for the Following Code Running Results

- The first value is the mAP (mean average precision), that is, the average of the APs of all classes.
- The second value is the mAP when IoU is set to 0.5, which is the VOC evaluation standard.
- The third value is the strictly evaluated mAP, which reflects the localization accuracy of the predicted boxes.

The middle values are the mAP values for different object sizes. For the AR, check the mAR values when maxDets is 10 and 100: if the two values are close, it indicates that 100 detection boxes per image are unnecessary for this dataset, which can improve performance.

```python
mindrecord_file = "./datasets/MindRecord_COCO/ssd_eval.mindrecord0"

def ssd_eval(dataset_path, ckpt_path, anno_json):
    """SSD evaluation."""
    batch_size = 1
    ds = create_ssd_dataset(dataset_path, batch_size=batch_size,
                            is_training=False, use_multiprocessing=False)

    network = SSD300Vgg16()
    print("Load Checkpoint!")
    net = SsdInferWithDecoder(network, Tensor(default_boxes), ckpt_path)

    net.set_train(False)
    total = ds.get_dataset_size() * batch_size
    print("\n========================================\n")
    print("total images num: ", total)
    eval_param_dict = {"net": net, "dataset": ds, "anno_json": anno_json}
    mAP = apply_eval(eval_param_dict)
    print("\n========================================\n")
    print(f"mAP: {mAP}")

def eval_net():
    print("Start Eval!")
    ssd_eval(mindrecord_file, "./ssd-60_9.ckpt", anno_json)

eval_net()
```

```text
Start Eval!
Load Checkpoint!

========================================

total images num:  9
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Loading and preparing results...
DONE (t=0.47s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.97s).
Accumulating evaluation results...
DONE (t=0.20s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.003
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.006
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.052
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.016
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.005
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.037
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.071
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.057
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.328

========================================

mAP: 0.0025924737758294216
```

## Reference

[1] Liu W, Anguelov D, Erhan D, et al. SSD: Single shot multibox detector[C]//European Conference on Computer Vision. Springer, Cham, 2016: 21-37.