sciai.common.lbfgs_train
- sciai.common.lbfgs_train(loss_net, input_data, lbfgs_iter)[源代码]
L-BFGS训练函数,目前只能在PYNATIVE模式下运行。
- 参数:
loss_net (Cell) - 返回loss作为目标函数的网络。
input_data (Union[Tensor, tuple[Tensor]]) - loss_net的输入数据。
lbfgs_iter (int) - l-bfgs训练过程的迭代次数。
- 支持平台:
GPU
CPU
Ascend
样例:
>>> import mindspore as ms >>> from mindspore import nn, ops >>> from sciai.common import lbfgs_train >>> ms.set_seed(1234) >>> class Net1In1Out(nn.Cell): >>> def __init__(self): >>> super().__init__() >>> self.dense1 = nn.Dense(2, 1) >>> def construct(self, x): >>> return self.dense1(x).abs().sum() >>> net = Net1In1Out() >>> x = ops.ones((3, 2), ms.float32) >>> lbfgs_train(net, (x,), 1000) >>> loss = net(x) >>> print(loss) 5.9944578e-06