mindspore.ops.stop_gradient
- mindspore.ops.stop_gradient(value)[源代码]
用于消除某个值对梯度的影响,例如截断来自于函数输出的梯度传播。更多细节请参考 Stop Gradient 。
- 参数:
value (Any) - 需要被消除梯度影响的值。
- 返回:
一个与 value 相同的值。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> from mindspore import ops >>> from mindspore import Tensor >>> from mindspore import dtype as mstype >>> def net(x, y): ... out1 = ops.MatMul()(x, y) ... out2 = ops.MatMul()(x, y) ... out2 = ops.stop_gradient(out2) ... return out1, out2 ... >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) >>> grad_fn = ops.grad(net) >>> output = grad_fn(x, y) >>> print(output) [[1.4100001 1.6 6.5999994] [1.4100001 1.6 6.5999994]]