mindspore.mutable

mindspore.mutable(input_data)[源代码]

设置一个常量值为可变的。

当前除了Tensor,所有顶层网络的输入,例如标量、tuple、list和dict,都被当做常量值。常量值是不能求导的,而且在编译优化阶段会被常量折叠掉。 另外,当网络的输入是tuple[Tensor], list[Tensor]或Dict[Tensor]时,即使里面Tensor的shape和dtype没有发生变化,在多次调用同一个网络的时候,这个网络每次都会被重新编译,这是因为这些类型的输入被当做常量值处理了。

为解决以上的问题,我们提供了 mutable 接口去设置网络的常量输入为”可变的”。一个”可变的”输入意味着这个输入成为了像Tensor一样的变量。最重要的是,我们可以对其进行求导了。

参数:
  • input_data (Union[Tensor, tuple[Tensor], list[Tensor], dict[Tensor]]) - 要设置为可变的输入数据。

Warning

  • 这是一个实验特性,未来有可能被修改或删除。

  • 目前运行时暂时不支持处理标量数据流,所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。

  • 当前暂时只支持在网络外部使用该接口。

  • 当前该接口只在图模式下生效。

返回:

状态设置为可变的原输入数据。

异常:
  • TypeError - 如果 input_data 不是Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]的其中一种类型或者不是它们的嵌套结构。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> from mindspore.ops.composite import GradOperation
>>> from mindspore.common import mutable
>>> from mindspore.common import dtype as mstype
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.matmul = ops.MatMul()
...
...     def construct(self, z):
...         x = z[0]
...         y = z[1]
...         out = self.matmul(x, y)
...         return out
...
>>> class GradNetWrtX(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtX, self).__init__()
...         self.net = net
...         self.grad_op = GradOperation()
...
...     def construct(self, z):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(z)
...
>>> z = mutable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
...              Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
>>> output = GradNetWrtX(Net())(z)
>>> print(output)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 1.41000009e+00, 1.60000002e+00, 6.59999943e+00],
 [ 1.41000009e+00, 1.60000002e+00, 6.59999943e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.70000005e+00, 1.70000005e+00, 1.70000005e+00],
 [ 1.89999998e+00, 1.89999998e+00, 1.89999998e+00],
 [ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]]))