mindspore.nn.OptTFTWrapper

查看源文件
class mindspore.nn.OptTFTWrapper(opt)[源代码]

实现TFT优化器封装器。该封装器将在优化器更新前向MindIO TFT上报状态。

说明

该优化器依赖于MindIO TFT特性。当前只支持Ascend后端的图模式,并且sink_size的配置必须小于等于1。

参数:
  • opt (Optimizer) - 该参数必须为Optimizer的子类。

输入:
  • gradients (tuple[Tensor]) - 参数opt的 params 的梯度,shape与opt的 params shape 相同。

输出:

Tensor,优化器opt执行返回的结果。

异常:
  • TypeError - 如果opt不是Optimizer的子类。

  • ValueError - 如果不是运行在Ascend后端的图模式,或者用户不开启TFT特性。

支持平台:

Ascend

样例:

>>> import mindspore as ms
>>> from mindspore import nn
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.SGD(params=net.trainable_params())
>>> optim_wrapper = nn.OptTFTWrapper(optim)
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)