mindspore.mutable

查看源文件
mindspore.mutable(input_data, dynamic_len=False)[源代码]

设置一个常量值为可变的。

当前除了Tensor,所有顶层网络的输入,例如标量、tuple、list和dict,都被当做常量值。常量值是不能求导的,而且在编译优化阶段会被常量折叠掉。 另外,当网络的输入是tuple[Tensor]、list[Tensor]或Dict[Tensor]时,即使里面Tensor的shape和dtype没有发生变化,在多次调用同一个网络的时候,这个网络每次都会被重新编译,这是因为这些类型的输入被当做常量值处理了。

为解决以上的问题,我们提供了 mutable 接口去设置网络的常量输入为”可变的”。一个”可变的”输入意味着这个输入成为了像Tensor一样的变量,最重要的是,我们可以对其进行求导了。

input_data 是tuple或者list并且 dynamic_len 是False的情况下,mutable 的返回值是一个固定长度的tuple或者list,且其中的每一个元素都是可变的。当 dyanmic_len 被设置为True的时候,返回的tuple或者list长度是动态的。

如果一个动态长度的tuple或者list被作为网络的输入并且这个网络被重复调用,且每一次的输入的tuple或者list长度都不一致,这个网络也不需要被重新编译。

参数:
  • input_data (Union[Tensor, scalar, tuple, list, dict]) - 要设置为可变的输入数据。如果 input_data 是list,tuple或者dict, 其内部元素的类型也需要是这些有效类型中的一个。

  • dynamic_len (bool) - 是否要将整个序列设置为动态长度的。在图编译内,如果 dynamic_len 被设置为 True , 那么 input_data 必须为tuple或者list, 并且其中的元素必须有相同的类型以及形状。默认值: False

警告

这是一个实验性API,后续可能修改或删除。 dynamic_len 是实验性质的参数,暂不支持 dynamic_lenTrue

说明

当前该接口只在图模式下生效。

返回:

状态设置为可变的原输入数据。

异常:
  • TypeError - 如果 input_data 不是Tensor, scalar, tuple, list 或dict的其中一种类型或者不是它们的嵌套结构。

  • TypeError - 如果 dynamic_len 被设置为 True 并且 input_data 不是tuple或者list。

  • ValueError - 如果 dynamic_len 被设置为 Trueinput_data 是tuple或者list的情况下,其中的元素的形状或者类型不一致。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import mutable, nn, ops, Tensor, context
>>> from mindspore import dtype as mstype
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.matmul = ops.MatMul()
...
...     def construct(self, z):
...         x = z[0]
...         y = z[1]
...         out = self.matmul(x, y)
...         return out
...
>>> class GradNetWrtX(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtX, self).__init__()
...         self.net = net
...         self.grad_op = ops.GradOperation()
...
...     def construct(self, z):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(z)
...
>>> z = mutable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
...              Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
>>> output = GradNetWrtX(Net())(z)
>>> print(output)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 1.41000009e+00,  1.60000002e+00,  6.59999943e+00],
 [ 1.41000009e+00,  1.60000002e+00,  6.59999943e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.70000005e+00,  1.70000005e+00,  1.70000005e+00],
 [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
 [ 1.50000000e+00,  1.50000000e+00,  1.50000000e+00]]))