mindarmour.reliability
Reliability methods of MindArmour.
- class mindarmour.reliability.ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10, step=10, threshold_index=1.5, need_label=False)
ConceptDriftCheckTimeSeries is used to detect distribution changes (concept drift) in time series data. For details, please check Implementing the Concept Drift Detection Application of Time Series Data.
- Parameters
window_size (int) – Size of a concept window, no less than 10. Given the input data, window_size must lie in [10, 1/3*len(input data)]. If the data is periodic, window_size usually equals 2 to 5 periods; for example, for monthly or weekly data, one period is the data volume of 30 or 7 days, respectively. Default: 100.
rolling_window (int) – Smoothing window size, in [1, window_size]. Default: 10.
step (int) – The jump length of the sliding window, in [1, window_size]. Default: 10.
threshold_index (float) – The threshold index, in \((-\infty, +\infty)\). Default: 1.5.
need_label (bool) – If True, concept drift labels are needed. Default: False.
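The parameter ranges above can be checked before constructing the detector. The following is a minimal sketch, not part of MindArmour: the helper names valid_window_range, suggest_window_sizes and check_params are hypothetical, and the upper bound n_samples / 3 simply restates the documented range [10, 1/3*len(input data)].

```python
# Hypothetical helpers (not MindArmour API) that restate the documented
# parameter ranges of ConceptDriftCheckTimeSeries.

def valid_window_range(n_samples):
    """Return the (low, high) bounds for window_size: [10, n_samples / 3]."""
    return 10, n_samples / 3


def suggest_window_sizes(period, n_samples):
    """Suggest window_size values of 2 to 5 periods that fall inside the valid range."""
    low, high = valid_window_range(n_samples)
    return [k * period for k in range(2, 6) if low <= k * period <= high]


def check_params(n_samples, window_size=100, rolling_window=10, step=10):
    """Raise ValueError if a parameter falls outside its documented range."""
    low, high = valid_window_range(n_samples)
    if not low <= window_size <= high:
        raise ValueError(f"window_size must be in [{low}, {high}], got {window_size}")
    if not 1 <= rolling_window <= window_size:
        raise ValueError("rolling_window must be in [1, window_size]")
    if not 1 <= step <= window_size:
        raise ValueError("step must be in [1, window_size]")


# Daily data with a weekly period (7 days) and 1000 samples.
print(suggest_window_sizes(period=7, n_samples=1000))  # [14, 21, 28, 35]
check_params(n_samples=1000)  # the defaults pass for 1000 samples
```

With only 100 samples the default window_size=100 exceeds 100 / 3, so check_params raises ValueError; this is why short series need a smaller window.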
Examples
>>> import numpy as np
>>> from mindarmour.reliability import ConceptDriftCheckTimeSeries
>>> concept = ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10,
...                                       step=10, threshold_index=1.5, need_label=False)
>>> data_example = 5*np.random.rand(1000)
>>> data_example[200: 800] = 20*np.random.rand(600)
>>> score, threshold, concept_drift_location = concept.concept_check(data_example)
- concept_check(data)
Find concept drift locations in a data series.
- Parameters
data (numpy.ndarray) – Input data. The shape of data could be \((n,1)\) or \((n,m)\). Note that each column (m columns) is one data series.
- Returns
numpy.ndarray, the concept drift score of the example series.
float, the threshold to judge concept drift.
list, the location of the concept drift.
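The sketch below shows one way, using NumPy, to arrange input for concept_check: a single series as an (n,1) array and several series as an (n,m) array with one column per series. The series are random placeholders, not real data.

```python
import numpy as np

# Two placeholder series of equal length n = 1000.
series_a = 5 * np.random.rand(1000)
series_b = 5 * np.random.rand(1000)

# A single series as an (n, 1) array.
single = series_a.reshape(-1, 1)
# Several series side by side as an (n, m) array: each column is one series.
multi = np.column_stack((series_a, series_b))

print(single.shape)  # (1000, 1)
print(multi.shape)   # (1000, 2)
```

Passing multi to concept_check treats each of its m columns as a separate data series, as described under Parameters.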
- class mindarmour.reliability.FaultInjector(model, fi_type=None, fi_mode=None, fi_size=None)
The fault injection module simulates various fault scenarios for deep neural networks and evaluates the performance and reliability of the model.
For details, please check Implementing the Model Fault Injection and Evaluation.
- Parameters
model (Model) – The model to be evaluated.
fi_type (list) – The types of fault injection, including bitflips_random (flip a random bit), bitflips_designated (flip the key bit), random, zeros, nan, inf, anti_activation, precision_loss, etc.
fi_mode (list) – The mode of fault injection: single_layer (inject faults into a single layer) or all_layer (inject faults into all layers).
fi_size (list) – The number of fault injections, that is, how many values are injected with faults.
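To make some of the fault types above concrete, the following pure-Python sketch corrupts fi_size randomly chosen values in a flat list of float32 weights (standing in for one layer). It is an illustration, not MindArmour's implementation: flip_bit and inject are hypothetical helpers, the remaining fault types are omitted, and treating the highest exponent bit (bit 30) as the "key bit" for bitflips_designated is an assumption.

```python
import math
import random
import struct


def flip_bit(value, bit):
    """Flip one bit (0 = least significant) of the float32 encoding of value."""
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    (flipped,) = struct.unpack("<f", struct.pack("<I", bits ^ (1 << bit)))
    return flipped


def inject(weights, fi_type, fi_size, rng):
    """Return a copy of weights with fi_size randomly chosen values corrupted."""
    faulty = list(weights)
    for idx in rng.sample(range(len(faulty)), fi_size):
        if fi_type == "zeros":
            faulty[idx] = 0.0
        elif fi_type == "nan":
            faulty[idx] = math.nan
        elif fi_type == "inf":
            faulty[idx] = math.inf
        elif fi_type == "bitflips_random":
            faulty[idx] = flip_bit(faulty[idx], rng.randrange(32))
        elif fi_type == "bitflips_designated":
            # Assumption: the "key bit" is the highest exponent bit (bit 30).
            faulty[idx] = flip_bit(faulty[idx], 30)
        else:
            raise ValueError(f"fault type not covered by this sketch: {fi_type}")
    return faulty


print(flip_bit(1.0, 30))  # inf: one exponent bit turns 1.0 into infinity
print(flip_bit(1.0, 31))  # -1.0: the sign bit flips
print(inject([0.5, -1.25, 1.0, 2.0], "zeros", fi_size=1, rng=random.Random(0)))
```

The first line of output shows why a single flipped bit can be enough to break a model: 1.0 becomes inf, which then propagates through every downstream operation.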
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Model
>>> import mindspore.ops.operations as P
>>> from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self._softmax = P.Softmax()
...         self._Dense = nn.Dense(10, 10)
...         self._squeeze = P.Squeeze(1)
...     def construct(self, inputs):
...         out = self._softmax(inputs)
...         out = self._Dense(out)
...         return self._squeeze(out)
>>> batch_size = 16
>>> ds_data = np.random.randn(batch_size, 1, 10).astype(np.float32)
>>> ds_label = np.random.randint(0, 10, batch_size).astype(np.int32)
>>> net = Net()
>>> model = Model(net)
>>> fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
...            'nan', 'inf', 'anti_activation', 'precision_loss']
>>> fi_mode = ['single_layer', 'all_layer']
>>> fi_size = [1]
>>> fi = FaultInjector(model, fi_type, fi_mode, fi_size)
>>> results = fi.kick_off(ds_data, ds_label, iter_times=1)
>>> result_summary = fi.metrics()
- kick_off(ds_data, ds_label, iter_times=100)
Start the fault injection and return the final results.
- Parameters
ds_data (np.ndarray) – Input data for testing. The evaluation is based on this data.
ds_label (np.ndarray) – The label of data, corresponding to the data.
iter_times (int) – The number of evaluations, which will determine the batch size.
- Returns
list, the result of fault injection.
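This page only states that iter_times determines the batch size. The following minimal sketch assumes the batch size is len(ds_data) // iter_times and that leftover samples are dropped; both are assumptions rather than documented behavior, and split_for_evaluation is a hypothetical helper.

```python
def split_for_evaluation(ds_data, ds_label, iter_times):
    """Split data and labels into iter_times equal batches (assumed rule; remainder dropped)."""
    batch_size = len(ds_data) // iter_times
    if batch_size == 0:
        raise ValueError("iter_times must not exceed the number of samples")
    return [
        (ds_data[i * batch_size:(i + 1) * batch_size],
         ds_label[i * batch_size:(i + 1) * batch_size])
        for i in range(iter_times)
    ]


batches = split_for_evaluation(list(range(1000)), [0] * 1000, iter_times=100)
print(len(batches), len(batches[0][0]))  # 100 10
```

Under this assumption, iter_times must not exceed the number of samples, otherwise each batch would be empty.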
- class mindarmour.reliability.OodDetector(model, ds_train)
The abstract class of the out-of-distribution detector.
- Parameters
model (Model) – The training model.
ds_train (numpy.ndarray) – The training dataset.
- get_optimal_threshold(label, ds_eval)
Get the optimal threshold for detecting OOD examples. The optimal threshold is calculated from the labeled dataset ds_eval.
- Parameters
label (numpy.ndarray) – The labels indicating whether each image is in-distribution or out-of-distribution.
ds_eval (numpy.ndarray) – The testing dataset to help find the threshold.
- Returns
float, the optimal threshold.
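This page does not state which criterion get_optimal_threshold optimizes. The sketch below assumes one common choice: among the observed scores, pick the threshold with the highest accuracy on labeled data, where label 1 marks OOD and a score larger than the threshold is predicted OOD (the same rule ood_predict uses). optimal_threshold is a hypothetical helper working on precomputed scores.

```python
def optimal_threshold(scores, labels):
    """Return the observed score that, used as a threshold, gives the best accuracy.

    Assumption: accuracy is the criterion, and score > threshold means OOD (label 1).
    """
    best_threshold, best_acc = None, -1.0
    for threshold in sorted(set(scores)):
        preds = [1 if s > threshold else 0 for s in scores]
        acc = sum(p == y for p, y in zip(preds, labels)) / len(labels)
        if acc > best_acc:  # keep the smallest threshold among ties
            best_threshold, best_acc = threshold, acc
    return best_threshold


print(optimal_threshold([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1]))  # 0.2
```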
- ood_predict(threshold, ds_test)
The out-of-distribution detection. This function detects whether the images in ds_test are OOD examples. If the prediction score of an image is larger than threshold, the image is regarded as out-of-distribution.
- Parameters
threshold (float) – The threshold for judging OOD data. It can be set empirically or obtained with get_optimal_threshold.
ds_test (numpy.ndarray) – The testing dataset.
- Returns
numpy.ndarray, the detection result. 0 means the data is not OOD, and 1 means the data is OOD.
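The decision rule described above reduces to one comparison per example. A minimal sketch on precomputed prediction scores; ood_decisions is a hypothetical helper, and how the scores are computed is detector-specific.

```python
def ood_decisions(scores, threshold):
    """Map prediction scores to OOD labels: 1 if score > threshold, else 0."""
    return [1 if score > threshold else 0 for score in scores]


print(ood_decisions([0.1, 0.5, 0.9], threshold=0.5))  # [0, 0, 1]
```

Because the comparison is strict, a score exactly equal to threshold is treated as in-distribution.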
- class mindarmour.reliability.OodDetectorFeatureCluster(model, ds_train, n_cluster, layer)
Train the OOD detector. Extract the features of the training data and obtain the clustering centers. The distance between the features of the testing data and the clustering centers determines whether an image is an out-of-distribution (OOD) image or not.
For details, please check Implementing the Concept Drift Detection Application of Image Data.
- Parameters
model (Model) – The training model.
ds_train (numpy.ndarray) – The training dataset.
n_cluster (int) – The number of clusters, in [2, 100]. Usually, n_cluster equals the number of classes in the training dataset. If the OOD detector performs poorly on the testing dataset, n_cluster can be increased appropriately.
layer (str) – The name of the feature layer, in the form ‘name[:Tensor]’, where ‘name’ is given by users when training the model. For details on how to name the model layer, see ‘README.md’.
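The clustering idea described above can be sketched in pure Python: cluster the training features into n_cluster centers, then score a test feature by its distance to the nearest center, where a larger distance suggests OOD. This is an illustration under assumptions, not MindArmour's code: k-means (Lloyd's algorithm seeded with the first n_cluster points), Euclidean distance, and the helper names kmeans and ood_score are choices made here.

```python
import math


def kmeans(points, n_cluster, n_iter=20):
    """Plain Lloyd's k-means; the first n_cluster points seed the centers (assumption)."""
    centers = [list(p) for p in points[:n_cluster]]
    for _ in range(n_iter):
        groups = [[] for _ in range(n_cluster)]
        for p in points:
            nearest = min(range(n_cluster), key=lambda k: math.dist(p, centers[k]))
            groups[nearest].append(p)
        for k, group in enumerate(groups):
            if group:  # keep the old center if a cluster ends up empty
                centers[k] = [sum(c) / len(group) for c in zip(*group)]
    return centers


def ood_score(feature, centers):
    """Distance from a feature vector to its nearest cluster center."""
    return min(math.dist(feature, c) for c in centers)


train_features = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
centers = kmeans(train_features, n_cluster=2)
print(centers)                           # [[0.0, 0.5], [10.0, 10.5]]
print(ood_score((0.0, 0.5), centers))    # 0.0: sits on a training cluster
print(ood_score((50.0, 50.0), centers))  # about 56.2: far from both clusters
```

Comparing such a score with a threshold, as in ood_predict, yields the OOD decision; a larger n_cluster gives finer-grained centers, which is one reason increasing it can help when detection is poor.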
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Model
>>> from mindspore.ops import TensorSummary
>>> import mindspore.ops.operations as P
>>> from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self._softmax = P.Softmax()
...         self._Dense = nn.Dense(10, 10)
...         self._squeeze = P.Squeeze(1)
...         self._summary = TensorSummary()
...     def construct(self, inputs):
...         out = self._softmax(inputs)
...         out = self._Dense(out)
...         self._summary('output', out)
...         return self._squeeze(out)
>>> net = Net()
>>> model = Model(net)
>>> batch_size = 16
>>> batches = 1
>>> ds_train = np.random.randn(batches * batch_size, 1, 10).astype(np.float32)
>>> ds_eval = np.random.randn(batches * batch_size, 1, 10).astype(np.float32)
>>> detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
>>> num = int(len(ds_eval) / 2)
>>> ood_label = np.concatenate((np.zeros(num), np.ones(num)), axis=0)
>>> optimal_threshold = detector.get_optimal_threshold(ood_label, ds_eval)
>>> result = detector.ood_predict(optimal_threshold, ds_eval)
- get_optimal_threshold(label, ds_eval)
Get the optimal threshold for detecting OOD examples. The optimal threshold is calculated from the labeled dataset ds_eval.
- Parameters
label (numpy.ndarray) – The labels indicating whether each image is in-distribution or out-of-distribution.
ds_eval (numpy.ndarray) – The testing dataset to help find the threshold.
- Returns
float, the optimal threshold.
- ood_predict(threshold, ds_test)
The out-of-distribution detection. This function detects whether the images in ds_test are OOD examples. If the prediction score of an image is larger than threshold, the image is regarded as out-of-distribution.
- Parameters
threshold (float) – The threshold for judging OOD data. It can be set empirically or obtained with get_optimal_threshold.
ds_test (numpy.ndarray) – The testing dataset.
- Returns
numpy.ndarray, the detection result. 0 means the data is not OOD, and 1 means the data is OOD.