mindarmour

MindArmour is a MindSpore toolbox for enhancing model trustworthiness and achieving privacy-preserving machine learning.

class mindarmour.Attack

The abstract base class for all attack classes that create adversarial examples. Adversarial examples are generated by adding adversarial noise to the original samples.

batch_generate(inputs, labels, batch_size=64)

Generate adversarial examples in batch, based on input samples and their labels.

Parameters
  • inputs (Union[numpy.ndarray, tuple]) – Samples based on which adversarial examples are generated.

  • labels (Union[numpy.ndarray, tuple]) – Original/target labels. If an input has more than one label, its labels are wrapped in a tuple.

  • batch_size (int) – The number of samples in one batch. Default: 64.

Returns

numpy.ndarray, generated adversarial examples
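The batching behaviour can be pictured with a short NumPy sketch. It assumes inputs and labels are plain ndarrays (the tuple form is not handled) and that each slice is passed to generate in order; batch_generate_sketch and toy_generate are hypothetical names used only for illustration, not MindArmour functions.

```python
import numpy as np


def batch_generate_sketch(generate, inputs, labels, batch_size=64):
    """Hypothetical illustration of batching: call generate() on consecutive
    slices of at most batch_size samples and stack the results in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    results = []
    for start in range(0, len(inputs), batch_size):
        stop = start + batch_size  # the final slice may be shorter
        results.append(generate(inputs[start:stop], labels[start:stop]))
    return np.concatenate(results, axis=0)


def toy_generate(x, y):
    """Toy stand-in for Attack.generate: shift every value by 0.1 and clip to [0, 1]."""
    return np.clip(x + 0.1, 0.0, 1.0)


inputs = np.zeros((10, 3), dtype=np.float32)
labels = np.arange(10)
adv = batch_generate_sketch(toy_generate, inputs, labels, batch_size=4)
print(adv.shape)  # (10, 3)
```

With batch_size=4 the ten samples are processed as slices of 4, 4 and 2, and the output keeps the original sample order.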

abstract generate(inputs, labels)

Generate adversarial examples based on normal samples and their labels.

Parameters
  • inputs (Union[numpy.ndarray, tuple]) – Samples based on which adversarial examples are generated.

  • labels (Union[numpy.ndarray, tuple]) – Original/target labels. If an input has more than one label, its labels are wrapped in a tuple.

Raises

NotImplementedError – It is an abstract method.
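A concrete attack only needs to override generate. The sketch below shows the pattern with a textbook fast gradient sign method (FGSM) step against a linear softmax classifier, so the input gradient can be written in closed form. AttackBase is a hypothetical stand-in for mindarmour.Attack so the example runs without MindArmour; real attacks compute gradients through a MindSpore network instead.

```python
import abc

import numpy as np


class AttackBase(abc.ABC):
    """Hypothetical stand-in for mindarmour.Attack (not the real class)."""

    @abc.abstractmethod
    def generate(self, inputs, labels):
        raise NotImplementedError


class LinearFGSM(AttackBase):
    """Untargeted FGSM against a linear softmax classifier: logits = x @ W + b."""

    def __init__(self, weights, bias, eps=0.1, bounds=(0.0, 1.0)):
        self.weights = weights  # shape (n_features, n_classes)
        self.bias = bias        # shape (n_classes,)
        self.eps = eps
        self.bounds = bounds

    def generate(self, inputs, labels):
        logits = inputs @ self.weights + self.bias
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
        one_hot = np.eye(self.weights.shape[1])[labels]
        # Gradient of the cross-entropy loss with respect to the inputs.
        grad = (probs - one_hot) @ self.weights.T
        # Step in the direction that increases the loss, then clip to valid pixels.
        return np.clip(inputs + self.eps * np.sign(grad), *self.bounds)


weights = np.array([[1.0, -1.0], [-1.0, 1.0]])
bias = np.zeros(2)
x = np.array([[0.6, 0.4]])  # classified as class 0
attack = LinearFGSM(weights, bias, eps=0.2)
x_adv = attack.generate(x, np.array([0]))  # pushed towards class 1
```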

class mindarmour.BlackModel

The abstract class that treats the target model as a black box. The model must be defined by the user.

is_adversarial(data, label, is_targeted)

Check whether the input sample is an adversarial example.

Parameters
  • data (numpy.ndarray) – The input sample to be checked, typically a maliciously perturbed example.

  • label (numpy.ndarray) – For targeted attacks, label is the intended label of the perturbed example. For untargeted attacks, label is the original label of the corresponding unperturbed sample.

  • is_targeted (bool) – True for targeted attacks, False for untargeted attacks.

Returns

bool.
  • If True, the input sample is adversarial.

  • If False, the input sample is not adversarial.
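The decision rule described above can be sketched as follows, assuming predict returns scores of shape (m, n) and that label is a class index (not one-hot). is_adversarial_sketch and identity_predict are hypothetical functions, not the library's implementation.

```python
import numpy as np


def is_adversarial_sketch(predict, data, label, is_targeted):
    """Predict a single sample and compare its top-1 class with label."""
    scores = predict(np.expand_dims(data, axis=0))[0]  # add, then drop, the batch axis
    predicted = int(np.argmax(scores))
    target = int(np.asarray(label).item())  # assumes a class index, not one-hot
    if is_targeted:
        # A targeted attack succeeds when the model outputs the intended label.
        return predicted == target
    # An untargeted attack succeeds when the prediction differs from the original label.
    return predicted != target


def identity_predict(batch):
    """Toy model: treat each input vector as its own class scores."""
    return batch


sample = np.array([0.1, 0.9])  # the toy model predicts class 1
print(is_adversarial_sketch(identity_predict, sample, np.array(1), True))   # True
print(is_adversarial_sketch(identity_predict, sample, np.array(0), False))  # True
```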

abstract predict(inputs)

Predict using the user-specified model. The prediction results should have shape (m, n), where m is the number of input samples and n is the number of classes the model classifies.

Parameters

inputs (numpy.ndarray) – The input samples to be predicted.

Raises

NotImplementedError – It is an abstract method.
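A user-defined black-box model only has to implement predict and return scores of shape (m, n). In practice predict would usually run the user's own MindSpore network; in this sketch a NumPy linear classifier stands in for it, and BlackModelBase is a hypothetical stand-in for mindarmour.BlackModel so the example runs on its own.

```python
import abc

import numpy as np


class BlackModelBase(abc.ABC):
    """Hypothetical stand-in for mindarmour.BlackModel (not the real class)."""

    @abc.abstractmethod
    def predict(self, inputs):
        raise NotImplementedError


class LinearBlackBox(BlackModelBase):
    """Wraps a linear classifier; an attack only ever sees its outputs."""

    def __init__(self, weights, bias):
        self._weights = weights  # (n_features, n_classes)
        self._bias = bias        # (n_classes,)

    def predict(self, inputs):
        # Flatten each sample to (m, n_features), then return scores of shape (m, n_classes).
        flat = np.reshape(inputs, (inputs.shape[0], -1))
        return flat @ self._weights + self._bias


model = LinearBlackBox(np.ones((4, 3)), np.zeros(3))
scores = model.predict(np.zeros((5, 2, 2)))  # five samples of shape (2, 2)
print(scores.shape)  # (5, 3)
```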

class mindarmour.ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10, step=10, threshold_index=1.5, need_label=False)

ConceptDriftCheckTimeSeries is used to detect distribution changes in data series. For details, please check the Tutorial.

Parameters
  • window_size (int) – Size of a concept window, no less than 10. Given the input data, window_size should lie in [10, 1/3*len(input data)]. If the data is periodic, window_size is usually 2 to 5 periods; for example, for monthly/weekly data, the data volume of 30/7 days is one period. Default: 100.

  • rolling_window (int) – Smoothing window size, in [1, window_size]. Default: 10.

  • step (int) – The jump length of the sliding window, in [1, window_size]. Default: 10.

  • threshold_index (float) – The threshold index, in \((-\infty, +\infty)\). Default: 1.5.

  • need_label (bool) – If True, concept drift labels are needed. Default: False.
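The documented ranges can be checked before calling concept_check. check_concept_params is a hypothetical helper, not part of MindArmour; it only encodes the bounds listed above.

```python
def check_concept_params(n_samples, window_size=100, rolling_window=10, step=10):
    """Raise ValueError if a parameter falls outside its documented range."""
    upper = n_samples / 3  # window_size may be at most 1/3 of the series length
    if not 10 <= window_size <= upper:
        raise ValueError(f"window_size must lie in [10, {upper:.1f}] for {n_samples} samples")
    if not 1 <= rolling_window <= window_size:
        raise ValueError("rolling_window must lie in [1, window_size]")
    if not 1 <= step <= window_size:
        raise ValueError("step must lie in [1, window_size]")
    return True


print(check_concept_params(1000))  # True: the defaults fit a 1000-sample series
try:
    check_concept_params(200)  # 100 > 200 / 3, so the default window is too large
except ValueError as err:
    print(err)
```

For daily data with a weekly period, for instance, the 2 to 5 period guideline corresponds to a window_size of 14 to 35.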

Examples

>>> import numpy as np
>>> from mindarmour import ConceptDriftCheckTimeSeries
>>> concept = ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10,
...                                       step=10, threshold_index=1.5, need_label=False)
>>> data_example = 5*np.random.rand(1000)
>>> data_example[200: 800] = 20*np.random.rand(600)
>>> score, threshold, concept_drift_location = concept.concept_check(data_example)

concept_check(data)

Find concept drift locations in a data series.

Parameters

data (numpy.ndarray) – Input data. The shape of data can be (n, 1) or (n, m), where each of the m columns is one data series.

Returns

  • numpy.ndarray, the concept drift score of the data series.

  • float, the threshold to judge concept drift.

  • list, the location of the concept drift.
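Because each column of data is treated as one series, a 1-D array or several separate arrays can be arranged as (n, 1) or (n, m) first. A minimal NumPy illustration, with illustrative variable names:

```python
import numpy as np

rng = np.random.default_rng(0)
series_a = 5 * rng.random(1000)  # one series, shape (1000,)
series_b = 5 * rng.random(1000)

# A single series as one column: shape (n, 1).
single = series_a.reshape(-1, 1)
# Several series side by side, one per column: shape (n, m).
multi = np.column_stack([series_a, series_b])
print(single.shape, multi.shape)  # (1000, 1) (1000, 2)
```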

class mindarmour.DPModel(micro_batches=2, norm_bound=1.0, noise_mech=None, clip_mech=None, **kwargs)

DPModel is used to construct a model for differential privacy training. This class overloads mindspore.train.model.Model.

For details, please check the Tutorial.

Parameters
  • micro_batches (int) – The number of small batches split from an original batch. Default: 2.

  • norm_bound (float) – The bound used for norm clipping. If set to 1, the original data is returned. Default: 1.0.

  • noise_mech (Mechanisms) – The object that generates noise of different types. Default: None.

  • clip_mech (Mechanisms) – The object used to update the adaptive clipping bound. Default: None.

Raises
  • ValueError – If DPOptimizer and noise_mech are both None or both not None.

  • ValueError – If the mech method of noise_mech or DPOptimizer is adaptive while clip_mech is not None.
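DPModel splits each batch into micro_batches and clips with norm_bound before noise is added. The general differentially private SGD step it builds on can be sketched in NumPy as below. This is a simplified illustration, not DPModel's implementation: dp_aggregate and noise_multiplier are hypothetical, clipping each micro-batch's average gradient is an assumption, and in MindArmour the noise comes from noise_mech.

```python
import numpy as np


def dp_aggregate(per_sample_grads, micro_batches=2, norm_bound=1.0,
                 noise_multiplier=1.5, rng=None):
    """Simplified DP-SGD step: average per micro-batch, clip, sum, add Gaussian noise."""
    rng = np.random.default_rng() if rng is None else rng
    groups = np.array_split(per_sample_grads, micro_batches, axis=0)
    clipped = []
    for group in groups:
        g = group.mean(axis=0)  # gradient of one micro-batch
        norm = np.linalg.norm(g)
        # Scale down only when the L2 norm exceeds norm_bound.
        clipped.append(g * min(1.0, norm_bound / max(norm, 1e-12)))
    total = np.sum(clipped, axis=0)
    # Noise scale is proportional to the clipping bound (the per-micro-batch sensitivity).
    noise = rng.normal(0.0, noise_multiplier * norm_bound, size=total.shape)
    return (total + noise) / micro_batches


grads = np.tile([3.0, 4.0], (4, 1))  # four per-sample gradients with L2 norm 5
update = dp_aggregate(grads, micro_batches=2, norm_bound=1.0, noise_multiplier=0.0)
print(update)  # each micro-batch gradient is clipped to norm 1; no noise is added here
```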

class mindarmour.Defense(network)

The abstract base class for all defense classes that defend against adversarial examples.

Parameters

network (Cell) – A MindSpore-style deep learning model to be defended.

batch_defense(inputs, labels, batch_size=32, epochs=5)

Defend the model with input samples, in batches.

Parameters
  • inputs (numpy.ndarray) – Samples based on which adversarial examples are generated.

  • labels (numpy.ndarray) – Labels of input samples.

  • batch_size (int) – Number of samples in one batch. Default: 32.

  • epochs (int) – Number of epochs. Default: 5.

Returns

numpy.ndarray, loss of batch_defense operation.

Raises

ValueError – If batch_size is 0.
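The batch_defense loop can be pictured with the sketch below, which calls a defense function on consecutive slices for each epoch and collects the returned losses. batch_defense_sketch and mean_defense are hypothetical; whether MindArmour shuffles the samples or drops a final partial batch is not stated here, and this sketch keeps the partial batch without shuffling.

```python
import numpy as np


def batch_defense_sketch(defense, inputs, labels, batch_size=32, epochs=5):
    """Call defense() on consecutive slices for every epoch and collect the losses."""
    if batch_size == 0:
        raise ValueError("batch_size must not be 0")
    losses = []
    for _ in range(epochs):
        for start in range(0, len(inputs), batch_size):
            stop = start + batch_size
            losses.append(defense(inputs[start:stop], labels[start:stop]))
    return np.array(losses)


def mean_defense(x, y):
    """Toy defense step that just reports the batch mean as its 'loss'."""
    return float(np.mean(x))


losses = batch_defense_sketch(mean_defense, np.arange(10.0), np.zeros(10),
                              batch_size=4, epochs=2)
print(losses)  # [1.5 5.5 8.5 1.5 5.5 8.5]
```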

abstract defense(inputs, labels)

Defend the model with input samples.

Parameters
  • inputs (numpy.ndarray) – Samples based on which adversarial examples are generated.

  • labels (numpy.ndarray) – Labels of input samples.

Raises

NotImplementedError – It is an abstract method.

class mindarmour.Detector

The abstract base class for all adversarial example detectors.

abstract detect(inputs)

Detect adversarial examples from input samples.

Parameters

inputs (Union[numpy.ndarray, list, tuple]) – The input samples to be detected.

Raises

NotImplementedError – It is an abstract method.

abstract detect_diff(inputs)

Calculate the difference between the input samples and de-noised samples.

Parameters

inputs (Union[numpy.ndarray, list, tuple]) – The input samples to be detected.

Raises

NotImplementedError – It is an abstract method.

abstract fit(inputs, labels=None)

Fit a threshold, and reject as adversarial any sample whose difference from its denoised version is larger than the threshold. The threshold is determined by a given false positive rate on normal samples.

Parameters
  • inputs (numpy.ndarray) – The input samples to calculate the threshold.

  • labels (numpy.ndarray) – Labels of training data. Default: None.

Raises

NotImplementedError – It is an abstract method.
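One common way to turn a false positive rate into a threshold is to take the matching percentile of the differences measured on normal samples. fit_threshold, detect_by_threshold and the false_positive_rate parameter are hypothetical names for this sketch; concrete detectors define their own difference measure.

```python
import numpy as np


def fit_threshold(normal_diffs, false_positive_rate=0.05):
    """Choose the threshold so that about false_positive_rate of normal samples exceed it."""
    return float(np.percentile(normal_diffs, 100 * (1 - false_positive_rate)))


def detect_by_threshold(diffs, threshold):
    """Flag (1) samples whose difference from their denoised version exceeds the threshold."""
    return (np.asarray(diffs) > threshold).astype(int)


normal_diffs = np.arange(100.0)  # differences measured on 100 normal samples
threshold = fit_threshold(normal_diffs, false_positive_rate=0.05)
print(threshold)  # about 94.05
print(detect_by_threshold([1.0, 200.0], threshold))  # [0 1]
```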

abstract transform(inputs)

Filter adversarial noise in input samples.

Parameters

inputs (Union[numpy.ndarray, list, tuple]) – The input samples to be transformed.

Raises

NotImplementedError – It is an abstract method.

class mindarmour.Fuzzer(target_model)

Fuzzing test framework for deep neural networks.

Reference: DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep Neural Networks

Parameters

target_model (Model) – Target fuzz model.
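Following DeepHunter, the fuzzer repeatedly picks a seed, mutates it, and keeps mutants that increase coverage. The toy loop below illustrates only that idea: toy_fuzz, the integer "samples", the shift mutation and the bucket-based coverage function are hypothetical stand-ins, not MindArmour's implementation; max_iters and mutate_num_per_seed mirror the names used by fuzzing.

```python
import random


def toy_fuzz(initial_seeds, mutate, coverage_of, max_iters=50,
             mutate_num_per_seed=5, rng=None):
    """Keep mutants that reach coverage not seen before, and reuse them as seeds."""
    rng = rng or random.Random(0)
    queue = list(initial_seeds)
    covered = set()
    for seed in queue:
        covered |= coverage_of(seed)
    kept = []
    for _ in range(max_iters):
        seed = rng.choice(queue)
        for _ in range(mutate_num_per_seed):
            mutant = mutate(seed, rng)
            new_cov = coverage_of(mutant) - covered
            if new_cov:  # the mutant exercised something new
                covered |= new_cov
                queue.append(mutant)
                kept.append(mutant)
    return kept, covered


def bucket_coverage(x):
    """Toy 'coverage': which bucket of width 10 the value falls into."""
    return {x // 10}


def shift(x, rng):
    """Toy mutation: move the value by a small random amount."""
    return x + rng.randint(-5, 5)


kept, covered = toy_fuzz([0, 1, 2], shift, bucket_coverage)
print(len(kept), sorted(covered))
```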

Examples

>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.common.initializer import TruncatedNormal
>>> from mindspore.ops import operations as P
>>> from mindspore.train import Model
>>> from mindspore.ops import TensorSummary
>>> from mindarmour.fuzz_testing import Fuzzer
>>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
...         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
...         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
...         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
...         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.reshape = P.Reshape()
...         self.summary = TensorSummary()
...
...     def construct(self, x):
...         x = self.conv1(x)
...         x = self.relu(x)
...         self.summary('conv1', x)
...         x = self.max_pool2d(x)
...         x = self.conv2(x)
...         x = self.relu(x)
...         self.summary('conv2', x)
...         x = self.max_pool2d(x)
...         x = self.reshape(x, (-1, 16 * 5 * 5))
...         x = self.fc1(x)
...         x = self.relu(x)
...         self.summary('fc1', x)
...         x = self.fc2(x)
...         x = self.relu(x)
...         self.summary('fc2', x)
...         x = self.fc3(x)
...         self.summary('fc3', x)
...         return x
>>> net = Net()
>>> model = Model(net)
>>> mutate_config = [{'method': 'GaussianBlur',
...                   'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
...                  {'method': 'MotionBlur',
...                   'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300],
...                   'auto_param': [True]}},
...                  {'method': 'UniformNoise',
...                   'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
...                  {'method': 'GaussianNoise',
...                   'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
...                  {'method': 'Contrast',
...                   'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
...                  {'method': 'Rotate',
...                   'params': {'angle': [20, 90], 'auto_param': [False, True]}},
...                  {'method': 'FGSM',
...                   'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
>>> batch_size = 8
>>> num_classes = 10
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
>>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> test_labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
>>> test_labels = (np.eye(num_classes)[test_labels]).astype(np.float32)
>>> initial_seeds = []
>>> # make initial seeds
>>> for img, label in zip(test_images, test_labels):
...     initial_seeds.append([img, label])
>>> initial_seeds = initial_seeds[:10]
>>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True)
>>> model_fuzz_test = Fuzzer(model)
>>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
...                                                                          nc, max_iters=100)

fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10000, mutate_num_per_seed=20)

Fuzzing tests for deep neural networks.

Parameters
  • mutate_config (list) – Mutation configs. The format is [{‘method’: ‘GaussianBlur’, ‘params’: {‘ksize’: [1, 2, 3, 5], ‘auto_param’: [True, False]}}, {‘method’: ‘UniformNoise’, ‘params’: {‘factor’: [0.1, 0.2, 0.3], ‘auto_param’: [False, True]}}, {‘method’: ‘GaussianNoise’, ‘params’: {‘factor’: [0.1, 0.2, 0.3], ‘auto_param’: [False, True]}}, {‘method’: ‘Contrast’, ‘params’: {‘alpha’: [0.5, 1, 1.5], ‘beta’: [-10, 0, 10], ‘auto_param’: [False, True]}}, {‘method’: ‘Rotate’, ‘params’: {‘angle’: [20, 90], ‘auto_param’: [False, True]}}, {‘method’: ‘FGSM’, ‘params’: {‘eps’: [0.3, 0.2, 0.4], ‘alpha’: [0.1], ‘bounds’: [(0, 1)]}}, …]. The supported methods are listed in self._strategies, and the params of each method must be within the range of its optional parameters. Supported methods fall into two types. Natural robustness methods include: ‘Translate’, ‘Scale’, ‘Shear’, ‘Rotate’, ‘Perspective’, ‘Curve’, ‘GaussianBlur’, ‘MotionBlur’, ‘GradientBlur’, ‘Contrast’, ‘GradientLuminance’, ‘UniformNoise’, ‘GaussianNoise’, ‘SaltAndPepperNoise’, ‘NaturalNoise’. Attack methods include: ‘FGSM’, ‘PGD’ and ‘MDIIM’, which are abbreviations of FastGradientSignMethod, ProjectedGradientDescent and MomentumDiverseInputIterativeMethod. mutate_config must include at least one method from [‘Contrast’, ‘GradientLuminance’, ‘GaussianBlur’, ‘MotionBlur’, ‘GradientBlur’, ‘UniformNoise’, ‘GaussianNoise’, ‘SaltAndPepperNoise’, ‘NaturalNoise’]. Parameter settings for natural robustness methods are described in ‘mindarmour/natural_robustness/transform/image’; for attack methods, the optional parameters are given in self._attack_param_checklists.

  • initial_seeds (list[list]) – Initial seeds used to generate mutated samples. The format of initial seeds is [[image_data, label], […], …] and the label must be one-hot.

  • coverage (CoverageMetrics) – Class of neuron coverage metrics.

  • evaluate (bool) – Whether to return an evaluation report. Default: True.

  • max_iters (int) – Maximum number of times a seed is selected for mutation. Default: 10000.

  • mutate_num_per_seed (int) – Number of times each selected seed is mutated. Default: 20.

Returns

  • list, mutated samples in fuzz_testing.

  • list, ground truth labels of mutated samples.

  • list, predictions of mutated samples.

  • list, strategies of mutated samples.

  • dict, metrics report of fuzzer.

Raises
  • ValueError – If coverage is not a subclass of CoverageMetrics.

  • ValueError – If initial_seeds is empty.

  • ValueError – If any seed in initial_seeds does not have exactly two elements.
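The seed format and the checks listed under Raises can be made concrete with a small NumPy sketch. make_initial_seeds and check_initial_seeds are hypothetical helpers; the one-hot test encodes the requirement stated for initial_seeds rather than a documented exception.

```python
import numpy as np


def make_initial_seeds(images, class_ids, num_classes):
    """Pair each image with a one-hot label: [[image_data, label], ...]."""
    one_hot = np.eye(num_classes, dtype=np.float32)[class_ids]
    return [[img, lbl] for img, lbl in zip(images, one_hot)]


def check_initial_seeds(initial_seeds):
    """Reject empty seed lists, seeds without exactly two elements, and non-one-hot labels."""
    if not initial_seeds:
        raise ValueError("initial_seeds must not be empty")
    for seed in initial_seeds:
        if len(seed) != 2:
            raise ValueError("each seed must be [image_data, label]")
        label = np.asarray(seed[1])
        if label.ndim != 1 or label.sum() != 1 or not np.isin(label, (0, 1)).all():
            raise ValueError("each label must be one-hot")
    return True


images = np.random.rand(3, 1, 32, 32).astype(np.float32)
seeds = make_initial_seeds(images, np.array([0, 2, 1]), num_classes=10)
print(check_initial_seeds(seeds))  # True
```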

class mindarmour.ImageInversionAttack(network, input_shape, input_bound, loss_weights=(1, 0.2, 5))

An attack method used to reconstruct images by inverting their deep representations.

References: Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them. 2014.

Parameters
  • network (Cell) – The network used to infer images’ deep representations.

  • input_shape (tuple) – Data shape of single network input, which should be in accordance with the given network. The format of shape should be (channel, image_width, image_height).

  • input_bound (Union[tuple, list]) – The pixel range of original images, which should be like [minimum_pixel, maximum_pixel] or (minimum_pixel, maximum_pixel).

  • loss_weights (Union[list, tuple]) – Weights of the three sub-losses in InversionLoss, which can be adjusted to obtain better results. Default: (1, 0.2, 5).

Raises
  • TypeError – If the type of network is not Cell.

  • ValueError – If any value of input_shape is not positive int.

  • ValueError – If any value of loss_weights is not positive.

Examples

>>> import numpy as np
>>> import mindspore.ops.operations as P
>>> from mindspore.nn import Cell
>>> from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack
>>> class Net(Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self._softmax = P.Softmax()
...         self._reduce = P.ReduceSum()
...         self._squeeze = P.Squeeze(1)
...     def construct(self, inputs):
...         out = self._softmax(inputs)
...         out = self._reduce(out, 2)
...         return self._squeeze(out)
>>> net = Net()
>>> original_images = np.random.random((2,1,10,10)).astype(np.float32)
>>> target_features = np.random.random((2, 10)).astype(np.float32)
>>> inversion_attack = ImageInversionAttack(net,
...                                         input_shape=(1, 10, 10),
...                                         input_bound=(0, 1),
...                                         loss_weights=[1, 0.2, 5])
>>> inversion_images = inversion_attack.generate(target_features, iters=10)
>>> evaluate_result = inversion_attack.evaluate(original_images, inversion_images)

evaluate(original_images, inversion_images, labels=None, new_network=None)

Evaluate the quality of inverted images by three indices: the average L2 distance and the average SSIM value between original and inverted images, and the average confidence of the inverted images on the true labels of the original images, as inferred by a newly trained network.

Parameters
  • original_images (numpy.ndarray) – Original images, whose shape should be (img_num, channels, img_width, img_height).

  • inversion_images (numpy.ndarray) – Inversion images, whose shape should be (img_num, channels, img_width, img_height).

  • labels (numpy.ndarray) – Ground truth labels of original images. Default: None.

  • new_network (Cell) – A network whose structure contains all parts of self._network, but loaded with a different checkpoint file. Default: None.

Returns

  • float, average L2 distance.

  • float, average ssim value.

  • Union[float, None], average confidence. It would be None if labels or new_network is None.
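The first index, the average L2 distance, can be read as the mean over images of the L2 norm of the pixel-wise difference. The sketch below uses that plain definition; average_l2_distance is a hypothetical helper, and MindArmour may normalize the distance differently.

```python
import numpy as np


def average_l2_distance(original_images, inversion_images):
    """Mean over images of the L2 norm of (original - inverted), flattened per image."""
    diff = (original_images - inversion_images).reshape(len(original_images), -1)
    return float(np.mean(np.linalg.norm(diff, axis=1)))


original = np.zeros((2, 1, 2, 2))
inverted = np.stack([np.zeros((1, 2, 2)), np.ones((1, 2, 2))])
print(average_l2_distance(original, inverted))  # 1.0: distances 0 and 2, averaged
```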

generate(target_features, iters=100)

Reconstruct images based on target_features.

Parameters
  • target_features (numpy.ndarray) – Deep representations of original images. The first dimension of target_features should be img_num. It should be noted that the shape of target_features should be (1, dim2, dim3, …) if img_num equals 1.

  • iters (int) – Number of iterations of the inversion attack, which should be a positive integer. Default: 100.

Returns

numpy.ndarray, reconstructed images, which are expected to be similar to original images.

Raises
  • TypeError – If the type of target_features is not numpy.ndarray.

  • ValueError – If iters is not a positive int.

class mindarmour.MembershipInference(model, n_jobs=-1)

Membership inference, proposed by Shokri, Stronati, Song and Shmatikov, is a grey-box attack for inferring users' private data. It requires the loss or logits of the training samples. (Here, privacy refers to sensitive attributes of a single user.)

For details, please refer to the Tutorial.

References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. Membership Inference Attacks against Machine Learning Models. 2017.

Parameters
  • model (Model) – Target model.

  • n_jobs (int) – Number of jobs run in parallel. -1 means using all processors, otherwise the value of n_jobs must be a positive integer.

Raises
  • TypeError – If type of model is not mindspore.train.Model.

  • TypeError – If type of n_jobs is not int.

  • ValueError – If the value of n_jobs is neither -1 nor a positive integer.

Examples

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> import mindspore.ops.operations as P
>>> from mindspore import nn
>>> from mindspore.nn import Cell
>>> from mindspore import Model
>>> from mindarmour.privacy.evaluation import MembershipInference
>>> def dataset_generator():
...     batch_size = 16
...     batches = 1
...     data = np.random.randn(batches * batch_size, 1, 10).astype(np.float32)
...     label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
...     for i in range(batches):
...         yield data[i*batch_size:(i+1)*batch_size], label[i*batch_size:(i+1)*batch_size]
>>> class Net(Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self._softmax = P.Softmax()
...         self._Dense = nn.Dense(10,10)
...         self._squeeze = P.Squeeze(1)
...     def construct(self, inputs):
...         out = self._softmax(inputs)
...         out = self._Dense(out)
...         return self._squeeze(out)
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(network=net, loss_fn=loss, optimizer=opt)
>>> inference_model = MembershipInference(model, 2)
>>> config = [{
...     "method": "KNN",
...     "params": {"n_neighbors": [3, 5, 7],}
...     }]
>>> ds_train = ds.GeneratorDataset(dataset_generator, ["image", "label"])
>>> ds_test = ds.GeneratorDataset(dataset_generator, ["image", "label"])
>>> inference_model.train(ds_train, ds_test, config)
>>> metrics = ["precision", "accuracy", "recall"]
>>> eval_train = ds.GeneratorDataset(dataset_generator, ["image", "label"])
>>> eval_test = ds.GeneratorDataset(dataset_generator, ["image", "label"])
>>> result = inference_model.eval(eval_train, eval_test, metrics)
>>> print(result)

eval(dataset_train, dataset_test, metrics)

Evaluate the privacy of the target model with the trained attack models. Evaluation indicators are specified by metrics.

Parameters
  • dataset_train (mindspore.dataset) – The training dataset for the target model.

  • dataset_test (mindspore.dataset) – The test dataset for the target model.

  • metrics (Union[list, tuple]) – Evaluation indicators. The value of metrics must be in [“precision”, “accuracy”, “recall”]. Default: [“precision”].

Returns

list, each element contains an evaluation indicator for the attack model.
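Membership inference is a binary task, so the three metrics follow their usual definitions once membership is encoded as a label. The sketch assumes 1 means "member of the training set" and 0 "non-member"; membership_metrics is a hypothetical helper, not the code behind eval.

```python
def membership_metrics(y_true, y_pred):
    """Precision, accuracy and recall for membership predictions (1 = member)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    correct = sum(1 for t, p in pairs if t == p)
    return {
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "accuracy": correct / len(pairs),
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


report = membership_metrics([1, 1, 1, 0, 0], [1, 1, 0, 1, 0])
print(report)
```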

train(dataset_train, dataset_test, attack_config)

Train the attack models on the input datasets according to attack_config, and save them to self._attack_list.

Parameters
  • dataset_train (mindspore.dataset) – The training dataset for the target model.

  • dataset_test (mindspore.dataset) – The test set for the target model.

  • attack_config (Union[list, tuple]) – Parameter settings for the attack models. The format is [{“method”: “knn”, “params”: {“n_neighbors”: [3, 5, 7]}}, {“method”: “lr”, “params”: {“C”: np.logspace(-4, 2, 10)}}]. The supported methods are knn, lr, mlp and rf, and the params of each method must be within the range of its adjustable parameters. Guidance on setting params can be found in the documentation for: KNN, LR, RF, MLP.

Raises
  • KeyError – If any config in attack_config doesn’t have keys {“method”, “params”}.

  • NameError – If the method (case-insensitive) in attack_config is not in [“lr”, “knn”, “rf”, “mlp”].
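The two documented failure modes of train can be checked before training. check_attack_config is a hypothetical pre-check that mirrors the rules above (required keys, case-insensitive method names); it does not validate the params ranges.

```python
SUPPORTED_METHODS = ("lr", "knn", "rf", "mlp")


def check_attack_config(attack_config):
    """Raise KeyError or NameError under the conditions documented for train()."""
    for config in attack_config:
        if not {"method", "params"} <= set(config):
            raise KeyError("each config needs the keys 'method' and 'params'")
        if str(config["method"]).lower() not in SUPPORTED_METHODS:
            raise NameError(f"unsupported method: {config['method']!r}")
    return True


config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
print(check_attack_config(config))  # True: "KNN" matches "knn" case-insensitively
```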