mindspore.nn.OptTFTWrapper
- class mindspore.nn.OptTFTWrapper(opt, **kwargs)
Implements a TFT optimizer wrapper, which reports status to MindIO TFT before the optimizer updates the parameters.
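To illustrate the report-then-update order described above, here is a minimal pure-Python sketch of the wrapper pattern. It is not MindSpore's implementation. `ReportingOptimizerWrapper`, the `report` callback (a stand-in for the status report sent to MindIO TFT) and the toy optimizer `toy_sgd` are all hypothetical, and plain floats stand in for Tensors.

```python
from typing import Any, Callable, Tuple


class ReportingOptimizerWrapper:
    """Hypothetical sketch: report status, then delegate to the wrapped optimizer."""

    def __init__(self, opt: Callable[[Tuple[float, ...]], Any],
                 report: Callable[[str], None]) -> None:
        self.opt = opt        # wrapped optimizer: any callable that takes gradients
        self.report = report  # stand-in for the status report sent to MindIO TFT

    def __call__(self, gradients: Tuple[float, ...]) -> Any:
        # Report first, so the status is recorded before any parameter changes.
        self.report("before_update")
        # Then run the wrapped optimizer and return its result unchanged.
        return self.opt(gradients)


# Usage: a toy SGD step over plain floats (hypothetical, for illustration only).
params = [1.0, 2.0]
events = []


def toy_sgd(grads: Tuple[float, ...], lr: float = 0.1) -> bool:
    events.append("update")  # log when the update actually happens
    for i, g in enumerate(grads):
        params[i] -= lr * g  # one gradient step per parameter
    return True


wrapper = ReportingOptimizerWrapper(toy_sgd, events.append)
result = wrapper((1.0, -1.0))
print(events)  # ['before_update', 'update']
```

Placing the report inside `__call__`, ahead of the delegation, means every update is preceded by a status report without modifying the wrapped optimizer itself.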
Note
This optimizer depends on the MindIO TFT feature. Currently, only Ascend graph mode is supported, and sink_size must be less than 1.
- Parameters
opt (Optimizer) – Must be an instance of a subclass of Optimizer.
- Inputs:
gradients (tuple[Tensor]) - The gradients of opt's parameters, each with the same shape as the corresponding parameter.
- Outputs:
Tensor, the result of executing the optimizer opt.
- Raises
TypeError – If opt is not an instance of a subclass of Optimizer.
ValueError – If the platform is not Ascend in graph mode, or the user has not switched on the TFT feature.
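The conditions listed under Raises can be sketched as a standalone check. This is a hypothetical illustration, not MindSpore's source: `check_tft_wrapper_args`, the stand-in `Optimizer` and `ToySGD` classes, and the `device_target`, `graph_mode` and `tft_enabled` arguments are assumptions made so the example runs without MindSpore installed.

```python
class Optimizer:
    """Hypothetical stand-in for mindspore.nn.Optimizer."""


def check_tft_wrapper_args(opt: object, device_target: str,
                           graph_mode: bool, tft_enabled: bool) -> None:
    # TypeError: opt must be an instance of an Optimizer subclass.
    if not isinstance(opt, Optimizer):
        raise TypeError(
            f"opt must be an Optimizer subclass instance, got {type(opt).__name__}")
    # ValueError: only Ascend in graph mode is supported.
    if device_target != "Ascend" or not graph_mode:
        raise ValueError("OptTFTWrapper only supports Ascend in graph mode")
    # ValueError: the TFT feature must be switched on.
    if not tft_enabled:
        raise ValueError("OptTFTWrapper requires the TFT feature to be enabled")


class ToySGD(Optimizer):
    """Hypothetical optimizer subclass used only for the checks below."""


# A valid configuration passes silently.
check_tft_wrapper_args(ToySGD(), "Ascend", graph_mode=True, tft_enabled=True)

# An unsupported platform raises ValueError.
try:
    check_tft_wrapper_args(ToySGD(), "GPU", graph_mode=True, tft_enabled=True)
except ValueError as err:
    print(err)  # OptTFTWrapper only supports Ascend in graph mode
```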
- Supported Platforms:
Ascend
Examples
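>>> import mindspore as ms
>>> from mindspore import nn
>>>
>>> # OptTFTWrapper only supports graph mode on Ascend.
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> # All parameters use the same learning rate and weight decay.
>>> optim = nn.SGD(params=net.trainable_params())
>>> # Wrap the optimizer so that it reports status to MindIO TFT before each update.
>>> optim_wrapper = nn.OptTFTWrapper(optim)
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> # Pass the wrapper, not the bare optimizer, to Model.
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim_wrapper)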