# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gradient explainer."""
from copy import deepcopy
from mindspore import Tensor
from mindspore.train._utils import check_value_type
from mindspore_xai.common.utils import abs_max, unify_inputs, unify_targets
from mindspore_xai.common.attribution import Attribution
from .backprop_utils import get_bp_weights, GradNet
class Gradient(Attribution):
    r"""
    Provides Gradient explanation method.

    Gradient is the simplest attribution method which uses the naive gradients of outputs w.r.t inputs as the
    explanation.

    .. math::

        attribution = \frac{\partial{y}}{\partial{x}}

    Note:
        The passed `network` will be set to eval mode through `network.set_grad(False)` and
        `network.set_train(False)`. If you want to train the `network` afterwards, please reset it back to training
        mode through the opposite operations.

    Args:
        network (Cell): The black-box model to be explained.

    Inputs:
        - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        - **targets** (Tensor, int, tuple, list) - The label of interest. It should be a 1D or scalar tensor, or an
          integer, or a tuple/list of integers. If it is a 1D tensor, tuple or list, its length should be :math:`N`.
        - **ret** (str, optional): The return object type. ``'tensor'`` means returns a Tensor object, ``'image'``
          means return a PIL.Image.Image list. Default: ``'tensor'``.
        - **show** (bool, optional): Show the saliency images, ``None`` means automatically show the saliency images
          if it is running on JupyterLab. Default: ``None``.

    Outputs:
        Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`, saliency maps. Or list[list[PIL.Image.Image]], the
        normalized saliency images if `ret` was set to ``'image'``.

    Raises:
        TypeError: Be raised for any argument or input type problem.
        ValueError: Be raised for any input value problem.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import set_context, PYNATIVE_MODE
        >>> from mindspore_xai.explainer import Gradient
        >>>
        >>> set_context(mode=PYNATIVE_MODE)
        >>> # The detail of LeNet5 is shown in models.official.cv.lenet.src.lenet.py
        >>> net = LeNet5(10, num_channel=3)
        >>> gradient = Gradient(net)
        >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
        >>> label = 5
        >>> saliency = gradient(inputs, label)
        >>> print(saliency.shape)
        (1, 1, 32, 32)
    """

    def __init__(self, network):
        """
        Build the gradient explainer around `network`.

        Args:
            network (Cell): The black-box model to be explained.
        """
        super().__init__(network)
        # Gradients are taken through a private deep copy of the network,
        # which is switched to evaluation mode right after being copied.
        self._backward_model = deepcopy(network)
        self._backward_model.set_train(False)
        self._backward_model.set_grad(False)
        self._grad_net = GradNet(self._backward_model)
        self._aggregation_fn = abs_max
        # Filled in lazily by `_get_bp_weights` on the first explanation call.
        self._num_classes = None

    def __call__(self, inputs, targets, ret='tensor', show=None):
        """
        Call function for `Gradient`.

        Args:
            inputs (Tensor): The input data to be explained, a 4D tensor.
            targets (Tensor, int, tuple, list): The label(s) of interest.
            ret (str, optional): The return object type. Default: ``'tensor'``.
            show (bool, optional): Whether to show the saliency images. Default: ``None``.

        Returns:
            The result of `self._postproc_saliency(saliency, ret, show)`.

        Raises:
            TypeError: If `inputs` or `targets` has an unsupported type.
            ValueError: If `inputs` is not 4D or `targets` does not match `inputs`.
        """
        self._verify_data(inputs, targets)
        # NOTE(review): `_verify_other_args` and `_postproc_saliency` are not defined in this class;
        # presumably they are inherited from `Attribution` - confirm.
        self._verify_other_args(ret, show)

        # `unify_inputs` yields a sequence of tensors; the batch size is read from the first one.
        inputs = unify_inputs(inputs)
        targets = unify_targets(inputs[0].shape[0], targets)

        weights = self._get_bp_weights(inputs, targets)
        gradient = self._grad_net(*inputs, weights)
        saliency = self._aggregation_fn(gradient)
        return self._postproc_saliency(saliency, ret, show)

    def _get_bp_weights(self, unified_inputs, unified_targets):
        """
        Build the back-propagation weights for the given targets.

        The number of classes is taken from the last dimension of the model output the first time this method is
        called, and the value is cached for subsequent calls.

        Args:
            unified_inputs: The inputs as returned by `unify_inputs`.
            unified_targets: The targets as returned by `unify_targets`.

        Returns:
            The result of `get_bp_weights(num_classes, unified_targets)`.
        """
        if self._num_classes is None:
            output = self._backward_model(*unified_inputs)
            self._num_classes = output.shape[-1]
        return get_bp_weights(self._num_classes, unified_targets)

    @staticmethod
    def _verify_data(inputs, targets):
        """
        Verify the validity of the passed inputs and targets.

        Args:
            inputs (Tensor): The inputs to be explained, must be a 4D tensor.
            targets (Tensor, int, tuple, list): The label of interest. It should be a 1D or 0D tensor, an integer,
                or a tuple/list. If it is a 1D tensor, tuple or list, its length should be the same as `inputs`.

        Raises:
            TypeError: If `inputs` or `targets` is not of a supported type (raised by `check_value_type`).
            ValueError: If `inputs` is not 4D, if a tensor `targets` has more than one dimension, or if the length
                of a 1D tensor/tuple/list `targets` differs from the length of `inputs`.
        """
        check_value_type('inputs', inputs, Tensor)
        if len(inputs.shape) != 4:
            raise ValueError(f'Argument inputs must be 4D Tensor. But got {len(inputs.shape)}D Tensor.')

        check_value_type('targets', targets, (Tensor, int, tuple, list))
        if isinstance(targets, Tensor):
            if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
                raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                                 'it should have the same length as inputs.')
        elif isinstance(targets, (tuple, list)) and len(targets) != len(inputs):
            # The documented contract requires one label per sample for tuple/list targets;
            # fail early here instead of letting a mismatch surface further down the pipeline.
            raise ValueError('Argument targets must have the same length as inputs when it is a tuple or list.')