mindspore.recompute
- mindspore.recompute(block, *args, **kwargs)
This function reduces memory usage. When block runs, the intermediate activations computed in the forward pass are not stored; instead, they are recomputed in the backward pass.
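The trade-off described above can be illustrated without MindSpore. The following is a minimal, framework-free sketch under assumed toy settings: a scalar weight w, a ReLU "block", and a sum-of-squares loss. All function names are hypothetical and are not part of the MindSpore API. The sketch shows that discarding the activation h in the forward pass and recomputing it from the block inputs in the backward pass gives the same gradient as storing h. The cost is one extra forward run of the block.

```python
# Framework-free sketch of activation recomputation (hypothetical toy model):
#   h = relu(w * x),  loss = sum(h_i ** 2)
# None of these names belong to MindSpore.

def relu(v):
    return [max(0.0, a) for a in v]

def block_forward(w, x):
    """The 'block' whose intermediate activation h may be discarded."""
    return relu([w * a for a in x])

def loss_from_h(h):
    return sum(a * a for a in h)

def grad_w_stored(w, x):
    """Standard backprop: keep h from the forward pass for the backward pass."""
    h = block_forward(w, x)                 # activation kept in memory
    loss = loss_from_h(h)
    # dloss/dh_i = 2 * h_i; ReLU passes the gradient only where w * x_i > 0
    g = sum(2.0 * hi * xi for hi, xi in zip(h, x) if w * xi > 0)
    return loss, g

def grad_w_recomputed(w, x):
    """Recompute: h is a temporary in the forward pass and is rebuilt later."""
    loss = loss_from_h(block_forward(w, x))  # h is not kept after this line
    # Backward pass: conceptually only the block inputs (w, x) were saved,
    # so run the block forward again to rebuild h.
    h = block_forward(w, x)
    g = sum(2.0 * hi * xi for hi, xi in zip(h, x) if w * xi > 0)
    return loss, g

w, x = 1.5, [1.0, -2.0, 3.0]
print(grad_w_stored(w, x))      # (22.5, 30.0)
print(grad_w_recomputed(w, x))  # (22.5, 30.0)
```

Both functions return the same loss and gradient. Only peak memory and compute differ: the recomputed variant runs block_forward twice but does not need h to survive until the backward pass.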
Note
The recompute function only supports a block that inherits from Cell.
This function currently supports only PyNative mode. In graph mode, use the Cell.recompute interface instead.
When using the recompute function, the block object should not be decorated with @jit.
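- Parameters
block (Cell) – Block to be recomputed.
args (tuple) – Inputs passed to block for its forward pass.
kwargs (dict) – Optional keyword inputs for the recompute function.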
- Returns
Same as the return type of block.
- Raises
TypeError – If block is not a Cell object.
AssertionError – If the execution mode is not PYNATIVE_MODE.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]