FNO for 3D Navier-Stokes
Overview
Computational fluid dynamics is one of the most important techniques in fluid mechanics in the 21st century. Flow analysis, prediction, and control can be achieved by solving the governing equations of fluid mechanics numerically. Traditional methods such as the finite element method (FEM) and the finite difference method (FDM) are inefficient because of their complex simulation pipeline (physical modeling, meshing, numerical discretization, iterative solution, etc.) and high computing costs. AI therefore offers a promising way to improve the efficiency of fluid simulation.
Machine learning methods provide a new paradigm for scientific computing by offering fast solvers that approximate traditional numerical methods. Classical neural networks learn mappings between finite-dimensional spaces and can only learn solutions tied to a specific discretization. Unlike traditional neural networks, the Fourier Neural Operator (FNO) is a deep learning architecture that learns mappings between infinite-dimensional function spaces. It directly learns the mapping from function-valued parameters to solutions, so a single model can solve a whole family of partial differential equations, which gives it stronger generalization capability. More information can be found in the paper Fourier Neural Operator for Parametric Partial Differential Equations.
This tutorial describes how to solve the Navier-Stokes equation using a 3D Fourier Neural Operator.
Navier-Stokes equation
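The Navier-Stokes equation (N-S equation for short) is a classical equation in computational fluid dynamics. It is a set of partial differential equations describing the conservation of momentum in a fluid. In vorticity form, the two-dimensional incompressible N-S equation reads:
\[
\begin{aligned}
\partial_t w(x, t) + u(x, t) \cdot \nabla w(x, t) &= \nu \Delta w(x, t) + f(x), && x \in (0,1)^2,\ t \in (0, T], \\
\nabla \cdot u(x, t) &= 0, && x \in (0,1)^2,\ t \in [0, T], \\
w(x, 0) &= w_0(x), && x \in (0,1)^2,
\end{aligned}
\]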
where \(u\) is the velocity field, \(w=\nabla \times u\) is the vorticity, \(w_0(x)\) is the initial vorticity, \(\nu\) is the viscosity coefficient, and \(f(x)\) is the forcing function.
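To make the definition \(w=\nabla \times u\) concrete: in two dimensions the vorticity is the scalar \(w = \partial_{x_1} u_2 - \partial_{x_2} u_1\). The NumPy sketch below is only an illustration (it is not part of MindSpore Flow, and the function name vorticity_2d is hypothetical); it computes the vorticity on a periodic \((0,1)^2\) grid with spectral differentiation, which fits the periodic setting of this dataset.

```python
import numpy as np

def vorticity_2d(u1, u2):
    """Vorticity w = d(u2)/dx1 - d(u1)/dx2 of a periodic velocity field on (0, 1)^2.

    u1, u2: real arrays of shape (n, n); axis 0 is x1 and axis 1 is x2.
    """
    n = u1.shape[0]
    # Angular wavenumbers 2*pi*k for a unit-length periodic domain sampled at n points
    k = 2 * np.pi * np.fft.fftfreq(n, d=1.0 / n)
    k1, k2 = np.meshgrid(k, k, indexing="ij")
    u1_hat = np.fft.fft2(u1)
    u2_hat = np.fft.fft2(u2)
    # Differentiation in Fourier space is multiplication by i*k
    w_hat = 1j * k1 * u2_hat - 1j * k2 * u1_hat
    return np.real(np.fft.ifft2(w_hat))

# Usage: u = (-sin(2*pi*x2), sin(2*pi*x1)) has vorticity 2*pi*(cos(2*pi*x1) + cos(2*pi*x2))
n = 64
x = np.arange(n) / n
x1, x2 = np.meshgrid(x, x, indexing="ij")
w = vorticity_2d(-np.sin(2 * np.pi * x2), np.sin(2 * np.pi * x1))
w_exact = 2 * np.pi * (np.cos(2 * np.pi * x1) + np.cos(2 * np.pi * x2))
print(np.max(np.abs(w - w_exact)))  # close to machine precision
```

Spectral differentiation is exact for trigonometric fields like this one, which is why the error is at round-off level.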
Problem Description
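We aim to solve the two-dimensional incompressible N-S equation by learning an operator that maps the vorticity over a window of earlier time steps to the vorticity over the subsequent time steps. With the 3D (two space dimensions plus time) FNO used here, the model takes input_timestep consecutive vorticity snapshots as input and predicts the following output_timestep snapshots.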
Technology Path
MindSpore Flow solves the problem as follows:
Training Dataset Construction.
Model Construction.
Optimizer and Loss Function.
Model Training.
Fourier Neural Operator
The following figure shows the architecture of the Fourier Neural Operator model. In the figure, \(w_0(x)\) represents the initial vorticity. The lifting layer lifts the input vector to a higher-dimensional channel space. The lifted features are then fed into the Fourier layers, which apply nonlinear transformations to the frequency-domain information. Finally, the decoding layer maps the transformed features to the final prediction \(w_1(x)\).
The Fourier Neural Operator consists of a lifting layer, Fourier layers, and a decoding layer.
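Fourier layers: starting from the input \(v\), the upper branch applies the Fourier transform \(\mathcal{F}\), applies a linear transform \(R\) to the lower Fourier modes while filtering out the higher modes, and then applies the inverse Fourier transform \(\mathcal{F}^{-1}\). The lower branch applies a local linear transform \(W\). The outputs of the two branches are summed and passed through an activation function \(\sigma\) to give the output of the Fourier layer:
\[
v_{l+1}(x) = \sigma\left( W v_l(x) + \left(\mathcal{F}^{-1}\left( R \cdot \mathcal{F}(v_l) \right)\right)(x) \right)
\]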
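The Fourier layer described above can be sketched in NumPy. This is a simplified illustration, not MindFlow's implementation: it handles a single sample on a 1D periodic grid (FNO3D applies the same idea along three axes), uses tanh as a stand-in activation, and the function and argument names are hypothetical.

```python
import numpy as np

def fourier_layer_1d(v, R, W, modes, activation=np.tanh):
    """One Fourier layer, sigma(W v + F^-1(R . F(v))), on a 1D periodic grid.

    v:     (n, c_in) real features sampled at n grid points
    R:     (modes, c_in, c_out) complex weights for the lowest `modes` frequencies
    W:     (c_in, c_out) real weights of the local (pointwise) linear transform
    modes: number of retained Fourier modes, at most n // 2 + 1
    """
    n = v.shape[0]
    v_hat = np.fft.rfft(v, axis=0)  # (n // 2 + 1, c_in), complex
    out_hat = np.zeros((v_hat.shape[0], W.shape[1]), dtype=complex)
    # Linear transform R on the lower modes; the higher modes stay zero (filtered out)
    out_hat[:modes] = np.einsum("ki,kio->ko", v_hat[:modes], R)
    spectral = np.fft.irfft(out_hat, n=n, axis=0)  # back to physical space, (n, c_out)
    local = v @ W  # local linear transform W applied at every grid point
    return activation(spectral + local)

rng = np.random.default_rng(0)
n, c_in, c_out, modes = 32, 3, 5, 4
v = rng.standard_normal((n, c_in))
R = rng.standard_normal((modes, c_in, c_out)) + 1j * rng.standard_normal((modes, c_in, c_out))
W = rng.standard_normal((c_in, c_out))
out = fourier_layer_1d(v, R, W, modes)
print(out.shape)  # (32, 5)
```

Only the lowest modes carry learnable weights, so the size of R does not depend on the grid size n; the same layer can therefore be evaluated on inputs sampled at different resolutions.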
import os
import time
import numpy as np
from mindspore import nn, ops, jit, data_sink, save_checkpoint, context, Tensor
from mindspore import set_seed
from mindspore import dtype as mstype
The src package used below can be downloaded from applications/data_driven/navier_stokes/fno3d/src.
from mindflow import get_warmup_cosine_annealing_lr, load_yaml_config
from mindflow.cell.neural_operators.fno import FNO3D
from src import LpLoss, UnitGaussianNormalizer, create_training_dataset
set_seed(0)
np.random.seed(0)
# set context for training: graph mode for high-performance training on GPU (set device_target='Ascend' to run on Ascend)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', device_id=0)
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
config = load_yaml_config('./configs/fno3d.yaml')
data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
sub = model_params["sub"]
grid_size = model_params["input_resolution"] // sub
input_timestep = model_params["input_timestep"]
output_timestep = model_params["output_timestep"]
Training Dataset Construction
Download the training and test datasets from data_driven/navier_stokes/dataset.
In this case, the training and test datasets are generated following Zongyi Li's dataset in Fourier Neural Operator for Parametric Partial Differential Equations. The settings are as follows:
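The initial condition \(w_0(x)\) is sampled from a Gaussian random field with periodic boundary conditions:
\[
w_0 \sim \mu, \quad \mu = \mathcal{N}\left(0,\ 7^{3/2}(-\Delta + 49I)^{-2.5}\right)
\]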
The forcing function is defined as:
\[
f(x) = 0.1\left(\sin\left(2\pi(x_1 + x_2)\right) + \cos\left(2\pi(x_1 + x_2)\right)\right)
\]
During data generation, the Crank–Nicolson scheme is used with a time step of 1e-4, and the solution is recorded every t = 1 time unit. All data are generated on a 256 × 256 grid and downsampled to 64 × 64. In this case, the viscosity coefficient is \(\nu=10^{-4}\), the training set contains 1000 samples, and the test set contains 200 samples.
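As a rough illustration of these settings, the NumPy sketch below evaluates the forcing function on a uniform grid over \([0,1)^2\) and downsamples a 256 × 256 field to 64 × 64. Taking every 4th grid point (strided subsampling) is an assumption made for illustration; the text does not state how the downsampling was actually done, and the helper names are hypothetical.

```python
import numpy as np

def forcing(n):
    """f(x) = 0.1 * (sin(2*pi*(x1 + x2)) + cos(2*pi*(x1 + x2))) on an n x n grid over [0, 1)^2."""
    x = np.arange(n) / n
    x1, x2 = np.meshgrid(x, x, indexing="ij")
    s = 2 * np.pi * (x1 + x2)
    return 0.1 * (np.sin(s) + np.cos(s))

def downsample(field, factor):
    """Strided subsampling of the last two (spatial) axes, e.g. 256 x 256 -> 64 x 64 with factor=4."""
    return field[..., ::factor, ::factor]

f_fine = forcing(256)
f_coarse = downsample(f_fine, 256 // 64)
print(f_fine.shape, f_coarse.shape)  # (256, 256) (64, 64)
```

With strided subsampling the coarse grid points coincide with a subset of the fine ones, so the downsampled forcing equals the forcing evaluated directly on the 64 × 64 grid.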
train_a = Tensor(np.load(os.path.join(
data_params["path"], "train_a.npy")), mstype.float32)
train_u = Tensor(np.load(os.path.join(
data_params["path"], "train_u.npy")), mstype.float32)
test_a = Tensor(np.load(os.path.join(
data_params["path"], "test_a.npy")), mstype.float32)
test_u = Tensor(np.load(os.path.join(
data_params["path"], "test_u.npy")), mstype.float32)
train_loader = create_training_dataset(data_params,
shuffle=True)
Data preparation finished
Model Construction
The network is composed of one lifting layer, multiple Fourier layers, and one decoding layer:
The lifting layer corresponds to FNO3D.fc0 in this case and maps the input data \(x\) to a higher-dimensional channel space.
The multi-layer Fourier block corresponds to FNO3D.fno_seq in this case; it uses the discrete Fourier transform to convert between the physical domain and the frequency domain.
The decoding layer corresponds to FNO3D.fc1 and FNO3D.fc2 in this case and produces the final prediction.
if use_ascend:
    compute_type = mstype.float16
else:
    compute_type = mstype.float32
# prepare model
model = FNO3D(in_channels=model_params["in_channels"],
              out_channels=model_params["out_channels"],
              n_modes=model_params["modes"],
              resolutions=[model_params["input_resolution"],
                           model_params["input_resolution"], output_timestep],
              hidden_channels=model_params["width"],
              n_layers=model_params["depth"],
              projection_channels=4*model_params["width"],
              fno_compute_dtype=compute_type
              )
model_params_list = []
for k, v in model_params.items():
    model_params_list.append(f"{k}-{v}")
model_name = "_".join(model_params_list)
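The lifting and decoding layers act pointwise: they apply the same linear map to the channel axis at every grid point. The NumPy sketch below illustrates this shape flow with hypothetical sizes and hypothetical helper names. It is not FNO3D's code: biases are omitted, the Fourier layers are skipped, and the GELU activation between the two decoding layers is an assumption. The 4 × width hidden size mirrors projection_channels=4*model_params["width"] above.

```python
import numpy as np

def gelu(z):
    """tanh approximation of the GELU activation."""
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z ** 3)))

def lift(x, w0):
    """Lifting (the role of FNO3D.fc0): c_in -> width channels at every grid point."""
    return x @ w0

def decode(h, w1, w2):
    """Decoding (the role of FNO3D.fc1 and fc2): width -> 4 * width -> c_out channels."""
    return gelu(h @ w1) @ w2

def lift_then_decode(x, w0, w1, w2):
    h = lift(x, w0)
    # the Fourier layers (FNO3D.fno_seq) would transform h here, keeping its shape
    return decode(h, w1, w2)

rng = np.random.default_rng(0)
bs, nx, ny, nt, c_in, width, c_out = 2, 16, 16, 8, 10, 20, 1  # hypothetical sizes
x = rng.standard_normal((bs, nx, ny, nt, c_in))  # channels-last input
w0 = rng.standard_normal((c_in, width))
w1 = rng.standard_normal((width, 4 * width))
w2 = rng.standard_normal((4 * width, c_out))

h = lift(x, w0)
y = lift_then_decode(x, w0, w1, w2)
print(h.shape, y.shape)  # (2, 16, 16, 8, 20) (2, 16, 16, 8, 1)
```

Because these layers never mix neighboring grid points, changing the input at one location changes the output only at that location; in the full model, only the Fourier layers propagate information across space and time.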
Optimizer and Loss Function
lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params["initial_lr"],
last_epoch=optimizer_params["train_epochs"],
steps_per_epoch=train_loader.get_dataset_size(),
warmup_epochs=optimizer_params["warmup_epochs"])
optimizer = nn.optim.Adam(model.trainable_params(),
learning_rate=Tensor(lr), weight_decay=optimizer_params['weight_decay'])
loss_fn = LpLoss()
a_normalizer = UnitGaussianNormalizer(train_a)
y_normalizer = UnitGaussianNormalizer(train_u)
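get_warmup_cosine_annealing_lr is given steps_per_epoch and last_epoch, and its result is wrapped in a Tensor and passed to Adam as a per-step learning-rate schedule. The sketch below shows a generic linear-warmup plus cosine-annealing schedule taking analogous arguments. The function name and the lr_min parameter are hypothetical, and MindFlow's exact formula may differ, so treat this as an illustration of the schedule's shape rather than a reimplementation.

```python
import math

def warmup_cosine_lr(lr_init, steps_per_epoch, last_epoch, warmup_epochs, lr_min=0.0):
    """Per-step learning rates: linear warmup followed by cosine annealing to lr_min."""
    total_steps = steps_per_epoch * last_epoch
    warmup_steps = steps_per_epoch * warmup_epochs
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            # linear ramp up to lr_init over the warmup steps
            lr = lr_init * ((step + 1) / warmup_steps)
        else:
            # cosine decay from lr_init towards lr_min over the remaining steps
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            lr = lr_min + 0.5 * (lr_init - lr_min) * (1.0 + math.cos(math.pi * progress))
        lrs.append(lr)
    return lrs

lrs = warmup_cosine_lr(lr_init=1e-3, steps_per_epoch=10, last_epoch=5, warmup_epochs=1)
print(len(lrs), lrs[9], lrs[10])  # 50 0.001 0.001
```

The warmup avoids large, poorly directed updates while Adam's moment estimates are still unreliable; the cosine phase then lowers the step size smoothly towards the end of training.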
def calculate_l2_error(model, inputs, labels):
    """
    Evaluate the model with respect to the input data and labels.

    Args:
        model (Cell): the network to evaluate.
        inputs (Tensor): the input data of the network.
        labels (Tensor): the true output values of the given inputs.
    """
    print("================================Start Evaluation================================")
    time_beg = time.time()
    rms_error = 0.0
    for i in range(labels.shape[0]):
        # evaluate one sample at a time
        label = labels[i:i + 1]
        test_batch = inputs[i:i + 1]
        test_batch = a_normalizer.encode(test_batch)
        label = y_normalizer.encode(label)
        # reshape to (1, X, Y, 1, T_in), then repeat along axis 3 to (1, X, Y, T_out, T_in)
        test_batch = test_batch.reshape(
            1, grid_size, grid_size, 1, input_timestep).repeat(output_timestep, axis=3)
        prediction = model(test_batch).reshape(
            1, grid_size, grid_size, output_timestep)
        prediction = y_normalizer.decode(prediction)
        label = y_normalizer.decode(label)
        # LpLoss error of this sample, computed in the original (decoded) units
        rms_error_step = loss_fn(prediction.reshape(
            1, -1), label.reshape(1, -1))
        rms_error += rms_error_step
    rms_error = rms_error / labels.shape[0]
    print("mean rms_error:", rms_error)
    print("predict total time: {} s".format(time.time() - time_beg))
    print("=================================End Evaluation=================================")
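calculate_l2_error relies on LpLoss and UnitGaussianNormalizer from the src package, whose code is not shown in this tutorial. The NumPy sketch below is an assumption about their behavior, modeled on the relative L2 loss and the per-location Gaussian normalizer commonly used with FNO; the class and function names are hypothetical, and the actual src implementation may differ.

```python
import numpy as np

class UnitGaussianNormalizerSketch:
    """Normalize with the mean and std computed over the sample axis (axis 0)."""

    def __init__(self, x, eps=1e-5):
        self.mean = x.mean(axis=0)
        self.std = x.std(axis=0)
        self.eps = eps  # guards against division by a near-zero std

    def encode(self, x):
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x):
        return x * (self.std + self.eps) + self.mean

def relative_l2_loss(pred, true):
    """Batch mean of ||pred - true||_2 / ||true||_2 for inputs of shape (batch, -1)."""
    diff = np.linalg.norm(pred - true, axis=1)
    ref = np.linalg.norm(true, axis=1)
    return float(np.mean(diff / ref))

rng = np.random.default_rng(0)
train_u = rng.standard_normal((8, 4, 4, 3)) * 2.0 + 1.0  # hypothetical data
norm = UnitGaussianNormalizerSketch(train_u)
z = norm.encode(train_u)
print(np.allclose(norm.decode(z), train_u))  # True
print(relative_l2_loss(train_u.reshape(8, -1), train_u.reshape(8, -1)))  # 0.0
```

Under this assumption, decoding the prediction before computing the loss means the reported error is relative to the magnitude of the true vorticity rather than to the normalized values.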
Training Function
In MindSpore >= 2.0.0, neural networks can be trained in a functional programming style: the single-step training function is decorated with jit, and the data_sink function wraps the single-step training function together with the training dataset, with sink_size controlling how many batches are sunk per call.
def forward_fn(data, label):
    bs = data.shape[0]
    data = a_normalizer.encode(data)
    label = y_normalizer.encode(label)
    # reshape to (bs, X, Y, 1, T_in), then repeat along axis 3 to (bs, X, Y, T_out, T_in)
    data = data.reshape(bs, grid_size, grid_size, 1, input_timestep).repeat(
        output_timestep, axis=3)
    logits = model(data).reshape(bs, grid_size, grid_size, output_timestep)
    logits = y_normalizer.decode(logits)
    label = y_normalizer.decode(label)
    loss = loss_fn(logits.reshape(bs, -1), label.reshape(bs, -1))
    return loss


grad_fn = ops.value_and_grad(
    forward_fn, None, optimizer.parameters, has_aux=False)


@jit
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    # ops.depend ties the returned loss to the optimizer update so the update is executed
    loss = ops.depend(loss, optimizer(grads))
    return loss


# sink_size controls how many batches from train_loader are sunk per call
sink_process = data_sink(train_step, train_loader, sink_size=100)
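The reshape-and-repeat step in forward_fn turns each input of input_timestep snapshots into a space-time volume: the input snapshots end up on the last axis (which, assuming FNO3D's channels-last layout, acts as the channel axis), and they are copied along a new time axis of length output_timestep so that the model can predict one value per grid point and output time step. The NumPy sketch below reproduces only this shape manipulation with hypothetical sizes; np.repeat plays the role of Tensor.repeat.

```python
import numpy as np

bs, grid_size, input_timestep, output_timestep = 2, 8, 10, 20  # hypothetical sizes
data = np.random.default_rng(0).standard_normal((bs, grid_size, grid_size, input_timestep))

# (bs, X, Y, T_in) -> (bs, X, Y, 1, T_in) -> (bs, X, Y, T_out, T_in)
volume = np.repeat(data.reshape(bs, grid_size, grid_size, 1, input_timestep),
                   output_timestep, axis=3)
print(volume.shape)  # (2, 8, 8, 20, 10)
```

Every slice along the new time axis is an identical copy of the input snapshots; the 3D Fourier layers are what let the prediction differ from one output time step to the next.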
Model Training
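The train function below calls sink_process once per epoch for train_epochs epochs and prints the training loss; every 10 epochs it also evaluates the model on the test set with calculate_l2_error and saves a checkpoint.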
def train():
    summary_dir = os.path.join(config["summary_dir"], model_name)
    ckpt_dir = os.path.join(summary_dir, "ckpt")
    os.makedirs(ckpt_dir, exist_ok=True)
    # forward_fn, grad_fn, train_step and sink_process are defined in the Training Function cell
    model.set_train()
    for step in range(1, 1 + optimizer_params["train_epochs"]):
        local_time_beg = time.time()
        cur_loss = sink_process()
        print(
            f"epoch: {step} train loss: {cur_loss} epoch time: {time.time() - local_time_beg:.2f}s")
        if step % 10 == 0:
            print(f"loss: {cur_loss.asnumpy():>7f}")
            print("step: {}, time elapsed: {}ms".format(
                step, (time.time() - local_time_beg) * 1000))
            # evaluate on the test set and save a checkpoint every 10 epochs
            calculate_l2_error(model, test_a, test_u)
            save_checkpoint(model, os.path.join(
                ckpt_dir, model_params["name"]))
train()
pid: 1993
2023-02-01 12:14:12.2323
use_ascend: False
device_id: 2
Data preparation finished
steps_per_epoch: 1000
epoch: 1 train loss: 1.7631323 epoch time: 50.41s
epoch: 2 train loss: 1.9283392 epoch time: 36.59s
epoch: 3 train loss: 1.4265916 epoch time: 35.09s
epoch: 4 train loss: 1.8609437 epoch time: 34.41s
epoch: 5 train loss: 1.5222052 epoch time: 34.60s
epoch: 6 train loss: 1.3424721 epoch time: 33.85s
epoch: 7 train loss: 1.607729 epoch time: 33.11s
epoch: 8 train loss: 1.3308442 epoch time: 33.05s
epoch: 9 train loss: 1.3169765 epoch time: 33.90s
epoch: 10 train loss: 1.4149593 epoch time: 33.91s
...
predict total time: 15.179609298706055 s
epoch: 141 train loss: 0.777328 epoch time: 32.55s
epoch: 142 train loss: 0.7008966 epoch time: 32.52s
epoch: 143 train loss: 0.72377646 epoch time: 32.57s
epoch: 144 train loss: 0.72175145 epoch time: 32.44s
epoch: 145 train loss: 0.6235678 epoch time: 32.46s
epoch: 146 train loss: 0.9351083 epoch time: 32.45s
epoch: 147 train loss: 0.9283789 epoch time: 32.47s
epoch: 148 train loss: 0.7655642 epoch time: 32.60s
epoch: 149 train loss: 0.7233772 epoch time: 32.65s
epoch: 150 train loss: 0.86825275 epoch time: 32.59s
================================Start Evaluation================================
mean rel_rmse_error: 0.07437102290522307
=================================End Evaluation=================================
predict total time: 15.212349653244019 s