mindspore.ops.TensorDump

查看源文件
class mindspore.ops.TensorDump[源代码]

将Tensor保存为numpy格式的npy文件。

文件名会按照执行顺序自动添加前缀。例如,如果 filea,第一次保存的文件名为 0_a.npy,第二次为 1_a.npy

警告

  • 如果在短时间内保存大量数据,可能会导致设备端内存溢出。可以考虑对数据进行切片,以减小数据规模。

  • 由于数据保存是异步处理的,当数据量过大或主进程退出过快时,可能出现数据丢失的问题,需要主动控制主进程销毁时间,例如使用sleep。

输入:
  • file (str) - 要保存的文件路径。

  • input_x (Tensor) - 任意维度的Tensor。

异常:
  • TypeError - 如果 file 不是str。

  • TypeError - 如果 input_x 不是Tensor。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> import time
>>> from mindspore import nn, Tensor, ops
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dump = ops.TensorDump()
...
...     def construct(self, x):
...         x += 1.
...         self.dump('add', x)
...         x /= 2.
...         self.dump('div', x)
...         x *= 5.
...         self.dump('mul', x)
...         return x
...
>>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
>>> input_x = Tensor(x)
>>> net = Net()
>>> out = net(input_x)
>>> time.sleep(0.5)
>>> add = np.load('0_add.npy')
>>> print(add)
[[2. 3. 4. 5.]
 [6. 7. 8. 9.]]