Reynolds-averaged Navier-Stokes
Overview
Flow over a periodic hill is a classic numerical simulation case in fluid mechanics and meteorology. It is used to study how air or another fluid behaves as it passes over periodically repeating hilly terrain. The aim is to explore how the terrain influences atmospheric or fluid motion, which helps explain meteorological phenomena, terrain effects, and flow characteristics over complex terrain. This project uses the Reynolds-averaged Navier-Stokes (RANS) model to simulate turbulent flow over a two-dimensional periodic hill.
Reynolds-Averaged Model
The Reynolds-Averaged Navier-Stokes (RANS) equations are a widely used approach in computational fluid dynamics for modelling the mean behavior of turbulent flows. Named after the British scientist Osborne Reynolds, the approach time-averages the flow field variables and offers an engineering-oriented way to handle turbulence. It is based on Reynolds decomposition, which splits each flow field variable into a mean component and a fluctuating component. Time-averaging the Navier-Stokes equations removes the terms that are linear in the fluctuations and yields equations for the mean flow, in which the averaged products of fluctuations remain as Reynolds stress terms. Taking the two-dimensional Reynolds-averaged momentum and continuity equations as examples:
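Reynolds-Averaged Momentum Equation
\[\overline{u}\frac{\partial \overline{u}}{\partial x} + \overline{v}\frac{\partial \overline{u}}{\partial y} = -\frac{1}{\rho}\frac{\partial \overline{p}}{\partial x} + \nu\left(\frac{\partial^2 \overline{u}}{\partial x^2} + \frac{\partial^2 \overline{u}}{\partial y^2}\right) - \frac{\partial \overline{u'u'}}{\partial x} - \frac{\partial \overline{u'v'}}{\partial y}\]
\[\overline{u}\frac{\partial \overline{v}}{\partial x} + \overline{v}\frac{\partial \overline{v}}{\partial y} = -\frac{1}{\rho}\frac{\partial \overline{p}}{\partial y} + \nu\left(\frac{\partial^2 \overline{v}}{\partial x^2} + \frac{\partial^2 \overline{v}}{\partial y^2}\right) - \frac{\partial \overline{u'v'}}{\partial x} - \frac{\partial \overline{v'v'}}{\partial y}\]
Continuity Equation
\[\frac{\partial \overline{u}}{\partial x} + \frac{\partial \overline{v}}{\partial y} = 0\]
Here, \(\overline{u}\) and \(\overline{v}\) are the time-averaged velocity components in the x and y directions, \(\overline{p}\) is the time-averaged pressure, \(\rho\) is the fluid density, and \(\nu\) is the kinematic viscosity. \(u'\) and \(v'\) are the fluctuations of the instantaneous velocity components \(u\) and \(v\) about their means, and the averaged products \(\overline{u'u'}\), \(\overline{u'v'}\), and \(\overline{v'v'}\) are the Reynolds stress components. The momentum equations are written in the steady form that matches the momentum_x and momentum_y residuals printed during model initialization.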
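Reynolds decomposition can be illustrated numerically. The following is a minimal NumPy sketch, not part of the tutorial code: the velocity samples are synthetic and the mean and fluctuation magnitudes are hypothetical values chosen only for illustration. It splits a sampled velocity signal into mean and fluctuating parts and averages the products of the fluctuations, which gives quantities analogous to the uu, uv, and vv fields in the dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic velocity samples over time at one point in space (hypothetical values):
# a constant mean plus zero-mean random fluctuations.
n_samples = 10_000
u = 1.5 + 0.2 * rng.standard_normal(n_samples)  # instantaneous x-velocity
v = 0.1 + 0.1 * rng.standard_normal(n_samples)  # instantaneous y-velocity

# Reynolds decomposition: instantaneous value = mean + fluctuation
u_mean, v_mean = u.mean(), v.mean()
u_fluct, v_fluct = u - u_mean, v - v_mean

# Averaged products of fluctuations (density-normalized Reynolds stresses)
uu = np.mean(u_fluct * u_fluct)
uv = np.mean(u_fluct * v_fluct)
vv = np.mean(v_fluct * v_fluct)

print(f"mean of u': {u_fluct.mean():.2e}")  # close to zero by construction
print(f"uu={uu:.4f}, uv={uv:.4f}, vv={vv:.4f}")
```

The fluctuations average to zero, so terms linear in them vanish under averaging, while the averaged products do not. This is why the Reynolds stress terms survive in the averaged momentum equations.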
Model Solution Introduction
The core idea of the RANS-PINNs (Reynolds-Averaged Navier-Stokes - Physics-Informed Neural Networks) method is to combine physical equations with neural networks to achieve simulation results that possess both the accuracy of traditional RANS models and the flexibility of neural networks. In this approach, the Reynolds-averaged equations for mean flow, along with an isotropic eddy viscosity model for turbulence, are combined to form an accurate baseline solution. Then, the remaining turbulent fluctuation part is modeled using Physics-Informed Neural Networks (PINNs), further enhancing the simulation accuracy.
The structure of the RANS-PINNs model is depicted below:
Preparation
Import the required libraries for training. The src folder includes functions for dataset processing, network models, and loss calculation.
Training runs on the GPU (by default) or on Ascend (single card). The cell below sets the MindSpore context to PYNATIVE_MODE; the training step defined later is decorated with @jit, so it is compiled and executed as a graph.
[1]:
import os
import time
import numpy as np
import mindspore
from mindspore import context, nn, ops, jit, set_seed, load_checkpoint, load_param_into_net, data_sink
from mindspore.amp import all_finite
from mindflow.cell import FCSequential
from mindflow.utils import load_yaml_config
from src import create_train_dataset, create_test_dataset, calculate_l2_error, NavierStokesRANS
from eval import predict
set_seed(0)
np.random.seed(0)
context.set_context(mode=context.PYNATIVE_MODE,
                    device_target="GPU")
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
Load Parameters
Import the configuration parameters for the dataset, model, optimizer, and summary settings (checkpointing, evaluation, and visualization) from the rans.yaml file.
[2]:
# load configurations
config = load_yaml_config('./configs/rans.yaml')
data_params = config["data"]
model_params = config["model"]
optim_params = config["optimizer"]
summary_params = config["summary"]
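Because the later cells index into these four dictionaries directly, a missing key only shows up as a KeyError partway through the run. The following is a hedged sketch of a pre-flight check. The key names are the ones read by the cells in this tutorial; the helper `missing_keys` and every value in `example_config` are hypothetical placeholders, not the contents of the real rans.yaml, and functions in src (such as calculate_l2_error) may read further keys not listed here.

```python
# Keys read by the cells in this tutorial, grouped by YAML section.
REQUIRED_KEYS = {
    "data": ["data_path", "batch_size"],
    "model": ["in_channels", "out_channels", "layers", "neurons", "residual"],
    "optimizer": ["initial_lr", "weight_decay", "train_epochs"],
    "summary": ["load_ckpt", "load_ckpt_path", "ckpt_path", "visual_dir",
                "eval_interval_epochs", "save_checkpoint_epochs"],
}


def missing_keys(config):
    """Return 'section.key' strings that the tutorial reads but config lacks."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        section_dict = config.get(section, {})
        missing.extend(f"{section}.{key}" for key in keys if key not in section_dict)
    return missing


# Placeholder values for illustration only (assumptions, not the real file).
example_config = {
    "data": {"data_path": "./dataset/periodic_hill.npy", "batch_size": 1000},
    "model": {"in_channels": 2, "out_channels": 6, "layers": 5,
              "neurons": 64, "residual": False},
    "optimizer": {"initial_lr": 0.001, "weight_decay": 0.0, "train_epochs": 500},
    "summary": {"load_ckpt": False, "load_ckpt_path": "", "ckpt_path": "./ckpt",
                "visual_dir": "./images", "eval_interval_epochs": 5,
                "save_checkpoint_epochs": 10},
}

print(missing_keys(example_config))        # prints []
print(missing_keys({"data": {"data_path": "x"}})[:1])  # prints ['data.batch_size']
```

In the notebook, the same check could be applied to the dictionary returned by load_yaml_config before training starts.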
Dataset Construction
Source: Numerical simulation flow field data for flow over a two-dimensional periodic hill, provided by Associate Professor Yu Jian’s team at the School of Aeronautic Science and Engineering, Beihang University.
Data Description: The data is stored in numpy’s npy format with shape [300, 700, 10]. The first two dimensions index the grid points of the flow field along its two directions, and the last dimension holds 10 variables: (x, y, u, v, p, uu, uv, vv, rho, nu). Here x, y, u, v, and p are the x-coordinate, y-coordinate, x-direction velocity, y-direction velocity, and pressure of the flow field, respectively. uu, uv, and vv are the Reynolds stress components (Reynolds-averaged statistics), rho is the fluid density, and nu is the kinematic viscosity.
Dataset Download Link: periodic_hill.npy
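To make the array layout concrete, the following is a minimal NumPy sketch of how a [300, 700, 10] array could be split into point-wise network inputs and targets. It is an illustration under assumptions: `split_fields` is a hypothetical helper, a zero-filled stand-in array is used instead of the real file, and the actual create_train_dataset and create_test_dataset functions in src may sample, normalize, or group the data differently.

```python
import numpy as np

# Order of the variables along the last dimension, as described above.
FIELDS = ("x", "y", "u", "v", "p", "uu", "uv", "vv", "rho", "nu")


def split_fields(array):
    """Flatten an [H, W, 10] array into per-point (coords, targets) arrays."""
    flat = array.reshape(-1, len(FIELDS))
    coords = flat[:, 0:2]   # x, y: candidate network inputs
    targets = flat[:, 2:8]  # u, v, p, uu, uv, vv: candidate network outputs
    return coords, targets


# Stand-in with the documented shape; with the real file this would be
# data = np.load("periodic_hill.npy")
data = np.zeros((300, 700, 10), dtype=np.float32)
coords, targets = split_fields(data)
print(coords.shape, targets.shape)  # prints (210000, 2) (210000, 6)
```

The two input columns and six target columns are consistent with a network that maps (x, y) to (u, v, p, uu, uv, vv), which is what the residuals printed at model initialization suggest.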
[3]:
# create training dataset
dataset = create_train_dataset(data_params["data_path"], data_params["batch_size"])
# create test dataset
inputs, label = create_test_dataset(data_params["data_path"])
Model Initialization
Initialize the RANS-PINNs model based on the configuration in rans.yaml. Use the Mean Squared Error (MSE) loss function and the Adam optimizer.
[4]:
model = FCSequential(in_channels=model_params["in_channels"],
                     out_channels=model_params["out_channels"],
                     layers=model_params["layers"],
                     neurons=model_params["neurons"],
                     residual=model_params["residual"],
                     act='tanh')

if summary_params["load_ckpt"]:
    param_dict = load_checkpoint(summary_params["load_ckpt_path"])
    load_param_into_net(model, param_dict)
if not os.path.exists(os.path.abspath(summary_params['ckpt_path'])):
    os.makedirs(os.path.abspath(summary_params['ckpt_path']))

params = model.trainable_params()
optimizer = nn.Adam(params, optim_params["initial_lr"], weight_decay=optim_params["weight_decay"])
problem = NavierStokesRANS(model)

if use_ascend:
    from mindspore.amp import DynamicLossScaler, auto_mixed_precision
    loss_scaler = DynamicLossScaler(1024, 2, 100)
    auto_mixed_precision(model, 'O3')
momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 0.000178571426658891*Derivative(u(x, y), (x, 2)) - 0.000178571426658891*Derivative(u(x, y), (y, 2)) + Derivative(uu(x, y), x) + Derivative(uv(x, y), y)
Item numbers of current derivative formula nodes: 7
momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) + Derivative(uv(x, y), x) - 0.000178571426658891*Derivative(v(x, y), (x, 2)) - 0.000178571426658891*Derivative(v(x, y), (y, 2)) + Derivative(vv(x, y), y)
Item numbers of current derivative formula nodes: 7
continuty: Derivative(u(x, y), x) + Derivative(v(x, y), y)
Item numbers of current derivative formula nodes: 2
bc_u: u(x, y)
Item numbers of current derivative formula nodes: 1
bc_v: v(x, y)
Item numbers of current derivative formula nodes: 1
bc_p: p(x, y)
Item numbers of current derivative formula nodes: 1
bc_uu: uu(x, y)
Item numbers of current derivative formula nodes: 1
bc_uv: uv(x, y)
Item numbers of current derivative formula nodes: 1
bc_vv: vv(x, y)
Item numbers of current derivative formula nodes: 1
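The printed expressions are residuals that the training drives toward zero. As a rough illustration of what the continuity residual measures, the sketch below evaluates du/dx + dv/dy on an analytic divergence-free field. Two assumptions should be kept in mind: the field is illustrative and unrelated to the periodic hill data, and the derivatives use finite differences via np.gradient, whereas a PINN would typically obtain them from the network by automatic differentiation.

```python
import numpy as np

# Uniform grid on [0, 2*pi] x [0, 2*pi]; axis 0 is x, axis 1 is y.
x = np.linspace(0.0, 2.0 * np.pi, 101)
y = np.linspace(0.0, 2.0 * np.pi, 101)
X, Y = np.meshgrid(x, y, indexing="ij")

# Analytic divergence-free velocity field (illustrative only).
u = np.sin(X) * np.cos(Y)
v = -np.cos(X) * np.sin(Y)

# Continuity residual: du/dx + dv/dy
du_dx = np.gradient(u, x, axis=0)
dv_dy = np.gradient(v, y, axis=1)
residual = du_dx + dv_dy

# Skip boundary points, where np.gradient falls back to one-sided differences.
interior = residual[1:-1, 1:-1]
mse = np.mean(interior ** 2)
print(f"continuity residual MSE (interior): {mse:.3e}")
```

A mean squared residual of this kind, computed per equation and summed with the data and boundary terms, is one common way to build a PINN loss; how NavierStokesRANS weights its terms is not shown in this tutorial.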
Model Training
For versions of MindSpore >= 2.0.0, you can use the functional programming paradigm to train neural networks.
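On Ascend, the training cell in this section wraps the loss with DynamicLossScaler(1024, 2, 100), where the three arguments are the initial scale value, the scale factor, and the scale window. The following pure-Python sketch mimics the scale, unscale, and adjust logic to show the idea. `SimpleDynamicLossScaler` is a hypothetical stand-in written for illustration, not the MindSpore class, and it ignores details such as guarding against an overflowing scale value.

```python
class SimpleDynamicLossScaler:
    """Illustrative stand-in for a dynamic loss scaler (not the MindSpore class)."""

    def __init__(self, scale_value, scale_factor, scale_window):
        self.scale_value = float(scale_value)
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.counter = 0  # consecutive steps with finite gradients

    def scale(self, value):
        # Enlarge the loss so small float16 gradients do not underflow.
        return value * self.scale_value

    def unscale(self, value):
        # Undo the scaling before the optimizer update or before logging.
        return value / self.scale_value

    def adjust(self, grads_finite):
        if not grads_finite:
            # Overflow: shrink the scale (never below 1) and restart the count.
            self.scale_value = max(1.0, self.scale_value / self.scale_factor)
            self.counter = 0
            return
        self.counter += 1
        if self.counter == self.scale_window:
            # A full window of finite steps: try a larger scale.
            self.scale_value *= self.scale_factor
            self.counter = 0


scaler = SimpleDynamicLossScaler(1024, 2, 100)
scaler.adjust(False)
print(scaler.scale_value)  # prints 512.0
for _ in range(100):
    scaler.adjust(True)
print(scaler.scale_value)  # prints 1024.0
```

The update is skipped whenever the gradients contain inf or NaN, which matches the `if is_finite:` branch in the training step.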
[5]:
def forward_fn(pde_data, data, label):
    loss = problem.get_loss(pde_data, data, label)
    if use_ascend:
        loss = loss_scaler.scale(loss)
    return loss

grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

@jit
def train_step(pde_data, data, label):
    loss, grads = grad_fn(pde_data, data, label)
    if use_ascend:
        loss = loss_scaler.unscale(loss)
        is_finite = all_finite(grads)
        if is_finite:
            grads = loss_scaler.unscale(grads)
            loss = ops.depend(loss, optimizer(grads))
        loss_scaler.adjust(is_finite)
    else:
        loss = ops.depend(loss, optimizer(grads))
    return loss

epochs = optim_params["train_epochs"]
sink_process = data_sink(train_step, dataset, sink_size=1)
train_data_size = dataset.get_dataset_size()

for epoch in range(1, 1 + epochs):
    # train
    time_beg = time.time()
    model.set_train(True)
    for _ in range(train_data_size + 1):
        step_train_loss = sink_process()
    print(f"epoch: {epoch} train loss: {step_train_loss} epoch time: {(time.time() - time_beg)*1000 :.3f}ms")
    model.set_train(False)
    if epoch % summary_params["eval_interval_epochs"] == 0:
        # eval
        calculate_l2_error(model, inputs, label, config)
        predict(model=model, epochs=epoch, input_data=inputs, label=label, path=summary_params["visual_dir"])
    if epoch % summary_params["save_checkpoint_epochs"] == 0:
        ckpt_name = "rans_{}.ckpt".format(epoch + 1)
        mindspore.save_checkpoint(model, os.path.join(summary_params['ckpt_path'], ckpt_name))
epoch: 1 train loss: 0.033210676 epoch time: 21279.999ms
epoch: 2 train loss: 0.019967956 epoch time: 11001.454ms
epoch: 3 train loss: 0.015202466 epoch time: 11049.534ms
epoch: 4 train loss: 0.009431531 epoch time: 10979.578ms
epoch: 5 train loss: 0.009564591 epoch time: 11857.952ms
predict total time: 361.42492294311523 ms
l2_error, U: 0.3499122378307982 , V: 1.089610520680924 , P: 1.0590148771220198
l2_error, uu: 0.6619816139038208 , uv: 0.9806737880811025 , vv: 1.223253942721496 , Total: 0.3788639206858165
==================================================================================================
epoch: 6 train loss: 0.0080219805 epoch time: 10980.343ms
epoch: 7 train loss: 0.007290244 epoch time: 11141.353ms
epoch: 8 train loss: 0.0072537386 epoch time: 11535.102ms
epoch: 9 train loss: 0.007020033 epoch time: 11041.171ms
epoch: 10 train loss: 0.0072951056 epoch time: 11033.113ms
predict total time: 45.89080810546875 ms
l2_error, U: 0.2574625213886651 , V: 1.0159654927310178 , P: 1.08665077365793
l2_error, uu: 0.6712817201442959 , uv: 1.6285996210166078 , vv: 1.6174848943769466 , Total: 0.2994041993242163
==================================================================================================
epoch: 11 train loss: 0.006911595 epoch time: 11269.898ms
epoch: 12 train loss: 0.0064922348 epoch time: 11014.546ms
epoch: 13 train loss: 0.012375369 epoch time: 10856.192ms
epoch: 14 train loss: 0.0063738413 epoch time: 11219.892ms
epoch: 15 train loss: 0.006205684 epoch time: 11509.733ms
predict total time: 1419.1265106201172 ms
l2_error, U: 0.26029930447820726 , V: 1.0100483948680088 , P: 1.1317783698512909
l2_error, uu: 0.6231199513484501 , uv: 1.097468251696328 , vv: 1.2687142671208649 , Total: 0.301384468926242
==================================================================================================
epoch: 16 train loss: 0.00825448 epoch time: 11118.031ms
epoch: 17 train loss: 0.0061626835 epoch time: 11953.393ms
epoch: 18 train loss: 0.0073482464 epoch time: 11729.854ms
epoch: 19 train loss: 0.0059430953 epoch time: 11183.294ms
epoch: 20 train loss: 0.006461049 epoch time: 11480.535ms
predict total time: 328.2887935638428 ms
l2_error, U: 0.2893996640103185 , V: 1.0164172238860398 , P: 1.118747335999008
l2_error, uu: 0.6171527683696496 , uv: 1.1570214426333394 , vv: 1.5968321768424096 , Total: 0.3270872725014816
==================================================================================================
...
epoch: 496 train loss: 0.001080659 epoch time: 11671.701ms
epoch: 497 train loss: 0.0007907547 epoch time: 11653.532ms
epoch: 498 train loss: 0.0015688213 epoch time: 11612.691ms
epoch: 499 train loss: 0.00085494306 epoch time: 11429.596ms
epoch: 500 train loss: 0.0026226037 epoch time: 11708.611ms
predict total time: 43.506622314453125 ms
l2_error, U: 0.16019161506598686 , V: 0.561610130067435 , P: 0.4730013943213571
l2_error, uu: 1.0206032668202991 , uv: 0.812573326422638 , vv: 1.5239299913682682 , Total: 0.18547458639343734
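The l2_error values in the log are easier to read with a definition in hand. The sketch below assumes a relative L2 error, the norm of the prediction error divided by the norm of the label. This is an assumption for illustration: `relative_l2_error` is a hypothetical helper, and the implementation of calculate_l2_error in src is not shown in this tutorial and may differ.

```python
import numpy as np


def relative_l2_error(pred, label):
    """Relative L2 error: ||pred - label|| / ||label|| (assumed definition)."""
    return np.linalg.norm(pred - label) / np.linalg.norm(label)


label = np.array([3.0, 4.0])
print(relative_l2_error(label, label))        # prints 0.0
print(relative_l2_error(np.zeros(2), label))  # prints 1.0
```

Under this definition, a value of 1.0 corresponds to predicting zero everywhere, so values above 1.0 for a variable would indicate that the prediction is further from the label than an all-zero field.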
Visualization of Prediction Results
Below is a comparison between the predicted results of the RANS-PINNs model and the ground truth:
The images display the distribution of the horizontal and vertical velocity components at different positions within the flow field. The lower image shows the ground truth, and the upper image shows the predicted values.
The following shows cross-sectional velocity profiles for the RANS-PINNs model:
The blue line is the true value and the orange dashed line is the predicted value.