Solve Navier-Stokes equation based on Spectral Neural Operator


Overview

Computational fluid dynamics is one of the most important techniques in the field of fluid mechanics in the 21st century. Flow analysis, prediction and control can be realized by solving the governing equations of fluid mechanics with numerical methods. Traditional approaches such as the finite element method (FEM) and the finite difference method (FDM) are inefficient because of the complex simulation pipeline (physical modeling, meshing, numerical discretization, iterative solution, etc.) and high computing costs. It is therefore worthwhile to improve the efficiency of fluid simulation with AI.

Machine learning methods provide a new paradigm for scientific computing by offering fast solvers comparable to traditional methods. Classical neural networks learn mappings between finite-dimensional spaces and can only learn solutions tied to a specific discretization. Unlike traditional neural networks, the Fourier Neural Operator (FNO) is a deep learning architecture that learns mappings between infinite-dimensional function spaces. It directly learns the mapping from arbitrary functional parameters to the solution, so a whole class of partial differential equations can be solved; it therefore has stronger generalization capability. More information can be found in the paper Fourier Neural Operator for Parametric Partial Differential Equations.

The Spectral Neural Operator (SNO) is an FNO-like architecture that uses a polynomial transformation to spectral space (Chebyshev, Legendre, etc.) instead of the Fourier transform. SNO is characterized by a smaller systematic bias caused by aliasing errors compared to FNO. One of the most important benefits of the architecture is that SNO supports multiple choices of basis, so it is possible to pick the set of polynomials that is most convenient for the representation, e.g., one that respects the symmetries of the problem or fits the solution interval well. Besides, neural operators based on orthogonal polynomials may be competitive with other spectral operators when the input function is defined on an unstructured grid.

More information can be found in the paper “Spectral Neural Operators”, arXiv preprint arXiv:2205.10573 (2022).

This tutorial describes how to solve the Navier-Stokes equation using the Spectral Neural Operator.

Problem Description

We aim to solve the two-dimensional incompressible Navier-Stokes equation by learning the operator mapping from each time step to the next time step:

$$w_t \mapsto w(\cdot, t+1)$$

Technology Path

MindFlow solves the problem as follows:

  1. Training Dataset Construction.

  2. Model Construction.

  3. Optimizer and Loss Function.

  4. Model Training.

Spectral Neural Operator

U-SNO modification

The following figure shows the architecture of the Spectral Neural Operator, which consists of an encoder, multiple spectral convolution layers (linear transformations in the space of coefficients of a polynomial basis) and a decoder. To compute the forward and inverse polynomial transformation matrices for the spectral convolutions, the input should be interpolated at the respective Gauss quadrature nodes (Chebyshev grid, etc.). The interpolated input is lifted to a higher-dimensional channel space by a convolutional encoder layer. The result is passed to a sequence of spectral (SNO) layers, each of which applies a linear convolution to its truncated spectral representation. The output of the SNO layers is projected back to the target dimension by a convolutional decoder and finally interpolated back to the original nodes.
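
For intuition, here is a minimal NumPy sketch (illustrative only; the tutorial's src helpers perform this step internally) of interpolating a 1D field from a uniform grid onto Chebyshev-Gauss nodes:

import numpy as np

# Interpolate a field from a uniform grid onto Chebyshev-Gauss nodes,
# the preprocessing step the encoder expects (1D for brevity).
n = 64
cheb_nodes = np.cos((2 * np.arange(n) + 1) * np.pi / (2 * n))  # Gauss quadrature points on [-1, 1]
uniform_grid = np.linspace(-1.0, 1.0, n)
field = np.sin(np.pi * uniform_grid)                           # example field on the uniform grid
field_on_cheb = np.interp(cheb_nodes, uniform_grid, field)     # input ready for the SNO encoder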

The spectral (SNO) layer performs the following operations: it applies the polynomial transformation $A$ to spectral space (Chebyshev, Legendre, etc.), then a linear convolution $L$ on the lower polynomial modes while filtering out the higher modes, and then the inverse transformation $S = A^{-1}$ (back to physical space). A linear convolution $W$ of the input is then added, and a nonlinear activation is applied.
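
A minimal 1D NumPy sketch of these operations follows (illustrative only: the actual SNO2D layer works in 2D, applies the transform along each spatial axis, and may use a different activation):

import numpy as np

def sno_layer_1d(x, A, S, L, W):
    # x: (channels, resolution) input in physical space
    # A: (n_modes, resolution) forward (analysis) transform matrix
    # S: (resolution, n_modes) inverse (synthesis) transform matrix
    # L: (out_channels, in_channels, n_modes) per-mode linear weights
    # W: (out_channels, in_channels) pointwise linear skip weights
    coeffs = x @ A.T                              # analysis: to spectral space, higher modes dropped
    mixed = np.einsum('oim,im->om', L, coeffs)    # linear convolution on the retained modes
    out = mixed @ S.T + W @ x                     # synthesis back to physical space + linear skip
    return np.maximum(out, 0.0)                   # nonlinear activation (ReLU here for brevity)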

U-SNO is an SNO modification in which a sequence of modified spectral convolution layers follows the main sequence. In the modified U-SNO layer, a UNet architecture (with a customizable number of strides) is used as the skip block instead of the linear $W$, as sketched below.
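
Under the same 1D conventions as the previous sketch, a U-SNO layer would simply swap the pointwise skip for a UNet block (the unet callable here is hypothetical):

def usno_layer_1d(x, A, S, L, unet):
    # Same spectral path as sno_layer_1d, but the skip block is a small UNet.
    coeffs = x @ A.T
    out = np.einsum('oim,im->om', L, coeffs) @ S.T + unet(x)
    return np.maximum(out, 0.0)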

Spectral Neural Operator model structure

[1]:
import os
import time
import numpy as np

import mindspore
from mindspore import nn, context, ops, Tensor, jit, set_seed, save_checkpoint
import mindspore.common.dtype as mstype

The following src package can be downloaded at applications/data_driven/navier_stokes/sno2d/src.

[2]:
from mindflow.cell import SNO2D, get_poly_transform
from mindflow.utils import load_yaml_config, print_log
from mindflow.pde import UnsteadyFlowWithLoss
from src import create_training_dataset, load_interp_data, calculate_l2_error
from mindflow.loss import RelativeRMSELoss
from mindflow.common import get_warmup_cosine_annealing_lr

set_seed(0)
np.random.seed(0)
[3]:
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', device_id=0)
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
config = load_yaml_config('./configs/sno2d.yaml')

data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
summary_params = config["summary"]

Training Dataset Construction

Download the training and test dataset: data_driven/navier_stokes/dataset.

In this case, the training and test datasets are generated following Zongyi Li’s dataset in Fourier Neural Operator for Parametric Partial Differential Equations. The settings are as follows:

The initial condition $w_0(x)$ is generated according to periodic boundary conditions:

$$w_0 \sim \mu, \quad \mu = \mathcal{N}\left(0,\, 7^{3/2}(-\Delta + 49I)^{-2.5}\right)$$

The forcing function is defined as:

$$f(x) = 0.1\left(\sin\left(2\pi(x_1+x_2)\right) + \cos\left(2\pi(x_1+x_2)\right)\right)$$

We use a time step of 1e-4 with the Crank–Nicolson scheme in the data generation process, recording the solution every t = 1 time units. All data are generated on a 256 × 256 grid and downsampled to 64 × 64. In this case, the viscosity coefficient is ν = 1e−5, the number of samples in the training set is 19000, and the number of samples in the test set is 3800.
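
For reference, initial conditions of this kind can be drawn from the Gaussian random field with an FFT-based sampler. The sketch below is a hedged illustration (FFT normalization conventions vary, so this is not an exact reproduction of the original data pipeline):

import numpy as np

def sample_initial_vorticity(n=64, tau=7.0, alpha=2.5, seed=0):
    # Sample w0 ~ N(0, tau^(3/2) * (-Delta + tau^2 I)^(-alpha)) on a periodic [0, 1]^2 grid;
    # on the torus, -Delta has eigenvalues 4*pi^2*|k|^2 for integer wavenumbers k.
    rng = np.random.default_rng(seed)
    k = np.fft.fftfreq(n, d=1.0 / n)                 # integer wavenumbers
    kx, ky = np.meshgrid(k, k, indexing='ij')
    eig = tau**1.5 * (4 * np.pi**2 * (kx**2 + ky**2) + tau**2)**(-alpha)
    eig[0, 0] = 0.0                                  # enforce a zero-mean field
    noise = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    w_hat = n * np.sqrt(eig) * noise                 # approximate discrete-FFT scaling
    return np.real(np.fft.ifft2(w_hat))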

[4]:
poly_type = data_params['poly_type']
load_interp_data(data_params, dataset_type='train')
train_dataset = create_training_dataset(data_params, shuffle=True)

test_data = load_interp_data(data_params, dataset_type='test')
test_input = test_data['test_inputs']
test_label = test_data['test_labels']

batch_size = data_params['batch_size']
resolution = data_params['resolution']

Model Construction

The network is composed of one encoding layer, multiple spectral layers and a decoding block:

  • The encoding convolution corresponds to SNO2D.encoder in this case; it lifts the input data x to a higher-dimensional channel space;

  • The sequence of SNO layers corresponds to SNO2D.sno_kernel in this case. The input matrices of polynomial transformations (forward and inverse conversions for each of the two spatial variables) are used to realize the transition between the physical domain and the spectral domain. Here, the sequence consists of two subsequences, with SNO layers and U-SNO layers, respectively;

  • The decoding layer corresponds to SNO2D.decoder and consists of two convolutions. The decoder is used to obtain the final prediction.

[5]:
n_modes = model_params['modes']

transform_data = get_poly_transform(resolution, n_modes, poly_type)

transform = Tensor(transform_data["analysis"], mstype.float32)
inv_transform = Tensor(transform_data["synthesis"], mstype.float32)

model = SNO2D(in_channels=model_params['in_channels'],
              out_channels=model_params['out_channels'],
              hidden_channels=model_params['hidden_channels'],
              num_sno_layers=model_params['sno_layers'],
              kernel_size=model_params['kernel_size'],
              transforms=[[transform, inv_transform]]*2,
              num_usno_layers=model_params['usno_layers'],
              num_unet_strides=model_params['unet_strides'],
              compute_dtype=mstype.float32)

total = 0
for param in model.get_parameters():
    print_log(param.shape)
    total += param.size
print_log(f"Total Parameters:{total}")
(64, 1, 1, 1)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 64, 5, 5)
(64, 64, 3, 3)
(64, 64, 3, 3)
(128, 64, 3, 3)
(128, 128, 3, 3)
(256, 128, 3, 3)
(256, 256, 3, 3)
(256, 128, 2, 2)
(128, 256, 3, 3)
(128, 128, 3, 3)
(128, 64, 2, 2)
(64, 128, 3, 3)
(64, 64, 3, 3)
(64, 128, 3, 3)
(64, 64, 1, 1)
(1, 64, 1, 1)
Total Parameters:2396288

Optimizer and Loss Function

[10]:
steps_per_epoch = train_dataset.get_dataset_size()
grad_clip_norm = optimizer_params['grad_clip_norm']
[8]:
lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params['learning_rate'],
                                    last_epoch=optimizer_params["epochs"],
                                    steps_per_epoch=steps_per_epoch,
                                    warmup_epochs=1)

optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=Tensor(lr),
                               weight_decay=optimizer_params['weight_decay'])
problem = UnsteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss(), data_format="NTCHW")
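
Here get_warmup_cosine_annealing_lr builds a per-step learning-rate list: a linear warmup over the first warmup_epochs, followed by cosine decay. The sketch below shows the assumed formula (the exact MindFlow implementation may differ in details such as the final learning rate):

import math

def warmup_cosine_lr(lr_init, epochs, steps_per_epoch, warmup_epochs=1):
    total = epochs * steps_per_epoch
    warmup = warmup_epochs * steps_per_epoch
    lrs = []
    for step in range(total):
        if step < warmup:
            lrs.append(lr_init * (step + 1) / warmup)                       # linear warmup
        else:
            progress = (step - warmup) / max(1, total - warmup)
            lrs.append(0.5 * lr_init * (1 + math.cos(math.pi * progress)))  # cosine annealing
    return lrs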

Model Training

With MindSpore version >= 2.0.0, we can train neural networks using the functional programming paradigm.

[11]:
def train():
    def forward_fn(train_inputs, train_label):
        loss = problem.get_loss(train_inputs, train_label)
        return loss

    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

    @jit
    def train_step(train_inputs, train_label):
        loss, grads = grad_fn(train_inputs, train_label)
        grads = ops.clip_by_global_norm(grads, grad_clip_norm)
        loss = ops.depend(loss, optimizer(grads))
        return loss

    sink_process = mindspore.data_sink(train_step, train_dataset, sink_size=1)
    ckpt_dir = os.path.join(model_params["root_dir"], summary_params["ckpt_dir"])
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    for epoch in range(1, 1 + optimizer_params["epochs"]):
        local_time_beg = time.time()
        model.set_train(True)
        for _ in range(steps_per_epoch):
            cur_loss = sink_process()

        local_time_end = time.time()
        epoch_seconds = local_time_end - local_time_beg
        step_seconds = (epoch_seconds/steps_per_epoch)*1000
        print_log(f"epoch: {epoch} train loss: {cur_loss} "
                  f"epoch time: {epoch_seconds:.3f}s step time: {step_seconds:5.3f}ms")

        model.set_train(False)
        if epoch % summary_params["save_ckpt_interval"] == 0:
            save_checkpoint(model, os.path.join(ckpt_dir, f"{model_params['name']}_epoch{epoch}"))

        if epoch % summary_params['test_interval'] == 0:
            calculate_l2_error(model, test_input, test_label, data_params)
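
During evaluation, calculate_l2_error reports the relative L2 error both on the Gauss (Chebyshev) grid where the model operates and on the original regular grid after interpolating back. A hedged sketch of the metric itself (the src implementation may differ in how it aggregates over batches):

import numpy as np

def relative_l2(pred, label):
    # Relative L2 error, as reported per grid in the evaluation logs below.
    return np.linalg.norm(pred - label) / np.linalg.norm(label)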

[12]:
train()
epoch: 1 train loss: 1.9672374 epoch time: 34.144s step time: 34.144ms
epoch: 2 train loss: 1.8687398 epoch time: 28.038s step time: 28.038ms
epoch: 3 train loss: 1.6240175 epoch time: 28.094s step time: 28.094ms
epoch: 4 train loss: 1.812437 epoch time: 28.001s step time: 28.001ms
epoch: 5 train loss: 1.6048276 epoch time: 28.006s step time: 28.006ms
epoch: 6 train loss: 1.3349447 epoch time: 28.045s step time: 28.045ms
epoch: 7 train loss: 1.445535 epoch time: 28.084s step time: 28.084ms
epoch: 8 train loss: 1.287163 epoch time: 28.050s step time: 28.050ms
epoch: 9 train loss: 1.2205887 epoch time: 28.079s step time: 28.079ms
epoch: 10 train loss: 1.1622387 epoch time: 28.048s step time: 28.048ms
================================Start Evaluation================================
on Gauss grid: 0.2202785452026874, on regular grid: 0.21447483365566075
=================================End Evaluation=================================
predict total time: 7.394038677215576 s
epoch: 11 train loss: 0.98966134 epoch time: 28.090s step time: 28.090ms
epoch: 12 train loss: 0.9963242 epoch time: 28.080s step time: 28.080ms
epoch: 13 train loss: 1.0154707 epoch time: 28.125s step time: 28.125ms
epoch: 14 train loss: 1.029425 epoch time: 28.087s step time: 28.087ms
epoch: 15 train loss: 1.0535842 epoch time: 28.069s step time: 28.069ms
epoch: 16 train loss: 1.0508957 epoch time: 28.217s step time: 28.217ms
epoch: 17 train loss: 0.73175216 epoch time: 28.239s step time: 28.239ms
epoch: 18 train loss: 0.7978346 epoch time: 28.060s step time: 28.060ms
epoch: 19 train loss: 1.2525742 epoch time: 28.057s step time: 28.057ms
epoch: 20 train loss: 1.0816319 epoch time: 28.052s step time: 28.052ms
================================Start Evaluation================================
on Gauss grid: 0.17742644541244953, on regular grid: 0.17132807192601202
=================================End Evaluation=================================
predict total time: 7.578975677490234 s
epoch: 21 train loss: 0.9601194 epoch time: 28.033s step time: 28.033ms
epoch: 22 train loss: 1.0366433 epoch time: 28.100s step time: 28.100ms
epoch: 23 train loss: 0.9956419 epoch time: 28.061s step time: 28.061ms
epoch: 24 train loss: 1.0766693 epoch time: 28.125s step time: 28.125ms
epoch: 25 train loss: 0.9773072 epoch time: 28.022s step time: 28.022ms
epoch: 26 train loss: 0.65455425 epoch time: 28.086s step time: 28.086ms
epoch: 27 train loss: 0.71299446 epoch time: 28.006s step time: 28.006ms
epoch: 28 train loss: 1.0231717 epoch time: 28.170s step time: 28.170ms
epoch: 29 train loss: 0.8839726 epoch time: 28.143s step time: 28.143ms
epoch: 30 train loss: 0.90894026 epoch time: 28.124s step time: 28.124ms
================================Start Evaluation================================
on Gauss grid: 0.16749235310871155, on regular grid: 0.169489491779019
=================================End Evaluation=================================
predict total time: 7.71979022026062 s
epoch: 31 train loss: 0.9652164 epoch time: 28.092s step time: 28.092ms
epoch: 32 train loss: 0.6686845 epoch time: 28.096s step time: 28.096ms
epoch: 33 train loss: 0.8932849 epoch time: 28.107s step time: 28.107ms
epoch: 34 train loss: 0.7517134 epoch time: 28.208s step time: 28.208ms
epoch: 35 train loss: 0.825667 epoch time: 28.188s step time: 28.188ms
epoch: 36 train loss: 0.74803126 epoch time: 28.128s step time: 28.128ms
epoch: 37 train loss: 0.8695539 epoch time: 28.032s step time: 28.032ms
epoch: 38 train loss: 0.686597 epoch time: 28.025s step time: 28.025ms
epoch: 39 train loss: 0.9947252 epoch time: 28.032s step time: 28.032ms
epoch: 40 train loss: 0.8597307 epoch time: 28.046s step time: 28.046ms
================================Start Evaluation================================
on Gauss grid: 0.12830503433849663, on regular grid: 0.13030632202877585
=================================End Evaluation=================================
predict total time: 7.54561448097229 s
epoch: 41 train loss: 0.5904021 epoch time: 28.101s step time: 28.101ms
epoch: 42 train loss: 0.6276789 epoch time: 28.145s step time: 28.145ms
epoch: 43 train loss: 0.62192535 epoch time: 28.092s step time: 28.092ms
epoch: 44 train loss: 0.6407144 epoch time: 28.059s step time: 28.059ms
epoch: 45 train loss: 0.60519314 epoch time: 28.014s step time: 28.014ms
epoch: 46 train loss: 1.0048012 epoch time: 28.078s step time: 28.078ms
epoch: 47 train loss: 0.5551628 epoch time: 28.087s step time: 28.087ms
epoch: 48 train loss: 0.8461705 epoch time: 28.101s step time: 28.101ms
epoch: 49 train loss: 0.7118721 epoch time: 28.077s step time: 28.077ms
epoch: 50 train loss: 0.55335164 epoch time: 28.170s step time: 28.170ms
================================Start Evaluation================================
on Gauss grid: 0.08227695803437382, on regular grid: 0.08470196191734738
=================================End Evaluation=================================
predict total time: 7.394194602966309 s
epoch: 51 train loss: 0.636775 epoch time: 28.049s step time: 28.049ms
epoch: 52 train loss: 0.5920238 epoch time: 28.095s step time: 28.095ms
epoch: 53 train loss: 0.58135617 epoch time: 28.278s step time: 28.278ms
epoch: 54 train loss: 0.7213563 epoch time: 28.203s step time: 28.203ms
epoch: 55 train loss: 0.71770614 epoch time: 28.166s step time: 28.166ms
epoch: 56 train loss: 0.48096988 epoch time: 28.130s step time: 28.130ms
epoch: 57 train loss: 0.5998644 epoch time: 28.143s step time: 28.143ms
epoch: 58 train loss: 0.6089008 epoch time: 28.111s step time: 28.111ms
epoch: 59 train loss: 0.595509 epoch time: 28.200s step time: 28.200ms
epoch: 60 train loss: 0.6066635 epoch time: 28.149s step time: 28.149ms
================================Start Evaluation================================
on Gauss grid: 0.08370403416315093, on regular grid: 0.08586561499600351
=================================End Evaluation=================================
predict total time: 7.493133306503296 s
epoch: 61 train loss: 0.5519717 epoch time: 28.119s step time: 28.119ms
epoch: 62 train loss: 0.4908938 epoch time: 28.166s step time: 28.166ms
epoch: 63 train loss: 0.43803358 epoch time: 28.126s step time: 28.126ms
epoch: 64 train loss: 0.47794145 epoch time: 28.171s step time: 28.171ms
epoch: 65 train loss: 0.504622 epoch time: 28.176s step time: 28.176ms
epoch: 66 train loss: 0.44892752 epoch time: 28.074s step time: 28.074ms
epoch: 67 train loss: 0.6695643 epoch time: 28.069s step time: 28.069ms
epoch: 68 train loss: 0.5254482 epoch time: 28.147s step time: 28.147ms
epoch: 69 train loss: 0.43325588 epoch time: 28.253s step time: 28.253ms
epoch: 70 train loss: 0.4950175 epoch time: 28.150s step time: 28.150ms
================================Start Evaluation================================
on Gauss grid: 0.07004086356284096, on regular grid: 0.07265735937107769
=================================End Evaluation=================================
predict total time: 7.431047439575195 s
epoch: 71 train loss: 0.48058861 epoch time: 28.090s step time: 28.090ms
epoch: 72 train loss: 0.48115337 epoch time: 28.087s step time: 28.087ms
epoch: 73 train loss: 0.5245213 epoch time: 28.215s step time: 28.215ms
epoch: 74 train loss: 0.40916815 epoch time: 28.153s step time: 28.153ms
epoch: 75 train loss: 0.48107946 epoch time: 28.155s step time: 28.155ms
epoch: 76 train loss: 0.4762331 epoch time: 28.062s step time: 28.062ms
epoch: 77 train loss: 0.5066639 epoch time: 28.141s step time: 28.141ms
epoch: 78 train loss: 0.43607965 epoch time: 28.142s step time: 28.142ms
epoch: 79 train loss: 0.49439412 epoch time: 28.142s step time: 28.142ms
epoch: 80 train loss: 0.45099196 epoch time: 28.150s step time: 28.150ms
================================Start Evaluation================================
on Gauss grid: 0.053801001163199545, on regular grid: 0.059015438375345855
=================================End Evaluation=================================
predict total time: 7.7433998584747314 s
epoch: 81 train loss: 0.66613305 epoch time: 28.178s step time: 28.178ms
epoch: 82 train loss: 0.3882894 epoch time: 28.167s step time: 28.167ms
epoch: 83 train loss: 0.5185521 epoch time: 28.212s step time: 28.212ms
epoch: 84 train loss: 0.49510124 epoch time: 28.142s step time: 28.142ms
epoch: 85 train loss: 0.46369594 epoch time: 28.168s step time: 28.168ms
epoch: 86 train loss: 0.37444192 epoch time: 28.185s step time: 28.185ms
epoch: 87 train loss: 0.38335305 epoch time: 27.993s step time: 27.993ms
epoch: 88 train loss: 0.523732 epoch time: 27.984s step time: 27.984ms
epoch: 89 train loss: 0.46601093 epoch time: 28.099s step time: 28.099ms
epoch: 90 train loss: 0.46671164 epoch time: 28.167s step time: 28.167ms
================================Start Evaluation================================
on Gauss grid: 0.05075095896422863, on regular grid: 0.05621649529775642
=================================End Evaluation=================================
predict total time: 7.390618801116943 s
[ ]:
from src import visual
visual(model, test_input, data_params)
[19]:
from IPython.display import Image, display
display(Image(filename='images/result.jpg', format='jpg', embed=True))
[18]:
with open('images/result.gif', 'rb') as f:
    display(Image(data=f.read(), format='gif', embed=True))