mindspore.ops.stop_gradient

mindspore.ops.stop_gradient(value)[源代码]

用于消除某个值对梯度的影响,例如截断来自于函数输出的梯度传播。更多细节请参考 Stop Gradient

参数:
  • value (Any) - 需要被消除梯度影响的值。

返回:

一个与 value 相同的值。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> def net(x, y):
...     out1 = ops.MatMul()(x, y)
...     out2 = ops.MatMul()(x, y)
...     out2 = ops.stop_gradient(out2)
...     return out1, out2
...
>>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
>>> grad_fn = ops.grad(net)
>>> output = grad_fn(x, y)
>>> print(output)
[[1.4100001 1.6       6.5999994]
 [1.4100001 1.6       6.5999994]]