mindelec.architecture.MTLWeightedLossCell
- class mindelec.architecture.MTLWeightedLossCell(num_losses)[源代码]
MTL策略自动加权多任务损失。请参考 自动加权进行多任务学习 。
- 参数:
num_losses (int) - 多任务损失的数量,应为正整数。
- 输入:
input (tuple[Tensor]) - 输入数据。
- 输出:
Scalar。多任务学习自动加权计算出的损失。
- 支持平台:
Ascend
样例:
>>> import numpy as np >>> from mindelec.architecture import MTLWeightedLossCell >>> import mindspore >>> from mindspore import Tensor >>> net = MTLWeightedLossCell(num_losses=2) >>> input1 = Tensor(1.0, mindspore.float32) >>> input2 = Tensor(0.8, mindspore.float32) >>> output = net((input1, input2)) >>> print(output) 2.2862945