sciai.common.lbfgs_train

查看源文件
sciai.common.lbfgs_train(loss_net, input_data, lbfgs_iter)[源代码]

L-BFGS训练函数,目前只能在PYNATIVE模式下运行。

参数:
  • loss_net (Cell) - 返回loss作为目标函数的网络。

  • input_data (Union[Tensor, tuple[Tensor]]) - loss_net的输入数据。

  • lbfgs_iter (int) - l-bfgs训练过程的迭代次数。

支持平台:

GPU CPU Ascend

样例:

>>> import mindspore as ms
>>> from mindspore import nn, ops
>>> from sciai.common import lbfgs_train
>>> ms.set_seed(1234)
>>> class Net1In1Out(nn.Cell):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.dense1 = nn.Dense(2, 1)
>>>     def construct(self, x):
>>>         return self.dense1(x).abs().sum()
>>> net = Net1In1Out()
>>> x = ops.ones((3, 2), ms.float32)
>>> lbfgs_train(net, (x,), 1000)
>>> loss = net(x)
>>> print(loss)
5.9944578e-06