sciai.common.lbfgs_train
- sciai.common.lbfgs_train(loss_net, input_data, lbfgs_iter)
Trains a network with the L-BFGS optimizer. Currently, this function can only run in PyNative mode.
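- Parameters
loss_net (Cell): Network whose output is the loss to be minimized.
input_data (tuple[Tensor]): Input data passed to loss_net.
lbfgs_iter (int): Number of L-BFGS iterations.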
- Supported Platforms:
GPU
CPU
Ascend
Examples
>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> from sciai.common import lbfgs_train
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> ms.set_seed(1234)
>>> class Net1In1Out(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.dense1 = nn.Dense(2, 1)
...     def construct(self, x):
...         return self.dense1(x).abs().sum()
...
>>> net = Net1In1Out()
>>> x = ops.ones((3, 2), ms.float32)
>>> lbfgs_train(net, (x,), 1000)
>>> loss = net(x)
>>> print(loss)
5.9944578e-06
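To illustrate what an L-BFGS training loop does at each iteration, here is a minimal, framework-free sketch of the textbook algorithm: the two-loop recursion builds a quasi-Newton search direction from the most recent (step, gradient-change) pairs, and a backtracking Armijo line search picks the step length. This is an assumption-labeled illustration, not the sciai implementation. lbfgs_minimize, dot, max_iter, history and tol are hypothetical names chosen here, and the function works on plain Python lists rather than a MindSpore Cell. The max_iter argument plays a role loosely analogous to lbfgs_iter.

```python
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def lbfgs_minimize(
    f: Callable[[Vector], float],
    grad: Callable[[Vector], Vector],
    x0: Vector,
    max_iter: int = 100,  # hypothetical: plays a role similar to lbfgs_iter
    history: int = 5,     # number of (s, y) pairs kept in memory
    tol: float = 1e-10,   # stop when the squared gradient norm drops below this
) -> Vector:
    """Hypothetical textbook L-BFGS sketch (not sciai's implementation)."""
    x = list(x0)
    g = grad(x)
    s_hist: List[Vector] = []  # recent steps s_k = x_{k+1} - x_k
    y_hist: List[Vector] = []  # recent gradient changes y_k = g_{k+1} - g_k
    for _ in range(max_iter):
        if dot(g, g) < tol:
            break
        # First loop (newest to oldest pair): fold curvature pairs into q.
        q = list(g)
        alphas = []
        for s, y in reversed(list(zip(s_hist, y_hist))):
            rho = 1.0 / dot(y, s)
            alpha = rho * dot(s, q)
            alphas.append((alpha, rho))
            q = [qi - alpha * yi for qi, yi in zip(q, y)]
        # Initial inverse-Hessian scaling gamma = s^T y / y^T y (newest pair).
        gamma = 1.0
        if s_hist:
            gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1])
        r = [gamma * qi for qi in q]
        # Second loop (oldest to newest pair): r becomes H_k * g.
        for (s, y), (alpha, rho) in zip(zip(s_hist, y_hist), reversed(alphas)):
            beta = rho * dot(y, r)
            r = [ri + si * (alpha - beta) for ri, si in zip(r, s)]
        d = [-ri for ri in r]  # quasi-Newton descent direction
        slope = dot(g, d)
        if slope >= 0:  # safeguard: fall back to steepest descent
            d = [-gi for gi in g]
            slope = dot(g, d)
            s_hist.clear()
            y_hist.clear()
        # Backtracking line search with the Armijo sufficient-decrease test.
        t, fx = 1.0, f(x)
        while f([xi + t * di for xi, di in zip(x, d)]) > fx + 1e-4 * t * slope:
            t *= 0.5
        x_new = [xi + t * di for xi, di in zip(x, d)]
        g_new = grad(x_new)
        s = [xn - xo for xn, xo in zip(x_new, x)]
        y = [gn - go for gn, go in zip(g_new, g)]
        if dot(s, y) > 1e-12:  # keep only pairs meeting the curvature condition
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > history:
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
    return x


if __name__ == "__main__":
    # Toy objective with minimum at (1, -2).
    def f(v: Vector) -> float:
        return (v[0] - 1.0) ** 2 + 10.0 * (v[1] + 2.0) ** 2

    def grad(v: Vector) -> Vector:
        return [2.0 * (v[0] - 1.0), 20.0 * (v[1] + 2.0)]

    print(lbfgs_minimize(f, grad, [0.0, 0.0]))
```

The limited history (history pairs instead of a full inverse Hessian) is what keeps L-BFGS memory-light compared with full BFGS, which is why it is a common choice for refining small scientific-computing networks after first-order training.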