# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DeepFool Attack.
"""
import numpy as np
from mindspore import Tensor
from mindspore.nn import Cell
from mindarmour.attacks.attack import Attack
from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import GradWrap
from mindarmour.utils.util import jacobian_matrix
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
check_value_positive, check_int_positive, check_norm_level, \
check_param_multi_types, check_param_type
LOGGER = LogUtil.get_instance()
TAG = 'DeepFool'
class DeepFool(Attack):
    """
    DeepFool is an untargeted & iterative attack achieved by moving the benign
    sample to the nearest classification boundary and crossing the boundary.

    Reference: `DeepFool: a simple and accurate method to fool deep neural
    networks <https://arxiv.org/abs/1511.04599>`_

    Args:
        network (Cell): Target model.
        num_classes (int): Number of labels of model output, which should be
            greater than zero.
        max_iters (int): Max iterations, which should be
            greater than zero. Default: 50.
        overshoot (float): Overshoot parameter. Default: 0.02.
        norm_level (Union[int, float, str]): Order of the vector norm.
            Possible values: np.inf, 'inf', 2 or '2'. Default: 2.
        bounds (Union[tuple, list, None]): Lower and upper bounds of data
            range, in form of (clip_min, clip_max). When given, the
            accumulated perturbation is scaled by (clip_max - clip_min) and
            the adversarial samples are clipped into the range. When None,
            no scaling or clipping is applied. Default: None.
        sparse (bool): If True, input labels are sparse-coded. If False,
            input labels are onehot-coded. Default: True.

    Examples:
        >>> attack = DeepFool(network, num_classes=10)
    """

    def __init__(self, network, num_classes, max_iters=50, overshoot=0.02,
                 norm_level=2, bounds=None, sparse=True):
        super(DeepFool, self).__init__()
        self._network = check_model('network', network, Cell)
        self._max_iters = check_int_positive('max_iters', max_iters)
        self._overshoot = check_value_positive('overshoot', overshoot)
        self._norm_level = check_norm_level(norm_level)
        self._num_classes = check_int_positive('num_classes', num_classes)
        self._net_grad = GradWrap(self._network)
        # `bounds=None` is the documented default and means "no clipping",
        # so only validate the container when one was actually supplied.
        if bounds is not None:
            bounds = check_param_multi_types('bounds', bounds, [list, tuple])
        self._bounds = bounds
        self._sparse = check_param_type('sparse', sparse, bool)
        if self._bounds is not None:
            for b in self._bounds:
                _ = check_param_multi_types('bound', b, [int, float])

    @staticmethod
    def _l1_norm(array):
        """Return the vector L1 norm of `array`, whatever its rank."""
        # np.linalg.norm(x, ord=1) without `axis` computes a matrix norm for
        # 2-D input and raises ValueError for rank > 2, so flatten first.
        return np.linalg.norm(np.ravel(array), ord=1)

    def _boundary_distance(self, f_k, w_k):
        """Return the estimated distance to the boundary of one class.

        Args:
            f_k: Difference between the candidate-class and original-class
                outputs for one sample.
            w_k (numpy.ndarray): Difference between the candidate-class and
                original-class gradients for the same sample.

        Raises:
            NotImplementedError: If norm_level is not in
                [2, np.inf, '2', 'inf'].
        """
        if self._norm_level == 2 or self._norm_level == '2':
            return abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
        if self._norm_level == np.inf or self._norm_level == 'inf':
            return abs(f_k) / (self._l1_norm(w_k) + 1e-8)
        msg = 'ord {} is not available.'.format(str(self._norm_level))
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def _perturbation_step(self, diff_w, weight):
        """Return the perturbation for one sample towards the chosen boundary.

        Args:
            diff_w: Smallest boundary distance found for the sample.
            weight (numpy.ndarray): Gradient difference belonging to that
                closest boundary.

        Raises:
            NotImplementedError: If norm_level is not in
                [2, np.inf, '2', 'inf'].
        """
        if self._norm_level == 2 or self._norm_level == '2':
            return diff_w*weight / (np.linalg.norm(weight) + 1e-8)
        if self._norm_level == np.inf or self._norm_level == 'inf':
            # NOTE(review): `diff_w` is already divided by the L1 norm, so
            # this divides by it a second time, unlike the l_inf update in
            # the DeepFool paper. Kept to preserve existing results; confirm.
            return diff_w*np.sign(weight) / (self._l1_norm(weight) + 1e-8)
        msg = 'ord {} is not available in normalization.' \
            .format(str(self._norm_level))
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def generate(self, inputs, labels):
        """
        Generate adversarial examples based on input samples and original labels.

        Args:
            inputs (numpy.ndarray): Input samples.
            labels (numpy.ndarray): Original labels.

        Returns:
            numpy.ndarray, adversarial examples.

        Raises:
            NotImplementedError: If norm_level is not in [2, np.inf, '2', 'inf'].

        Examples:
            >>> advs = attack.generate(np.array([[0.2, 0.3, 0.4],
            >>>                                  [0.3, 0.4, 0.5]]),
            >>>                        np.array([1, 2]))
        """
        inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                'labels', labels)
        if not self._sparse:
            # One-hot labels are reduced to class indices.
            labels = np.argmax(labels, axis=1)
        inputs_dtype = inputs.dtype
        iteration = 0
        origin_labels = labels
        cur_labels = origin_labels.copy()
        weight = np.squeeze(np.zeros(inputs.shape[1:]))
        # Perturbation accumulated per sample over all iterations.
        r_tot = np.zeros(inputs.shape)
        x_origin = inputs
        # Iterate until every sample is misclassified or the budget is spent.
        while np.any(cur_labels == origin_labels) \
                and iteration < self._max_iters:
            preds = self._network(Tensor(inputs)).asnumpy()
            grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
            for idx in range(inputs.shape[0]):
                label = origin_labels[idx]
                if cur_labels[idx] != label:
                    # Already adversarial: keep its perturbation unchanged.
                    continue
                # Search all other classes for the closest decision boundary.
                diff_w = np.inf
                for k in range(self._num_classes):
                    if k == label:
                        continue
                    w_k = grads[k, idx, ...] - grads[label, idx, ...]
                    f_k = preds[idx, k] - preds[idx, label]
                    diff_w_k = self._boundary_distance(f_k, w_k)
                    if diff_w_k < diff_w:
                        diff_w = diff_w_k
                        weight = w_k
                r_i = self._perturbation_step(diff_w, weight)
                r_tot[idx, ...] = r_tot[idx, ...] + r_i

            if self._bounds is not None:
                clip_min, clip_max = self._bounds
                # Scale the perturbation by the data range, then clip.
                inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
                                                                 - clip_min)
                inputs = np.clip(inputs, clip_min, clip_max)
            else:
                inputs = x_origin + (1 + self._overshoot)*r_tot
            # r_tot is float64; cast back so the network always receives the
            # caller's dtype, both here and at the top of the next iteration.
            inputs = inputs.astype(inputs_dtype)
            cur_labels = np.argmax(
                self._network(Tensor(inputs)).asnumpy(), axis=1)
            iteration += 1
            # Drop this iteration's predictions and Jacobian before the next
            # forward/backward pass allocates new ones.
            del preds, grads
        return inputs