sciai.common.lbfgs_train

sciai.common.lbfgs_train(loss_net, input_data, lbfgs_iter)

L-BFGS training function: trains loss_net by minimizing its output with the L-BFGS algorithm. Currently, it can only run in PyNative mode (ms.PYNATIVE_MODE).
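L-BFGS is a quasi-Newton method. It approximates the inverse Hessian from the last few (step, gradient-change) pairs using the two-loop recursion, then takes a line-searched step along the resulting direction. The pure-Python sketch below is a generic textbook version with an Armijo backtracking line search, included only to illustrate the algorithm. It is not the sciai implementation, and the names lbfgs_minimize, history and tol are hypothetical. Its max_iter plays a role loosely analogous to lbfgs_iter.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def lbfgs_minimize(
    f: Callable[[Vector], float],
    grad: Callable[[Vector], Vector],
    x0: Vector,
    max_iter: int = 100,
    history: int = 5,
    tol: float = 1e-10,
) -> Vector:
    """Generic L-BFGS with an Armijo backtracking line search (illustrative only)."""
    x = list(x0)
    g = grad(x)
    s_hist: List[Vector] = []  # recent steps s_k = x_{k+1} - x_k
    y_hist: List[Vector] = []  # recent gradient changes y_k = g_{k+1} - g_k
    for _ in range(max_iter):
        if math.sqrt(dot(g, g)) < tol:
            break
        # First loop of the two-loop recursion: newest pair to oldest.
        q = list(g)
        coeffs = []
        for s, y in reversed(list(zip(s_hist, y_hist))):
            rho = 1.0 / dot(y, s)
            alpha = rho * dot(s, q)
            coeffs.append((rho, alpha))
            q = [qi - alpha * yi for qi, yi in zip(q, y)]
        # Initial inverse-Hessian approximation gamma * I, gamma = s^T y / y^T y.
        gamma = 1.0
        if s_hist:
            gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1])
        r = [gamma * qi for qi in q]
        # Second loop: oldest pair to newest; r ends up approximating H^{-1} g.
        for (s, y), (rho, alpha) in zip(zip(s_hist, y_hist), reversed(coeffs)):
            beta = rho * dot(y, r)
            r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
        d = [-ri for ri in r]  # quasi-Newton search direction
        # Halve the step until the Armijo sufficient-decrease condition holds.
        step, fx, slope = 1.0, f(x), dot(g, d)
        while step > 1e-12 and (
            f([xi + step * di for xi, di in zip(x, d)]) > fx + 1e-4 * step * slope
        ):
            step *= 0.5
        x_new = [xi + step * di for xi, di in zip(x, d)]
        g_new = grad(x_new)
        s_k = [xn - xo for xn, xo in zip(x_new, x)]
        y_k = [gn - go for gn, go in zip(g_new, g)]
        # Store the pair only if s^T y is clearly positive (curvature condition).
        if dot(s_k, y_k) > 1e-12:
            s_hist.append(s_k)
            y_hist.append(y_k)
            if len(s_hist) > history:  # keep only the `history` most recent pairs
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
    return x
```

For example, minimizing the convex quadratic f(v) = (v[0] - 3)^2 + 10 (v[1] + 1)^2 from [0.0, 0.0] with this sketch converges to approximately [3.0, -1.0]. Keeping only a short history of pairs is what makes L-BFGS "limited-memory": storage grows with history times the number of parameters instead of the square of the number of parameters.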

Parameters
  • loss_net (Cell) – Network that returns the loss, which is used as the objective function to minimize.

  • input_data (Union[Tensor, tuple[Tensor]]) – Input data for loss_net.

  • lbfgs_iter (int) – Number of iterations of the L-BFGS training process.
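The type Union[Tensor, tuple[Tensor]] indicates that input_data may be either a single tensor or a tuple of tensors. One plausible convention, assumed here and not confirmed by this page, is that a tuple is unpacked into the positional arguments of loss_net, while a single tensor is passed as one argument. The sketch below illustrates that assumption with a hypothetical helper, call_loss_net, and plain Python callables standing in for a Cell.

```python
from typing import Any, Callable, Tuple, Union


def call_loss_net(loss_net: Callable[..., Any], input_data: Union[Any, Tuple[Any, ...]]) -> Any:
    """Hypothetical helper: evaluate loss_net on a single input or a tuple of inputs."""
    # Assumption: a tuple is unpacked into positional arguments ...
    if isinstance(input_data, tuple):
        return loss_net(*input_data)
    # ... while any non-tuple input is passed through as a single argument.
    return loss_net(input_data)
```

Under this convention, call_loss_net(lambda x, y: x + y, (1, 2)) returns 3, and a one-element tuple such as (5,) is treated the same as passing 5 directly to a single-input network.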

Supported Platforms:

GPU CPU Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> from sciai.common import lbfgs_train
>>> ms.set_context(mode=ms.PYNATIVE_MODE)  # lbfgs_train currently supports PyNative mode only
>>> ms.set_seed(1234)
>>> class Net1In1Out(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.dense1 = nn.Dense(2, 1)
...     def construct(self, x):
...         return self.dense1(x).abs().sum()
...
>>> net = Net1In1Out()
>>> x = ops.ones((3, 2), ms.float32)
>>> lbfgs_train(net, (x,), 1000)
>>> loss = net(x)
>>> print(loss)
5.9944578e-06