mindspore.boost

Boost provides automatic acceleration for networks through features such as Less BN, Gradient Freeze, Gradient Accumulation, and so on.

Note

This feature is a beta feature, and we are still improving its functionality.

class mindspore.boost.AdaSum(rank, device_number, group_number, parameter_tuple)

Adaptive Summation (AdaSum) is a novel algorithm for improving distributed data-parallel training of deep learning models.

Parameters
  • rank (int) – Rank number.

  • device_number (int) – Device number.

  • group_number (int) – Group number.

  • parameter_tuple (Tuple(Parameter)) – Tuple of parameters.

Inputs:
  • delta_weights (Tuple(Tensor)) - Tuple of gradients.

  • parameters (Tuple(Parameter)) - Tuple of current parameters.

  • old_parameters (Tuple(Parameter)) - Tuple of last parameters.

Outputs:
  • adasum_parameters (Tuple(Tensor)) - Tuple of parameters after AdaSum processing.
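The pairwise combination rule at the core of the published AdaSum algorithm can be sketched in plain Python. This is an illustrative sketch, not MindSpore's implementation: it assumes gradients (or weight deltas) are flat lists of floats, and `adasum_pair` and `adasum_reduce` are hypothetical names. Each vector is shrunk by half of its projection onto the other before summing, so orthogonal gradients are added while identical gradients are averaged.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def adasum_pair(g1, g2):
    """Combine two gradient vectors with the pairwise AdaSum rule."""
    dot12 = dot(g1, g2)
    n1, n2 = dot(g1, g1), dot(g2, g2)
    if n1 == 0.0 or n2 == 0.0:
        # Assumption: if either vector is zero, fall back to plain summation.
        return [x + y for x, y in zip(g1, g2)]
    # Scale each vector by 1 - (g1 . g2) / (2 * |g|^2) before adding.
    c1 = 1.0 - dot12 / (2.0 * n1)
    c2 = 1.0 - dot12 / (2.0 * n2)
    return [c1 * x + c2 * y for x, y in zip(g1, g2)]


def adasum_reduce(grads):
    """Reduce a power-of-two number of gradient vectors pairwise, as a binary tree."""
    if len(grads) == 0 or len(grads) & (len(grads) - 1):
        raise ValueError("this sketch expects a power-of-two number of gradients")
    while len(grads) > 1:
        grads = [adasum_pair(grads[i], grads[i + 1]) for i in range(0, len(grads), 2)]
    return grads[0]


print(adasum_pair([1.0, 0.0], [0.0, 1.0]))  # orthogonal, acts like a sum: [1.0, 1.0]
print(adasum_pair([2.0, 2.0], [2.0, 2.0]))  # identical, acts like an average: [2.0, 2.0]
```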

class mindspore.boost.AutoBoost(level='O0', boost_config_dict='')

Provides automatic acceleration for the network.

Parameters
  • level (str) – Boost config level. Default: "O0".

  • boost_config_dict (dict) –

    User-configured hyperparameter dict. Recommended format:

    {
        "boost": {
            "mode": "auto",
            "less_bn": False,
            "grad_freeze": False,
            "adasum": False,
            "grad_accumulation": False,
            "dim_reduce": False,
            "loss_scale_group": False
        },
        "common": {
            "gradient_split_groups": [50, 100],
            "device_number": 8
        },
        "less_bn": {
            "fn_flag": True,
            "gc_flag": True
        },
        "grad_freeze": {
            "param_groups": 10,
            "freeze_type": 1,
            "freeze_p": 0.7,
            "total_steps": 65536
        },
        "dim_reduce": {
            "rho": 0.55,
            "gamma": 0.9,
            "alpha": 0.001,
            "sigma": 0.4,
            "n_components": 32,
            "pca_mat_path": None,
            "weight_load_dir": None,
            "timeout": 1800
        }
    }
    

    Default: "".

    • boost:

      • mode (str): How to set the boost. Supports ["auto", "manual", "enable_all", "disable_all"]. Default: "auto".

        • auto: Depends on the argument "boost_level" in class Model.

        • manual: Depends on "boost_config_dict".

        • enable_all: Sets all boost functions to True.

        • disable_all: Sets all boost functions to False.

      • less_bn (bool): Whether to apply the less_bn function. Default: False.

      • grad_freeze (bool): Whether to apply the grad_freeze function. Default: False.

      • adasum (bool): Whether to apply the adasum function. Default: False.

      • grad_accumulation (bool): Whether to apply the grad_accumulation function. Default: False.

      • dim_reduce (bool): Whether to apply the dim_reduce function. Default: False.

      • loss_scale_group (bool): Whether to apply the loss_scale_group function. Default: False.

      If dim_reduce is set to True, all other functions are set to False. If grad_freeze is set to True and dim_reduce is False, all other functions are set to False.

    • common:

      • gradient_split_groups (list): The gradient split points of this network. Default: [50, 100].

      • device_number (int): Device number. Default: 8.

    • less_bn:

      • fn_flag (bool): Whether to replace FC with FN. Default: True.

      • gc_flag (bool): Whether to apply gradient centralization (GC). Default: True.

    • grad_freeze:

      • param_groups (int): The number of parameter groups. Default: 10.

      • freeze_type (int): Gradient freeze grouping strategy, selected from [0, 1]. Default: 1.

      • freeze_p (float): Gradient freezing probability. Default: 0.7.

      • total_steps (int): Total training steps. Default: 65536.

    • dim_reduce:

      The leading principles of dim_reduce:

      $$
      \begin{aligned}
      grad\_k &= pca\_mat \cdot grad \\
      dk &= -bk \cdot grad\_k \\
      sk &= rho^{m} \cdot dk \\
      delta\_loss &= sigma \cdot grad\_k^{T} \cdot sk
      \end{aligned}
      $$

      Here:

      • pca_mat (array): Shape (k, n), where k is part of the n_components principal components and n is the size of the weight.

      • bk (array): Shape (k, k), the symmetric positive definite matrix of the quasi-Newton method.

      m must be chosen so that:

      $$new\_loss < old\_loss + delta\_loss$$

      Then, delta_grad is computed to update the model weights:

      $$
      \begin{aligned}
      grad\_k\_proj &= pca\_mat^{T} \cdot grad\_k \\
      new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj \\
      delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat^{T} \cdot sk
      \end{aligned}
      $$

      • rho (float): Generally, it does not need to be modified. Default: 0.55.

      • gamma (float): Generally, it does not need to be modified. Default: 0.9.

      • alpha (float): Generally, it does not need to be modified. Default: 0.001.

      • sigma (float): Generally, it does not need to be modified. Default: 0.4.

      • n_components (int): Number of PCA components. Default: 32.

      • pca_mat_path (str): The path from which to load the PCA matrix. Default: None.

      • weight_load_dir (str): The directory from which to load weight files saved as ckpt. Default: None.

      • timeout (int): Time, in seconds, to wait for the local PCA matrix to be loaded. Default: 1800.

    Users can load the configuration from a JSON file or pass the dictionary directly. Unconfigured parameters take their default values.
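The defaulting behaviour and the dim_reduce / grad_freeze precedence rule described above can be sketched in plain Python. This is a hypothetical illustration, not AutoBoost's code: `DEFAULT_CONFIG`, `merge_with_defaults`, `validate_mode`, and `apply_precedence` are made-up names, the defaults are copied from the documented values, and the merge is assumed to work key by key within each top-level section.

```python
import copy
import json

# Documented defaults, copied from the parameter descriptions above.
DEFAULT_CONFIG = {
    "boost": {"mode": "auto", "less_bn": False, "grad_freeze": False, "adasum": False,
              "grad_accumulation": False, "dim_reduce": False, "loss_scale_group": False},
    "common": {"gradient_split_groups": [50, 100], "device_number": 8},
    "less_bn": {"fn_flag": True, "gc_flag": True},
    "grad_freeze": {"param_groups": 10, "freeze_type": 1, "freeze_p": 0.7, "total_steps": 65536},
    "dim_reduce": {"rho": 0.55, "gamma": 0.9, "alpha": 0.001, "sigma": 0.4, "n_components": 32,
                   "pca_mat_path": None, "weight_load_dir": None, "timeout": 1800},
}
VALID_MODES = ("auto", "manual", "enable_all", "disable_all")
FEATURES = ("less_bn", "grad_freeze", "adasum", "grad_accumulation", "dim_reduce", "loss_scale_group")


def merge_with_defaults(user_config):
    """Return a new config in which every key the user did not set keeps its default."""
    merged = copy.deepcopy(DEFAULT_CONFIG)  # never mutate the shared defaults
    for section, values in user_config.items():
        merged.setdefault(section, {}).update(values)
    return merged


def validate_mode(mode):
    """Reject modes outside the documented set, mirroring the documented ValueError."""
    if mode not in VALID_MODES:
        raise ValueError(f"boost mode must be one of {VALID_MODES}, got {mode!r}")
    return mode


def apply_precedence(flags):
    """dim_reduce switches every other function off; otherwise grad_freeze does."""
    for exclusive in ("dim_reduce", "grad_freeze"):
        if flags[exclusive]:
            return {name: name == exclusive for name in FEATURES}
    return {name: bool(flags[name]) for name in FEATURES}


# JSON booleans (true/false) load as Python True/False.
user_config = json.loads('{"boost": {"mode": "manual", "adasum": true, "dim_reduce": true}}')
config = merge_with_defaults(user_config)
validate_mode(config["boost"]["mode"])
print(apply_precedence({name: config["boost"][name] for name in FEATURES}))
```

With both adasum and dim_reduce requested, only dim_reduce survives, matching the rule stated for the boost section.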

Raises

ValueError – If the boost mode is not in ["auto", "manual", "enable_all", "disable_all"].

Supported Platforms:

Ascend

Examples

>>> from mindspore.boost import AutoBoost
>>> #1) when configuring the dict directly:
>>> boost_config_dict = {"boost": {"mode": "auto"}}
>>> boost = AutoBoost("O1", boost_config_dict)
>>>
>>> #2) when loading the dict from a json file:
>>> import json
>>> boost_json = "/path/boost_config.json"
>>> with open(boost_json, 'r') as fp:
...     boost_config_dict = json.load(fp)
>>> boost = AutoBoost("O1", boost_config_dict)

network_auto_process_eval(network)

Applies boost processing to the network for evaluation.

Parameters

network (Cell) – The inference network.

network_auto_process_train(network, optimizer)

Applies boost processing to the network and optimizer for training.

Parameters
  • network (Cell) – The training network.

  • optimizer (Cell) – Optimizer for updating the weights.

class mindspore.boost.BoostTrainOneStepCell(network, optimizer, sens=None)

Boost network training wrapper class.

Wraps the network with an optimizer. The resulting Cell is trained with input '*inputs'. The backward graph is created in the construct function to update the parameters, and different parallel modes are available for training.

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Union[Cell]) – Optimizer for updating the weights.

  • sens (numbers.Number) – The scaling number to be filled as the input of backpropagation. Default: None, which means 1.0.

Inputs:
  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape (N, ...).

Outputs:

Tensor, the loss value, whose shape is usually ().

  • loss (Tensor): A scalar Tensor.

  • overflow (Tensor): A scalar Tensor of type bool.

  • loss scaling value (Tensor): A scalar Tensor.

Raises

TypeError – If sens is not a number.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import boost
>>> from mindspore import nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> #1) Using the existing WithLossCell provided by MindSpore
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>> class MyWithLossCell(nn.Cell):
...    def __init__(self, backbone, loss_fn):
...        super(MyWithLossCell, self).__init__(auto_prefix=False)
...        self._backbone = backbone
...        self._loss_fn = loss_fn
...
...    def construct(self, x, y, label):
...        out = self._backbone(x, y)
...        return self._loss_fn(out, label)
...
...    @property
...    def backbone_network(self):
...        return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)

adasum_process(loss, grads)

Adasum algorithm process.

Parameters
  • loss (Tensor) – Tensor with shape ().

  • grads (tuple(Tensor)) – Tuple of gradient tensors.

Returns

  • loss (Tensor) - Network loss, tensor with shape ().

check_adasum_enable()

Checks whether AdaSum is enabled.

Returns

  • enable_adasum (bool) - Whether the AdaSum algorithm is enabled.

check_dim_reduce_enable()

Checks whether dim_reduce is enabled.

Returns

  • enable_dim_reduce (bool) - Whether the dimensionality reduction second-order training algorithm is enabled.

gradient_accumulation_process(loss, grads, sens, *inputs)

Gradient accumulation algorithm process.

Parameters
  • loss (Tensor) – Tensor with shape ().

  • grads (tuple(Tensor)) – Tuple of gradient tensors.

  • sens (Tensor) – Tensor with shape ().

  • inputs (tuple(Tensor)) – Tuple of input tensors with shape (N, ...).

Returns

  • loss (Tensor) - Network loss, tensor with shape ().

gradient_freeze_process(*inputs)

Gradient freeze algorithm process.

Parameters

inputs (tuple(Tensor)) – Tuple of input tensors with shape (N, ...).

Returns

  • loss (Tensor) - Network loss, tensor with shape ().

class mindspore.boost.BoostTrainOneStepWithLossScaleCell(network, optimizer, scale_sense)

Boost network training with loss scaling.

This is a training step with loss scaling. It takes a network, an optimizer, and possibly a scale update Cell as arguments. The loss scale value can be updated on either the host side or the device side. BoostTrainOneStepWithLossScaleCell is compiled into a graph that takes *inputs as input data. When scale_sense is a Tensor, it acts as the loss scaling value; to update the loss scale on the host side, this value must be provided. When scale_sense is not a Tensor, the loss scale update logic must be provided by passing a Cell as scale_sense.
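The host-side update path (shrink the scale when the scaled gradients overflow, grow it after a run of clean steps, and skip the weight update on overflow) can be illustrated with a scalar model in plain Python. This is a hypothetical sketch, not MindSpore's implementation: `DynamicLossScale` and `train_step` are made-up names, and the growth rule is modeled on the `loss_scale_value`, `scale_factor`, and `scale_window` arguments of the `nn.DynamicLossScaleUpdateCell` used in this class's examples.

```python
import math


class DynamicLossScale:
    """Shrink the loss scale on overflow; grow it after `window` consecutive clean steps."""

    def __init__(self, init_scale=2.0 ** 12, factor=2.0, window=1000):
        self.scale = init_scale
        self.factor = factor
        self.window = window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # Assumption: the scale never drops below 1.0.
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.window:
                self.scale *= self.factor
                self.good_steps = 0


def train_step(w, x, y, lr, manager):
    """One loss-scaled SGD step for the scalar model pred = w * x with squared error.

    Returns (new_w, loss, overflow, loss_scale): the cell's three outputs plus the weight.
    """
    pred = w * x
    loss = (pred - y) * (pred - y)
    scaled_grad = manager.scale * 2.0 * (pred - y) * x  # gradient of (scale * loss) w.r.t. w
    overflow = not math.isfinite(scaled_grad)
    if not overflow:
        w = w - lr * (scaled_grad / manager.scale)      # unscale before updating
    manager.update(overflow)                            # on overflow: keep w, lower the scale
    return w, loss, overflow, manager.scale


manager = DynamicLossScale(init_scale=4.0, factor=2.0, window=2)
print(train_step(1.0, 2.0, 0.0, 0.1, manager))
print(train_step(1.0, 1e300, 0.0, 0.1, manager))  # the scaled gradient overflows to inf
```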

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Cell) – Optimizer for updating the weights.

  • scale_sense (Union[Tensor, Cell]) – If this value is a Cell, it is the cell that implements the loss scaling update logic. If this value is a Tensor, mindspore.nn.TrainOneStepWithLossScaleCell.set_sense_scale() can be called to update the loss scale factor; the Tensor must have shape () or (1,).

Inputs:
  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape (N, ...).

Outputs:

Tuple of 3 Tensors: the loss, the overflow flag, and the current loss scaling value.

  • loss (Tensor) - Tensor with shape ().

  • overflow (Tensor) - Tensor with shape (), of type bool.

  • loss scaling value (Tensor) - Tensor with shape ().

Raises
  • TypeError – If scale_sense is neither Cell nor Tensor.

  • ValueError – If shape of scale_sense is neither (1,) nor ().

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.nn import WithLossCell
>>> from mindspore import dtype as mstype
>>> from mindspore import boost
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> #1) when the type of scale_sense is Cell:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
>>> labels = Tensor(np.ones([out_features,]), mstype.float32)
>>> output = train_network(input, labels)
>>>
>>> #2) when the type of scale_sense is Tensor:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
>>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)

class mindspore.boost.DimReduce(network, optimizer, weight, pca_mat_local, n_components, rho, gamma, alpha, sigma, rank, rank_size)

Dimension reduction training is a novel algorithm for accelerating the convergence of deep learning models.

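$$
\begin{aligned}
grad\_k &= pca\_mat \cdot grad \\
dk &= -bk \cdot grad\_k \\
sk &= rho^{m} \cdot dk \\
delta\_loss &= sigma \cdot grad\_k^{T} \cdot sk
\end{aligned}
$$

Here:

  • pca_mat (array): Shape (k, n), where k is part of the n_components principal components and n is the size of the weight.

  • bk (array): Shape (k, k), the symmetric positive definite matrix of the quasi-Newton method.

m must be chosen so that:

$$new\_loss < old\_loss + delta\_loss$$

Then, delta_grad is computed to update the model weights:

$$
\begin{aligned}
grad\_k\_proj &= pca\_mat^{T} \cdot grad\_k \\
new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj \\
delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat^{T} \cdot sk
\end{aligned}
$$
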
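The equations above can be traced on a toy problem in plain Python. This is a hypothetical sketch, not the DimReduce cell: it uses lists instead of tensors, a caller-supplied `loss_fn`, and a cap `max_m` on the line-search exponent. Three further points are assumptions for illustration only: new_loss is evaluated at the trial weights w + pca_mat^T · sk, the subspace step is skipped if no m up to `max_m` satisfies the condition, and the weights are updated as w - delta_grad.

```python
def matvec(mat, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def transpose(mat):
    return [list(col) for col in zip(*mat)]


def dim_reduce_step(loss_fn, w, grad, pca_mat, bk, old_momentum,
                    rho=0.55, gamma=0.9, alpha=0.001, sigma=0.4, max_m=20):
    """Return (new_w, new_momentum, accepted_m) for one dimension-reduced step."""
    old_loss = loss_fn(w)
    pca_t = transpose(pca_mat)
    grad_k = matvec(pca_mat, grad)                  # grad_k = pca_mat . grad
    dk = [-v for v in matvec(bk, grad_k)]           # dk = -bk . grad_k
    sk, accepted_m = [0.0] * len(dk), None          # assumption: no subspace step if search fails
    for m in range(max_m):
        trial = [rho ** m * v for v in dk]          # sk = rho^m . dk
        delta_loss = sigma * sum(g * s for g, s in zip(grad_k, trial))
        w_trial = [wi + si for wi, si in zip(w, matvec(pca_t, trial))]
        if loss_fn(w_trial) < old_loss + delta_loss:
            sk, accepted_m = trial, m
            break
    grad_k_proj = matvec(pca_t, grad_k)             # pca_mat^T . grad_k
    momentum = [gamma * o + g - p for o, g, p in zip(old_momentum, grad, grad_k_proj)]
    delta_grad = [alpha * mo - s for mo, s in zip(momentum, matvec(pca_t, sk))]
    new_w = [wi - d for wi, d in zip(w, delta_grad)]  # assumption: weights move by -delta_grad
    return new_w, momentum, accepted_m


def quadratic(v):
    """Toy loss 0.5 * |v|^2, whose gradient is v itself."""
    return 0.5 * sum(x * x for x in v)


w = [1.0, 2.0, 3.0]
pca_mat = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # k = 2 components, n = 3 weights
bk = [[1.0, 0.0], [0.0, 1.0]]                 # identity is symmetric positive definite
print(dim_reduce_step(quadratic, w, w[:], pca_mat, bk, [0.0, 0.0, 0.0]))
```

The two directions covered by pca_mat are solved by the quasi-Newton step, while the remaining direction only receives the small momentum update scaled by alpha.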
Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Union[Cell]) – Optimizer for updating the weights.

  • weight (Tuple(Parameter)) – Tuple of parameters.

  • pca_mat_local (numpy.ndarray) – Matrix for the PCA operation, with shape (k, n), where k is part of n_components and n is the size of the weight.

  • n_components (int) – Number of PCA components.

  • rho (float) – Coefficient.

  • gamma (float) – Coefficient.

  • alpha (float) – Coefficient.

  • sigma (float) – Coefficient.

  • rank (int) – Rank number.

  • rank_size (int) – Rank size.

Inputs:
  • loss (Tensor) - Tensor with shape ().

  • old_grad (Tuple(Tensor)) - Tuple of gradient tensors.

  • weight (Tuple(Tensor)) - Tuple of parameters.

  • weight_clone (Tuple(Tensor)) - Clone of weight.

  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape (N, ...).

Outputs:
  • loss (Tensor) - Tensor with shape ().

class mindspore.boost.FreezeOpt(opt, train_parameter_groups=None, train_strategy=None)

Optimizer that supports gradient freezing training.

Parameters
  • opt (Cell) – Non-freezing optimizer instance, such as Momentum or SGD.

  • train_parameter_groups (Union[tuple, list]) – Groups of parameters for gradient freezing training. Default: None.

  • train_strategy (Union[tuple(int), list(int), Tensor]) – Strategy for gradient freezing training. Default: None.

Supported Platforms:

Ascend

class mindspore.boost.GradientAccumulation(max_accumulation_step, optimizer)

Accumulates gradients over multiple steps, then calls the optimizer to update the weights.

Parameters
  • max_accumulation_step (int) – Steps to accumulate gradients.

  • optimizer (Cell) – Optimizer used.
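The accumulate-then-update idea can be sketched in plain Python. This is hypothetical, not MindSpore's class: `GradientAccumulator` is a made-up name, the optimizer is replaced by plain SGD with learning rate `lr`, and averaging the accumulated gradients over the window is an assumption (an implementation may sum them instead).

```python
class GradientAccumulator:
    """Hypothetical sketch: average gradients over max_accumulation_step calls, then take one SGD step."""

    def __init__(self, max_accumulation_step, lr):
        self.max_accumulation_step = max_accumulation_step
        self.lr = lr
        self.buffer = None
        self.count = 0

    def step(self, weights, grads):
        """Return (weights, updated); weights change only on every max_accumulation_step-th call."""
        if self.buffer is None:
            self.buffer = [0.0] * len(grads)
        self.buffer = [b + g for b, g in zip(self.buffer, grads)]
        self.count += 1
        if self.count < self.max_accumulation_step:
            return weights, False
        mean_grads = [b / self.max_accumulation_step for b in self.buffer]
        new_weights = [w - self.lr * g for w, g in zip(weights, mean_grads)]
        self.buffer = [0.0] * len(grads)  # reset for the next accumulation window
        self.count = 0
        return new_weights, True


acc = GradientAccumulator(max_accumulation_step=2, lr=0.5)
w, updated = acc.step([1.0], [2.0])  # gradient accumulated only
print(w, updated)                    # [1.0] False
w, updated = acc.step(w, [4.0])      # mean gradient 3.0 applied
print(w, updated)                    # [-0.5] True
```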

class mindspore.boost.GradientFreeze(param_groups, freeze_type, freeze_p, total_steps)

Gradient freezing algorithm that randomly freezes the gradients of some layers to improve network training performance. The number and probability of frozen layers can be configured by users.

Parameters
  • param_groups (Union[tuple, list]) – Groups of parameters for gradient freezing training.

  • freeze_type (int) – Strategy of gradient freezing training.

  • freeze_p (float) – Probability of gradient freezing training.

  • total_steps (int) – Steps of the whole training.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.nn import WithLossCell
>>> from mindspore import dtype as mstype
>>> from mindspore import boost
>>>
>>> class Net(nn.Cell):
...    def __init__(self, in_features, out_features):
...        super(Net, self).__init__()
...        self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                name='weight')
...        self.matmul = ops.MatMul()
...
...    def construct(self, x):
...        output = self.matmul(x, self.weight)
...        return output
>>> size, in_features, out_features = 16, 16, 10
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> gradient_freeze_class = boost.GradientFreeze(10, 1, 0.5, 2000)
>>> network, optimizer = gradient_freeze_class.freeze_generate(net_with_loss, optimizer)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> output = network(inputs, label)

freeze_generate(network, optimizer)

Generate freeze network and optimizer.

Parameters
  • network (Cell) – The training network.

  • optimizer (Cell) – Optimizer for updating the weights.

generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, total_steps)

Generate index sequence for gradient freezing training.

Parameters
  • parameter_groups_number (int) – The number of parameter groups.

  • freeze_strategy (int) – Gradient freeze grouping strategy, select from [0, 1].

  • freeze_p (float) – Gradient freezing probability.

  • total_steps (int) – Total training steps.
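One way to build a per-step freezing schedule is sketched below in plain Python. It is hypothetical and not the algorithm behind generate_freeze_index_sequence: the encoding (each entry is the number of leading parameter groups frozen at that step, 0 meaning none) is an assumption, and the sketch ignores freeze_strategy because its two strategies are not specified here.

```python
import random


def freeze_index_sequence(param_groups, freeze_p, total_steps, seed=0):
    """Hypothetical schedule: entry i = number of leading groups frozen at step i (0 = none).

    With probability freeze_p a step freezes between 1 and param_groups - 1 of the
    earliest groups, chosen uniformly; otherwise every group is trained.
    """
    if param_groups < 2:
        raise ValueError("need at least two parameter groups")
    rng = random.Random(seed)  # seeded so the schedule is reproducible
    sequence = []
    for _ in range(total_steps):
        if rng.random() < freeze_p:
            sequence.append(rng.randint(1, param_groups - 1))
        else:
            sequence.append(0)
    return sequence


def trainable_groups(step, sequence, param_groups):
    """Indices of the parameter groups that receive gradient updates at `step`."""
    return list(range(sequence[step], param_groups))


schedule = freeze_index_sequence(param_groups=10, freeze_p=0.7, total_steps=8, seed=1)
print(schedule)
print(trainable_groups(0, schedule, 10))
```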

split_parameters_groups(net, freeze_para_groups_number)

Split parameter groups for gradient freezing training.

Parameters
  • net (Cell) – The training network.

  • freeze_para_groups_number (int) – The number of gradient freeze groups.
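Splitting an ordered parameter list into a fixed number of groups can be sketched in plain Python. This is hypothetical: `split_into_groups` is a made-up name, parameters are represented by plain list items, and contiguous near-equal groups are an assumed policy that split_parameters_groups may not follow exactly.

```python
def split_into_groups(params, num_groups):
    """Split an ordered list into num_groups contiguous groups whose sizes differ by at most one."""
    if not 1 <= num_groups <= len(params):
        raise ValueError("num_groups must be between 1 and len(params)")
    base, extra = divmod(len(params), num_groups)
    groups, start = [], 0
    for i in range(num_groups):
        size = base + (1 if i < extra else 0)  # the first `extra` groups take one more item
        groups.append(params[start:start + size])
        start += size
    return groups


print(split_into_groups(list(range(10)), 3))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```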

class mindspore.boost.GroupLossScaleManager(init_loss_scale, loss_scale_groups)

Enhanced mixed-precision algorithm that supports applying different loss scales to different layers and updating the loss scales dynamically.

Parameters
  • init_loss_scale (Number) – The initialized loss scale value.

  • loss_scale_groups (List) – The loss scale groups, which are divided from the param list.

Inputs:
  • x (Tensor) - The output of the last operator.

  • layer1 (int) - Current network layer value.

  • layer2 (int) - Last network layer value.

Outputs:
  • out (Tensor) - A tensor with a group of loss scale tags that marks the loss scale group number of the current tensor.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import boost, nn
>>>
>>> class Net(nn.Cell):
...     def __init__(self, enhanced_amp, num_class=10, num_channel=1):
...         super(Net, self).__init__()
...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.flatten = nn.Flatten()
...         self.enhanced_amp = enhanced_amp
...
...     def construct(self, x):
...         x = self.enhanced_amp(x, 0, 1)
...         x = self.max_pool2d(self.relu(self.conv1(x)))
...         x = self.max_pool2d(self.relu(self.conv2(x)))
...         x = self.flatten(x)
...         x = self.enhanced_amp(x, 1, 2)
...         x = self.relu(self.fc1(x))
...         x = self.relu(self.fc2(x))
...         x = self.fc3(x)
...         x = self.enhanced_amp(x, 2, 3)
...         return x
>>>
>>> loss_scale_manager = boost.GroupLossScaleManager(4096, [])
>>> net = Net(loss_scale_manager)
>>> param_group1 = []
>>> param_group2 = []
>>> for param in net.trainable_params():
...     if 'conv' in param.name:
...         param_group1.append(param)
...     else:
...         param_group2.append(param)
>>> loss_scale_manager.loss_scale_groups = [param_group1, param_group2]
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> boost_config_dict = {"boost": {"mode": "manual", "less_bn": False, "grad_freeze": False, "adasum": False,
...                      "grad_accumulation": False, "dim_reduce": False, "loss_scale_group": True}}
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None,
...                        loss_scale_manager=loss_scale_manager,
...                        boost_level="O1", boost_config_dict=boost_config_dict)
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> model.train(2, dataset)

get_loss_scale()

Get loss scale value.

Returns

Number, the loss scale value.

get_update_cell()

Returns the instance of mindspore.boost.GroupLossScaleManager.

Returns

mindspore.boost.GroupLossScaleManager.

set_loss_scale_status(loss_scale_number, init_loss_scale)

Generate dynamic loss scale tuple and set overflow status list.

Parameters
  • loss_scale_number (int) – The number of loss scales.

  • init_loss_scale (float) – The initial loss scale.

update_loss_scale_status(layer, update_ratio)

Update dynamic loss scale.

Parameters
  • layer (int) – Current layer.

  • update_ratio (float) – The ratio of loss scale update.

Returns

float, new loss scale.
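Keeping a separate, dynamically updated loss scale per layer group can be sketched in plain Python. This is hypothetical, not GroupLossScaleManager: `PerGroupLossScale` is a made-up name, each group is represented by a flat list of already-scaled gradients, and dividing only the overflowing group's scale by `update_ratio` is an assumed policy.

```python
import math


class PerGroupLossScale:
    """Hypothetical sketch: one dynamic loss scale per parameter group."""

    def __init__(self, init_loss_scale, num_groups, update_ratio=2.0):
        self.scales = [float(init_loss_scale)] * num_groups
        self.update_ratio = update_ratio

    def unscale(self, group_grads):
        """group_grads[i] holds the scaled gradients of group i.

        Returns one entry per group: the unscaled gradients, or None when that group
        overflowed (its update should be skipped, and its scale is reduced).
        """
        result = []
        for i, grads in enumerate(group_grads):
            if all(math.isfinite(g) for g in grads):
                result.append([g / self.scales[i] for g in grads])
            else:
                # Assumed policy: only the overflowing group backs off.
                self.scales[i] /= self.update_ratio
                result.append(None)
        return result


manager = PerGroupLossScale(4096, num_groups=2)
print(manager.unscale([[4096.0, 8192.0], [float("inf")]]))  # [[1.0, 2.0], None]
print(manager.scales)                                         # [4096.0, 2048.0]
```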

class mindspore.boost.LessBN(network, fn_flag=False)

Automatically reduces the number of BatchNorm (BN) layers to improve network performance while ensuring network accuracy.

Parameters
  • network (Cell) – Network to be modified.

  • fn_flag (bool) – Whether to replace FC with FN. Default: False.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.nn import WithLossCell
>>> from mindspore import dtype as mstype
>>> from mindspore import boost
>>>
>>> class Net(nn.Cell):
...    def __init__(self, in_features, out_features):
...        super(Net, self).__init__()
...        self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                name='weight')
...        self.matmul = ops.MatMul()
...
...    def construct(self, x):
...        output = self.matmul(x, self.weight)
...        return output
>>> size, in_features, out_features = 16, 16, 10
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> train_network = boost.LessBN(net_with_loss)
>>> output = train_network(inputs, label)

class mindspore.boost.OptimizerProcess(opt)

Processes the optimizer for Boost. Currently, this class supports adding GC (gradient centralization) tags and creating new optimizers.

Parameters

opt (Cell) – Optimizer used.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.boost import OptimizerProcess
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> optimizer_process = OptimizerProcess(optimizer)
>>> optimizer_process.add_grad_centralization(network)
>>> optimizer = optimizer_process.generate_new_optimizer()

add_grad_centralization(network)

Add gradient centralization.

Parameters

network (Cell) – The training network.
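Gradient centralization, in its published form, subtracts the mean from each output slice of a weight gradient. A minimal plain-Python sketch for a 2-D gradient laid out as [output][input] follows; it is illustrative rather than MindSpore's implementation, and `centralize_gradient` is a hypothetical name. One-dimensional parameters such as biases are normally left untouched.

```python
def centralize_gradient(grad_rows):
    """Gradient centralization for a 2-D weight gradient laid out as [output][input].

    Each output row has its own mean subtracted, so every row of the result sums to zero.
    """
    centered = []
    for row in grad_rows:
        mean = sum(row) / len(row)
        centered.append([g - mean for g in row])
    return centered


print(centralize_gradient([[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]))
# [[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
```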

static build_gc_params_group(params_dict, parameters)

Builds the parameter groups with gradient centralization.

Parameters
  • params_dict (dict) – The network's parameter dict.

  • parameters (list) – The network's parameter list.

static build_params_dict(network)

Builds the parameter dict of the network.

Parameters

network (Cell) – The training network.

generate_new_optimizer()

Generate new optimizer.

class mindspore.boost.ParameterProcess

Processes parameters for Boost. Currently, this class supports creating group parameters and automatically setting gradient split points.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.boost import ParameterProcess
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight2')
...         self.matmul = ops.MatMul()
...         self.matmul2 = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         output2 = self.matmul2(x, self.weight2)
...         return output + output2
...
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> new_parameter = network.trainable_params()[:1]
>>> group_params = ParameterProcess.generate_group_params(new_parameter, network.trainable_params())

assign_parameter_group(parameters, split_point=None)

Assign parameter group.

Parameters
  • parameters (list) – The network's parameter list.

  • split_point (list) – The gradient split point of this network. Default: None.
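Mapping parameter indices to groups from a list of split points can be sketched with the standard `bisect` module. This is hypothetical: `assign_groups` is a made-up name, and treating each split index as the first index of a new group is an assumed convention that assign_parameter_group may not share.

```python
import bisect


def assign_groups(num_params, split_point=None):
    """Return a group id for each parameter index.

    Each split index is treated as the first index of a new group, so
    split_point=[50, 100] yields groups [0, 50), [50, 100) and [100, num_params).
    """
    points = sorted(split_point or [])
    return [bisect.bisect_right(points, i) for i in range(num_params)]


group_ids = assign_groups(120, [50, 100])
print([group_ids.count(g) for g in range(3)])  # [50, 50, 20]
```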

static generate_group_params(parameters, origin_params)

Generate group parameters.

Parameters
  • parameters (list) – The network's parameter list.

  • origin_params (list) – The network's origin parameter list.

mindspore.boost.freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None, max_accumulation_step=1)

Generate freeze network and optimizer.

Parameters
  • reducer_flag (bool) – Reducer flag.

  • network (Cell) – The training network.

  • optimizer (Cell) – Optimizer for updating the weights.

  • sens (numbers.Number) – The scaling number.

  • grad (tuple(Tensor)) – Tuple of gradient tensors.

  • use_grad_accumulation (bool) – Use gradient accumulation flag.

  • mean (bool) – Gradients mean flag. Default: None.

  • degree (int) – Device number. Default: None.

  • max_accumulation_step (int) – Max accumulation steps. Default: 1.

Examples

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore import ops
>>> from mindspore.boost.grad_freeze import freeze_cell, FreezeOpt
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> in_features, out_features = 16, 10
>>> network = Net(in_features, out_features)
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> optimizer = FreezeOpt(optimizer)
>>> grad = ops.GradOperation(get_by_list=True, sens_param=True)
>>> freeze_nets = freeze_cell(False, network, optimizer, 1.0, grad, False, None, None, 1)