# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualization for detection/segmentation dataset.
"""
import os
import sys
import importlib
import numpy as np
from mindspore import log as logger
[docs]def imshow_det_bbox(image,
bboxes,
labels,
segm=None,
class_names=None,
score_threshold=0,
bbox_color=(0, 255, 0),
text_color=(203, 192, 255),
mask_color=(128, 0, 128),
thickness=2,
font_size=0.8,
show=True,
win_name="win",
wait_time=2000,
out_file=None
):
"""Draw an image with given bboxes and class labels (with scores).
Args:
image (ndarray): The image to be displayed, shaped (C, H, W) or (H, W, C), formatted RGB.
bboxes (ndarray): Bounding boxes (with scores), shaped (N, 4) or (N, 5),
data should be ordered with (N, x, y, w, h).
labels (ndarray): Labels of bboxes, shaped (N, 1).
segm (ndarray): The segmentation masks of image in M classes, shaped (M, H, W) (Default=None).
class_names (list[str], tuple[str], dict): Names of each class to map label to class name
(Default=None, only display label).
score_threshold (float): Minimum score of bboxes to be shown (Default=0).
bbox_color (tuple(int)): Color of bbox lines.
The tuple of color should be in BGR order (Default=(0, 255 ,0), means 'green').
text_color (tuple(int)): Color of texts.
The tuple of color should be in BGR order (Default=(203, 192, 255), means 'pink').
mask_color (tuple(int)): Color of mask.
The tuple of color should be in BGR order (Default=(128, 0, 128), means 'purple').
thickness (int): Thickness of lines (Default=2).
font_size (int, float): Font size of texts (Default=0.8).
show (bool): Whether to show the image (Default=True).
win_name (str): The window name (Default="win").
wait_time (int): Value of waitKey param (Default=2000, means display interval is 2000ms).
out_file (str, optional): The filename to write the imagee (Default=None). File extension name
is required to indicate the image compression type, e.g. 'jpg', 'png'.
Returns:
ndarray: The image with bboxes drawn on it.
"""
try:
cv2 = importlib.import_module("cv2")
except ModuleNotFoundError:
raise ImportError("import cv2 failed, seems you have to run `pip install opencv-python`.")
# validation
assert isinstance(image, np.ndarray) and image.ndim == 3 and (image.shape[0] == 3 or image.shape[2] == 3),\
"image must be a ndarray in (H, W, C) or (C, H, W) format."
if bboxes is not None:
assert isinstance(bboxes, np.ndarray) and bboxes.ndim == 2 and (bboxes.shape[1] == 4 or bboxes.shape[1] == 5), \
"bboxes must be a ndarray in (N, 4) or (N, 5) format."
assert isinstance(labels, np.ndarray) and labels.ndim == 2 and labels.shape[1] == 1 and \
labels.shape[0] == bboxes.shape[0], "labels must be a ndarray in (N, 1) format and has same N with bboxes."
if segm is not None:
assert isinstance(segm, np.ndarray) and segm.ndim == 3, "segm must be a ndarray in (M, H, W) format."
H, W = (image.shape[0], image.shape[1]) if image.shape[2] == 3 else (image.shape[1], image.shape[2])
assert H == segm.shape[1] and W == segm.shape[2], "segm must has same height and width with image."
if bboxes is not None:
assert bboxes.shape[0] <= segm.shape[0], "number of segm masks must not be less than the number of bboxes."
assert isinstance(class_names, (tuple, list, dict)), "class_names must be a list, tuple or dict."
assert isinstance(bbox_color, tuple) and len(bbox_color) == 3, \
"bbox_color must be a three tuple, formatted (B, G, R)."
assert isinstance(text_color, tuple) and len(text_color) == 3, \
"text_color must be a three tuple, formatted (B, G, R)."
assert isinstance(mask_color, tuple) and len(mask_color) == 3, \
"mask_color must be a three tuple, formatted (B, G, R)."
assert isinstance(thickness, int), "thickness must be a int."
assert thickness >= 0, "thickness must be larger than or equal to zero."
assert isinstance(font_size, (int, float)), "font_size must be a int or float."
assert font_size >= 0, "font_size must be larger than or equal to zero."
assert isinstance(show, bool), "show must be a bool."
assert isinstance(win_name, str), "win_name must be a str."
assert isinstance(wait_time, int), "wait_time must be a int."
assert wait_time >= 0, "wait_time must be larger than or equal to zero."
if out_file is not None:
assert isinstance(out_file, str), "out_file must be a str."
if score_threshold > 0:
assert bboxes.shape[1] == 5
if not show:
assert out_file is not None
# image
if image.shape[0] == 3:
image = image.transpose((1, 2, 0))
draw_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if bboxes is not None:
bbox_num = bboxes.shape[0]
for i in range(bbox_num):
draw_bbox = bboxes[i]
if len(draw_bbox) > 4:
if draw_bbox[4] < score_threshold:
continue
# bbox
x1, y1 = int(draw_bbox[0]), int(draw_bbox[1])
x2, y2 = int(draw_bbox[0]+draw_bbox[2]), int(draw_bbox[1]+draw_bbox[3])
cv2.rectangle(draw_image, (x1, y1), (x2, y2), bbox_color, thickness)
# label
try:
draw_label = str(class_names[labels[i][0]]) if class_names is not None else f'class {labels[i][0]}'
except (IndexError, KeyError):
draw_label = f'class {labels[i][0]}'
if len(draw_bbox) > 4:
draw_label += f'|{draw_bbox[-1]:.02f}'
cv2.putText(draw_image, draw_label, (x1, y2), cv2.FONT_HERSHEY_SIMPLEX, font_size, text_color, thickness)
if segm is not None:
mask = segm[i].astype(bool)
draw_image[mask] = draw_image[mask] * 0.5 + np.array(mask_color) * 0.5
else:
if segm is not None:
segm_num = segm.shape[0]
for i in range(segm_num):
mask = segm[i].astype(bool)
draw_image[mask] = draw_image[mask] * 0.5 + np.array(mask_color) * 0.5
if show:
cv2.imshow(win_name, draw_image)
if cv2.waitKey(wait_time) == 27:
sys.exit()
if out_file:
logger.info("Saving image file with name: " + out_file + "...")
cv2.imwrite(out_file, draw_image)
os.chmod(out_file, 0o600)
return draw_image