# Source code for mindspore.nn.optim.tft_wrapper

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OptTFTWrapper"""
from __future__ import absolute_import

import os
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops.operations.manually_defined._inner import TensorReport
from mindspore import ops, context
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype

class OptTFTWrapper(Optimizer):
    r"""
    Implements TFT optimizer wrapper, this wrapper is used to report status to MindIO TFT before optimizer updating.

    Before delegating to the wrapped optimizer, each step all-reduces an internal flag parameter and
    reports it through ``TensorReport`` under the tag ``"tft_report"``. The gradients passed to the wrapped
    optimizer are made to depend on that report.

    Note:
        This optimizer depends on the MindIO TFT feature. Currently it only supports Ascend graph mode,
        and sink_size must be less than 1.

    Args:
        opt (Optimizer): The optimizer to wrap. Must be an instance of a subclass of Optimizer.
        kwargs: Accepted for interface compatibility; currently unused.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of opt's `params`, the shape is the same as opt's `params`.

    Outputs:
        Tensor, result of executing optimizer 'opt'.

    Raises:
        TypeError: If the parameter opt is not a subclass of Optimizer.
        ValueError: If neither ``TTP:1`` nor ``UCE:1`` is present in the ``MS_ENABLE_TFT`` environment
            variable, or if the device target is not Ascend or the execution mode is not GRAPH_MODE.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/r2.4.10/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.SGD(params=net.trainable_params())
        >>> optim_wrapper = nn.OptTFTWrapper(optim)
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim_wrapper)
    """

    def __init__(self, opt, **kwargs):
        if not isinstance(opt, Optimizer):
            raise TypeError(f"For 'OptTFTWrapper', the argument 'opt' must be Optimizer type, "
                            f"but got {type(opt)}.")

        # Validate the environment before initializing the base Optimizer, so that a misconfiguration
        # is reported without first building any base-class state.
        tft_env = os.getenv("MS_ENABLE_TFT", "")
        # At least one of the TTP / UCE switches must be enabled.
        if "TTP:1" not in tft_env and "UCE:1" not in tft_env:
            raise ValueError("MindIO TFT register need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
        mode = context.get_context("mode")
        device_target = context.get_context("device_target")
        if device_target != "Ascend" or mode != context.GRAPH_MODE:
            raise ValueError("MindIO adapter only support on Ascend device with GRAPH Mode!")

        super().__init__(opt.learning_rate, opt._parameters)  # pylint: disable=W0212
        self.opt = opt
        self.report = TensorReport()
        self.depend = ops.Depend()
        self.allreduce_sum = ops.AllReduce()
        # Tag this all-reduce so it can be identified as the one issued before the TFT report.
        self.allreduce_sum.add_prim_attr("tft_report_before", True)
        # Single-element int32 flag that is all-reduced and reported on every step.
        self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))

        # Mirror the wrapped optimizer's configuration onto the wrapper so that code inspecting
        # the optimizer sees the same state as the wrapped `opt`.
        self.param_rank = opt.param_rank
        self.optim_filter = opt.optim_filter
        self.loss_scale = opt.loss_scale
        self.dynamic_weight_decay = opt.dynamic_weight_decay
        self.grad_centralization = opt.grad_centralization
        self.dynamic_lr = opt.dynamic_lr
        self.global_step = opt.global_step
        self.is_group = opt.is_group
        self.is_group_lr = opt.is_group_lr
        self.is_group_params_ordered = opt.is_group_params_ordered
        self.use_parallel = opt.use_parallel
        # Group-specific attributes are copied only when the wrapped optimizer uses parameter groups.
        if self.is_group:
            self.group_params = opt.group_params
            self.group_lr = opt.group_lr
            self.group_weight_decay = opt.group_weight_decay
            self.group_grad_centralization = opt.group_grad_centralization
            self.grad_centralization_flags = opt.grad_centralization_flags
        self.skip_auto_parallel_compile = opt.skip_auto_parallel_compile
        self.learning_rate = opt.learning_rate
        self.parameters = opt.parameters
        self.decay_flags = opt.decay_flags
        self.dynamic_decay_flags = opt.dynamic_decay_flags
        self.weight_decay = opt.weight_decay
        self.exec_weight_decay = opt.exec_weight_decay
        self.ps_parameters = opt.ps_parameters
        self.cache_enable = opt.cache_enable
        self.reciprocal_scale = opt.reciprocal_scale
        self.need_scale = opt.need_scale
        self.global_step_increase_tensor = opt.global_step_increase_tensor
        self.param_length = opt.param_length
        self.enable_tuple_broaden = opt.enable_tuple_broaden

    def construct(self, gradients):
        """Report status to MindIO TFT, then run the wrapped optimizer on `gradients`.

        Args:
            gradients (tuple[Tensor]): Gradients of the wrapped optimizer's parameters.

        Returns:
            Whatever the wrapped optimizer `self.opt` returns for these gradients.
        """
        # Make the flag read depend on `gradients`, so the all-reduce is ordered after them.
        tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
        # NOTE(review): the all-reduced value is written back into the flag parameter every step.
        # If AllReduce() uses a SUM reduction, as the attribute name suggests, the stored value
        # grows by a factor of the group size each step and may overflow int32. Confirm that
        # the value is only used as an ordering/report anchor.
        self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
        # Emit the TFT report, and make the gradients handed to `opt` depend on it.
        grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
        return self.opt(grads)