Using CV Explainers
What are CV Explainers
Explainers are algorithms that explain the decisions made by AI models. MindSpore XAI currently provides 7 explainers for image classification scenarios. Their outputs are saliency maps (or heatmaps), in which brightness represents the importance of the corresponding regions of the original image.
A saliency map overlaid on top of the original image:
There are 2 categories of explainers: gradient-based and perturbation-based. Gradient-based explainers rely on backpropagation to compute pixel importance, while perturbation-based explainers apply random perturbations to the original images.
| Explainer | Category | PYNATIVE_MODE | GRAPH_MODE |
|---|---|---|---|
| Gradient | gradient | Supported | Supported |
| GradCAM | gradient | Supported | |
| GuidedBackprop | gradient | Supported | |
| Deconvolution | gradient | Supported | |
| Occlusion | perturbation | Supported | Supported |
| RISE | perturbation | Supported | Supported |
| RISEPlus | perturbation | Supported | |
Preparations
Downloading Data Package and Model
First of all, we have to download the data package and put it underneath the xai/examples/ directory of a local XAI source package:
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/xai_examples_data.tar.gz
tar -xf xai_examples_data.tar.gz
git clone https://gitee.com/mindspore/xai
mv xai_examples_data xai/examples/
The xai/examples/ directory will then have the following files:
xai/examples/
├── xai_examples_data/
│ ├── ckpt/
│ │ ├── resnet50.ckpt
│ ├── train/
│ └── test/
├── common/
│ ├── dataset.py
│ └── resnet.py
├── using_cv_explainers.py
├── using_rise_plus.py
└── using_cv_benchmarks.py
- xai_examples_data/: The extracted data package.
- xai_examples_data/ckpt/resnet50.ckpt: ResNet50 checkpoint file.
- xai_examples_data/test: Test dataset.
- xai_examples_data/train: Training dataset.
- common/dataset.py: Dataset loader.
- common/resnet.py: ResNet model definitions.
- using_cv_explainers.py: Example of using explainers.
- using_rise_plus.py: Example of using the RISEPlus explainer; its usage differs from that of the other explainers.
- using_cv_benchmarks.py: Example of using benchmarks.
Preparing Python Environment
The complete code of the tutorial below is using_cv_explainers.py.
After downloading the example package, we load the trained classifier and an image to be explained:
# have to change the current directory to xai/examples/ first
from mindspore import load_checkpoint, load_param_into_net, set_context, PYNATIVE_MODE
from common.resnet import resnet50
from common.dataset import load_image_tensor
# 20 classes
num_classes = 20
# load the trained classifier
net = resnet50(num_classes)
param_dict = load_checkpoint("xai_examples_data/ckpt/resnet50.ckpt")
load_param_into_net(net, param_dict)
# [1, 3, 224, 224] Tensor
boat_image = load_image_tensor("xai_examples_data/test/boat.jpg")
Using GradCAM
GradCAM is a typical and effective gradient-based explainer:
from PIL import Image
import mindspore as ms
from mindspore import Tensor
from mindspore_xai.explainer import GradCAM
from mindspore_xai.visual.cv import saliency_to_image
# only PYNATIVE_MODE is supported
set_context(mode=PYNATIVE_MODE)
# usually specify the last convolutional layer
grad_cam = GradCAM(net, layer="layer4")
# 3 is the class id of 'boat'
saliency = grad_cam(boat_image, targets=3, show=False)
orig_img = Image.open("xai_examples_data/test/boat.jpg")
# convert the saliency map to a PIL.Image.Image object blended with the original image
saliency_to_image(saliency, orig_img)
The returned saliency is a 1x1x224x224 tensor for a 1xCx224x224 image tensor; it stores the importance (range: [0.0, 1.0]) of every pixel to the classification decision for 'boat'. Users may specify any class to be explained.
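For instance, the same image can be explained against a different class simply by passing another class id as targets, and the blended visualization can be saved to disk. Below is a minimal sketch reusing the variables above; the class id 0, the variable names and the output file name are only for illustration, and it assumes saliency_to_image() returns the PIL.Image.Image mentioned above:
# class id 0 is only an example; any of the 20 class ids may be explained
other_saliency = grad_cam(boat_image, targets=0, show=False)
# saliency values fall within [0.0, 1.0]
print(other_saliency.asnumpy().min(), other_saliency.asnumpy().max())
# saliency_to_image() returns a PIL.Image.Image, so the blended result can be saved directly
saliency_to_image(saliency, orig_img).save("boat_gradcam.jpg")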
Batch Explanation
For gradient-based explainers, batch explanation is usually more efficient. The other explainers can also batch their evaluations:
from common.dataset import load_dataset
test_ds = load_dataset('xai_examples_data/test').batch(4)
for images, labels in test_ds:
saliencies = grad_cam(images, targets=ms.Tensor([3, 3, 3, 3], dtype=ms.int32))
# other custom operations ...
The returned saliencies is a 4x1x224x224 tensor for a 4xCx224x224 batched image tensor.
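If per-sample post-processing is needed, the batched result can be sliced back into single-sample maps. Below is a minimal sketch that reuses the loop above; slicing with a range keeps the leading batch dimension, so each slice remains a 1x1x224x224 tensor:
for images, labels in test_ds:
    saliencies = grad_cam(images, targets=ms.Tensor([3, 3, 3, 3], dtype=ms.int32))
    for i in range(saliencies.shape[0]):
        # slice with a range to keep the batch dimension: 1x1x224x224
        single_saliency = saliencies[i:i + 1]
        # per-sample post-processing (e.g. saliency_to_image) would go here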
Using Other Explainers
The ways of using other explainers are very similar to GradCAM
, except RISEPlus
.
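For example, a perturbation-based explainer such as RISE can be constructed and called in the same fashion, reusing the net, boat_image and orig_img from the GradCAM example above. The constructor arguments below (the network plus an activation function, mirroring the RISEPlus example later in this tutorial) are an assumption; please check the RISE API reference for the exact signature:
from mindspore.nn import Softmax
from mindspore_xai.explainer import RISE
# the constructor arguments are assumed to mirror those of RISEPlus (network + activation function)
rise = RISE(net, activation_fn=Softmax())
# same calling convention as GradCAM: an image tensor and the target class id
saliency = rise(boat_image, targets=3, show=False)
saliency_to_image(saliency, orig_img)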
Using RISEPlus
The complete code of the tutorial below is using_rise_plus.py.
RISEPlus is based on RISE with the introduction of an Out-of-Distribution (OoD) detector. It solves the degeneration problem of RISE on samples unlike anything the classifier has seen during training.
First, we need to train an OoD detector (OoDNet) with the classifier's training dataset:
# have to change the current directory to xai/examples/ first
from mindspore import set_context, save_checkpoint, load_checkpoint, load_param_into_net, PYNATIVE_MODE
from mindspore.nn import Softmax, SoftmaxCrossEntropyWithLogits
from mindspore_xai.tool.cv import OoDNet
from mindspore_xai.explainer import RISEPlus
from common.dataset import load_dataset, load_image_tensor
from common.resnet import resnet50
# only PYNATIVE_MODE is supported
set_context(mode=PYNATIVE_MODE)
num_classes = 20
# classifier training dataset
train_ds = load_dataset('xai_examples_data/train').batch(4)
# load the trained classifier
net = resnet50(num_classes)
param_dict = load_checkpoint('xai_examples_data/ckpt/resnet50.ckpt')
load_param_into_net(net, param_dict)
ood_net = OoDNet(underlying=net, num_classes=num_classes)
# use SoftmaxCrossEntropyWithLogits as loss function if the activation function of
# the classifier is Softmax, use BCEWithLogitsLoss if the activation function is Sigmoid
ood_net.train(train_ds, loss_fn=SoftmaxCrossEntropyWithLogits())
save_checkpoint(ood_net, 'ood_net.ckpt')
The underlying classifier for OoDNet must be a subclass of nn.Cell, and its __init__() must:

- define an int member attribute named num_features, which is the number of feature values returned by the feature layer;
- define a bool member attribute named output_features with False as the initial value. OoDNet tells the classifier to return the feature tensor from construct() by setting output_features to True.
A LeNet5 example of an underlying classifier:
from mindspore import nn
from mindspore.common.initializer import Normal
class MyLeNet5(nn.Cell):
def __init__(self, num_class, num_channel):
super(MyLeNet5, self).__init__()
# must add the following 2 attributes to your model
self.num_features = 84 # no. of features, int
self.output_features = False # output features flag, bool
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, self.num_features, weight_init=Normal(0.02))
self.fc3 = nn.Dense(self.num_features, num_class, weight_init=Normal(0.02))
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
# return the features tensor if output_features is True
if self.output_features:
return x
x = self.fc3(x)
return x
Now we can use RISEPlus with the trained OoDNet:
from PIL import Image
from mindspore_xai.visual.cv import saliency_to_image
# create a new classifier as the underlying when loading OoDNet from a checkpoint
ood_net = OoDNet(underlying=resnet50(num_classes), num_classes=num_classes)
param_dict = load_checkpoint('ood_net.ckpt')
load_param_into_net(ood_net, param_dict)
rise_plus = RISEPlus(ood_net=ood_net, network=net, activation_fn=Softmax())
boat_image = load_image_tensor("xai_examples_data/test/boat.jpg")
saliency = rise_plus(boat_image, targets=3, show=False)
orig_img = Image.open("xai_examples_data/test/boat.jpg")
saliency_to_image(saliency, orig_img)
The returned saliency is a 1x1x224x224 tensor for a 1xCx224x224 image tensor.