mindelec.architecture.MTLWeightedLossCell

class mindelec.architecture.MTLWeightedLossCell(num_losses)[源代码]

MTL策略自动加权多任务损失。请参考 自动加权进行多任务学习

参数:
  • num_losses (int) - 多任务损失的数量,应为正整数。

输入:
  • input (tuple[Tensor]) - 输入数据。

输出:

Scalar。多任务学习自动加权计算出的损失。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> from mindelec.architecture import MTLWeightedLossCell
>>> import mindspore
>>> from mindspore import Tensor
>>> net = MTLWeightedLossCell(num_losses=2)
>>> input1 = Tensor(1.0, mindspore.float32)
>>> input2 = Tensor(0.8, mindspore.float32)
>>> output = net((input1, input2))
>>> print(output)
2.2862945