mindspore.mutable
- mindspore.mutable(input_data)[源代码]
设置一个常量值为可变的。
当前除了Tensor,所有顶层网络的输入,例如标量、tuple、list和dict,都被当做常量值。常量值是不能求导的,而且在编译优化阶段会被常量折叠掉。 另外,当网络的输入是tuple[Tensor], list[Tensor]或Dict[Tensor]时,即使里面Tensor的shape和dtype没有发生变化,在多次调用同一个网络的时候,这个网络每次都会被重新编译,这是因为这些类型的输入被当做常量值处理了。
为解决以上的问题,我们提供了 mutable 接口去设置网络的常量输入为”可变的”。一个”可变的”输入意味着这个输入成为了像Tensor一样的变量。最重要的是,我们可以对其进行求导了。
- 参数:
input_data (Union[Tensor, tuple[Tensor], list[Tensor], dict[Tensor]]) - 要设置为可变的输入数据。
Warning
这是一个实验特性,未来有可能被修改或删除。
目前运行时暂时不支持处理标量数据流,所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。
当前暂时只支持在网络外部使用该接口。
当前该接口只在图模式下生效。
- 返回:
状态设置为可变的原输入数据。
- 异常:
TypeError - 如果 input_data 不是Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]的其中一种类型或者不是它们的嵌套结构。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore.nn as nn >>> import mindspore.ops as ops >>> from mindspore.ops.composite import GradOperation >>> from mindspore.common import mutable >>> from mindspore.common import dtype as mstype >>> from mindspore import Tensor >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.matmul = ops.MatMul() ... ... def construct(self, z): ... x = z[0] ... y = z[1] ... out = self.matmul(x, y) ... return out ... >>> class GradNetWrtX(nn.Cell): ... def __init__(self, net): ... super(GradNetWrtX, self).__init__() ... self.net = net ... self.grad_op = GradOperation() ... ... def construct(self, z): ... gradient_function = self.grad_op(self.net) ... return gradient_function(z) ... >>> z = mutable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), ... Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32))) >>> output = GradNetWrtX(Net())(z) >>> print(output) (Tensor(shape=[2, 3], dtype=Float32, value= [[ 1.41000009e+00, 1.60000002e+00, 6.59999943e+00], [ 1.41000009e+00, 1.60000002e+00, 6.59999943e+00]]), Tensor(shape=[3, 3], dtype=Float32, value= [[ 1.70000005e+00, 1.70000005e+00, 1.70000005e+00], [ 1.89999998e+00, 1.89999998e+00, 1.89999998e+00], [ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]]))