mindspore.ops.Morph

查看源文件
class mindspore.ops.Morph(fn, infer_shape, infer_dtype)[源代码]

Morph 算子用于对用户自定义函数 fn 进行封装,允许其被当做自定义算子使用。

Morph 算子的主要适用于静态图的分布式自动并行场景,通过在自定义函数 fn 中使用集合通信算子,实现自定义的并行计算逻辑,尤其适用于 fn 内存在动态Shape的场景。

Morph 算子作用于输入时,实际上是其内封装的自定义函数 fn 作用于输入。

Morph 算子与 mindspore.ops.Custom() 的主要区别在于,前者会在自动微分前被展开替换为用户自定义 fn,故无需实现反向函数。

说明

  • 本算子只支持图模式。

  • fn 必须满足图模式语法约束。

  • 用户无需实现自定义反向函数。

  • 用户自定义函数不支持 varargkwargkwonlyargs 和自由变量。

参数:
  • fn (Function) - MindSpore Function,用户自定义函数。

  • infer_shape (Function) - Mindspore Function,用户自定义 infer_shape 函数。

  • infer_dtype (Function) - Mindspore Function,用户自定义 infer_dtype 函数。

输入:

用户自定义 fn 的输入。

输出:

用户自定义 fn 的输出。

异常:
  • RuntimeError - 如果算子在非图模式下被使用。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import context, nn, ops, Tensor, Parameter
>>>
>>> np_weight0 = np.array([1.0, 2.0, 3.0])
>>> np_weight1 = np.array([4.0, 5.0, 6.0])
>>> np_input_x = np.array([7.0, 8.0, 9.0])
>>>
>>> def infer_dtype(args):
...     return args
>>>
>>> def infer_shape(args):
...     return args
>>>
>>> def mul_by(*args):
...     def inner(x):
...         return args[0] * x
...     return inner
>>>
>>> NUMBER_100 = 100
>>> class MorphNet(nn.Cell):
...     def __init__(self):
...         super(MorphNet, self).__init__()
...         self.weight0 = Parameter(Tensor(np_weight0, ms.float32), name="weight0")
...         self.weight1 = Parameter(Tensor(np_weight1, ms.float32), name="weight1")
...         self.mul_by_100 = ops.Morph(mul_by(NUMBER_100), infer_shape, infer_dtype)
...     def construct(self, x):
...         a = x * self.weight0
...         b = self.mul_by_100(a)
...         out = b * self.weight1
...         return out
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> input_x = Tensor(np_input_x, ms.float32)
>>> net = MorphNet()
>>> grad_op = ops.GradOperation(get_all=True, get_by_list=True)
>>> grad_net = grad_op(net, net.trainable_params())
>>> bwd_out = grad_net(input_x)
>>> x_grad = bwd_out[0][0].asnumpy()
>>> weight0_grad = bwd_out[1][0].asnumpy()
>>> weight1_grad = bwd_out[1][1].asnumpy()
>>> print("x_grad", x_grad)
>>> print("weight0_grad", weight0_grad)
>>> print("weight1_grad", weight1_grad)
x_grad [ 400. 1000. 1800.]
weight0_grad [2800. 4000. 5400.]
weight1_grad [ 700. 1600. 2700.]