mindspore_xai.explanation

Predefined Attribution explainers.

class mindspore_xai.explanation.Deconvolution(network)

Deconvolution explanation.

The Deconvolution method is a modified version of the Gradient method. For each ReLU operation in the network to be explained, Deconvolution changes the propagation rule from backpropagating the gradients directly to backpropagating only the positive gradients.
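As a rough illustration of this rule change, the NumPy sketch below compares the backward pass through a single ReLU under the plain gradient rule and under the Deconvolution rule. The function names are hypothetical and this is not the mindspore_xai implementation; it only shows which gradient values each rule lets through.

```python
import numpy as np


def relu_backward_gradient(x, grad_out):
    """Plain gradient rule: pass grad_out only where the forward input x was positive."""
    return grad_out * (x > 0)


def relu_backward_deconvolution(x, grad_out):
    """Deconvolution rule: ignore the forward input and pass only the positive part of grad_out."""
    return np.maximum(grad_out, 0.0)


x = np.array([-1.0, 2.0, 3.0, -4.0])         # forward input to the ReLU
grad_out = np.array([0.5, -0.5, 1.0, 2.0])    # gradient arriving from the layer above

plain = relu_backward_gradient(x, grad_out)
deconv = relu_backward_deconvolution(x, grad_out)
```

Note that the Deconvolution rule keeps the gradient at positions where the forward input was negative (first and last entries), which the plain gradient rule zeroes out.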

Note

The network to be explained will be set to eval mode through network.set_grad(False) and network.set_train(False). If you want to train the network afterwards, reset it to training mode through the opposite operations. To use Deconvolution, the ReLU operations in the network must be implemented as mindspore.nn.Cell objects (such as mindspore.nn.ReLU) rather than with the mindspore.ops.ReLU operator; otherwise, the results will not be correct.

Parameters

network (Cell) – The black-box model to be explained.

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. If it is a 1D tensor, its length should equal N, the number of samples in inputs.
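The accepted target forms can be summarized with a small hypothetical helper that expands them to one label per sample. It works on NumPy values (a MindSpore Tensor would first be converted with Tensor.asnumpy()) and is only an illustration of the rules above, not part of mindspore_xai.

```python
import numpy as np


def normalize_targets(targets, batch_size):
    """Expand an int, a 0D or a 1D target into one label per sample, shape (N,).

    Hypothetical helper mirroring the documented rules; not mindspore_xai code.
    """
    t = np.asarray(targets)
    if t.ndim == 0:                      # int or 0D tensor: same label for every sample
        return np.full(batch_size, int(t))
    if t.ndim == 1:                      # 1D tensor: one label per sample, length must be N
        if t.shape[0] != batch_size:
            raise ValueError(f"expected {batch_size} targets, got {t.shape[0]}")
        return t.astype(int)
    raise ValueError("targets must be an int, a 0D tensor or a 1D tensor")


per_sample = normalize_targets(5, 3)     # the same label 5 for a batch of 3 inputs
```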

Outputs:

Tensor, a 4D tensor of shape \((N, 1, H, W)\).

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context
>>> from mindspore_xai.explanation import Deconvolution
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # LeNet5 is defined in model_zoo/official/cv/lenet/src/lenet.py (not shown here)
>>> net = LeNet5(10, num_channel=3)
>>> deconvolution = Deconvolution(net)
>>> # feed data and the target label to be explained and get the saliency map
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
>>> label = 5
>>> saliency = deconvolution(inputs, label)
>>> print(saliency.shape)
(1, 1, 32, 32)

class mindspore_xai.explanation.GradCAM(network, layer='')

Provides GradCAM explanation method.

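GradCAM generates a saliency map at an intermediate layer. The attribution is obtained as:

\[
\begin{aligned}
\alpha_k^c &= \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{i,j}^k} \\
\text{attribution} &= \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
\end{aligned}
\]

where \(y^c\) is the score for class \(c\), \(A^k\) is the \(k\)-th feature map of the chosen layer, \(A_{i,j}^k\) is its value at spatial position \((i, j)\), and \(Z\) is the number of spatial positions.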
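The combination step of the formula can be sketched in NumPy for a single sample and class, assuming the feature maps A^k and the gradients dy^c/dA^k are already available as arrays of shape (K, H, W). The helper name grad_cam is hypothetical, and resizing the map to the input resolution is omitted.

```python
import numpy as np


def grad_cam(feature_maps, grads):
    """feature_maps, grads: arrays of shape (K, H, W) for one sample and one class c."""
    _, h, w = feature_maps.shape
    z = h * w                                          # Z: number of spatial positions
    alpha = grads.sum(axis=(1, 2)) / z                 # alpha_k^c: averaged gradient per channel
    cam = np.tensordot(alpha, feature_maps, axes=1)    # sum_k alpha_k^c * A^k -> (H, W)
    return np.maximum(cam, 0.0)                        # ReLU keeps only positive evidence


A = np.array([[[1.0, 2.0], [3.0, 4.0]],
              [[1.0, 1.0], [1.0, 1.0]]])               # K = 2 feature maps of size 2x2
dy_dA = np.stack([np.ones((2, 2)), -2.0 * np.ones((2, 2))])
cam = grad_cam(A, dy_dA)                               # alpha = [1, -2]
```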

For more details, please refer to the original paper: GradCAM.

Note

The network to be explained will be set to eval mode through network.set_grad(False) and network.set_train(False). If you want to train the network afterwards, reset it to training mode through the opposite operations.

Parameters
  • network (Cell) – The black-box model to be explained.

  • layer (str, optional) – The name of the layer at which to generate the explanation; the last convolutional layer is usually a good choice. If it is '', the explanation is generated at the input layer. Default: ''.

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. If it is a 1D tensor, its length should equal N, the number of samples in inputs.

Outputs:

Tensor, a 4D tensor of shape \((N, 1, H, W)\), saliency maps.

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context
>>> from mindspore_xai.explanation import GradCAM
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # LeNet5 is defined in model_zoo/official/cv/lenet/src/lenet.py (not shown here)
>>> net = LeNet5(10, num_channel=3)
>>> # specify the layer at which to generate the explanation, usually the last conv layer
>>> layer_name = 'conv2'
>>> # init GradCAM with a trained network and specify the layer to obtain attribution
>>> gradcam = GradCAM(net, layer=layer_name)
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
>>> label = 5
>>> saliency = gradcam(inputs, label)
>>> print(saliency.shape)
(1, 1, 32, 32)

class mindspore_xai.explanation.Gradient(network)

Provides Gradient explanation method.

Gradient is the simplest attribution method: it uses the plain gradient of the output with respect to the input as the explanation.

\[attribution = \frac{\partial{y}}{\partial{x}}\]
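To make the formula concrete, here is a minimal NumPy sketch on a hypothetical toy model y = w2 · ReLU(W1 x): the analytic gradient dy/dx is the attribution, and a finite-difference estimate is included only as a check. The toy model, its weights and the helper names are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 3))   # hypothetical hidden-layer weights
w2 = rng.normal(size=4)        # hypothetical output weights for the class of interest


def model(x):
    """Toy class score y = w2 . ReLU(W1 x) for a single input vector x."""
    return w2 @ np.maximum(W1 @ x, 0.0)


def gradient_attribution(x):
    """Analytic dy/dx by the chain rule: W1^T (w2 * 1[W1 x > 0])."""
    active = (W1 @ x > 0).astype(float)   # ReLU derivative: 1 where the unit is active
    return W1.T @ (w2 * active)


def numerical_gradient(f, x, eps=1e-6):
    """Central finite differences, used only to check the analytic gradient."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


x = rng.normal(size=3)
attribution = gradient_attribution(x)   # same shape as x: one score per input feature
```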

Note

The network to be explained will be set to eval mode through network.set_grad(False) and network.set_train(False). If you want to train the network afterwards, reset it to training mode through the opposite operations.

Parameters

network (Cell) – The black-box model to be explained.

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. If it is a 1D tensor, its length should equal N, the number of samples in inputs.

Outputs:

Tensor, a 4D tensor of shape \((N, 1, H, W)\), saliency maps.

Raises
  • TypeError – Raised if any argument has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context
>>> from mindspore_xai.explanation import Gradient
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # LeNet5 is defined in model_zoo/official/cv/lenet/src/lenet.py (not shown here)
>>> net = LeNet5(10, num_channel=3)
>>> gradient = Gradient(net)
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
>>> label = 5
>>> saliency = gradient(inputs, label)
>>> print(saliency.shape)
(1, 1, 32, 32)

class mindspore_xai.explanation.GuidedBackprop(network)

Guided-Backpropagation explanation.

The Guided-Backpropagation method is an extension of the Gradient method. On top of the original ReLU operations in the network to be explained, Guided-Backpropagation introduces an additional ReLU operation that filters out negative gradients during backpropagation.
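A minimal NumPy sketch of this rule for a single ReLU, assuming the forward input x and the gradient arriving from the layer above are known: the original ReLU mask (x > 0) is kept, and the extra ReLU additionally drops negative gradients. The helper name is hypothetical and this is not the mindspore_xai implementation.

```python
import numpy as np


def relu_backward_guided(x, grad_out):
    """Guided backprop: keep a gradient only where the forward input was positive
    (original ReLU mask) AND the incoming gradient is positive (extra ReLU)."""
    return np.maximum(grad_out, 0.0) * (x > 0)


x = np.array([-1.0, 2.0, 3.0, -4.0])         # forward input to the ReLU
grad_out = np.array([0.5, -0.5, 1.0, 2.0])    # gradient arriving from the layer above
guided = relu_backward_guided(x, grad_out)
```

Only the third entry survives: it is the one position where both the forward input and the incoming gradient are positive.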

Note

The network to be explained will be set to eval mode through network.set_grad(False) and network.set_train(False). If you want to train the network afterwards, reset it to training mode through the opposite operations. To use GuidedBackprop, the ReLU operations in the network must be implemented as mindspore.nn.Cell objects (such as mindspore.nn.ReLU) rather than with the mindspore.ops.ReLU operator; otherwise, the results will not be correct.

Parameters

network (Cell) – The black-box model to be explained.

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. If it is a 1D tensor, its length should equal N, the number of samples in inputs.

Outputs:

Tensor, a 4D tensor of shape \((N, 1, H, W)\), saliency maps.

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context
>>> from mindspore_xai.explanation import GuidedBackprop
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # LeNet5 is defined in model_zoo/official/cv/lenet/src/lenet.py (not shown here)
>>> net = LeNet5(10, num_channel=3)
>>> gbp = GuidedBackprop(net)
>>> # feed data and the target label to be explained and get the saliency map
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
>>> label = 5
>>> saliency = gbp(inputs, label)
>>> print(saliency.shape)
(1, 1, 32, 32)

class mindspore_xai.explanation.Occlusion(network, activation_fn, perturbation_per_eval=32)

Occlusion uses a sliding window to replace the pixels with a reference value (e.g. a constant value) and computes the difference between the resulting output and the original output. The output difference caused by the perturbed pixels is assigned to those pixels as their feature importance. For pixels covered by multiple sliding windows, the feature importance is the average of the differences from those windows.
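The procedure can be sketched in NumPy for a single-channel image. This is a simplified illustration, not the library's code: model is a hypothetical callable returning class probabilities, and window, stride and baseline are assumed parameters (the window settings used by mindspore_xai are not documented here).

```python
import numpy as np


def occlusion(model, image, target, window=2, stride=1, baseline=0.0):
    """Occlusion saliency for a single-channel (H, W) image.

    model(image) -> 1D array of class probabilities (hypothetical callable).
    Every window position is filled with `baseline`; the drop in the target
    probability is credited to all pixels in that window, and each pixel's
    credit is averaged over the windows that covered it.
    """
    h, w = image.shape
    reference = model(image)[target]
    total = np.zeros((h, w))
    count = np.zeros((h, w))
    for top in range(0, h - window + 1, stride):
        for left in range(0, w - window + 1, stride):
            perturbed = image.copy()
            perturbed[top:top + window, left:left + window] = baseline
            diff = reference - model(perturbed)[target]
            total[top:top + window, left:left + window] += diff
            count[top:top + window, left:left + window] += 1
    return total / np.maximum(count, 1)   # pixels never covered stay at 0


def toy_model(img):
    """Hypothetical 2-class model whose class-0 probability is the top-left pixel."""
    return np.array([img[0, 0], 1.0 - img[0, 0]])


saliency = occlusion(toy_model, np.ones((3, 3)), target=0)
```

Only the window covering the top-left pixel changes the toy model's output, so the importance is highest there and is diluted for neighbouring pixels that were also covered by windows with no effect.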

For more details, please refer to the original paper: https://arxiv.org/abs/1311.2901.

Parameters
  • network (Cell) – The black-box model to be explained.

  • activation_fn (Cell) – The activation layer that transforms logits into prediction probabilities. For single-label classification tasks, nn.Softmax is usually applied; for multi-label classification tasks, nn.Sigmoid is usually applied. Users can also pass a customized activation_fn, as long as combining it with network yields the prediction probabilities of the input.

  • perturbation_per_eval (int, optional) – Number of perturbed samples processed in each inference. Within the memory capacity, a larger number usually gives the explanation faster. Default: 32.
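To illustrate what perturbation_per_eval controls, the sketch below splits a set of perturbed samples into chunks of at most perturbation_per_eval, one chunk per forward pass: larger chunks mean fewer inference calls but more memory per call. The helper and the count of 100 perturbations are hypothetical; the sketch only models the batching arithmetic, not how mindspore_xai schedules its inferences.

```python
def batch_perturbations(num_perturbations, perturbation_per_eval):
    """Group perturbation indices into chunks of at most perturbation_per_eval;
    each chunk stands for one forward pass of the network."""
    return [
        list(range(start, min(start + perturbation_per_eval, num_perturbations)))
        for start in range(0, num_perturbations, perturbation_per_eval)
    ]


chunks = batch_perturbations(100, 32)
num_forward_passes = len(chunks)   # ceil(100 / 32) = 4 passes, the last one holding 4 samples
```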

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. If it is a 1D tensor, its length should equal N, the number of samples in inputs.

Outputs:

Tensor, a 4D tensor of shape \((N, 1, H, W)\), saliency maps.

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context
>>> from mindspore_xai.explanation import Occlusion
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # LeNet5 is defined in model_zoo/official/cv/lenet/src/lenet.py (not shown here)
>>> net = LeNet5(10, num_channel=3)
>>> # initialize Occlusion explainer with the pretrained model and activation function
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
>>> occlusion = Occlusion(net, activation_fn=activation_fn)
>>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
>>> label = ms.Tensor([1], ms.int32)
>>> saliency = occlusion(input_x, label)
>>> print(saliency.shape)
(1, 1, 32, 32)

class mindspore_xai.explanation.OoDNet(underlying, num_classes)

Out of distribution network.

OoDNet takes an underlying classifier and outputs the out-of-distribution (OoD) scores of samples.

Note

OoDNet must be trained on the classifier's training dataset in order to give correct OoD scores.

Parameters
  • underlying (Cell) – The underlying classifier. It must have the num_features (int) and output_features (bool) attributes; see the example code for details.

  • num_classes (int) – The number of classes for the classifier.

Returns

Tensor, classification logits (if set_train(True) was called) or OoD scores (if set_train(False) was called), in the shape of [batch_size, num_classes].

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

  • AttributeError – Raised if underlying is missing any required attribute.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, nn, load_checkpoint, load_param_into_net
>>> from mindspore_xai.explanation import OoDNet
>>> from mindspore.common.initializer import Normal
>>>
>>>
>>> class MyLeNet5(nn.Cell):
...
...    def __init__(self, num_class, num_channel):
...        super(MyLeNet5, self).__init__()
...
...        # must add the following 2 attributes to your model
...        self.num_features = 84 # no. of features, int
...        self.output_features = False # output features flag, bool
...
...        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...        self.relu = nn.ReLU()
...        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...        self.flatten = nn.Flatten()
...        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
...        self.fc2 = nn.Dense(120, self.num_features, weight_init=Normal(0.02))
...        self.fc3 = nn.Dense(self.num_features, num_class, weight_init=Normal(0.02))
...
...    def construct(self, x):
...        x = self.conv1(x)
...        x = self.relu(x)
...        x = self.max_pool2d(x)
...        x = self.conv2(x)
...        x = self.relu(x)
...        x = self.max_pool2d(x)
...        x = self.flatten(x)
...        x = self.relu(self.fc1(x))
...        x = self.relu(self.fc2(x))
...
...        # return the features tensor if output_features is True
...        if self.output_features:
...            return x
...
...        x = self.fc3(x)
...        return x
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # prepare trained classifier
>>> net = MyLeNet5(10, num_channel=3)
>>> param_dict = load_checkpoint('mylenet5.ckpt')
>>> load_param_into_net(net, param_dict)
>>> # prepare train_dataset and your OoD network
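>>> # create_dataset_cifar10 is not defined here: it stands for your own loader yielding (data, one-hot label) items
>>> train_dataset = create_dataset_cifar10("/path/to/cifar/dataset")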
>>> ood_net = OoDNet(net, 10)
>>> ood_net.train(train_dataset, nn.SoftmaxCrossEntropyWithLogits())
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
>>> ood_map = ood_net(inputs)
>>> print(ood_map.shape)
(1, 10)

construct(x)

Runs forward inference and returns the classification logits or OoD scores.

Returns

Tensor, logits of softmax with temperature (if set_train(True) was called) or OoD scores (if set_train(False) was called), in the shape of [batch_size, num_classes].

get_train_parameters(train_underlying=False)

Get the training parameters.

Returns

list[Parameter], parameters.

property num_classes

Get the number of classes.

Returns

int, the number of classes.

prepare_train(learning_rate=0.1, momentum=0.9, weight_decay=0.0001, lr_base_factor=0.1, lr_epoch_denom=30, train_underlying=False)

Creates the optimizer and the learning rate scheduler needed for training.

Parameters
  • learning_rate (float) – The optimizer learning rate.

  • momentum (float) – The optimizer momentum.

  • weight_decay (float) – The optimizer weight decay.

  • lr_base_factor (float) – The base scaling factor of learning rate scheduler.

  • lr_epoch_denom (int) – The epoch denominator of learning rate scheduler.

  • train_underlying (bool) – True to train the underlying classifier as well.

Returns

  • Optimizer, optimizer.

  • LearningRateScheduler, learning rate scheduler.

set_train(mode=True)

Set training mode.

Parameters

mode (bool) – Whether to set training mode (True) or evaluation mode (False). Default: True.

train(dataset, loss_fn, callbacks=None, epoch=90, optimizer=None, scheduler=None, **kwargs)

Trains this OOD net.

Parameters
  • dataset (Dataset) – The training dataset, expecting (data, one-hot label) items.

  • loss_fn (Cell) – The loss function. If the classifier's activation function is nn.Softmax(), use nn.SoftmaxCrossEntropyWithLogits(); if it is nn.Sigmoid(), use nn.BCEWithLogitsLoss().

  • callbacks (Callback, optional) – The train callbacks.

  • epoch (int, optional) – The number of epochs to be trained. Default: 90.

  • optimizer (Optimizer, optional) – The optimizer. If None, the one created by prepare_train() is used.

  • scheduler (LearningRateScheduler, optional) – The learning rate scheduler. If None, the one created by prepare_train() is used.

  • **kwargs (any, optional) – Keyword arguments for prepare_train().

property underlying

Get the underlying classifier.

Returns

nn.Cell, the underlying classifier.

class mindspore_xai.explanation.RISE(network, activation_fn, perturbation_per_eval=32)

RISE: Randomized Input Sampling for Explanation of Black-box Model.

RISE is a perturbation-based method that generates attribution maps by sampling multiple random binary masks. The original image is randomly masked and then fed into the black-box model to get predictions. The final attribution map is the weighted sum of these random masks, with the weights being the corresponding outputs on the node of interest:

\[attribution = \sum_{i}f_c(I\odot M_i) M_i\]

where \(I\) is the input image, \(M_i\) is the \(i\)-th random mask, \(\odot\) denotes element-wise multiplication and \(f_c\) is the output on the node of interest \(c\).
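A minimal NumPy sketch of the formula for a single-channel image, assuming a hypothetical model callable that returns class probabilities. Masks are drawn on a coarse grid and upsampled by repetition; the original paper's bilinear upsampling with random shifts is omitted, the image size is assumed to be divisible by the grid size, and the sum is left unnormalized as in the formula above.

```python
import numpy as np


def rise(model, image, target, num_masks=500, grid=4, p_keep=0.5, seed=0):
    """RISE-style saliency for a single-channel (H, W) image:
    attribution = sum_i f_c(I * M_i) * M_i.

    model(image) -> 1D array of class probabilities (hypothetical callable).
    """
    h, w = image.shape
    assert h % grid == 0 and w % grid == 0, "image size must be divisible by grid"
    rng = np.random.default_rng(seed)
    saliency = np.zeros((h, w))
    for _ in range(num_masks):
        coarse = (rng.random((grid, grid)) < p_keep).astype(float)   # random binary grid
        mask = np.kron(coarse, np.ones((h // grid, w // grid)))      # upsample to (H, W) by repetition
        score = model(image * mask)[target]                          # f_c(I * M_i)
        saliency += score * mask
    return saliency


def toy_model(img):
    """Hypothetical 2-class model whose class-0 probability is the top-left pixel."""
    return np.array([img[0, 0], 1.0 - img[0, 0]])


saliency = rise(toy_model, np.ones((8, 8)), target=0)
```

Because the toy model only looks at the top-left pixel, masks that keep that pixel receive all the weight, so the top-left region ends up with the largest attribution.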

For more details, please refer to the original paper: RISE.

Parameters
  • network (Cell) – The black-box model to be explained.

  • activation_fn (Cell) – The activation layer that transforms logits into prediction probabilities. For single-label classification tasks, nn.Softmax is usually applied; for multi-label classification tasks, nn.Sigmoid is usually applied. Users can also pass a customized activation_fn, as long as combining it with network yields the prediction probabilities of the input.

  • perturbation_per_eval (int, optional) – Number of perturbed samples processed in each inference. Within the memory capacity, a larger number usually gives the explanation faster. Default: 32.

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The labels of interest to be explained. When targets is an integer, attribution maps are generated for all inputs with respect to this label. When targets is a tensor, it should be of shape \((N, l)\) (l being the number of labels for each sample), \((N,)\) or \(()\).

Outputs:

Tensor, saliency maps: a 4D tensor of shape \((N, l, H, W)\) when targets is a tensor of shape \((N, l)\), otherwise a tensor of shape \((N, 1, H, W)\).
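The link between the accepted target shapes and the number of maps l in the output can be sketched with a hypothetical helper that turns every documented form into an (N, l) label array. It illustrates the documented rules only and is not mindspore_xai code.

```python
import numpy as np


def rise_targets_to_2d(targets, batch_size):
    """Map the documented target forms to an (N, l) label array; l is the number of
    saliency maps per sample in the (N, l, H, W) output. Hypothetical helper."""
    t = np.asarray(targets)
    if t.ndim == 0:                                  # int or (): same label for every sample
        return np.full((batch_size, 1), int(t))
    if t.ndim == 1:                                  # (N,): one label per sample
        return t.reshape(batch_size, 1)
    if t.ndim == 2 and t.shape[0] == batch_size:     # (N, l): l labels per sample
        return t
    raise ValueError("targets must have shape (), (N,) or (N, l)")


labels = rise_targets_to_2d([[5], [1]], 2)   # as in the example below: l = 1, one map per sample
```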

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context
>>> from mindspore_xai.explanation import RISE
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # LeNet5 is defined in model_zoo/official/cv/lenet/src/lenet.py (not shown here)
>>> net = LeNet5(10, num_channel=3)
>>> # initialize RISE explainer with the pretrained model and activation function
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
>>> rise = RISE(net, activation_fn=activation_fn)
>>> # given an instance of RISE, the saliency map can be generated
>>> inputs = ms.Tensor(np.random.rand(2, 3, 32, 32), ms.float32)
>>> # when `targets` is an integer
>>> targets = 5
>>> saliency = rise(inputs, targets)
>>> print(saliency.shape)
(2, 1, 32, 32)
>>> # `targets` can also be a 2D tensor
>>> targets = ms.Tensor([[5], [1]], ms.int32)
>>> saliency = rise(inputs, targets)
>>> print(saliency.shape)
(2, 1, 32, 32)

class mindspore_xai.explanation.RISEPlus(ood_net, network, activation_fn, perturbation_per_eval=32)

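RISEPlus is a perturbation-based method that generates attribution maps by sampling multiple random binary masks. An OoD detector is adopted to produce an 'inlier score', estimating the probability that a masked sample comes from the data distribution (that of the classifier's training data). The inlier score is then incorporated into the weighted sum of the random masks, so that each mask is weighted by its inlier score and by the corresponding output on the node of interest:

\[attribution = \sum_{i} s_i f_c(I\odot M_i) M_i\]

where \(s_i\) is the inlier score of the masked input \(I\odot M_i\), \(M_i\) is the \(i\)-th random mask and \(f_c\) is the output on the node of interest \(c\).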
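Compared with RISE, the formula only adds the per-mask inlier score s_i as an extra weight. In the NumPy sketch below, model and inlier_score are hypothetical callables standing in for the classifier (with its activation) and for the OoD detector; how OoDNet's outputs are turned into s_i is not shown, and the mask sampling is simplified as in a plain RISE sketch.

```python
import numpy as np


def rise_plus(model, inlier_score, image, target, num_masks=500, grid=4, p_keep=0.5, seed=0):
    """attribution = sum_i s_i * f_c(I * M_i) * M_i for a single-channel (H, W) image.

    model(image) -> 1D array of class probabilities; inlier_score(image) -> float.
    Both are hypothetical stand-ins for the classifier and the OoD detector.
    """
    h, w = image.shape
    assert h % grid == 0 and w % grid == 0, "image size must be divisible by grid"
    rng = np.random.default_rng(seed)
    saliency = np.zeros((h, w))
    for _ in range(num_masks):
        coarse = (rng.random((grid, grid)) < p_keep).astype(float)
        mask = np.kron(coarse, np.ones((h // grid, w // grid)))
        masked = image * mask
        s_i = inlier_score(masked)                 # how in-distribution the masked sample looks
        saliency += s_i * model(masked)[target] * mask
    return saliency


def toy_model(img):
    """Hypothetical 2-class model whose class-0 probability is the top-left pixel."""
    return np.array([img[0, 0], 1.0 - img[0, 0]])


image = np.ones((8, 8))
full_trust = rise_plus(toy_model, lambda m: 1.0, image, target=0)   # s_i = 1: plain RISE
half_trust = rise_plus(toy_model, lambda m: 0.5, image, target=0)
coverage_weighted = rise_plus(toy_model, lambda m: float(m.mean()), image, target=0)
```

With a constant inlier score the result is just a rescaled RISE map; a score that varies per mask (here, the hypothetical fraction of kept pixels) changes how much each mask contributes.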

For more details, please refer to the original paper: Resisting Out-of-Distribution Samples for Perturbation-based XAI.

Parameters
  • ood_net (OoDNet) – The OoD network for generating inlier score.

  • network (Cell) – The black-box model to be explained.

  • activation_fn (Cell) – The activation layer that transforms logits into prediction probabilities. For single-label classification tasks, nn.Softmax is usually applied; for multi-label classification tasks, nn.Sigmoid is usually applied. Users can also pass a customized activation_fn, as long as combining it with network yields the prediction probabilities of the input.

  • perturbation_per_eval (int, optional) – Number of perturbed samples processed in each inference. Within the memory capacity, a larger number usually gives the explanation faster. Default: 32.

Inputs:
  • inputs (Tensor) - The input data to be explained, a 4D tensor of shape \((N, C, H, W)\).

  • targets (Tensor, int) - The labels of interest to be explained. When targets is an integer, attribution maps are generated for all inputs with respect to this label. When targets is a tensor, it should be of shape \((N, l)\) (l being the number of labels for each sample), \((N,)\) or \(()\).

Outputs:

Tensor, saliency maps: a 4D tensor of shape \((N, l, H, W)\) when targets is a tensor of shape \((N, l)\), otherwise a tensor of shape \((N, 1, H, W)\).

Raises
  • TypeError – Raised if any argument or input has an invalid type.

  • ValueError – Raised if any input has an invalid value.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, context, load_checkpoint, load_param_into_net
>>> from mindspore.common.initializer import Normal
>>> from mindspore_xai.explanation import RISEPlus, OoDNet
>>>
>>>
>>> class MyLeNet5(nn.Cell):
...
...    def __init__(self, num_class, num_channel):
...        super(MyLeNet5, self).__init__()
...
...        # must add the following 2 attributes to your model
...        self.num_features = 84 # no. of features, int
...        self.output_features = False # output features flag, bool
...
...        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...        self.relu = nn.ReLU()
...        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...        self.flatten = nn.Flatten()
...        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
...        self.fc2 = nn.Dense(120, self.num_features, weight_init=Normal(0.02))
...        self.fc3 = nn.Dense(self.num_features, num_class, weight_init=Normal(0.02))
...
...    def construct(self, x):
...        x = self.conv1(x)
...        x = self.relu(x)
...        x = self.max_pool2d(x)
...        x = self.conv2(x)
...        x = self.relu(x)
...        x = self.max_pool2d(x)
...        x = self.flatten(x)
...        x = self.relu(self.fc1(x))
...        x = self.relu(self.fc2(x))
...
...        # return the features tensor if output_features is True
...        if self.output_features:
...            return x
...
...        x = self.fc3(x)
...        return x
>>>
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> # prepare trained classifier
>>> net = MyLeNet5(10, num_channel=3)
>>> param_dict = load_checkpoint('mylenet5.ckpt')
>>> load_param_into_net(net, param_dict)
>>> # prepare train_dataset and your OoD network
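>>> # create_dataset_cifar10 is not defined here: it stands for your own loader yielding (data, one-hot label) items
>>> train_dataset = create_dataset_cifar10("/path/to/cifar/dataset")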
>>> ood_net = OoDNet(net, 10)
>>> ood_net.train(train_dataset, nn.SoftmaxCrossEntropyWithLogits())
>>> # initialize RISEPlus explainer with the pretrained model and activation function
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
>>> riseplus = RISEPlus(ood_net, net, activation_fn=activation_fn)
>>> # given an instance of RISEPlus, the saliency map can be generated
>>> inputs = ms.Tensor(np.random.rand(2, 3, 32, 32), ms.float32)
>>> # when `targets` is an integer
>>> targets = 5
>>> saliency = riseplus(inputs, targets)
>>> print(saliency.shape)
(2, 1, 32, 32)