mindspore.ops.Assign

查看源文件
class mindspore.ops.Assign[源代码]

为网络参数赋值。

更多细节请参考 mindspore.ops.assign()

输入:
  • variable (Parameter) - 待赋值的网络参数,shape: \((N,*)\) ,其中 \(*\) 表示任何数量的附加维度。其秩应小于8。

  • value (Tensor) - 被赋给网络参数的值,和 variable 有相同的shape。

输出:

Tensor,shape和dtype与 variable 相同。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> value = Tensor([2.0], mindspore.float32)
>>> variable = mindspore.Parameter(Tensor([1.0], mindspore.float32), name="variable")
>>> assign = ops.Assign()
>>> x = assign(variable, value)
>>> print(variable.asnumpy())
[2.]