# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base Class of Defense.
"""
from abc import abstractmethod
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_pair_numpy_param, \
check_int_positive
# Module-wide logger, obtained through LogUtil.get_instance().
LOGGER = LogUtil.get_instance()
# Tag passed as the first argument to LOGGER calls made from this module.
TAG = 'Defense'
class Defense:
    """
    The abstract base class for all defense classes defending adversarial
    examples.

    Child classes are expected to override `defense()`; the base
    implementation only logs an error and raises NotImplementedError.

    Args:
        network (Cell): A MindSpore-style deep learning model to be defended.
    """

    def __init__(self, network):
        # Kept on the instance so that child classes can reach the model.
        self._network = network

    @abstractmethod
    def defense(self, inputs, labels):
        """
        Defense model with samples.

        Args:
            inputs (numpy.ndarray): Samples based on which adversarial
                examples are generated.
            labels (numpy.ndarray): Labels of input samples.

        Raises:
            NotImplementedError: It is an abstract method.
        """
        msg = 'The function defense() is an abstract function in class ' \
              '`Defense` and should be implemented in child class.'
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def batch_defense(self, inputs, labels, batch_size=32, epochs=5):
        """
        Defense model with samples in batch.

        The samples are cut into consecutive slices of `batch_size` items
        and each slice is passed to `defense()`; this is repeated `epochs`
        times. Only full batches are used: trailing samples that do not
        fill a whole batch are skipped.

        Args:
            inputs (numpy.ndarray): Samples based on which adversarial
                examples are generated.
            labels (numpy.ndarray): Labels of input samples.
            batch_size (int): Number of samples in one batch.
            epochs (int): Number of epochs.

        Returns:
            numpy.ndarray, loss of batch_defense operation, i.e. the value
            returned by the last `defense()` call. None if no batch was
            processed (fewer samples than `batch_size`, or `epochs` yields
            no iteration).

        Raises:
            ValueError: If batch_size is 0 (validation is delegated to
                `check_int_positive`).
        """
        inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                'labels', labels)
        x_len = len(inputs)
        batch_size = check_int_positive('batch_size', batch_size)

        # Integer floor division: number of full batches per epoch, without
        # a round-trip through float.
        iters_per_epoch = x_len // batch_size

        loss = None
        for _ in range(epochs):
            for step in range(iters_per_epoch):
                start = step * batch_size
                stop = start + batch_size
                loss = self.defense(inputs[start:stop], labels[start:stop])
        return loss