mindspore.ops.stop_gradient

查看源文件
mindspore.ops.stop_gradient(value)[源代码]

用于消除某个值对梯度的影响,例如截断来自于函数输出的梯度传播。更多细节请参考 Stop Gradient

参数:
  • value (Any) - 需要被消除梯度影响的值。

返回:

一个与 value 相同的值。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> def f1(x):
...     return x ** 2
>>> x = 3.0
>>> f1(x)
9.0
>>> mindspore.ops.grad(f1)(mindspore.tensor(x))
Tensor(shape=[], dtype=Float32, value= 6)
>>>
>>> # The same function with stop_gradient, return a zero gradient because x is effectively treated as a constant.
>>> def f2(x):
...     return mindspore.ops.stop_gradient(x) ** 2
>>> f2(x)
9.0
>>> mindspore.ops.grad(f2)(mindspore.tensor(x))
Tensor(shape=[], dtype=Float32, value= 0)