mindflow.pde.UnsteadyFlowWithLoss

class mindflow.pde.UnsteadyFlowWithLoss(model, t_in=1, t_out=1, loss_fn='mse', data_format='NTCHW')

Base class for unsteady, user-defined, data-driven problems.

Parameters
  • model (mindspore.nn.Cell) – The model to train or evaluate.

  • t_in (int) – Number of input (initial) time steps. Default: 1.

  • t_out (int) – Number of output time steps. Default: 1.

  • loss_fn (Union[str, Cell]) – Loss function, given either as a name string or as a Cell instance. Default: "mse".

  • data_format (str) – Layout of the input data, either "NTCHW" or "NHWTC". Default: "NTCHW".
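Each letter of a layout string names one axis; assuming the usual convention, N is the batch, T is time, C is channels, and H and W are the spatial dimensions. Converting between the two supported layouts is therefore a pure axis permutation. The stdlib sketch below computes that permutation; `axis_permutation` and `permute_shape` are hypothetical helpers, not part of MindFlow.

```python
def axis_permutation(src: str, dst: str) -> tuple:
    """Return the axes order that rearranges a tensor laid out as `src` into `dst`.

    Assumed convention: N = batch, T = time, C = channels, H/W = spatial axes.
    """
    if sorted(src) != sorted(dst):
        raise ValueError(f"layouts {src!r} and {dst!r} name different axes")
    # Output axis i is taken from input axis src.index(dst[i]).
    return tuple(src.index(axis) for axis in dst)


def permute_shape(shape: tuple, perm: tuple) -> tuple:
    """Apply an axes permutation to a shape, as a transpose would."""
    return tuple(shape[i] for i in perm)


perm = axis_permutation("NTCHW", "NHWTC")
print(perm)                                     # (0, 3, 4, 1, 2)
print(permute_shape((32, 1, 1, 64, 64), perm))  # (32, 64, 64, 1, 1)
```

The resulting tuple has the same meaning as the `axes` argument of `numpy.transpose`: output axis `i` comes from input axis `perm[i]`. The shape `(32, 64, 64, 1, 1)` matches the "NHWTC" tensors used in the example.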

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore
>>> from mindflow.pde import UnsteadyFlowWithLoss
>>> from mindflow.cell import FNO2D
>>> from mindflow.core import RelativeRMSELoss
...
>>> model = FNO2D(in_channels=1, out_channels=1, resolution=64, modes=12)
>>> problem = UnsteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss(), data_format='NHWTC')
>>> inputs = Tensor(np.random.randn(32, 64, 64, 1, 1), mindspore.float32)
>>> label = Tensor(np.random.randn(32, 64, 64, 1, 1), mindspore.float32)
>>> loss = problem.get_loss(inputs, label)
>>> print(loss)
31.999998
get_loss(inputs, labels)

Compute the loss of the training or test model.

Parameters
  • inputs (Tensor) – Input data in "NTCHW" or "NHWTC" format.

  • labels (Tensor) – Ground-truth values of the samples.

Returns

float, loss value.
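`loss_fn` can be a name string or a loss Cell, so `get_loss` has to resolve it before scoring the predictions against `labels`. The stdlib sketch below shows that dispatch under two assumptions: that "mse" means the mean squared error over all elements, and that the predictions are already flattened. The function names and the `_LOSSES` table are hypothetical. In MindFlow itself, the predictions would presumably come from `step` and be Tensors rather than lists.

```python
from typing import Callable, Sequence, Union

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(pred: Sequence[float], label: Sequence[float]) -> float:
    """Mean squared error over flattened values."""
    if len(pred) != len(label):
        raise ValueError("pred and label must have the same number of elements")
    return sum((p - y) ** 2 for p, y in zip(pred, label)) / len(label)


# Hypothetical name-to-function table; the names MindFlow actually accepts may differ.
_LOSSES = {"mse": mse}


def resolve_loss(loss_fn: Union[str, LossFn]) -> LossFn:
    """Map a string name to a loss function, or pass a callable through unchanged."""
    if isinstance(loss_fn, str):
        return _LOSSES[loss_fn]
    return loss_fn


def get_loss_sketch(pred: Sequence[float], labels: Sequence[float],
                    loss_fn: Union[str, LossFn] = "mse") -> float:
    """Score flattened predictions against flattened labels with the resolved loss."""
    return resolve_loss(loss_fn)(pred, labels)


print(get_loss_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # 1.3333333333333333
```

Passing a callable, for example a custom relative-error function, bypasses the table. This mirrors how the example above passes `RelativeRMSELoss()` instead of a string.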

step(inputs)

Predict one or more time steps from the inputs, supporting both single-step and multi-step training.

Parameters

inputs (Tensor) – Input data in "NTCHW" or "NHWTC" format.

Returns

List(Tensor), outputs in "NTCHW" or "NHWTC" format.
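A common way to support multi-step training is an autoregressive rollout. The model predicts the next time step from the most recent `t_in` steps. That prediction is appended to the window, the oldest step is dropped, and the process repeats `t_out` times, collecting one prediction per step. The sketch below assumes this scheme; it is not taken from MindFlow's source. Each time slice is represented by a float. In the real class, each slice would be a Tensor along the T axis, and the window update would be a concatenation along that axis.

```python
from typing import Callable, List, Sequence

Frame = float  # stand-in for one time slice along the T axis


def rollout(model: Callable[[List[Frame]], Frame],
            inputs: Sequence[Frame], t_out: int) -> List[Frame]:
    """Predict t_out future steps autoregressively from a window of t_in steps."""
    window = list(inputs)            # the t_in most recent time steps
    preds: List[Frame] = []
    for _ in range(t_out):
        nxt = model(window)          # predict the next step from the current window
        preds.append(nxt)
        window = window[1:] + [nxt]  # slide: drop the oldest step, append the prediction
    return preds


# Toy model (hypothetical): the next step is the mean of the current window.
print(rollout(lambda w: sum(w) / len(w), [1.0, 2.0], t_out=3))  # [1.5, 1.75, 1.625]
```

With `t_out=1` this reduces to a single forward pass, which is the single-step case. Returning a list with one entry per predicted step is consistent with the `List(Tensor)` return type.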