mindspore.nn.WithEvalCell
- class mindspore.nn.WithEvalCell(network, loss_fn, add_cast_fp32=False)[源代码]
封装前向网络和损失函数。 返回用于计算评估指标的损失函数值、前向输出和标签。
- 参数:
network (Cell) - 前向网络。
loss_fn (Cell) - 损失函数。
add_cast_fp32 (bool) - 是否将数据类型调整为float32。默认值:
False
。
- 输入:
data (Tensor) - shape为 \((N, \ldots)\) 的Tensor。
label (Tensor) - shape为 \((N, \ldots)\) 的Tensor。
- 输出:
Tuple(Tensor),包括标量损失函数、shape为 \((N, \ldots)\) 的网络输出和shape为 \((N, \ldots)\) 的标签。
- 异常:
TypeError - add_cast_fp32 不是bool。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore.nn as nn >>> # Define a forward network without loss function, taking LeNet5 as an example. >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() >>> eval_net = nn.WithEvalCell(net, loss_fn)