mindspore.nn.WithEvalCell
- class mindspore.nn.WithEvalCell(network, loss_fn, add_cast_fp32=False)[source]
Wraps the forward network with the loss function.
It returns the loss, the forward output, and the label, which are then used to calculate metrics.
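- Parameters
network (Cell) – The forward network.
loss_fn (Cell) – The loss function.
add_cast_fp32 (bool) – Whether to adjust the data type to float32. Default: False.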
- Inputs:
data (Tensor) - Tensor of shape \((N, \ldots)\).
label (Tensor) - Tensor of shape \((N, \ldots)\).
- Outputs:
Tuple(Tensor), containing a scalar loss Tensor, a network output Tensor of shape \((N, \ldots)\) and a label Tensor of shape \((N, \ldots)\).
- Raises
TypeError – If add_cast_fp32 is not a bool.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore.nn as nn
>>> # Define a forward network without loss function, taking LeNet5 as an example.
>>> # Refer to https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> eval_net = nn.WithEvalCell(net, loss_fn)
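The return contract described above (loss, forward output, label) can be illustrated without MindSpore installed. The following is a minimal plain-Python sketch of the wrapping pattern, not MindSpore's implementation: `EvalWrapperSketch`, `toy_network` and `toy_mse` are hypothetical names introduced only for illustration, lists of floats stand in for Tensors, and the `add_cast_fp32` option is omitted.

```python
# Plain-Python sketch of the wrapping pattern; NOT MindSpore's implementation.
# All names below are hypothetical and exist only for illustration.


class EvalWrapperSketch:
    """Stand-in that mirrors the (loss, outputs, label) return contract."""

    def __init__(self, network, loss_fn):
        self.network = network
        self.loss_fn = loss_fn

    def __call__(self, data, label):
        outputs = self.network(data)         # forward pass only, no gradients
        loss = self.loss_fn(outputs, label)  # scalar loss value
        return loss, outputs, label          # label is passed through unchanged


def toy_network(data):
    # Hypothetical "network": doubles each input value.
    return [2.0 * x for x in data]


def toy_mse(outputs, label):
    # Mean squared error over two flat lists of floats.
    return sum((o - t) ** 2 for o, t in zip(outputs, label)) / len(label)


eval_net = EvalWrapperSketch(toy_network, toy_mse)
loss, outputs, label = eval_net([1.0, 2.0], [2.0, 5.0])
print(loss, outputs, label)
```

Returning the outputs and the label alongside the loss lets a metric (accuracy, for instance) be computed from the same forward pass instead of running the network a second time.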