mindspore.ops.Morph
- class mindspore.ops.Morph(fn, infer_shape, infer_dtype)[源代码]
Morph 算子用于对用户自定义函数 fn 进行封装,允许其被当做自定义算子使用。
Morph 算子的主要适用于静态图的分布式自动并行场景,通过在自定义函数 fn 中使用集合通信算子,实现自定义的并行计算逻辑,尤其适用于 fn 内存在动态Shape的场景。
Morph 算子作用于输入时,实际上是其内封装的自定义函数 fn 作用于输入。
Morph 算子与
mindspore.ops.Custom()
的主要区别在于,前者会在自动微分前被展开替换为用户自定义 fn,故无需实现反向函数。说明
本算子只支持图模式。
fn 必须满足图模式语法约束。
用户无需实现自定义反向函数。
用户自定义函数不支持 vararg、kwarg、kwonlyargs 和自由变量。
- 参数:
fn (Function) - MindSpore Function,用户自定义函数。
infer_shape (Function) - Mindspore Function,用户自定义 infer_shape 函数。
infer_dtype (Function) - Mindspore Function,用户自定义 infer_dtype 函数。
- 输入:
用户自定义 fn 的输入。
- 输出:
用户自定义 fn 的输出。
- 异常:
RuntimeError - 如果算子在非图模式下被使用。
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import context, nn, ops, Tensor, Parameter >>> >>> np_weight0 = np.array([1.0, 2.0, 3.0]) >>> np_weight1 = np.array([4.0, 5.0, 6.0]) >>> np_input_x = np.array([7.0, 8.0, 9.0]) >>> >>> def infer_dtype(args): ... return args >>> >>> def infer_shape(args): ... return args >>> >>> def mul_by(*args): ... def inner(x): ... return args[0] * x ... return inner >>> >>> NUMBER_100 = 100 >>> class MorphNet(nn.Cell): ... def __init__(self): ... super(MorphNet, self).__init__() ... self.weight0 = Parameter(Tensor(np_weight0, ms.float32), name="weight0") ... self.weight1 = Parameter(Tensor(np_weight1, ms.float32), name="weight1") ... self.mul_by_100 = ops.Morph(mul_by(NUMBER_100), infer_shape, infer_dtype) ... def construct(self, x): ... a = x * self.weight0 ... b = self.mul_by_100(a) ... out = b * self.weight1 ... return out >>> >>> context.set_context(mode=context.GRAPH_MODE) >>> input_x = Tensor(np_input_x, ms.float32) >>> net = MorphNet() >>> grad_op = ops.GradOperation(get_all=True, get_by_list=True) >>> grad_net = grad_op(net, net.trainable_params()) >>> bwd_out = grad_net(input_x) >>> x_grad = bwd_out[0][0].asnumpy() >>> weight0_grad = bwd_out[1][0].asnumpy() >>> weight1_grad = bwd_out[1][1].asnumpy() >>> print("x_grad", x_grad) >>> print("weight0_grad", weight0_grad) >>> print("weight1_grad", weight1_grad) x_grad [ 400. 1000. 1800.] weight0_grad [2800. 4000. 5400.] weight1_grad [ 700. 1600. 2700.]