mindarmour.privacy.diff_privacy
This module provides differential privacy features to protect user privacy.
- class mindarmour.privacy.diff_privacy.AdaClippingWithGaussianRandom(decay_policy='Linear', learning_rate=0.001, target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0)
Adaptive clipping. If decay_policy is ‘Linear’, the update formula is \(norm\_bound = norm\_bound - learning\_rate \times (\beta - target\_unclipped\_quantile)\). If decay_policy is ‘Geometric’, the update formula is \(norm\_bound = norm\_bound \times \exp(-learning\_rate \times (\beta - target\_unclipped\_quantile))\), where \(\beta\) (empirical_fraction) is the empirical fraction of samples whose gradient norm is at most the current norm bound.
- Parameters
decay_policy (str) – Decay policy of adaptive clipping; must be one of [‘Linear’, ‘Geometric’]. Default: ‘Linear’.
learning_rate (float) – Learning rate for updating the norm clip bound. Default: 0.001.
target_unclipped_quantile (float) – Target quantile of the norm clip, i.e. the desired fraction of unclipped samples. Default: 0.9.
fraction_stddev (float) – Standard deviation of the Gaussian noise added to empirical_fraction; the noisy fraction is \(empirical\_fraction + N(0, fraction\_stddev)\). Default: 0.01.
seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers; if seed!=0, it generates values from the given seed. Default: 0.
- Returns
Tensor, updated norm clip bound.
Examples
>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import AdaClippingWithGaussianRandom
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
...                                          learning_rate=learning_rate,
...                                          target_unclipped_quantile=target_unclipped_quantile,
...                                          fraction_stddev=beta_stddev)
>>> next_norm_bound = ada_clip(beta, norm_bound)
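The two update rules of AdaClippingWithGaussianRandom can be sketched in plain Python. This is a minimal illustration under stated assumptions, not MindArmour code: update_norm_bound and its rng argument are hypothetical, and the noise on the empirical fraction is drawn with Python's random module rather than MindSpore's secure random generator.

```python
import math
import random


def update_norm_bound(norm_bound, beta, learning_rate=0.001,
                      target_unclipped_quantile=0.9, decay_policy="Linear",
                      fraction_stddev=0.01, rng=None):
    """Hypothetical sketch of one adaptive-clipping update (not the MindArmour API).

    beta is the empirical fraction of samples whose gradient norm did not
    exceed the current norm_bound.
    """
    rng = rng if rng is not None else random.Random()
    # Perturb the empirical fraction with N(0, fraction_stddev) noise.
    noisy_beta = beta + rng.gauss(0.0, fraction_stddev)
    error = noisy_beta - target_unclipped_quantile
    if decay_policy == "Linear":
        return norm_bound - learning_rate * error
    if decay_policy == "Geometric":
        return norm_bound * math.exp(-learning_rate * error)
    raise ValueError("decay_policy must be 'Linear' or 'Geometric'")


# Too few samples are unclipped (0.5 < 0.9), so the bound grows slightly.
print(update_norm_bound(1.0, 0.5, fraction_stddev=0.0))
```

With either policy, β below the target raises the bound and β above it lowers the bound; the Linear policy applies the correction additively and the Geometric policy multiplicatively.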
- class mindarmour.privacy.diff_privacy.ClipMechanismsFactory
Factory class of clip mechanisms.
- static create(mech_name, decay_policy='Linear', learning_rate=0.001, target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0)
- Parameters
mech_name (str) – Clip noise generation strategy. Currently only ‘Gaussian’ is supported.
decay_policy (str) – Decay policy of adaptive clipping; must be one of [‘Linear’, ‘Geometric’]. Default: ‘Linear’.
learning_rate (float) – Learning rate for updating the norm clip bound. Default: 0.001.
target_unclipped_quantile (float) – Target quantile of the norm clip, i.e. the desired fraction of unclipped samples. Default: 0.9.
fraction_stddev (float) – Standard deviation of the Gaussian noise added to empirical_fraction; the noisy fraction is \(empirical\_fraction + N(0, fraction\_stddev)\). Default: 0.01.
seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers; if seed!=0, it generates values from the given seed. Default: 0.
- Raises
NameError – mech_name must be in [‘Gaussian’].
- Returns
Mechanisms, the created clip mechanism (an AdaClippingWithGaussianRandom instance).
Examples
>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import ClipMechanismsFactory
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.1
>>> learning_rate = 0.1
>>> target_unclipped_quantile = 0.3
>>> clip_mechanism = ClipMechanismsFactory()
>>> ada_clip = clip_mechanism.create('Gaussian',
...                                  decay_policy=decay_policy,
...                                  learning_rate=learning_rate,
...                                  target_unclipped_quantile=target_unclipped_quantile,
...                                  fraction_stddev=beta_stddev)
>>> next_norm_bound = ada_clip(beta, norm_bound)
- class mindarmour.privacy.diff_privacy.DPModel(micro_batches=2, norm_bound=1.0, noise_mech=None, clip_mech=None, **kwargs)
This class overloads mindspore.train.model.Model to support differentially private training.
- Parameters
micro_batches (int) – The number of small batches split from an original batch. Default: 2.
norm_bound (float) – Clipping bound for the L2 norm of the gradients; gradients whose norm does not exceed the bound are returned unchanged. Default: 1.0.
noise_mech (Mechanisms) – Object that generates the noise added to the gradients. Default: None.
clip_mech (Mechanisms) – Object used to update the adaptive clipping bound. Default: None.
- Raises
ValueError – If DPOptimizer and noise_mech are both None or both not None.
ValueError – If the mechanism of noise_mech or of the DPOptimizer is adaptive while clip_mech is not None.
Examples
>>> import mindspore.dataset as ds
>>> import mindspore.nn as nn
>>> from mindarmour.privacy.diff_privacy import ClipMechanismsFactory, DPModel, DPOptimizerClassFactory
>>> # LeNet5 and dataset_generator are user-defined helpers, not part of this module.
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> micro_batches = 2
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
>>> factory_opt.set_mechanisms('Gaussian',
...                            norm_bound=norm_bound,
...                            initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
...                                          learning_rate=0.1, momentum=0.9)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
...                                            decay_policy='Linear',
...                                            learning_rate=0.01,
...                                            target_unclipped_quantile=0.9,
...                                            fraction_stddev=0.01)
>>> model = DPModel(micro_batches=micro_batches,
...                 norm_bound=norm_bound,
...                 clip_mech=clip_mech,
...                 noise_mech=None,
...                 network=network,
...                 loss_fn=loss,
...                 optimizer=net_opt,
...                 metrics=None)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
...                             ['data', 'label'])
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
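The micro-batch, norm_bound and noise settings of DPModel correspond to the standard DP-SGD aggregation step, which can be sketched with NumPy: clip each micro-batch gradient to norm_bound, sum, add Gaussian noise with standard deviation norm_bound × noise_multiplier, and average over micro-batches. This is an assumption-based illustration of the technique, not DPModel's internal code; dp_aggregate is a hypothetical helper, and each micro-batch gradient is flattened into one row for simplicity.

```python
import numpy as np


def dp_aggregate(micro_batch_grads, norm_bound=1.0, noise_multiplier=1.0, rng=None):
    """Hypothetical DP-SGD aggregation sketch (not the MindArmour implementation).

    micro_batch_grads: array of shape (micro_batches, num_params).
    Returns the noisy averaged gradient and the fraction of unclipped micro-batches.
    """
    rng = rng if rng is not None else np.random.default_rng()
    grads = np.asarray(micro_batch_grads, dtype=np.float64)
    micro_batches = grads.shape[0]
    norms = np.linalg.norm(grads, axis=1)
    # Shrink only the gradients whose L2 norm exceeds norm_bound.
    scale = np.minimum(1.0, norm_bound / np.maximum(norms, 1e-12))
    clipped = grads * scale[:, None]
    # Noise is calibrated to the clipping bound (the per-micro-batch sensitivity).
    noise = rng.normal(0.0, norm_bound * noise_multiplier, size=clipped.shape[1])
    noisy_mean = (clipped.sum(axis=0) + noise) / micro_batches
    # Empirical fraction of unclipped micro-batches (the beta used by adaptive clipping).
    beta = float(np.mean(norms <= norm_bound))
    return noisy_mean, beta


grads = np.array([[3.0, 4.0], [0.3, 0.4]])
mean, beta = dp_aggregate(grads, norm_bound=1.0, noise_multiplier=0.0)
print(mean, beta)
```

Passing noise_multiplier=0.0 disables the noise, which is handy for checking the clipping arithmetic but provides no privacy.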
- class mindarmour.privacy.diff_privacy.DPOptimizerClassFactory(micro_batches=2)
Factory class of differentially private optimizers.
- Parameters
micro_batches (int) – The number of small batches split from an original batch. Default: 2.
- Returns
Optimizer, Optimizer class.
Examples
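>>> from mindarmour.privacy.diff_privacy import DPOptimizerClassFactory
>>> # `network` is a user-defined mindspore.nn.Cell, e.g. LeNet5().
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
>>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(),
...                                          learning_rate=0.001,
...                                          momentum=0.9)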
- class mindarmour.privacy.diff_privacy.NoiseAdaGaussianRandom(norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, noise_decay_rate=6e-06, decay_policy='Exp')
Adaptive Gaussian noise generation mechanism. The noise decays as training proceeds; the decay mode can be ‘Time’, ‘Step’ or ‘Exp’. self._noise_multiplier is updated during model.train by _MechanismsParamsUpdater.
- Parameters
norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.
initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, used to calculate the privacy spent. Default: 1.0.
seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers; if seed!=0, it generates values from the given seed. Default: 0.
noise_decay_rate (float) – Hyperparameter controlling the noise decay. Default: 6e-6.
decay_policy (str) – Noise decay strategy, one of ‘Step’, ‘Time’, ‘Exp’. Default: ‘Exp’.
- Returns
Tensor, generated noise with the same shape as the given gradients.
Examples
>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import NoiseAdaGaussianRandom
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 1.5
>>> seed = 0
>>> noise_decay_rate = 6e-4
>>> decay_policy = "Exp"
>>> net = NoiseAdaGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_decay_rate, decay_policy)
>>> res = net(gradients)
>>> print(res)
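The three decay modes can be illustrated with the one-step update rules below. These formulas are an assumption based on one reading of MindArmour's _MechanismsParamsUpdater and may differ between versions; next_noise_multiplier is a hypothetical helper, not part of the library.

```python
import math


def next_noise_multiplier(current, initial, decay_rate, decay_policy):
    """Hypothetical one-step noise-multiplier decay (assumed rules, not the library API)."""
    if decay_policy == "Time":
        # Inverse-time decay: after t steps, sigma_t = sigma_0 / (1 + decay_rate * t).
        return initial / (initial / current + decay_rate)
    if decay_policy == "Step":
        # Geometric decay: sigma_t = sigma_0 * (1 - decay_rate) ** t.
        return current * (1.0 - decay_rate)
    if decay_policy == "Exp":
        # Exponential decay: sigma_t = sigma_0 * exp(-decay_rate * t).
        return current / math.exp(decay_rate)
    raise ValueError("decay_policy must be 'Time', 'Step' or 'Exp'")


sigma = 1.5
for _ in range(1000):
    sigma = next_noise_multiplier(sigma, 1.5, 6e-4, "Exp")
print(sigma)  # close to 1.5 * exp(-0.6)
```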
- class mindarmour.privacy.diff_privacy.NoiseGaussianRandom(norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None)
Gaussian noise generated mechanism.
- Parameters
norm_bound (float) – Clipping bound for the l2 norm of the gradients. Default: 1.0.
initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, used to calculate the privacy spent. Default: 1.0.
seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers; if seed!=0, it generates values from the given seed. Default: 0.
decay_policy (str) – Mechanism parameter update policy. Default: None.
- Returns
Tensor, generated noise with the same shape as the given gradients.
Examples
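>>> from mindspore import Tensor
>>> from mindspore.common import dtype as mstype
>>> from mindarmour.privacy.diff_privacy import NoiseGaussianRandom
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.5
>>> initial_noise_multiplier = 1.5
>>> seed = 0
>>> decay_policy = None
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier, seed, decay_policy)
>>> res = net(gradients)
>>> print(res)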
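Following the parameter descriptions above, the generated noise is zero-mean Gaussian with standard deviation norm_bound × initial_noise_multiplier and the same shape as the gradients. A NumPy sketch follows; gaussian_noise is a hypothetical helper, and its seed handling (None means fresh OS entropy) is NumPy's convention, not MindArmour's seed=0 secure-random convention.

```python
import numpy as np


def gaussian_noise(gradients, norm_bound=1.0, noise_multiplier=1.0, seed=None):
    """Hypothetical sketch: zero-mean Gaussian noise shaped like `gradients`."""
    gradients = np.asarray(gradients, dtype=np.float32)
    stddev = norm_bound * noise_multiplier  # noise scale relative to the clipping bound
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, stddev, size=gradients.shape).astype(np.float32)


noise = gaussian_noise([0.2, 0.9], norm_bound=0.5, noise_multiplier=1.5, seed=42)
print(noise.shape, noise.dtype)  # (2,) float32
```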
- class mindarmour.privacy.diff_privacy.NoiseMechanismsFactory[source]
Factory class of noise mechanisms.
- static create(mech_name, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, noise_decay_rate=6e-06, decay_policy=None)
- Parameters
mech_name (str) – Noise generation strategy, either ‘Gaussian’ or ‘AdaGaussian’. The noise decays during training with the ‘AdaGaussian’ mechanism and stays constant with the ‘Gaussian’ mechanism.
norm_bound (float) – Clipping bound for the L2 norm of the gradients. Default: 1.0.
initial_noise_multiplier (float) – Ratio of the standard deviation of the Gaussian noise to norm_bound, used to calculate the privacy spent. Default: 1.0.
seed (int) – Original random seed. If seed=0, the random normal generator uses secure random numbers; if seed!=0, it generates values from the given seed. Default: 0.
noise_decay_rate (float) – Hyperparameter controlling the noise decay. Default: 6e-6.
decay_policy (str) – Mechanism parameter update policy. None means no parameters need updating. Default: None.
- Raises
NameError – mech_name must be in [‘Gaussian’, ‘AdaGaussian’].
- Returns
Mechanisms, the created noise generation mechanism.
Examples
>>> import mindspore.dataset as ds
>>> import mindspore.nn as nn
>>> from mindarmour.privacy.diff_privacy import ClipMechanismsFactory, DPModel, NoiseMechanismsFactory
>>> # LeNet5 and dataset_generator are user-defined helpers, not part of this module.
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
...                                              norm_bound=norm_bound,
...                                              initial_noise_multiplier=initial_noise_multiplier)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
...                                            decay_policy='Linear',
...                                            learning_rate=0.01,
...                                            target_unclipped_quantile=0.9,
...                                            fraction_stddev=0.01)
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
...                       momentum=0.9)
>>> model = DPModel(micro_batches=2,
...                 clip_mech=clip_mech,
...                 norm_bound=norm_bound,
...                 noise_mech=noise_mech,
...                 network=network,
...                 loss_fn=loss,
...                 optimizer=net_opt,
...                 metrics=None)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
...                             ['data', 'label'])
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
- class mindarmour.privacy.diff_privacy.PrivacyMonitorFactory[source]
Factory class of privacy monitors for DP training.
- static create(policy, *args, **kwargs)
Create a privacy monitor class.
- Parameters
policy (str) – Monitor policy. Currently ‘rdp’ and ‘zcdp’ are supported. If policy is ‘rdp’, the monitor computes the privacy budget of DP training based on Rényi differential privacy theory; if policy is ‘zcdp’, it computes the privacy budget based on zero-concentrated differential privacy theory. Note that ‘zcdp’ is not suitable for subsampling noise mechanisms.
args (Union[int, float, numpy.ndarray, list, str]) – Parameters used for creating a privacy monitor.
kwargs (Union[int, float, numpy.ndarray, list, str]) – Keyword parameters used for creating a privacy monitor.
- Returns
Callback, a privacy monitor.
Examples
>>> from mindarmour.privacy.diff_privacy import PrivacyMonitorFactory
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
...                                    num_samples=60000, batch_size=32)
- class mindarmour.privacy.diff_privacy.RDPMonitor(num_samples, batch_size, initial_noise_multiplier=1.5, max_eps=10.0, target_delta=0.001, max_delta=None, target_eps=None, orders=None, noise_decay_mode='Time', noise_decay_rate=0.0006, per_print_times=50, dataset_sink_mode=False)
Compute the privacy budget of DP training based on Rényi differential privacy (RDP) theory. According to the reference below, if a randomized mechanism satisfies ε'-Rényi differential privacy of order α, it also satisfies conventional differential privacy with the following (ε, δ):
\[\left(\epsilon' + \frac{\log(1/\delta)}{\alpha - 1},\ \delta\right)\]
Reference: Rényi Differential Privacy of the Sampled Gaussian Mechanism
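The conversion above can be evaluated at several orders and the smallest ε kept, which is why trying a list of orders can tighten the estimate. The sketch below performs only this conversion; computing the RDP of the sampled Gaussian mechanism itself is omitted, and rdp_to_dp is a hypothetical helper, not RDPMonitor's code.

```python
import math


def rdp_to_dp(orders, rdp_values, delta):
    """Convert eps'-RDP at each order alpha to (eps, delta)-DP and keep the tightest eps."""
    best_eps, best_order = math.inf, None
    for alpha, rdp_eps in zip(orders, rdp_values):
        if alpha <= 1:
            raise ValueError("RDP orders must be greater than 1")
        eps = rdp_eps + math.log(1.0 / delta) / (alpha - 1)
        if eps < best_eps:
            best_eps, best_order = eps, alpha
    return best_eps, best_order


# Illustrative RDP values at two orders; real values come from the privacy accountant.
eps, order = rdp_to_dp([2, 10], [0.5, 2.5], delta=1e-3)
print(order)  # 10
```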
- Parameters
num_samples (int) – The total number of samples in the training dataset.
batch_size (int) – The number of samples in a batch during training.
initial_noise_multiplier (Union[float, int]) – Ratio of the standard deviation of the Gaussian noise to norm_bound, used to calculate the privacy spent. Default: 1.5.
max_eps (Union[float, int, None]) – The maximum acceptable epsilon budget for DP training, used to estimate the maximum number of training epochs. ‘None’ means there is no limit on the epsilon budget. Default: 10.0.
target_delta (Union[float, int, None]) – Target delta budget for DP training. If target_delta is set to δ, the privacy budget δ is fixed during the whole training process. Default: 1e-3.
max_delta (Union[float, int, None]) – The maximum acceptable delta budget for DP training, used to estimate the maximum number of training epochs. max_delta must be less than 1, and values below 1e-3 are recommended; otherwise overflow may occur. ‘None’ means there is no limit on the delta budget. Default: None.
target_eps (Union[float, int, None]) – Target epsilon budget for DP training. If target_eps is set to ε, the privacy budget ε is fixed during the whole training process. Default: None.
orders (Union[None, list[int, float]]) – Finite orders used for computing RDP, which must be greater than 1. The computed privacy budget differs across orders, so a list of orders can be tried to obtain a tighter (smaller) estimate. Default: None.
noise_decay_mode (Union[None, str]) – Decay mode of the noise added during training, which can be None, ‘Time’, ‘Step’ or ‘Exp’. Default: ‘Time’.
noise_decay_rate (float) – Decay rate of the noise during training. Default: 6e-4.
per_print_times (int) – The interval, in steps, for computing and printing the privacy budget. Default: 50.
dataset_sink_mode (bool) – If True, all training data is passed to the device (Ascend) at once. If False, training data is passed to the device after each training step. Default: False.
Examples
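>>> import mindspore.nn as nn
>>> from mindarmour.privacy.diff_privacy import DPModel, NoiseMechanismsFactory, PrivacyMonitorFactory
>>> # Net and the training dataset `ds_train` (batched at 256) are user-defined.
>>> network = Net()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
>>> epochs = 2
>>> norm_clip = 1.0
>>> initial_noise_multiplier = 1.5
>>> # Match the monitor's defaults: noise_decay_mode='Time', noise_decay_rate=6e-4.
>>> mech = NoiseMechanismsFactory().create('AdaGaussian',
...                                        norm_bound=norm_clip,
...                                        initial_noise_multiplier=initial_noise_multiplier,
...                                        noise_decay_rate=6e-4,
...                                        decay_policy='Time')
>>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
>>> model = DPModel(micro_batches=2, norm_bound=norm_clip,
...                 noise_mech=mech, network=network, loss_fn=net_loss,
...                 optimizer=net_opt, metrics=None)
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
...                                    num_samples=60000, batch_size=256,
...                                    initial_noise_multiplier=initial_noise_multiplier)
>>> model.train(epochs, ds_train, callbacks=[rdp], dataset_sink_mode=False)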
- max_epoch_suggest()
Estimate the maximum number of training epochs that satisfies the predefined privacy budget.
- Returns
int, the recommended maximum training epochs.
Examples
>>> from mindarmour.privacy.diff_privacy import PrivacyMonitorFactory
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
...                                    num_samples=60000, batch_size=32)
>>> suggest_epoch = rdp.max_epoch_suggest()
- step_end(run_context)
Compute the privacy budget after each training step.
- Parameters
run_context (RunContext) – Contains information about the model.
- class mindarmour.privacy.diff_privacy.ZCDPMonitor(num_samples, batch_size, initial_noise_multiplier=1.5, max_eps=10.0, target_delta=0.001, noise_decay_mode='Time', noise_decay_rate=0.0006, per_print_times=50, dataset_sink_mode=False)
Compute the privacy budget of DP training based on zero-concentrated differential privacy (zCDP) theory. According to the reference below, if a randomized mechanism satisfies ρ-zCDP, it also satisfies conventional differential privacy with the following (ε, δ):
\[\left(\rho + 2\sqrt{\rho \log(1/\delta)},\ \delta\right)\]
Note that ZCDPMonitor is not suitable for subsampling noise mechanisms (such as NoiseAdaGaussianRandom and NoiseGaussianRandom). A noise mechanism matching zCDP will be developed in the future.
Reference: Concentrated Differentially Private Gradient Descent with Adaptive per-Iteration Privacy Budget
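As a worked illustration of the zCDP conversion: the Gaussian mechanism with sensitivity Δ and noise standard deviation σ satisfies ρ-zCDP with ρ = Δ²/(2σ²), and zCDP composes additively across steps. With σ = noise_multiplier × Δ, T steps give ρ = T/(2 · noise_multiplier²). The sketch ignores subsampling and noise decay, and zcdp_epsilon is a hypothetical helper rather than ZCDPMonitor's code.

```python
import math


def zcdp_epsilon(noise_multiplier, steps, delta):
    """Hypothetical sketch: (eps, delta)-DP of `steps` full-batch Gaussian steps via zCDP."""
    rho_per_step = 1.0 / (2.0 * noise_multiplier ** 2)  # rho = sensitivity^2 / (2 * sigma^2)
    rho = steps * rho_per_step  # zCDP composes additively
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))


print(zcdp_epsilon(noise_multiplier=1.5, steps=10, delta=1e-3))
```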
- Parameters
num_samples (int) – The total number of samples in the training dataset.
batch_size (int) – The number of samples in a batch during training.
initial_noise_multiplier (Union[float, int]) – Ratio of the standard deviation of the Gaussian noise to norm_bound, used to calculate the privacy spent. Default: 1.5.
max_eps (Union[float, int]) – The maximum acceptable epsilon budget for DP training, used to estimate the maximum number of training epochs. Default: 10.0.
target_delta (Union[float, int]) – Target delta budget for DP training. If target_delta is set to δ, the privacy budget δ is fixed during the whole training process. Default: 1e-3.
noise_decay_mode (Union[None, str]) – Decay mode of the noise added during training, which can be None, ‘Time’, ‘Step’ or ‘Exp’. Default: ‘Time’.
noise_decay_rate (float) – Decay rate of the noise during training. Default: 6e-4.
per_print_times (int) – The interval, in steps, for computing and printing the privacy budget. Default: 50.
dataset_sink_mode (bool) – If True, all training data is passed to the device (Ascend) at once. If False, training data is passed to the device after each training step. Default: False.
Examples
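>>> import mindspore.nn as nn
>>> from mindarmour.privacy.diff_privacy import DPModel, NoiseMechanismsFactory, PrivacyMonitorFactory
>>> # Net and the training dataset `ds_train` (batched at 256) are user-defined.
>>> network = Net()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
>>> epochs = 2
>>> norm_clip = 1.0
>>> initial_noise_multiplier = 1.5
>>> # Match the monitor's defaults: noise_decay_mode='Time', noise_decay_rate=6e-4.
>>> mech = NoiseMechanismsFactory().create('AdaGaussian',
...                                        norm_bound=norm_clip,
...                                        initial_noise_multiplier=initial_noise_multiplier,
...                                        noise_decay_rate=6e-4,
...                                        decay_policy='Time')
>>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
>>> model = DPModel(micro_batches=2, norm_bound=norm_clip,
...                 noise_mech=mech, network=network, loss_fn=net_loss,
...                 optimizer=net_opt, metrics=None)
>>> zcdp = PrivacyMonitorFactory.create(policy='zcdp',
...                                     num_samples=60000, batch_size=256,
...                                     initial_noise_multiplier=initial_noise_multiplier)
>>> model.train(epochs, ds_train, callbacks=[zcdp], dataset_sink_mode=False)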
- max_epoch_suggest()
Estimate the maximum number of training epochs that satisfies the predefined privacy budget.
- Returns
int, the recommended maximum training epochs.
Examples
>>> from mindarmour.privacy.diff_privacy import PrivacyMonitorFactory
>>> zcdp = PrivacyMonitorFactory.create(policy='zcdp',
...                                     num_samples=60000, batch_size=32)
>>> suggest_epoch = zcdp.max_epoch_suggest()
- step_end(run_context)
Compute the privacy budget after each training step.
- Parameters
run_context (RunContext) – Contains information about the model.