mindspore.nn.WithEvalCell
- class mindspore.nn.WithEvalCell(network, loss_fn, add_cast_fp32=False)[源代码]
封装前向网络和损失函数,返回用于计算评估指标的损失函数值、前向输出和标签。
参数:
network (Cell) - 前向网络。
loss_fn (Cell) - 损失函数。
add_cast_fp32 (bool):是否将数据类型调整为float32。默认值:False。
输入:
data (Tensor) - shape为 \((N, \ldots)\) 的Tensor。
label (Tensor) - shape为 \((N, \ldots)\) 的Tensor。
输出:
Tuple(Tensor),包括标量损失函数、shape为 \((N, \ldots)\) 的网络输出和shape为 \((N, \ldots)\) 的标签。
异常:
TypeError: add_cast_fp32 不是bool。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> # Forward network without loss function >>> net = Net() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() >>> eval_net = nn.WithEvalCell(net, loss_fn)