mindspore.nn.WithEvalCell

class mindspore.nn.WithEvalCell(network, loss_fn, add_cast_fp32=False)

Wraps the forward network with the loss function.

It returns the loss, the forward output, and the label, which together can be used to compute evaluation metrics.
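The contract can be sketched without any framework. The following is a minimal sketch in plain Python with NumPy, not MindSpore's implementation. `make_eval_cell` and `mse` are hypothetical names. The float32 rule used here is an assumption: the output is always cast, and the label is cast only when it is floating point.

```python
import numpy as np


def make_eval_cell(network, loss_fn, add_cast_fp32=False):
    """Hypothetical, framework-free sketch of the WithEvalCell pattern.

    It mirrors only the documented contract: run the forward network,
    compute the loss, then return loss, output and label together.
    """
    if not isinstance(add_cast_fp32, bool):
        raise TypeError("add_cast_fp32 must be a bool")

    def construct(data, label):
        outputs = network(data)  # forward network only; it contains no loss
        if add_cast_fp32:
            # Assumption: cast the output, and the label only if it is floating point.
            outputs = outputs.astype(np.float32)
            if np.issubdtype(label.dtype, np.floating):
                label = label.astype(np.float32)
        loss = loss_fn(outputs, label)  # scalar loss
        return loss, outputs, label  # everything a metric needs

    return construct


def mse(outputs, label):
    """Toy scalar loss used only for this illustration."""
    return float(np.mean((outputs - label) ** 2))


weights = np.array([[1.0], [2.0]], dtype=np.float16)
eval_fn = make_eval_cell(lambda x: x @ weights, mse, add_cast_fp32=True)
data = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float16)
label = np.array([[1.0], [1.0]], dtype=np.float16)
loss, outputs, out_label = eval_fn(data, label)
print(loss, outputs.dtype, out_label.dtype)  # 0.5 float32 float32
```

Returning the label alongside the output means that the caller can update a metric from a single call, without having to track the label separately.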

Parameters
  • network (Cell) – The forward network.

  • loss_fn (Cell) – The loss function.

  • add_cast_fp32 (bool) – Whether to cast the network output and label to float32 before the loss is computed. Default: False.

Inputs:
  • data (Tensor) - Tensor of shape \((N, \ldots)\).

  • label (Tensor) - Tensor of shape \((N, \ldots)\).

Outputs:

Tuple(Tensor), containing a scalar loss Tensor, a network output Tensor of shape \((N, \ldots)\) and a label Tensor of shape \((N, \ldots)\).

Raises

TypeError – If add_cast_fp32 is not a bool.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.nn as nn
>>> # Define a forward network without a loss function, using LeNet5 as an example.
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> eval_net = nn.WithEvalCell(net, loss_fn)