# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Radar map.
"""
from math import pi
import numpy as np
import matplotlib.pyplot as plt
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_param_type, check_numpy_param, \
check_param_multi_types, check_equal_length
# Logger obtained from LogUtil.get_instance(); RadarMetric uses it to report
# an invalid `scale` argument before raising.
LOGGER = LogUtil.get_instance()
# Tag passed as the first argument of LOGGER calls made from this module.
TAG = 'RadarMetric'
class RadarMetric:
    """
    Radar chart to show the robustness of a model by multiple metrics.

    Args:
        metrics_name (Union[tuple, list]): An array of names of metrics to show.
        metrics_data (numpy.ndarray): The (normalized) values of each metrics of
            multiple radar curves, like [[0.5, 0.8, ...], [0.2,0.6,...], ...].
            Each set of values corresponds to one radar curve, so the array
            must be two-dimensional: one row per curve, one column per metric.
        labels (Union[tuple, list]): Legends of all radar curves.
        title (str): Title of the chart.
        scale (str): Scalar to adjust axis ticks, such as 'hide', 'norm',
            'sparse' or 'dense'. Default: 'hide'.

    Raises:
        ValueError: If scale not in ['hide', 'norm', 'sparse', 'dense'].

    Examples:
        >>> import numpy as np
        >>> from mindarmour.adv_robustness.evaluations import RadarMetric
        >>> metrics_name = ['MR', 'ACAC', 'ASS', 'NTE', 'ACTC']
        >>> def_metrics = [0.9, 0.85, 0.6, 0.7, 0.8]
        >>> raw_metrics = [0.5, 0.3, 0.55, 0.65, 0.7]
        >>> metrics_data = np.array([def_metrics, raw_metrics])
        >>> metrics_labels = ['before', 'after']
        >>> rm = RadarMetric(metrics_name,
        ...                  metrics_data,
        ...                  metrics_labels,
        ...                  title='',
        ...                  scale='sparse')
        >>> #rm.show()
    """

    # Radial tick positions drawn for each labelled scale. 'hide' is handled
    # separately in show() because it passes no label strings to matplotlib.
    _SCALE_TICKS = {
        'norm': [0.2, 0.4, 0.6, 0.8],
        'sparse': [0.5],
        'dense': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
    }

    def __init__(self, metrics_name, metrics_data, labels, title, scale='hide'):
        self._metrics_name = check_param_multi_types('metrics_name',
                                                     metrics_name,
                                                     [tuple, list])
        self._metrics_data = check_numpy_param('metrics_data', metrics_data)
        self._labels = check_param_multi_types('labels', labels, (tuple, list))

        # One metric name per column of a curve, one legend label per curve.
        check_equal_length('metrics_name', metrics_name,
                           'metrics_data', self._metrics_data[0])
        check_equal_length('labels', labels, 'metrics_data', metrics_data)

        self._title = check_param_type('title', title, str)

        if scale not in ['hide', 'norm', 'sparse', 'dense']:
            msg = "scale must be in ['hide', 'norm', 'sparse', 'dense'], but " \
                  "got {}".format(scale)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        self._scale = scale

        self._nb_var = len(metrics_name)

        # Spread the axes evenly around the full circle, then repeat the first
        # angle so that every plotted polygon closes on itself.
        self._angles = [n / self._nb_var*2.0*pi for n in range(self._nb_var)]
        self._angles += self._angles[:1]

        # Likewise append each curve's first value as an extra trailing column
        # so the data has the same length as the closed list of angles.
        closing_column = self._metrics_data[:, [0]]
        self._metrics_data = np.concatenate(
            [self._metrics_data, closing_column], axis=1)

    def show(self):
        """
        Show the radar chart.

        Clears the current matplotlib figure, draws one filled polygon per
        curve on a polar axis and displays the result with ``plt.show()``.
        """
        # Initialise the spider plot.
        plt.clf()
        axis_pic = plt.subplot(111, polar=True)
        axis_pic.spines['polar'].set_visible(False)
        axis_pic.set_yticklabels([])

        # Put the first axis on top and lay the remaining ones out clockwise.
        axis_pic.set_theta_offset(pi / 2)
        axis_pic.set_theta_direction(-1)

        # Draw one axis per variable and label it with the metric name; the
        # last angle only duplicates the first one, so it gets no tick.
        plt.xticks(self._angles[:-1], self._metrics_name)

        # Draw the radial (y) tick labels along the first axis.
        axis_pic.set_rlabel_position(0)
        ticks = self._SCALE_TICKS.get(self._scale)
        if ticks is None:
            # 'hide' (also the fallback for any other value): a single tick at
            # the origin with no label strings supplied.
            plt.yticks([0.0], color="grey", size=7)
        else:
            plt.yticks(ticks, [str(tick) for tick in ticks],
                       color="grey", size=7)
        plt.ylim(0, 1)

        # Plot the outer border at radius 1.
        axis_pic.plot(self._angles, [1]*(self._nb_var + 1), color='grey',
                      linewidth=1, linestyle='solid')

        # Outline and lightly fill one polygon per labelled curve.
        for label, curve in zip(self._labels, self._metrics_data):
            axis_pic.plot(self._angles, curve, linewidth=1,
                          linestyle='solid', label=label)
            axis_pic.fill(self._angles, curve, alpha=0.1)

        # Add legend and title, then display the figure.
        plt.legend(loc='upper right', bbox_to_anchor=(0., 0.))
        plt.title(self._title, y=1.1, color='g')
        plt.show()