# Source code for mindspore.dataset.utils.browse_dataset

# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualization for detection/segmentation dataset."""
import os
import sys
import importlib
import numpy as np
from mindspore import log as logger

def imshow_det_bbox(image, bboxes, labels, segm=None, class_names=None, score_threshold=0, bbox_color=(0, 255, 0),
                    text_color=(203, 192, 255), mask_color=(128, 0, 128), thickness=2, font_size=0.8, show=True,
                    win_name="win", wait_time=2000, out_file=None):
    """Draw an image with given bboxes and class labels (with scores).

    Args:
        image (ndarray): The image to be displayed, shaped (C, H, W) or (H, W, C), formatted RGB.
            An array whose first dimension is 3 is treated as (C, H, W).
        bboxes (ndarray): Bounding boxes (with scores), shaped (N, 4) or (N, 5). Each row is ordered as
            (x, y, w, h) or (x, y, w, h, score). May be None to draw only the segmentation masks.
        labels (ndarray): Labels of bboxes, shaped (N, 1). Only checked and used when `bboxes` is not None.
        segm (ndarray): The segmentation masks of image, shaped (M, H, W) (Default=None). When `bboxes`
            is given, the i-th mask is drawn together with the i-th bbox, so M must be at least N.
        class_names (list[str], tuple[str], dict): Names of each class to map label to class name
            (Default=None, only display label). Labels missing from `class_names` are shown as 'class <label>'.
        score_threshold (float): Minimum score of bboxes to be shown (Default=0). A value larger than 0
            requires `bboxes` shaped (N, 5).
        bbox_color (tuple(int)): Color of bbox lines. The tuple of color should be in BGR order
            (Default=(0, 255, 0), means 'green').
        text_color (tuple(int)): Color of texts. The tuple of color should be in BGR order
            (Default=(203, 192, 255), means 'pink').
        mask_color (tuple(int)): Color of mask. The tuple of color should be in BGR order
            (Default=(128, 0, 128), means 'purple').
        thickness (int): Thickness of lines (Default=2).
        font_size (int, float): Font size of texts (Default=0.8).
        show (bool): Whether to show the image (Default=True). If False, `out_file` must be given.
        win_name (str): The window name (Default="win").
        wait_time (int): Value of waitKey param (Default=2000, means display interval is 2000ms).
            If key code 27 (ESC) is returned while the window is shown, `sys.exit()` is called.
        out_file (str, optional): The filename to write the image (Default=None). File extension name is
            required to indicate the image compression type, e.g. 'jpg', 'png'. The written file's
            permissions are set to 0o600.

    Returns:
        ndarray: The image with bboxes drawn on it, in (H, W, C) layout and BGR channel order.

    Raises:
        ImportError: If OpenCV (`cv2`) cannot be imported.
        AssertionError: If any argument fails validation.
    """
    try:
        cv2 = importlib.import_module("cv2")
    except ModuleNotFoundError as err:
        raise ImportError("import cv2 failed, seems you have to run `pip install opencv-python`.") from err

    # validation
    assert isinstance(image, np.ndarray) and image.ndim == 3 and (image.shape[0] == 3 or image.shape[2] == 3), \
        "image must be a ndarray in (H, W, C) or (C, H, W) format."
    # Decide the layout once, so the segm shape check and the transpose below always agree.
    channel_first = image.shape[0] == 3
    if bboxes is not None:
        assert isinstance(bboxes, np.ndarray) and bboxes.ndim == 2 and bboxes.shape[1] in (4, 5), \
            "bboxes must be a ndarray in (N, 4) or (N, 5) format."
        assert isinstance(labels, np.ndarray) and labels.ndim == 2 and labels.shape[1] == 1 and \
            labels.shape[0] == bboxes.shape[0], "labels must be a ndarray in (N, 1) format and has same N with bboxes."
    if segm is not None:
        assert isinstance(segm, np.ndarray) and segm.ndim == 3, "segm must be a ndarray in (M, H, W) format."
        height, width = (image.shape[1], image.shape[2]) if channel_first else (image.shape[0], image.shape[1])
        assert height == segm.shape[1] and width == segm.shape[2], "segm must has same height and width with image."
        if bboxes is not None:
            assert bboxes.shape[0] <= segm.shape[0], \
                "number of segm masks must not be less than the number of bboxes."
    # class_names defaults to None (display raw labels), so only validate it when it was supplied.
    if class_names is not None:
        assert isinstance(class_names, (tuple, list, dict)), "class_names must be a list, tuple or dict."
    assert isinstance(bbox_color, tuple) and len(bbox_color) == 3, \
        "bbox_color must be a three tuple, formatted (B, G, R)."
    assert isinstance(text_color, tuple) and len(text_color) == 3, \
        "text_color must be a three tuple, formatted (B, G, R)."
    assert isinstance(mask_color, tuple) and len(mask_color) == 3, \
        "mask_color must be a three tuple, formatted (B, G, R)."
    assert isinstance(thickness, int), "thickness must be a int."
    assert thickness >= 0, "thickness must be larger than or equal to zero."
    assert isinstance(font_size, (int, float)), "font_size must be a int or float."
    assert font_size >= 0, "font_size must be larger than or equal to zero."
    assert isinstance(show, bool), "show must be a bool."
    assert isinstance(win_name, str), "win_name must be a str."
    assert isinstance(wait_time, int), "wait_time must be a int."
    assert wait_time >= 0, "wait_time must be larger than or equal to zero."
    if out_file is not None:
        assert isinstance(out_file, str), "out_file must be a str."
    # Filtering by score needs a score column; with no bboxes there is nothing to filter.
    if score_threshold > 0 and bboxes is not None:
        assert bboxes.shape[1] == 5, "bboxes must be a ndarray in (N, 5) format when score_threshold > 0."
    if not show:
        assert out_file is not None, "out_file must be given when show is False."

    # image: convert to (H, W, C) and from RGB to the BGR order OpenCV draws in.
    if channel_first:
        image = image.transpose((1, 2, 0))
    draw_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    mask_color_array = np.array(mask_color)

    if bboxes is not None:
        for i, draw_bbox in enumerate(bboxes):
            has_score = len(draw_bbox) > 4
            if has_score and draw_bbox[4] < score_threshold:
                continue
            # bbox: (x, y, w, h) -> top-left (x1, y1) and bottom-right (x2, y2) corners.
            x1, y1 = int(draw_bbox[0]), int(draw_bbox[1])
            x2, y2 = int(draw_bbox[0] + draw_bbox[2]), int(draw_bbox[1] + draw_bbox[3])
            cv2.rectangle(draw_image, (x1, y1), (x2, y2), bbox_color, thickness)
            # label: fall back to the raw label when it cannot be mapped through class_names.
            label = labels[i][0]
            try:
                draw_label = str(class_names[label]) if class_names is not None else f'class {label}'
            except (IndexError, KeyError, TypeError):
                draw_label = f'class {label}'
            if has_score:
                draw_label += f'|{draw_bbox[-1]:.02f}'
            # The text is anchored at the bbox's bottom-left corner (x1, y2).
            cv2.putText(draw_image, draw_label, (x1, y2), cv2.FONT_HERSHEY_SIMPLEX, font_size, text_color, thickness)
            if segm is not None:
                # 50/50 blend of the mask color over the pixels selected by the i-th mask.
                mask = segm[i].astype(bool)
                draw_image[mask] = draw_image[mask] * 0.5 + mask_color_array * 0.5
    elif segm is not None:
        for seg_mask in segm:
            mask = seg_mask.astype(bool)
            draw_image[mask] = draw_image[mask] * 0.5 + mask_color_array * 0.5

    if show:
        cv2.imshow(win_name, draw_image)
        # Key code 27 (ESC) terminates the whole process, not just the window.
        if cv2.waitKey(wait_time) == 27:
            sys.exit()
    if out_file:
        logger.info("Saving image file with name: " + out_file + "...")
        cv2.imwrite(out_file, draw_image)
        # Restrict the written file to owner read/write only.
        os.chmod(out_file, 0o600)
    return draw_image