Source code for mindspore_xai.benchmark.robustness

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Robustness."""

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore.train._utils import check_value_type
from mindspore import log

from mindspore_xai.explainer.perturb.replacement import RandomPerturb
from .metric import LabelSensitiveMetric


class Robustness(LabelSensitiveMetric):
    """
    Robustness perturbs the inputs by adding random noise and chooses the maximum sensitivity as evaluation score
    from the perturbations.

    Args:
        num_labels (int): Number of classes in the dataset.
        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as when combining this function with network, the final output is the probability of the input.

    Raises:
        TypeError: Be raised for any argument type problem.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, num_labels, activation_fn):
        super().__init__(num_labels)
        check_value_type("activation_fn", activation_fn, nn.Cell)
        self._perturb = RandomPerturb()
        self._num_perturbations = 10  # number of perturbations used in evaluation
        self._threshold = 0.1  # max L2-distance between original and perturbed outputs for a perturbation to be kept
        self._activation_fn = activation_fn

    def evaluate(self, explainer, inputs, targets, saliency=None):
        """
        Evaluate robustness on the explainer.

        The score is ``exp(-max_k ||S(x + d_k) - S(x)||_2 / ||S(x)||_2)``, where ``S`` is the saliency produced by
        `explainer` and ``d_k`` are random perturbations. The score lies in (0, 1]; larger means more robust.

        Note:
            Currently only single sample (:math:`N=1`) at each call is supported.

        Args:
            explainer (Explainer): The explainer to be evaluated, see `mindspore_xai.explainer`.
            inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
            targets (Tensor, int): The label of interest. It should be a 1D or scalar tensor, or an integer.
                If `targets` is a 1D tensor, its length should be :math:`N`.
            saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`.
                If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and
                continue the evaluation. Default: ``None``.

        Returns:
            numpy.ndarray, 1D array of shape :math:`(N,)`, result of robustness evaluated on `explainer`.
            Entries are NaN for samples whose saliency map has zero norm.

        Raises:
            TypeError: Be raised for any argument type problem.
            ValueError: Be raised if :math:`N` is not 1.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn, set_context, PYNATIVE_MODE
            >>> from mindspore_xai.explainer import Gradient
            >>> from mindspore_xai.benchmark import Robustness
            >>>
            >>> set_context(mode=PYNATIVE_MODE)
            >>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
            >>> num_labels = 10
            >>> activation_fn = nn.Softmax()
            >>> robustness = Robustness(num_labels, activation_fn)
            >>>
            >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
            >>> net = LeNet5(10, num_channel=3)
            >>> # prepare your explainer to be evaluated, e.g., Gradient.
            >>> gradient = Gradient(net)
            >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
            >>> target_label = ms.Tensor([0], ms.int32)
            >>> # robustness is a Robustness instance
            >>> res = robustness.evaluate(gradient, input_x, target_label)
            >>> print(res.shape)
            (1,)
        """
        self._check_evaluate_param(explainer, inputs, targets, saliency)
        # The docstring promises ValueError whenever N is not 1, so an empty batch is rejected too.
        if inputs.shape[0] != 1:
            raise ValueError('Robustness only support a sample each time, but receive {}'.format(inputs.shape[0]))

        if isinstance(targets, int):
            targets = ms.Tensor([targets], ms.int32)
        if saliency is None:
            saliency = explainer(inputs, targets, show=False)
        saliency = saliency.asnumpy()

        # Reduce over every axis except the batch axis to get one L2 norm per sample.
        reduce_axes = tuple(range(1, saliency.ndim))
        norm = np.sqrt(np.sum(np.square(saliency), axis=reduce_axes))
        if (norm == 0).any():
            log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
            norm[norm == 0] = np.nan

        # The full network outputs probabilities, so perturbation errors are measured in probability space.
        full_network = nn.SequentialCell([explainer.network, self._activation_fn])
        original_outputs = full_network(inputs).asnumpy()

        inputs = inputs.asnumpy()
        sensitivities = []
        for _ in range(self._num_perturbations):
            perturbations = np.vstack([
                self._perturb_with_threshold(full_network, np.expand_dims(sample, axis=0), original_outputs[j])
                for j, sample in enumerate(inputs)
            ])
            perturbed_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets, show=False).asnumpy()
            sensitivity = np.sqrt(np.sum((perturbed_saliency - saliency) ** 2, axis=reduce_axes))
            sensitivities.append(sensitivity)

        # Shape (N, num_perturbations): take the worst-case (largest) sensitivity per sample.
        sensitivities = np.stack(sensitivities, axis=-1)
        sensitivity = np.max(sensitivities, axis=1) / norm
        return 1 / np.exp(sensitivity)

    def _perturb_with_threshold(self, network: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
        """
        Generate a perturbation of `sample` whose model output stays close to `original_output`.

        Perturbations are drawn repeatedly until the L2-distance between `original_output` and the network output on
        the perturbed sample is not greater than `self._threshold`, or until `max_attempt_time` attempts have been
        made. In the latter case the last drawn perturbation is returned.

        Args:
            network (Cell): The network including its activation layer (the full network built in `evaluate`),
                so its output is directly comparable to `original_output`.
            sample (numpy.ndarray): A single sample with a leading batch axis of size 1.
            original_output (numpy.ndarray): The activated network output on the unperturbed sample.

        Returns:
            numpy.ndarray, the selected perturbed sample.
        """
        # the maximum time attempt to get a perturbation with perturb_error low than self._threshold
        max_attempt_time = 3
        perturbation = None
        for _ in range(max_attempt_time):
            perturbation = self._perturb(sample)
            # Evaluate the *perturbed* sample; `network` already applies the activation, so it must not be
            # applied a second time here.
            perturbation_output = network(ms.Tensor(perturbation, ms.float32)).asnumpy()
            perturb_error = np.linalg.norm(original_output - perturbation_output)
            if perturb_error <= self._threshold:
                return perturbation
        return perturbation