mindflow.operators.batched_hessian

mindflow.operators.batched_hessian(model)

Computes the batched Hessian matrix of a network model's outputs with respect to its inputs.

Parameters:
  • model (mindspore.nn.Cell) - A network model with input dimension in_channels and output dimension out_channels.

Returns:

Function, a Hessian instance used to compute the Hessian matrix. Its input has shape [batch_size, in_channels] and its output has shape [out_channels, in_channels, batch_size, in_channels].
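To make the output layout concrete, the following NumPy sketch computes a tensor with the same [out_channels, in_channels, batch_size, in_channels] shape using central finite differences. It assumes the model treats each sample (row) independently, so that entry [o, i, b, j] is the second derivative of output o with respect to input channels i and j at sample b. `reference_batched_hessian` is a hypothetical helper for illustration only; it is not part of MindFlow and does not reflect how batched_hessian is implemented.

```python
import numpy as np


def reference_batched_hessian(f, x, eps=1e-4):
    """Hypothetical finite-difference reference (not a MindFlow API).

    f maps an array of shape [batch_size, in_channels] to
    [batch_size, out_channels] and is assumed to act on each row
    independently. Returns H with shape
    [out_channels, in_channels, batch_size, in_channels], where
    H[o, i, b, j] approximates d^2 f(x)[b, o] / dx[b, i] dx[b, j].
    """
    # Use float64: second differences are sensitive to rounding error.
    x = np.asarray(x, dtype=np.float64)
    batch_size, in_channels = x.shape
    out_channels = f(x).shape[1]
    hess = np.zeros((out_channels, in_channels, batch_size, in_channels))
    for i in range(in_channels):
        for j in range(in_channels):
            # Perturb channels i and j of every sample at once; this is
            # valid only because samples do not interact.
            ei = np.zeros(in_channels)
            ei[i] = eps
            ej = np.zeros(in_channels)
            ej[j] = eps
            # Central mixed second difference, shape [batch_size, out_channels].
            d2 = (f(x + ei + ej) - f(x + ei - ej)
                  - f(x - ei + ej) + f(x - ei - ej)) / (4 * eps * eps)
            hess[:, i, :, j] = d2.T
    return hess


def f(x):
    # One output per sample: y = x0**2 * x1
    return (x[:, 0] ** 2 * x[:, 1])[:, None]


x = np.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
h = reference_batched_hessian(f, x)
print(h.shape)  # (1, 2, 3, 2)
```

For y = x0**2 * x1 the analytic per-sample Hessian is [[2*x1, 2*x0], [2*x0, 0]], so h[0, :, b, :] can be compared directly with that matrix for each sample b.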

Note

Requires MindSpore version >= 2.0.0, because this function calls the following interface: mindspore.jacrev

Supported Platforms:

Ascend GPU CPU

Examples:

>>> import numpy as np
>>> from mindspore import nn, ops, Tensor
>>> from mindspore import dtype as mstype
>>> from mindflow.operators import batched_hessian
>>> np.random.seed(123456)
>>> class Net(nn.Cell):
...     def __init__(self, cin=2, cout=1, hidden=10):
...         super().__init__()
...         self.fc1 = nn.Dense(cin, hidden)
...         self.fc2 = nn.Dense(hidden, hidden)
...         self.fcout = nn.Dense(hidden, cout)
...         self.act = ops.Tanh()
...
...     def construct(self, x):
...         x = self.act(self.fc1(x))
...         x = self.act(self.fc2(x))
...         x = self.fcout(x)
...         return x
>>> model = Net()
>>> hessian = batched_hessian(model)
>>> inputs = np.random.random(size=(3, 2))  # batch_size=3, in_channels=2
>>> res = hessian(Tensor(inputs, mstype.float32))
>>> print(res.shape)  # (out_channels, in_channels, batch_size, in_channels)
(1, 2, 3, 2)
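After converting the result to NumPy (for example with res.asnumpy()), the layout can be rearranged for downstream use. The sketch below assumes the index meaning res[o, i, b, j] = d^2 y_o / dx_i dx_j at sample b. It shows two hypothetical helpers (not MindFlow APIs): one regroups the tensor per sample, and one traces over the two in_channels axes, which gives the Laplacian when every input channel is a spatial coordinate. A synthetic array stands in for res so the snippet runs without MindSpore.

```python
import numpy as np


def per_sample_hessians(hess):
    """Hypothetical helper: [out, in, batch, in] -> [batch, out, in, in]."""
    return np.transpose(hess, (2, 0, 1, 3))


def laplacian(hess):
    """Hypothetical helper: sum of d^2 y_o / dx_i^2 over input channels i.

    Traces over axes 1 and 3 (the two in_channels axes); the result has
    shape [out_channels, batch_size].
    """
    return np.trace(hess, axis1=1, axis2=3)


# Stand-in for res.asnumpy(): per-sample Hessians of y = x0**2 + 3 * x1**2,
# which equal diag(2, 6) for every sample.
batch_size = 3
per_sample = np.tile(np.diag([2.0, 6.0]), (batch_size, 1, 1, 1))  # [3, 1, 2, 2]
hess = np.transpose(per_sample, (1, 2, 0, 3))  # documented layout [1, 2, 3, 2]

print(per_sample_hessians(hess).shape)  # (3, 1, 2, 2)
print(laplacian(hess))                  # [[8. 8. 8.]]
```

If some input channels are not spatial (for example a time coordinate), the trace should be restricted to the spatial channels, e.g. by slicing hess[:, idx][:, :, :, idx] with a list of spatial indices idx before calling laplacian.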