# Source code for mindspore_xai.explainer.perturb.occlusion

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Occlusion explainer."""

from typing import Tuple
import numpy as np
import mindspore as ms
import mindspore.nn as nn

from mindspore_xai.common.utils import abs_max, unify_targets
from .ablation import Ablation
from .perturbation import PerturbationAttribution
from .replacement import Constant


def _generate_patches(array, window_size: Tuple, strides: Tuple):
    """Generate patches from image w.r.t given window_size and strides."""
    window_strides = array.strides
    slices = tuple(slice(None, None, stride) for stride in strides)
    indexing_strides = array[slices].strides
    win_indices_shape = (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1

    patches_shape = tuple(win_indices_shape) + window_size
    strides_in_memory = indexing_strides + window_strides
    patches = np.lib.stride_tricks.as_strided(array, shape=patches_shape, strides=strides_in_memory, writeable=False)
    patches = patches.reshape((-1,) + window_size)
    return patches


class Occlusion(PerturbationAttribution):
    r"""
    Provides Occlusion explanation method.

    Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes
    the output difference w.r.t the original output. The output differences caused by perturbed pixels are assigned
    as feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is
    the average of the differences from those sliding windows.

    For more details, please refer to the original paper via: `Visualizing and Understanding Convolutional Networks
    <https://arxiv.org/abs/1311.2901>`_.

    Note:
        Currently only single sample (:math:`N=1`) at each call is supported.

    Args:
        network (Cell): The black-box model to be explained.
        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as when combining this function with network, the final output is the probability of the input.
        perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
            perturbed samples. Within the memory capacity, usually the larger this number is, the faster the
            explanation is obtained. Default: ``32``.

    Inputs:
        - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        - **targets** (Tensor, int, tuple, list) - The label of interest. It should be a 1D or scalar tensor, or an
          integer, or a tuple/list of integers. If it is a 1D tensor, tuple or list, its length should be :math:`N`.
        - **ret** (str, optional): The return object type. ``'tensor'`` means return a Tensor object, ``'image'``
          means return a PIL.Image.Image list. Default: ``'tensor'``.
        - **show** (bool, optional): Show the saliency images, ``None`` means automatically show the saliency images
          if it is running on JupyterLab. Default: ``None``.

    Outputs:
        Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`, saliency maps. Or list[list[PIL.Image.Image]], the
        normalized saliency images if `ret` was set to ``'image'``.

    Raises:
        TypeError: Be raised for any argument or input type problem.
        ValueError: Be raised for any input value problem.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore_xai.explainer import Occlusion
        >>> from mindspore import set_context, PYNATIVE_MODE
        >>>
        >>> set_context(mode=PYNATIVE_MODE)
        >>> # The detail of LeNet5 is shown in models.official.cv.lenet.src.lenet.py
        >>> net = LeNet5(10, num_channel=3)
        >>> # initialize Occlusion explainer with the pretrained model and activation function
        >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
        >>> occlusion = Occlusion(net, activation_fn=activation_fn)
        >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
        >>> label = ms.Tensor([1], ms.int32)
        >>> saliency = occlusion(input_x, label)
        >>> print(saliency.shape)
        (1, 1, 32, 32)
    """

    def __init__(self, network, activation_fn, perturbation_per_eval=32):
        super().__init__(network, activation_fn, perturbation_per_eval)
        self._ablation = Ablation(perturb_mode='Deletion')
        self._aggregation_fn = abs_max
        self._get_replacement = Constant(base_value=0.0)
        self._num_sample_per_dim = 32  # specify the number of perturbations each dimension.

    def __call__(self, inputs, targets, ret='tensor', show=None):
        """Call function for 'Occlusion'."""
        self._verify_data(inputs, targets)
        self._verify_other_args(ret, show)

        inputs = inputs.asnumpy()
        targets = unify_targets(inputs.shape[0], targets).asnumpy()

        batch_size = inputs.shape[0]
        window_size, strides = self._get_window_size_and_strides(inputs)
        full_network = nn.SequentialCell([self._network, self._activation_fn])
        # Probability of the target class for each unperturbed sample.
        original_outputs = full_network(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets]
        masks = Occlusion._generate_masks(inputs, window_size, strides)
        saliency = self._perturbate(batch_size, full_network, (original_outputs, masks, inputs, targets))
        return self._postproc_saliency(saliency, ret, show)

    def _perturbate(self, batch_size, full_network, data):
        """Perform perturbations.

        Evaluates the occluded inputs in chunks of at most ``self._perturbation_per_eval`` masks, and accumulates
        ``original_output - occluded_output`` onto every pixel covered by the corresponding window.

        Args:
            batch_size (int): Number of samples in ``inputs``.
            full_network (Cell): ``network`` followed by ``activation_fn``.
            data (tuple): ``(original_outputs, masks, inputs, targets)``, where ``masks`` has shape
                ``(N, num_perturbations) + inputs.shape[1:]`` as produced by ``_generate_masks``.

        Returns:
            The result of ``self._aggregation_fn`` applied to the per-pixel average of the output differences.
        """
        original_outputs, masks, inputs, targets = data
        total_attribution = np.zeros_like(inputs)
        # Number of windows covering each pixel. It starts at zero so the division below is a true average;
        # starting at one would bias pixels covered by few windows (borders/corners) towards zero.
        weights = np.zeros_like(inputs)
        num_perturbations = masks.shape[1]
        reference = self._get_replacement(inputs)

        count = 0
        while count < num_perturbations:
            ith_masks = masks[:, count:min(count + self._perturbation_per_eval, num_perturbations)]
            actual_num_eval = ith_masks.shape[1]
            num_samples = batch_size * actual_num_eval
            # NOTE(review): assumes Ablation yields one occluded copy per (sample, mask) in sample-major order,
            # which matches the np.repeat ordering of targets/original_outputs below - TODO confirm.
            occluded_inputs = self._ablation(inputs, reference, ith_masks)
            occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
            targets_repeat = np.repeat(targets, repeats=actual_num_eval, axis=0)
            occluded_outputs = full_network(
                ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
            original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
            outputs_diff = original_outputs_repeat - occluded_outputs
            # Broadcast each scalar difference over its window mask, then sum over the masks in this chunk.
            total_attribution += (
                outputs_diff.reshape(ith_masks.shape[:2] + (1,) * (len(masks.shape) - 2)) * ith_masks).sum(axis=1)
            weights += ith_masks.sum(axis=1)
            count += actual_num_eval

        # Average over the covering windows; pixels not covered by any window keep a zero attribution.
        averaged = np.divide(total_attribution, weights, out=np.zeros_like(total_attribution), where=weights > 0)
        attribution = self._aggregation_fn(ms.Tensor(averaged, ms.float32))
        return attribution

    def _get_window_size_and_strides(self, inputs):
        """
        Return window_size and strides.

        The channel dimension is always covered by a single window (window size and stride both equal to ``C``).
        For each spatial dimension of length ``x``:

        - if ``x <= self._num_sample_per_dim``, the window size is ``min(3, x)`` and the stride is 1;
        - otherwise both the window size and the stride are ``x // self._num_sample_per_dim`` (non-overlapping
          windows).

        Capping the window at ``x`` keeps inputs with a spatial size below 3 usable instead of producing no
        windows at all.
        """
        num_sample_per_dim = self._num_sample_per_dim
        spatial_dims = inputs.shape[2:]
        window_size = tuple(
            [inputs.shape[1]] +
            [x // num_sample_per_dim if x > num_sample_per_dim else min(3, x) for x in spatial_dims])
        strides = tuple(
            [inputs.shape[1]] +
            [x // num_sample_per_dim if x > num_sample_per_dim else 1 for x in spatial_dims])
        return window_size, strides

    @staticmethod
    def _generate_masks(inputs, window_size, strides):
        """Generate masks to perturb contiguous regions.

        Returns:
            numpy.ndarray of bool, shape ``(N, num_perturbations) + inputs.shape[1:]``; entry ``[n, i]`` is True on
            the elements covered by the ``i``-th window (the same masks are tiled for every sample ``n``).
        """
        total_dim = np.prod(inputs.shape[1:]).item()
        # Each element holds its own flat index, so patches of `template` are the flat indices a window covers.
        template = np.arange(total_dim).reshape(inputs.shape[1:])
        indices = _generate_patches(template, window_size, strides)
        num_perturbations = indices.shape[0]
        # Explicit window volume instead of -1: reshape(0, -1) is ambiguous when there are no windows.
        window_volume = int(np.prod(indices.shape[1:]))
        indices = indices.reshape(num_perturbations, window_volume)
        mask = np.zeros((num_perturbations, total_dim), dtype=bool)
        # Vectorised scatter: row i is set True at every flat index covered by window i.
        mask[np.arange(num_perturbations)[:, None], indices] = True
        mask = mask.reshape((num_perturbations,) + inputs.shape[1:])
        masks = np.tile(mask, reps=(inputs.shape[0],) + (1,) * mask.ndim)
        return masks