mindspore.ops.stop_gradient
- mindspore.ops.stop_gradient(value)[源代码]
用于消除某个值对梯度的影响,例如截断来自于函数输出的梯度传播。更多细节请参考 Stop Gradient 。
- 参数:
value (Any) - 需要被消除梯度影响的值。
- 返回:
一个与 value 相同的值。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> def f1(x): ... return x ** 2 >>> x = 3.0 >>> f1(x) 9.0 >>> mindspore.ops.grad(f1)(mindspore.tensor(x)) Tensor(shape=[], dtype=Float32, value= 6) >>> >>> # The same function with stop_gradient, return a zero gradient because x is effectively treated as a constant. >>> def f2(x): ... return mindspore.ops.stop_gradient(x) ** 2 >>> f2(x) 9.0 >>> mindspore.ops.grad(f2)(mindspore.tensor(x)) Tensor(shape=[], dtype=Float32, value= 0)