mindspore.nn.probability.transforms.TransformToBNN

class mindspore.nn.probability.transforms.TransformToBNN(trainable_dnn, dnn_factor=1, bnn_factor=1)

Transforms a Deep Neural Network (DNN) model into a Bayesian Neural Network (BNN) model.

Parameters
  • trainable_dnn (Cell) – A trainable DNN model (backbone) wrapped by TrainOneStepCell.

  • dnn_factor (int, float) – The coefficient of the backbone's loss, which is computed by the loss function. Default: 1.

  • bnn_factor (int, float) – The coefficient of the KL loss, i.e. the KL divergence of the Bayesian layers. Default: 1.
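As a rough illustration of what the two coefficients control, the sketch below treats the training objective of the transformed network as a weighted sum of the backbone loss and the KL divergences of the Bayesian layers. It assumes the per-layer KL terms are simply summed; `bnn_total_loss` is a hypothetical helper, not part of MindSpore, and plain floats stand in for tensors.

```python
# Hypothetical sketch of how dnn_factor and bnn_factor are assumed to weight
# the two loss terms. Plain Python floats stand in for MindSpore tensors.


def bnn_total_loss(backbone_loss, kl_losses, dnn_factor=1, bnn_factor=1):
    """Return dnn_factor * backbone_loss + bnn_factor * sum(kl_losses).

    backbone_loss: loss computed by the backbone's loss function.
    kl_losses: KL divergence of each Bayesian layer (assumed to be summed).
    """
    kl_total = sum(kl_losses)
    return dnn_factor * backbone_loss + bnn_factor * kl_total


# With the factors used in the example below (60000 and 0.0001),
# the backbone term dominates the KL term.
loss = bnn_total_loss(0.5, [120.0, 80.0], dnn_factor=60000, bnn_factor=0.0001)
print(loss)
```

Under this assumption, raising `bnn_factor` pulls the Bayesian layers' weight distributions more strongly toward their priors, while raising `dnn_factor` emphasizes fitting the data.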

Supported Platforms:

Ascend GPU

Examples

>>> import mindspore.nn as nn
>>> from mindspore.nn.probability import bnn_layers
>>> from mindspore.nn.probability.transforms import TransformToBNN
>>>
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
...         self.bn = nn.BatchNorm2d(64)
...         self.relu = nn.ReLU()
...         self.flatten = nn.Flatten()
...         self.fc = nn.Dense(64*224*224, 12)  # assumes 224 x 224 inputs; the default pad_mode='same' keeps the spatial size
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         x = self.relu(x)
...         x = self.flatten(x)
...         out = self.fc(x)
...         return out
>>>
>>> net = Net()
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> optim = nn.AdamWeightDecay(params=net.trainable_params(), learning_rate=0.0001)
>>> net_with_loss = nn.WithLossCell(net, criterion)
>>> train_network = nn.TrainOneStepCell(net_with_loss, optim)
>>> bnn_transformer = TransformToBNN(train_network, 60000, 0.0001)

transform_to_bnn_layer(dnn_layer_type, bnn_layer_type, get_args=None, add_args=None)

Transforms every layer of a specific type in the DNN model into the corresponding BNN layer.

Parameters
  • dnn_layer_type (Cell) – The type of DNN layer to be transformed into a BNN layer. Supported values are nn.Dense and nn.Conv2d.

  • bnn_layer_type (Cell) – The type of BNN layer to transform into. Supported values are DenseReparam and ConvReparam.

  • get_args – A function that takes the DNN layer and returns the arguments used to build the BNN layer. Default: None.

  • add_args (dict) – The new arguments added to the BNN layer. Note that the keys in add_args must not duplicate the arguments returned by get_args. Default: None.
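To illustrate why add_args must not repeat keys returned by get_args, the sketch below assumes the BNN layer is built by unpacking both dictionaries into a single constructor call; this is an assumption about the mechanism, not a copy of MindSpore's code. `DenseReparamStub` and `build_bnn_layer` are hypothetical stand-ins, `SimpleNamespace` fakes an `nn.Dense` instance, and `activation` is used only as an example extra keyword.

```python
from types import SimpleNamespace


class DenseReparamStub:
    """Hypothetical stand-in for a Bayesian dense layer; records its arguments."""

    def __init__(self, in_channels, out_channels, has_bias=True, **extra):
        self.kwargs = {"in_channels": in_channels, "out_channels": out_channels,
                       "has_bias": has_bias, **extra}


def get_dense_args(dp):
    """Mirror of the documented default get_dense_args lambda."""
    return {"in_channels": dp.in_channels, "out_channels": dp.out_channels,
            "has_bias": dp.has_bias}


def build_bnn_layer(dnn_layer, bnn_layer_type, get_args, add_args=None):
    """Hypothetical helper: build a BNN layer from a DNN layer's arguments."""
    add_args = add_args or {}
    copied = get_args(dnn_layer)  # arguments read from the DNN layer
    # Unpacking two dicts into one call raises TypeError if a key appears in both.
    return bnn_layer_type(**copied, **add_args)


# Stand-in for an nn.Dense instance: only the attributes read above are set.
dense = SimpleNamespace(in_channels=84, out_channels=10, has_bias=True)

layer = build_bnn_layer(dense, DenseReparamStub, get_dense_args,
                        add_args={"activation": "relu"})
print(layer.kwargs)

try:
    # has_bias is already supplied by get_dense_args, so this call fails.
    build_bnn_layer(dense, DenseReparamStub, get_dense_args,
                    add_args={"has_bias": False})
except TypeError as err:
    print("rejected duplicate argument:", err)
```

If this assumption holds, overriding a copied argument requires changing get_args itself rather than adding the key to add_args.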

Returns

Cell, a trainable model wrapped by TrainOneStepCell, in which every layer of the specified type has been transformed into the corresponding Bayesian layer.
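Conceptually, transforming one layer type means walking the model and swapping each matching layer for a newly constructed Bayesian one. The sketch below shows that idea on a toy model tree. `Layer`, `Dense`, `DenseReparam`, `Sequential` and `transform_layers` are all hypothetical stand-ins, not MindSpore cells, and MindSpore's actual traversal of `Cell` children may differ.

```python
class Layer:
    """Toy base layer that stores its constructor arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class Dense(Layer):
    """Toy stand-in for a deterministic dense layer."""


class DenseReparam(Layer):
    """Toy stand-in for a Bayesian dense layer."""


class Sequential(Layer):
    """Toy container holding child layers."""

    def __init__(self, *children):
        super().__init__()
        self.children = list(children)


def transform_layers(model, dnn_layer_type, bnn_layer_type, get_args, add_args=None):
    """Replace every dnn_layer_type layer under model in place; return the count."""
    add_args = add_args or {}
    count = 0
    for i, child in enumerate(model.children):
        if isinstance(child, Sequential):
            # Recurse into nested containers.
            count += transform_layers(child, dnn_layer_type, bnn_layer_type,
                                      get_args, add_args)
        elif isinstance(child, dnn_layer_type):
            # Rebuild the layer as its Bayesian counterpart from copied arguments.
            model.children[i] = bnn_layer_type(**get_args(child), **add_args)
            count += 1
    return count


net = Sequential(Dense(in_channels=4, out_channels=8),
                 Sequential(Dense(in_channels=8, out_channels=8)),
                 Dense(in_channels=8, out_channels=2))
replaced = transform_layers(net, Dense, DenseReparam, lambda dp: dict(dp.kwargs))
print(replaced)
```

Replacing layers in place keeps the surrounding model structure (and therefore the forward pass) unchanged; only the matched layers become Bayesian.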

Supported Platforms:

Ascend GPU

Examples

>>> import mindspore.nn as nn
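>>> from mindspore.nn.probability import bnn_layers
>>> from mindspore.nn.probability.transforms import TransformToBNN
>>> net = LeNet()  # LeNet is assumed to be defined elsewhere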
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> optim = nn.AdamWeightDecay(params=net.trainable_params(), learning_rate=0.0001)
>>> net_with_loss = nn.WithLossCell(net, criterion)
>>> train_network = nn.TrainOneStepCell(net_with_loss, optim)
>>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
>>> train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, bnn_layers.DenseReparam)

transform_to_bnn_model(get_dense_args=lambda dp: ..., get_conv_args=lambda dp: ..., add_dense_args=None, add_conv_args=None)

Transforms the whole DNN model into a BNN model and wraps the BNN model in TrainOneStepCell.

Parameters
  • get_dense_args – A function that takes a DNN fully connected layer and returns the arguments used to build the BNN layer. Default: lambda dp: {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.

  • get_conv_args – A function that takes a DNN convolutional layer and returns the arguments used to build the BNN layer. Default: lambda dp: {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode, "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.

  • add_dense_args (dict) – The new arguments added to the BNN fully connected layer. Note that the keys in add_dense_args must not duplicate the arguments returned by get_dense_args. Default: None.

  • add_conv_args (dict) – The new arguments added to the BNN convolutional layer. Note that the keys in add_conv_args must not duplicate the arguments returned by get_conv_args. Default: None.
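The default get_conv_args copies pad_mode but not padding, so a model that relies on explicit padding may need a custom extractor. The sketch below reproduces the documented default as a named function and extends it; `get_conv_args_with_padding` is hypothetical, a `SimpleNamespace` stands in for an `nn.Conv2d` instance, and it is an assumption that the target BNN convolutional layer accepts a `padding` keyword (check the ConvReparam signature before relying on it).

```python
from types import SimpleNamespace


def default_get_conv_args(dp):
    """Named copy of the documented default get_conv_args lambda."""
    return {"in_channels": dp.in_channels, "out_channels": dp.out_channels,
            "pad_mode": dp.pad_mode, "kernel_size": dp.kernel_size,
            "stride": dp.stride, "has_bias": dp.has_bias}


def get_conv_args_with_padding(dp):
    """Hypothetical extractor that also copies `padding` from the DNN layer.

    Assumes the BNN convolutional layer accepts a `padding` keyword.
    """
    return {**default_get_conv_args(dp), "padding": dp.padding}


# Stand-in for an nn.Conv2d instance; only the attributes read above are set.
conv = SimpleNamespace(in_channels=3, out_channels=64, pad_mode="pad",
                       kernel_size=(3, 3), stride=(1, 1), has_bias=False,
                       padding=1)

print(default_get_conv_args(conv))
print(get_conv_args_with_padding(conv))
```

Such a function would be passed as `get_conv_args=get_conv_args_with_padding`; because `padding` is then returned by the extractor, it must not also appear in add_conv_args.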

Returns

Cell, a trainable BNN model wrapped by TrainOneStepCell.

Supported Platforms:

Ascend GPU

Examples

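>>> import mindspore.nn as nn
>>> from mindspore.nn.probability.transforms import TransformToBNN
>>> net = Net()  # Net as defined in the TransformToBNN class example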
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> optim = nn.AdamWeightDecay(params=net.trainable_params(), learning_rate=0.0001)
>>> net_with_loss = nn.WithLossCell(net, criterion)
>>> train_network = nn.TrainOneStepCell(net_with_loss, optim)
>>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
>>> train_bnn_network = bnn_transformer.transform_to_bnn_model()