# Source code for mindvision.utils.charts

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The chart API of paper experiment part."""

from typing import Optional, Dict, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid

from mindvision.check_param import Validator

# Public names exported by ``from <this module> import *``: one entry per
# chart function defined in this module.
__all__ = [
    'topn_accuracy_chart',
    'accuracy_on_dataset_chart_v1',
    'accuracy_on_dataset_chart_v2',
    'accuracy_on_dataset_chart_v3',
    'accuracy_model_size_chart',
    'accuracy_model_flops_chart',
    'pos_embedding_cosine_chart',
]

# Shared text style: handed to matplotlib as ``fontdict`` for axis labels and
# titles, and its 'size' entry also drives legend and tick-label font sizes.
font_format = dict(family='Arial', size=12)

# Palette of 20 distinct named colours; series ``i`` of a chart uses
# ``color[i]`` (the dataset scatter charts start at index 2).
# The list previously contained both 'slategrey' and 'slategray', which are
# two spellings of the same colour, so only 19 entries were actually distinct;
# the second occurrence is replaced by 'olive'.
color = [
    'darkred',
    'darkgrey',
    'royalblue',
    'pink',
    'forestgreen',
    'steelblue',
    'orange',
    'black',
    'darkorange',
    'slategrey',
    'lightpink',
    'rosybrown',
    'goldenrod',
    'mediumturquoise',
    'mediumpurple',
    'olive',
    'saddlebrown',
    'lawngreen',
    'purple',
    'teal',
]

# 20 distinct matplotlib marker codes; series ``i`` of a chart uses
# ``marker[i]``. The list previously repeated 's' (index 10) and 'P'
# (index 13), leaving only 18 distinct shapes; the repeats are replaced by
# 'p' (pentagon) and 'X' (filled x) so every index maps to a unique marker.
marker = ['s', '*', 'o', '8', 'v',
          '<', '>', 'h', 'x', '^',
          'p', 'P', 'd', 'X', 'D',
          '1', '2', '3', '4', '.']


def topn_accuracy_chart(accuracy_data: Dict,
                        save_path: str = './',
                        ylim: Optional[List] = None,
                        figsize: Optional[Tuple] = None,
                        title: Optional[str] = None,
                        xlabel: Optional[str] = None,
                        ylabel: Optional[str] = None):
    """
    Accuracy charts, xlabel can be network models or iteration numbers, ylabel is the accuracy.

    One line is drawn per top-level key of ``accuracy_data``; the keys of each
    inner dict are the x values and the values are the accuracies.

    Args:
        accuracy_data (dict): The accuracy data of models on different AI frame,
            shaped as ``{label: {x_value: accuracy}}``.
        save_path (str): The save path of line chart, forwarded to ``plt.savefig``. Default: './'.
        ylim (list, optional): The range of y coordinate. Default: None.
        figsize (tuple, optional): The size of figure. Default: None.
        title (str, optional): The title of graph. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.

    Raises:
        ValueError: If ``accuracy_data`` holds more series than there are
            entries in the module-level ``color`` / ``marker`` lists.

    Examples:
        >>> accuracy_data = {'MindSpore': {'Resnet18': 70.078, 'Resnet34': 73.72, 'Resnet50': 76.6},
        ...                  'Pytorch': {'Resnet18': 69.758, 'Resnet34': 73.31, 'Resnet50': 76.13}}
        >>> topn_accuracy_chart(accuracy_data=accuracy_data)
    """
    # Both lists are indexed by series position, so the shorter one is the limit.
    # Checked up front so that no figure is created for invalid input.
    max_series = min(len(color), len(marker))
    if len(accuracy_data) > max_series:
        raise ValueError(f'The number of labels exceeds {max_series}.')

    plt.figure(figsize=figsize)

    tick_labels = []
    for index, (label, data) in enumerate(accuracy_data.items()):
        x = list(data.keys())
        y = list(data.values())
        tick_labels.extend(x)
        plt.plot(x, y, color=color[index], marker=marker[index], label=label)

    # Long tick labels are drawn vertically. The decision looks at every label
    # of every series and goes through str(), so integer x values (iteration
    # numbers) and empty inputs no longer raise.
    rotation = 90 if any(len(str(tick)) > 8 for tick in tick_labels) else None

    plt.legend(fontsize=font_format['size'])
    plt.xlabel(xlabel, fontdict=font_format)
    plt.ylabel(ylabel, fontdict=font_format)
    plt.xticks(fontsize=font_format['size'] - 2, rotation=rotation)
    # NOTE(review): the y tick labels receive the same rotation as the x tick
    # labels; kept as in the original implementation -- confirm it is intended.
    plt.yticks(fontsize=font_format['size'] - 2, rotation=rotation)
    plt.title(title, fontdict=font_format)
    plt.grid(linestyle='--')
    plt.ylim(ylim)
    plt.savefig(save_path, bbox_inches="tight")
def accuracy_on_dataset_chart_v1(accuracy_data: Dict,
                                 save_path: str = './',
                                 ylim: Optional[List] = None,
                                 figsize: Optional[Tuple] = None,
                                 title: Optional[str] = None,
                                 xlabel: Optional[str] = None,
                                 ylabel: Optional[str] = None
                                 ):
    """
    The function is used to plot the accuracy range between models on different dataset.

    Exactly two models are drawn as grey lines and the area between the two
    lines is shaded.

    Args:
        accuracy_data (dict): The data of line chart models. Each value is a dict with
            an 'accuracy' mapping (dataset name -> accuracy) and a 'marker_size' entry.
        save_path (str): Path to save the chart, forwarded to ``plt.savefig``. Default: './'.
        ylim (list, optional): The range of y coordinate. Default: None.
        figsize (tuple, optional): The size of figure. Default: None.
        title (str, optional): The title of chart. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.

    Raises:
        ValueError: If the two models do not list the same datasets in the same order.

    Examples:
        >>> accuracy_data = {
        ...     'ResNet50': {
        ...         'accuracy': {
        ...             'ImageNet': 76.8,
        ...             'ImageNet21K': 80.2,
        ...             'JFT-300M': 79.2,
        ...         },
        ...         'marker_size': 4
        ...     },
        ...     'ResNet152': {
        ...         'accuracy': {
        ...             'ImageNet': 81.2,
        ...             'ImageNet21K': 85.5,
        ...             'JFT-300M': 87.8,
        ...         },
        ...         'marker_size': 6
        ...     }
        ... }
        >>> accuracy_on_dataset_chart_v1(accuracy_data=accuracy_data)
    """
    # The count check is delegated to the project validator (defined outside this file).
    Validator.check_equal_int(len(accuracy_data), 2, 'The Number of line chart models')

    # Gather both series first so that the inputs are validated before anything is drawn.
    lines = []
    for model, data in accuracy_data.items():
        datasets = list(data['accuracy'].keys())
        accuracies = list(data['accuracy'].values())
        lines.append((model, datasets, accuracies, data['marker_size']))

    # fill_between needs a shared x axis. This used to be an ``assert``, which
    # is removed when Python runs with -O; raise explicitly instead.
    if lines[0][1] != lines[1][1]:
        raise ValueError('The two line chart models must provide the same datasets in the same order.')

    plt.figure(figsize=figsize)
    for model, datasets, accuracies, marker_size in lines:
        plt.plot(
            datasets,
            accuracies,
            c='darkgrey',
            marker='s',
            markersize=marker_size,
            label=model,
            alpha=0.5)

    plt.fill_between(
        lines[0][1],
        lines[0][2],
        lines[1][2],
        facecolor='darkgrey',
        alpha=0.1)

    # Major y ticks every 5 units, minor ticks every 1 unit.
    ax = plt.gca()
    ax.yaxis.set_major_locator(plt.MultipleLocator(5))
    ax.yaxis.set_minor_locator(plt.MultipleLocator(1))

    plt.legend(fontsize=font_format['size'])
    plt.xlabel(xlabel, fontdict=font_format)
    plt.ylabel(ylabel, fontdict=font_format)
    plt.xticks(fontsize=font_format['size'] - 2)
    plt.yticks(fontsize=font_format['size'] - 2)
    plt.title(title, fontdict=font_format)
    plt.grid(which='minor', linestyle=':', alpha=0.3)
    plt.grid(which='major', linestyle='-', linewidth=1.0, alpha=0.5)
    plt.ylim(ylim)
    plt.savefig(save_path, bbox_inches="tight")
def accuracy_on_dataset_chart_v2(accuracy_data: Dict,
                                 save_path: str = './',
                                 ylim: Optional[List] = None,
                                 figsize: Optional[Tuple] = None,
                                 title: Optional[str] = None,
                                 xlabel: Optional[str] = None,
                                 ylabel: Optional[str] = None
                                 ):
    """
    The function is used to plot the accuracy comparison between models on different dataset.

    Every model is drawn as one scatter series with its own colour.

    Args:
        accuracy_data (dict): The data of scatter models. Each value is a dict with
            an 'accuracy' mapping (dataset name -> accuracy) and a 'marker_size' entry.
        save_path (str): Path to save the chart, forwarded to ``plt.savefig``. Default: './'.
        ylim (list, optional): The range of y coordinate. Default: None.
        figsize (tuple, optional): The size of figure. Default: None.
        title (str, optional): The title of chart. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.

    Raises:
        ValueError: If there are more models than colours available for scatter series.

    Examples:
        >>> accuracy_data = {
        ...     'ViT-B_32': {
        ...         'accuracy': {
        ...             'ImageNet': 73.38,
        ...             'ImageNet21K': 81.28,
        ...             'JFT-300M': 80.73,
        ...         },
        ...         'marker_size': 45
        ...     }
        ... }
        >>> accuracy_on_dataset_chart_v2(accuracy_data=accuracy_data)
    """
    # Scatter series start at the third palette entry; the second entry
    # ('darkgrey') is the colour the line-chart variants in this module draw with.
    color_offset = 2
    max_models = len(color) - color_offset
    if len(accuracy_data) > max_models:
        # Previously this surfaced as a bare IndexError from ``color[index + 2]``.
        raise ValueError(f'The number of models exceeds {max_models}.')

    plt.figure(figsize=figsize)
    for index, (model, data) in enumerate(accuracy_data.items()):
        datasets = list(data['accuracy'].keys())
        accuracies = list(data['accuracy'].values())
        plt.scatter(datasets,
                    accuracies,
                    c=color[index + color_offset],
                    s=data['marker_size'],
                    marker='o',
                    label=model,
                    alpha=0.5)

    # Major y ticks every 5 units, minor ticks every 1 unit.
    ax = plt.gca()
    ax.yaxis.set_major_locator(plt.MultipleLocator(5))
    ax.yaxis.set_minor_locator(plt.MultipleLocator(1))

    plt.legend(fontsize=font_format['size'])
    plt.xlabel(xlabel, fontdict=font_format)
    plt.ylabel(ylabel, fontdict=font_format)
    plt.xticks(fontsize=font_format['size'] - 2)
    plt.yticks(fontsize=font_format['size'] - 2)
    plt.title(title, fontdict=font_format)
    plt.grid(which='minor', linestyle=':', alpha=0.3)
    plt.grid(which='major', linestyle='-', linewidth=1.0, alpha=0.5)
    plt.ylim(ylim)
    plt.savefig(save_path, bbox_inches="tight")
def accuracy_on_dataset_chart_v3(line_models_data: Dict,
                                 scatter_models_data: Dict,
                                 save_path: str = './',
                                 ylim: Optional[List] = None,
                                 figsize: Optional[Tuple] = None,
                                 title: Optional[str] = None,
                                 xlabel: Optional[str] = None,
                                 ylabel: Optional[str] = None
                                 ):
    """
    The function is used to plot the accuracy comparison between architectures on different dataset.

    Two models are drawn as grey lines with the area between them shaded, and
    any number of further models (up to the palette size) are overlaid as
    coloured scatter series.

    Args:
        line_models_data (dict): The data of line chart models. Each value is a dict with
            an 'accuracy' mapping (dataset name -> accuracy) and a 'marker_size' entry.
        scatter_models_data (dict): The data of scatter models, in the same layout.
        save_path (str): Path to save the chart, forwarded to ``plt.savefig``. Default: './'.
        ylim (list, optional): The range of y coordinate. Default: None.
        figsize (tuple, optional): The size of figure. Default: None.
        title (str, optional): The title of chart. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.

    Raises:
        ValueError: If the two line models do not list the same datasets in the
            same order, or if there are more scatter models than colours available.

    Examples:
        >>> line_models_data = {
        ...     'ResNet50': {
        ...         'accuracy': {
        ...             'ImageNet': 76.8,
        ...             'ImageNet21K': 80.2,
        ...             'JFT-300M': 79.2,
        ...         },
        ...         'marker_size': 4
        ...     },
        ...     'ResNet152': {
        ...         'accuracy': {
        ...             'ImageNet': 81.2,
        ...             'ImageNet21K': 85.5,
        ...             'JFT-300M': 87.8,
        ...         },
        ...         'marker_size': 6
        ...     }
        ... }
        >>> scatter_models_data = {
        ...     'ViT-B_32': {
        ...         'accuracy': {
        ...             'ImageNet': 73.38,
        ...             'ImageNet21K': 81.28,
        ...             'JFT-300M': 80.73,
        ...         },
        ...         'marker_size': 45
        ...     }
        ... }
        >>> accuracy_on_dataset_chart_v3(line_models_data=line_models_data,
        ...                              scatter_models_data=scatter_models_data)
    """
    # The count check is delegated to the project validator (defined outside this file).
    Validator.check_equal_int(len(line_models_data), 2, 'The Number of line chart models')

    # Scatter series start at the third palette entry; 'darkgrey' (the second
    # entry) is the colour used for the two line models below.
    color_offset = 2
    max_scatter_models = len(color) - color_offset
    if len(scatter_models_data) > max_scatter_models:
        # Previously this surfaced as a bare IndexError from ``color[index + 2]``.
        raise ValueError(f'The number of scatter models exceeds {max_scatter_models}.')

    # Gather both line series first so that the inputs are validated before anything is drawn.
    lines = []
    for model, data in line_models_data.items():
        datasets = list(data['accuracy'].keys())
        accuracies = list(data['accuracy'].values())
        lines.append((model, datasets, accuracies, data['marker_size']))

    # fill_between needs a shared x axis. This used to be an ``assert``, which
    # is removed when Python runs with -O; raise explicitly instead.
    if lines[0][1] != lines[1][1]:
        raise ValueError('The two line chart models must provide the same datasets in the same order.')

    plt.figure(figsize=figsize)
    for model, datasets, accuracies, marker_size in lines:
        plt.plot(
            datasets,
            accuracies,
            c='darkgrey',
            marker='s',
            markersize=marker_size,
            label=model,
            alpha=0.5)

    for index, (model, data) in enumerate(scatter_models_data.items()):
        datasets = list(data['accuracy'].keys())
        accuracies = list(data['accuracy'].values())
        plt.scatter(datasets,
                    accuracies,
                    c=color[index + color_offset],
                    s=data['marker_size'],
                    marker='o',
                    label=model,
                    alpha=0.5)

    plt.fill_between(
        lines[0][1],
        lines[0][2],
        lines[1][2],
        facecolor='darkgrey',
        alpha=0.1)

    # Major y ticks every 5 units, minor ticks every 1 unit.
    ax = plt.gca()
    ax.yaxis.set_major_locator(plt.MultipleLocator(5))
    ax.yaxis.set_minor_locator(plt.MultipleLocator(1))

    plt.legend(fontsize=font_format['size'], ncol=2)
    plt.xlabel(xlabel, fontdict=font_format)
    plt.ylabel(ylabel, fontdict=font_format)
    plt.xticks(fontsize=font_format['size'] - 2)
    plt.yticks(fontsize=font_format['size'] - 2)
    plt.title(title, fontdict=font_format)
    plt.grid(which='minor', linestyle=':', alpha=0.3)
    plt.grid(which='major', linestyle='-', linewidth=1.0, alpha=0.5)
    plt.ylim(ylim)
    plt.savefig(save_path, bbox_inches="tight")
def accuracy_model_size_chart(accuracy_data: Dict,
                              size_unit: str,
                              save_path: str = './',
                              ylim: Optional[List] = None,
                              figsize: Optional[Tuple] = None,
                              title: Optional[str] = None,
                              xlabel: Optional[str] = None,
                              ylabel: Optional[str] = None
                              ):
    """
    The function is used to plot the accuracy comparison between models on different pre-trained dataset size.

    One line is drawn per model on a logarithmic x axis. The x ticks are placed
    at every dataset size that occurs in any model and are labelled with the
    size followed by ``size_unit``.

    Args:
        accuracy_data (dict): The accuracy of model on different pre-trained dataset size,
            shaped as ``{model: {dataset_size: accuracy}}``.
        size_unit (str): Units for dataset size, appended to every x tick label.
        save_path (str): Path to save the chart, forwarded to ``plt.savefig``. Default: './'.
        ylim (list, optional): The range of y coordinate. Default: None.
        figsize (tuple, optional): The size of figure. Default: None.
        title (str, optional): The title of chart. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.

    Raises:
        ValueError: If there are more models than entries in the module-level ``color`` list.

    Examples:
        >>> accuracy_data = {
        ...     'ViT-b_32': {
        ...         10: 37,
        ...         30: 41,
        ...         100: 41.5
        ...     },
        ...     'ViT-B_32': {
        ...         10: 38,
        ...         30: 53,
        ...         100: 54
        ...     }
        ... }
        >>> accuracy_model_size_chart(accuracy_data=accuracy_data, size_unit='M')
    """
    if len(accuracy_data) > len(color):
        # Previously this surfaced as a bare IndexError from ``color[index]``.
        raise ValueError(f'The number of models exceeds {len(color)}.')

    plt.figure(figsize=figsize)

    all_sizes = set()
    for index, (model, data) in enumerate(accuracy_data.items()):
        sizes = list(data.keys())
        accuracies = list(data.values())
        all_sizes.update(sizes)
        plt.plot(sizes, accuracies, c=color[index], marker='o', markersize=8, label=model)

    # Major y ticks every 10 units, minor ticks every 2 units; grid on the y axis only.
    ax = plt.gca()
    ax.yaxis.set_major_locator(plt.MultipleLocator(10))
    ax.yaxis.set_minor_locator(plt.MultipleLocator(2))
    ax.yaxis.grid(which='major', linestyle='-', linewidth=1.0, alpha=0.5)
    ax.yaxis.grid(which='minor', linestyle=':', alpha=0.3)

    plt.xscale('log')
    plt.legend(fontsize=font_format['size'], ncol=3, loc='lower right')
    plt.xlabel(xlabel, fontdict=font_format)
    plt.ylabel(ylabel, fontdict=font_format)

    # Ticks cover the sizes of every model. The earlier code reused the loop
    # variable of the last model only, which dropped ticks when models were
    # measured at different sizes and failed for empty input.
    tick_positions = sorted(all_sizes)
    tick_labels = [f'{size}{size_unit}' for size in tick_positions]
    plt.xticks(tick_positions, tick_labels, fontsize=font_format['size'] - 2)
    plt.yticks(fontsize=font_format['size'] - 2)
    plt.title(title, fontdict=font_format)
    plt.ylim(ylim)
    plt.savefig(save_path, bbox_inches="tight")
def accuracy_model_flops_chart(accuracy_data: Dict,
                               save_path: str = './',
                               ylim: Optional[List] = None,
                               figsize: Optional[Tuple] = None,
                               title: Optional[str] = None,
                               xlabel: Optional[str] = None,
                               ylabel: Optional[str] = None
                               ):
    """
    The function is used to plot the accuracy comparison between architectures on different pre-trained compute.

    Every architecture becomes one scatter series (own colour and marker) on a
    logarithmic x axis; each of its models contributes its ``{flops: accuracy}``
    points to that series.

    Args:
        accuracy_data (dict): The accuracy of model on different pre-trained compute,
            shaped as ``{architecture: {model_name: {flops: accuracy}}}``.
        save_path (str): Path to save the chart, forwarded to ``plt.savefig``. Default: './'.
        ylim (list, optional): The range of y coordinate. Default: None.
        figsize (tuple, optional): The size of figure. Default: None.
        title (str, optional): The title of chart. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.

    Raises:
        ValueError: If there are more architectures than entries in the
            module-level ``color`` / ``marker`` lists.

    Examples:
        >>> accuracy_data = {
        ...     'Transform(ViT)': {
        ...         "vit-B_32_7": {55: 80.73},
        ...         "ViT-B_16_7": {224: 84.15},
        ...         "ViT-L_32_7": {196: 84.37},
        ...         "ViT-L_16_7": {783: 86.30}
        ...     }
        ... }
        >>> accuracy_model_flops_chart(accuracy_data=accuracy_data, ylim=[75, 90])
    """
    # Both lists are indexed by series position, so the shorter one is the limit.
    max_series = min(len(color), len(marker))
    if len(accuracy_data) > max_series:
        # Previously this surfaced as a bare IndexError from ``color[index]`` / ``marker[index]``.
        raise ValueError(f'The number of architectures exceeds {max_series}.')

    plt.figure(figsize=figsize)
    for index, (architecture, model_data) in enumerate(accuracy_data.items()):
        # Flatten the per-model {flops: accuracy} dicts into one list of points.
        points = [point for data in model_data.values() for point in data.items()]
        flops = [point[0] for point in points]
        accuracies = [point[1] for point in points]
        plt.scatter(flops,
                    accuracies,
                    c=color[index],
                    s=80,
                    marker=marker[index],
                    label=architecture)

    # Major y ticks every 5 units, minor ticks every 1 unit.
    ax = plt.gca()
    ax.yaxis.set_major_locator(plt.MultipleLocator(5))
    ax.yaxis.set_minor_locator(plt.MultipleLocator(1))

    plt.xscale('log')
    plt.legend(fontsize=font_format['size'], loc='lower right')
    plt.xlabel(xlabel, fontdict=font_format)
    plt.ylabel(ylabel, fontdict=font_format)
    plt.xticks(fontsize=font_format['size'] - 2)
    plt.yticks(fontsize=font_format['size'] - 2)
    plt.grid(which='minor', linestyle=':', alpha=0.3)
    plt.grid(which='major', linestyle='-', linewidth=1.0, alpha=0.5)
    plt.title(title, fontdict=font_format)
    plt.ylim(ylim)
    plt.savefig(save_path, bbox_inches="tight")
def pos_embedding_cosine_chart(pos_embedding: np.ndarray,
                               save_path: str = './',
                               title: Optional[str] = None,
                               xlabel: Optional[str] = None,
                               ylabel: Optional[str] = None,
                               colorbar_label: Optional[str] = None
                               ):
    """
    The function is used to plot the cosine similarity of position embedding.

    The first row of the (squeezed) embedding is dropped, the pairwise cosine
    similarity of the remaining rows is computed, and the similarities of each
    row to all others are drawn as one ``len_side x len_side`` image inside a
    ``len_side x len_side`` grid of images.

    Args:
        pos_embedding (ndarray): The data of position embedding. After ``squeeze()`` it
            must be 2-D, and the number of rows minus one must be a perfect square.
        save_path (str): Path to save the chart, forwarded to ``plt.savefig``. Default: './'.
        title (str, optional): The title of chart. Default: None.
        xlabel (str, optional): The Label of x coordinate. Default: None.
        ylabel (str, optional): The Label of y coordinate. Default: None.
        colorbar_label (str, optional): The Label of colorbar. Default: None.

    Raises:
        ValueError: If the squeezed embedding is not 2-D, or if the number of
            remaining rows is zero or not a perfect square.

    Examples:
        >>> pos_embedding = np.random.randn(1, 50, 768)
        >>> pos_embedding_cosine_chart(pos_embedding=pos_embedding)
    """
    pos_embedding = pos_embedding.squeeze()
    if pos_embedding.ndim != 2:
        raise ValueError(
            f'pos_embedding must be 2-D after squeezing, but got {pos_embedding.ndim} dimension(s).')

    # Drop the first row (the variable name in the original code marks it as the cls position).
    rm_cls_pos_embedding = pos_embedding[1:, :]
    num_pos = rm_cls_pos_embedding.shape[0]
    len_side = int(round(np.sqrt(num_pos)))
    if num_pos == 0 or len_side * len_side != num_pos:
        # The reshape below and the square image grid both need a perfect square.
        raise ValueError(
            f'The number of positions without the first row must be a non-zero '
            f'perfect square, but got {num_pos}.')

    # Pairwise cosine similarity: dot(x_i, x_j) / (||x_i|| * ||x_j||).
    # The earlier implementation divided by ||x_i|| * ||x_i||, which is only a
    # cosine when both vectors have the same length; it also looped over all
    # pairs in Python. Rows with zero norm still produce NaN, as before.
    norms = np.linalg.norm(rm_cls_pos_embedding, ord=2, axis=1)
    cos = (rm_cls_pos_embedding @ rm_cls_pos_embedding.T) / np.outer(norms, norms)
    cos = cos.reshape((num_pos, len_side, len_side))

    fig = plt.figure(figsize=(len_side, len_side))
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=(len_side, len_side),
                     share_all=True,
                     axes_pad=0.2,
                     cbar_location="right",
                     cbar_mode="single",
                     cbar_pad=0.2,
                     )

    for index, ax in enumerate(grid):
        image = ax.imshow(cos[index, :, :], vmin=-1, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        # 1-based column / row of this position inside the square layout.
        ax.set_xlabel(index % len_side + 1, fontsize='xx-large')
        ax.set_ylabel(index // len_side + 1, fontsize='xx-large')

    # All images share vmin/vmax, so the last image is used for the single colorbar.
    cb = plt.colorbar(image, cax=ax.cax, ticks=[-1, 1])
    cb.set_label(colorbar_label, fontsize='xx-large')
    fig.suptitle(title, fontsize='xx-large')
    fig.supxlabel(xlabel, fontsize='xx-large')
    fig.supylabel(ylabel, fontsize='xx-large')
    plt.savefig(save_path, bbox_inches="tight")