mindspore.nn.OptTFTWrapper

class mindspore.nn.OptTFTWrapper(opt, **kwargs)[source]

Implements a TFT optimizer wrapper, which reports status to MindIO TFT before the optimizer updates the parameters.

Note

This optimizer depends on the MindIO TFT feature. Currently, only Ascend graph mode is supported, and sink_size must be less than 1.
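Because the wrapper works only in Ascend graph mode, the context should be configured accordingly before constructing it. A minimal sketch of that setup (enabling MindIO TFT itself is deployment-specific and is assumed to be configured separately):

>>> import mindspore as ms
>>>
>>> # OptTFTWrapper requires graph mode on Ascend; PyNative mode is not supported.
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")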

Parameters

opt (Optimizer) – The optimizer to be wrapped. Must be a subclass of Optimizer.

Inputs:
  • gradients (tuple[Tensor]) - The gradients of opt's params; the shapes are the same as those of opt's params.

Outputs:

Tensor, the result of executing the wrapped optimizer opt.
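The wrapper is invoked like the optimizer it wraps: it takes the gradient tuple and returns the update result. A minimal sketch of a hand-written training step, where forward_fn, data, and label are hypothetical placeholders:

>>> import mindspore as ms
>>>
>>> # grad_fn returns the loss and the gradients of the wrapped optimizer's parameters.
>>> grad_fn = ms.value_and_grad(forward_fn, None, optim.parameters)
>>> loss, grads = grad_fn(data, label)
>>> # Reports status to MindIO TFT, then executes the wrapped optimizer on grads.
>>> out = optim_wrapper(grads)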

Raises
  • TypeError – If the parameter opt is not a subclass of Optimizer.

  • ValueError – If the platform is not Ascend, the execution mode is not graph mode, or the TFT feature is not enabled.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import nn
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.SGD(params=net.trainable_params())
>>> optim_wrapper = nn.OptTFTWrapper(optim)
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim_wrapper)
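To continue the example, training then proceeds as usual through Model; train_dataset below is a hypothetical, pre-built dataset:

>>> # Each training step reports status to MindIO TFT before updating parameters.
>>> model.train(1, train_dataset)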