Solving the Navier-Stokes Equations with a 2D Spectral Neural Operator


Overview

Computational fluid dynamics (CFD) is one of the most important technologies in 21st-century fluid mechanics. It solves the governing equations of fluid dynamics numerically on a computer in order to analyze, predict, and control flows. Traditional approaches such as the finite element method (FEM) and the finite difference method (FDM) require a complex simulation pipeline (physical modeling, mesh generation, numerical discretization, iterative solution, etc.) and incur high computational cost, and are therefore often inefficient. Using AI to accelerate fluid simulation is thus highly desirable.

In recent years, the rapid development of neural networks has provided a new paradigm for scientific computing. A classical neural network maps between finite-dimensional spaces and can therefore only learn solutions tied to a specific discretization. Unlike classical neural networks, the Fourier Neural Operator (FNO) is a deep learning architecture that learns mappings between infinite-dimensional function spaces. It learns the mapping from an arbitrary function parameter directly to the solution, can be used to solve a whole class of partial differential equations, and generalizes better. For more details, see Fourier Neural Operator for Parametric Partial Differential Equations.

The Spectral Neural Operator (SNO) is an FNO-like architecture that uses polynomial transforms (Chebyshev, Legendre, etc.) to carry out the computation in spectral space. Compared with FNO, SNO suffers from a smaller systematic bias caused by aliasing errors. One of its most important benefits is that the choice of basis is much wider, so a set of polynomials can be chosen that represents the problem most conveniently — for example, a basis adapted to the symmetries of the problem or to the time interval. In addition, when the input is defined on an unstructured grid, neural operators based on orthogonal polynomials may be more competitive than other spectral operators.

For more details, see "Spectral Neural Operators", arXiv preprint arXiv:2205.10573 (2022).

This tutorial shows how to solve the Navier-Stokes equations with a spectral neural operator.

Navier-Stokes Equation

The Navier-Stokes equation (N-S equation for short) is a classical equation of computational fluid dynamics: a system of partial differential equations describing the conservation of fluid momentum. Its vorticity form for two-dimensional incompressible flow reads:

$$\partial_t w(x, t) + u(x, t) \cdot \nabla w(x, t) = \nu \Delta w(x, t) + f(x), \quad x \in (0, 1)^2, \; t \in (0, T]$$
$$\nabla \cdot u(x, t) = 0, \quad x \in (0, 1)^2, \; t \in [0, T]$$
$$w(x, 0) = w_0(x), \quad x \in (0, 1)^2$$

where $u$ denotes the velocity field, $w = \nabla \times u$ the vorticity, $w_0(x)$ the initial condition, $\nu$ the viscosity coefficient, and $f(x)$ the forcing term.
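In two dimensions the vorticity reduces to the scalar $w = \partial v / \partial x - \partial u / \partial y$. As a quick illustration (not part of this case's pipeline), the following numpy sketch approximates it with central differences on a periodic grid and checks it against an analytic field:

```python
import numpy as np

def vorticity_2d(u, v, dx):
    """Approximate w = dv/dx - du/dy with central differences on a periodic grid."""
    dv_dx = (np.roll(v, -1, axis=0) - np.roll(v, 1, axis=0)) / (2 * dx)
    du_dy = (np.roll(u, -1, axis=1) - np.roll(u, 1, axis=1)) / (2 * dx)
    return dv_dx - du_dy

# Periodic test field from the streamfunction psi = sin(2*pi*x) * sin(2*pi*y):
# u = d(psi)/dy, v = -d(psi)/dx, hence w = -Laplace(psi) = 8*pi^2 * psi.
n = 64
dx = 1.0 / n
x = np.arange(n) * dx
xx, yy = np.meshgrid(x, x, indexing='ij')
u = 2 * np.pi * np.sin(2 * np.pi * xx) * np.cos(2 * np.pi * yy)
v = -2 * np.pi * np.cos(2 * np.pi * xx) * np.sin(2 * np.pi * yy)
w_exact = 8 * np.pi ** 2 * np.sin(2 * np.pi * xx) * np.sin(2 * np.pi * yy)

w = vorticity_2d(u, v, dx)
err = np.abs(w - w_exact).max() / np.abs(w_exact).max()
```

On this 64×64 grid the second-order central difference recovers the analytic vorticity to well under 1% relative error.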

Problem Description

In this case, the Spectral Neural Operator learns the mapping from the vorticity at one time step to the vorticity at the next, thereby solving the two-dimensional incompressible N-S equation:

$$w(\cdot, t) \mapsto w(\cdot, t+1)$$

Technical Path

MindFlow solves this problem in the following steps:

  1. Create the dataset.

  2. Build the model.

  3. Define the optimizer and loss function.

  4. Train the model.

Spectral Neural Operator

U-SNO Modification

The figure below shows the architecture of the spectral neural operator, which consists of an encoder, a stack of spectral convolution layers (linear transformations in spectral space), and a decoder. To compute the forward and inverse polynomial transform matrices used by the spectral convolutions, the input should be interpolated at the corresponding Gauss quadrature nodes (e.g., a Chebyshev grid). The interpolated input is lifted to a higher-dimensional channel space by a convolutional encoding layer. The result then passes through the stack of spectral convolution layers, each of which applies a linear convolution to its truncated spectral representation. The output of the SNO layers is projected back to the target dimension by a convolutional decoder and finally interpolated back to the original nodes.

An SNO layer performs the following operations: it applies the polynomial transform $A$ to move to spectral space (Chebyshev, Legendre, etc.); applies a linear convolution $L$ to the low-order polynomial modes and filters out the high-order modes; and then applies the inverse transform $S = A^{-1}$ to return to physical space. A linear convolution $W$ of the layer input is then added, followed by a nonlinear activation.

U-SNO is an enhanced modification of SNO, in which a series of modified SNO layers is placed after the main sequence. In the modified U-SNO layers, a UNet architecture (with a configurable number of strides) is used as the skip block in place of the linear $W$.
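The SNO layer just described can be summarized in a shape-level numpy sketch. This is a simplified single-channel, one-dimensional illustration: the matrices `A`, `S`, `L`, `W` below are random placeholders standing in for the real polynomial transforms and learned convolutions (they are not what `get_poly_transform` or `SNO2D` actually produce):

```python
import numpy as np

rng = np.random.default_rng(0)
n, modes = 32, 12  # grid points and retained low-order modes

A = rng.standard_normal((modes, n))      # forward polynomial transform (placeholder)
S = rng.standard_normal((n, modes))      # inverse transform, ideally S = A^-1 (placeholder)
L = rng.standard_normal((modes, modes))  # linear convolution on the low-order modes
W = rng.standard_normal((n, n))          # linear skip convolution on the layer input

def sno_layer(x):
    """Spectral branch inverse(conv(forward(x))): truncating to `modes` rows
    implicitly filters the high-order modes. Add the skip W @ x, then apply
    a nonlinearity (ReLU stands in for the actual activation)."""
    y = S @ (L @ (A @ x)) + W @ x
    return np.maximum(y, 0.0)

x = rng.standard_normal(n)
y = sno_layer(x)
```

The key structural point is that the learnable weights `L` act only on a small number of retained modes, while the skip path `W` preserves information the truncation discards.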

SNO network architecture

[1]:
import os
import time
import numpy as np

import mindspore
from mindspore import nn, context, ops, Tensor, jit, set_seed, save_checkpoint
import mindspore.common.dtype as mstype

The src package below can be downloaded at applications/data_driven/navier_stokes/sno2d/src.

[2]:
from mindflow.cell import SNO2D, get_poly_transform
from mindflow.utils import load_yaml_config, print_log
from mindflow.pde import UnsteadyFlowWithLoss
from src import create_training_dataset, load_interp_data, calculate_l2_error
from mindflow.loss import RelativeRMSELoss
from mindflow.common import get_warmup_cosine_annealing_lr

set_seed(0)
np.random.seed(0)
[3]:
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', device_id=0)
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
config = load_yaml_config('./configs/sno2d.yaml')

data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
summary_params = config["summary"]

Creating the Dataset

Download the training and test data: data_driven/navier_stokes/dataset.

In this case, the training and test datasets are generated following the setup used by Zongyi Li in Fourier Neural Operator for Parametric Partial Differential Equations, as follows:

With periodic boundary conditions, the initial condition $w_0(x)$ is sampled from the following distribution:

$$w_0 \sim \mu, \quad \mu = \mathcal{N}\left(0, 7^{3/2}(-\Delta + 49 I)^{-2.5}\right)$$

The forcing term is set to:

$$f(x) = 0.1\left(\sin\left(2\pi(x_1 + x_2)\right) + \cos\left(2\pi(x_1 + x_2)\right)\right)$$

The data are generated with the Crank-Nicolson method using a time step of 1e-4, and solutions are recorded every $t = 1$ time unit. All data are generated on a 256×256 grid and downsampled to 64×64. This case uses a viscosity coefficient of $\nu = 10^{-5}$, 19000 training samples, and 3800 test samples.
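The Gaussian random field above can be sampled in Fourier space: on a periodic domain the covariance operator is diagonal in the Fourier basis, so white noise is scaled mode-by-mode by the square root of the corresponding eigenvalue. The numpy sketch below is one common convention (normalizations differ between implementations, so treat the constants as an assumption rather than the exact recipe used to build this dataset):

```python
import numpy as np

def sample_w0(n, tau=7.0, alpha=2.5, seed=0):
    """Draw one sample of w0 ~ N(0, tau^(3/2) * (-Delta + tau^2 I)^(-alpha))
    on an n x n periodic grid by scaling white noise in Fourier space."""
    rng = np.random.default_rng(seed)
    k = np.fft.fftfreq(n, d=1.0 / n)                 # integer wavenumbers
    kx, ky = np.meshgrid(k, k, indexing='ij')
    # Square root of the covariance eigenvalues; the periodic Laplacian
    # contributes 4*pi^2*|k|^2 for wavenumber k.
    sqrt_eig = tau ** 1.5 * (4 * np.pi ** 2 * (kx ** 2 + ky ** 2) + tau ** 2) ** (-alpha / 2)
    noise = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    return np.fft.ifft2(n * sqrt_eig * noise).real   # real part yields a real field

w0 = sample_w0(64)
```

With $\tau = 7$ and $\alpha = 2.5$ this matches the measure $\mathcal{N}(0, 7^{3/2}(-\Delta + 49 I)^{-2.5})$ stated above, since $\tau^2 = 49$.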

[4]:
poly_type = data_params['poly_type']
load_interp_data(data_params, dataset_type='train')
train_dataset = create_training_dataset(data_params, shuffle=True)

test_data = load_interp_data(data_params, dataset_type='test')
test_input = test_data['test_inputs']
test_label = test_data['test_labels']

batch_size = data_params['batch_size']
resolution = data_params['resolution']

Building the Model

The network consists of one encoding layer, several spectral layers, and a decoding block:

  • The encoding convolution corresponds to SNO2D.encoder and lifts the input data $x$ to a higher-dimensional channel space;

  • The sequence of SNO layers corresponds to SNO2D.sno_kernel. The polynomial transform matrices (forward and inverse transforms for each of the two spatial variables) are used to convert between the space-time domain and the frequency domain. Here the sequence consists of two sub-sequences, one of SNO layers and one of U-SNO layers;

  • The decoding layer corresponds to SNO2D.decoder and consists of two convolutions. The decoder produces the final prediction.
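For intuition, an analysis/synthesis pair like the one passed to `transforms` can be built from the discrete orthogonality of Chebyshev polynomials at Chebyshev-Gauss nodes. The sketch below is an assumption about what `get_poly_transform` computes for a Chebyshev basis; MindFlow's actual scaling conventions may differ:

```python
import numpy as np

def chebyshev_transforms(n, modes):
    """Analysis (values -> coefficients) and synthesis (coefficients -> values)
    matrices for the first `modes` Chebyshev polynomials at n Gauss nodes."""
    theta = (2 * np.arange(n) + 1) * np.pi / (2 * n)  # Chebyshev-Gauss angles
    k = np.arange(modes)[:, None]
    basis = np.cos(k * theta[None, :])                # T_k at the nodes, shape (modes, n)
    scale = np.full((modes, 1), 2.0 / n)
    scale[0] = 1.0 / n                                # T_0 has a different discrete norm
    return scale * basis, basis.T

analysis, synthesis = chebyshev_transforms(16, 8)

# A polynomial of degree < modes is represented exactly: f = T_0 + 2*T_2.
coeffs = np.zeros(8)
coeffs[0], coeffs[2] = 1.0, 2.0
values = synthesis @ coeffs
recovered = analysis @ values
```

Round-tripping a low-degree polynomial through `synthesis` and `analysis` recovers its coefficients exactly, which is the property the SNO layers rely on when truncating to low-order modes.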

[5]:
n_modes = model_params['modes']

transform_data = get_poly_transform(resolution, n_modes, poly_type)

transform = Tensor(transform_data["analysis"], mstype.float32)
inv_transform = Tensor(transform_data["synthesis"], mstype.float32)

model = SNO2D(in_channels=model_params['in_channels'],
              out_channels=model_params['out_channels'],
              hidden_channels=model_params['hidden_channels'],
              num_sno_layers=model_params['sno_layers'],
              kernel_size=model_params['kernel_size'],
              transforms=[[transform, inv_transform]]*2,
              num_usno_layers=model_params['usno_layers'],
              num_unet_strides=model_params['unet_strides'],
              compute_dtype=mstype.float32)

total = 0
for param in model.get_parameters():
    print_log(param.shape)
    total += param.size
print_log(f"Total Parameters:{total}")
(64, 1, 1, 1)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 64, 5, 5)
(64, 64, 3, 3)
(64, 64, 3, 3)
(128, 64, 3, 3)
(128, 128, 3, 3)
(256, 128, 3, 3)
(256, 256, 3, 3)
(256, 128, 2, 2)
(128, 256, 3, 3)
(128, 128, 3, 3)
(128, 64, 2, 2)
(64, 128, 3, 3)
(64, 64, 3, 3)
(64, 128, 3, 3)
(64, 64, 1, 1)
(1, 64, 1, 1)
Total Parameters:2396288

Optimizer and Loss Function

[10]:
steps_per_epoch = train_dataset.get_dataset_size()
grad_clip_norm = optimizer_params['grad_clip_norm']
[8]:
lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params['learning_rate'],
                                    last_epoch=optimizer_params["epochs"],
                                    steps_per_epoch=steps_per_epoch,
                                    warmup_epochs=1)

optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=Tensor(lr),
                               weight_decay=optimizer_params['weight_decay'])
problem = UnsteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss(), data_format="NTCHW")
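The training objective is the relative L2 error between prediction and label. The following numpy sketch shows what a relative RMSE of this kind computes (the exact reduction and reshaping inside mindflow's RelativeRMSELoss may differ):

```python
import numpy as np

def relative_rmse(pred, label):
    """Per-sample relative L2 error ||pred - label|| / ||label||, averaged over the batch."""
    b = pred.shape[0]
    diff = (pred - label).reshape(b, -1)
    ref = label.reshape(b, -1)
    return float(np.mean(np.linalg.norm(diff, axis=1) / np.linalg.norm(ref, axis=1)))

label = np.ones((4, 1, 64, 64))
loss_zero = relative_rmse(label, label)        # identical fields -> 0
loss_double = relative_rmse(2 * label, label)  # prediction scaled by 2 -> 1
```

Normalizing by the label's norm makes the loss scale-invariant, which is useful here because the vorticity magnitude varies across samples.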

Model Training

With MindSpore >= 2.0.0, neural networks can be trained using the functional programming paradigm.

[11]:
def train():
    def forward_fn(train_inputs, train_label):
        loss = problem.get_loss(train_inputs, train_label)
        return loss

    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

    @jit
    def train_step(train_inputs, train_label):
        loss, grads = grad_fn(train_inputs, train_label)
        grads = ops.clip_by_global_norm(grads, grad_clip_norm)
        loss = ops.depend(loss, optimizer(grads))
        return loss

    sink_process = mindspore.data_sink(train_step, train_dataset, sink_size=1)
    ckpt_dir = os.path.join(model_params["root_dir"], summary_params["ckpt_dir"])
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    for epoch in range(1, 1 + optimizer_params["epochs"]):
        local_time_beg = time.time()
        model.set_train(True)
        for _ in range(steps_per_epoch):
            cur_loss = sink_process()

        local_time_end = time.time()
        epoch_seconds = local_time_end - local_time_beg
        step_seconds = (epoch_seconds/steps_per_epoch)*1000
        print_log(f"epoch: {epoch} train loss: {cur_loss} "
                  f"epoch time: {epoch_seconds:.3f}s step time: {step_seconds:5.3f}ms")

        model.set_train(False)
        if epoch % summary_params["save_ckpt_interval"] == 0:
            save_checkpoint(model, os.path.join(ckpt_dir, f"{model_params['name']}_epoch{epoch}"))

        if epoch % summary_params['test_interval'] == 0:
            calculate_l2_error(model, test_input, test_label, data_params)

[12]:
train()
epoch: 1 train loss: 1.9672374 epoch time: 34.144s step time: 34.144ms
epoch: 2 train loss: 1.8687398 epoch time: 28.038s step time: 28.038ms
epoch: 3 train loss: 1.6240175 epoch time: 28.094s step time: 28.094ms
epoch: 4 train loss: 1.812437 epoch time: 28.001s step time: 28.001ms
epoch: 5 train loss: 1.6048276 epoch time: 28.006s step time: 28.006ms
epoch: 6 train loss: 1.3349447 epoch time: 28.045s step time: 28.045ms
epoch: 7 train loss: 1.445535 epoch time: 28.084s step time: 28.084ms
epoch: 8 train loss: 1.287163 epoch time: 28.050s step time: 28.050ms
epoch: 9 train loss: 1.2205887 epoch time: 28.079s step time: 28.079ms
epoch: 10 train loss: 1.1622387 epoch time: 28.048s step time: 28.048ms
================================Start Evaluation================================
on Gauss grid: 0.2202785452026874, on regular grid: 0.21447483365566075
=================================End Evaluation=================================
predict total time: 7.394038677215576 s
epoch: 11 train loss: 0.98966134 epoch time: 28.090s step time: 28.090ms
epoch: 12 train loss: 0.9963242 epoch time: 28.080s step time: 28.080ms
epoch: 13 train loss: 1.0154707 epoch time: 28.125s step time: 28.125ms
epoch: 14 train loss: 1.029425 epoch time: 28.087s step time: 28.087ms
epoch: 15 train loss: 1.0535842 epoch time: 28.069s step time: 28.069ms
epoch: 16 train loss: 1.0508957 epoch time: 28.217s step time: 28.217ms
epoch: 17 train loss: 0.73175216 epoch time: 28.239s step time: 28.239ms
epoch: 18 train loss: 0.7978346 epoch time: 28.060s step time: 28.060ms
epoch: 19 train loss: 1.2525742 epoch time: 28.057s step time: 28.057ms
epoch: 20 train loss: 1.0816319 epoch time: 28.052s step time: 28.052ms
================================Start Evaluation================================
on Gauss grid: 0.17742644541244953, on regular grid: 0.17132807192601202
=================================End Evaluation=================================
predict total time: 7.578975677490234 s
epoch: 21 train loss: 0.9601194 epoch time: 28.033s step time: 28.033ms
epoch: 22 train loss: 1.0366433 epoch time: 28.100s step time: 28.100ms
epoch: 23 train loss: 0.9956419 epoch time: 28.061s step time: 28.061ms
epoch: 24 train loss: 1.0766693 epoch time: 28.125s step time: 28.125ms
epoch: 25 train loss: 0.9773072 epoch time: 28.022s step time: 28.022ms
epoch: 26 train loss: 0.65455425 epoch time: 28.086s step time: 28.086ms
epoch: 27 train loss: 0.71299446 epoch time: 28.006s step time: 28.006ms
epoch: 28 train loss: 1.0231717 epoch time: 28.170s step time: 28.170ms
epoch: 29 train loss: 0.8839726 epoch time: 28.143s step time: 28.143ms
epoch: 30 train loss: 0.90894026 epoch time: 28.124s step time: 28.124ms
================================Start Evaluation================================
on Gauss grid: 0.16749235310871155, on regular grid: 0.169489491779019
=================================End Evaluation=================================
predict total time: 7.71979022026062 s
epoch: 31 train loss: 0.9652164 epoch time: 28.092s step time: 28.092ms
epoch: 32 train loss: 0.6686845 epoch time: 28.096s step time: 28.096ms
epoch: 33 train loss: 0.8932849 epoch time: 28.107s step time: 28.107ms
epoch: 34 train loss: 0.7517134 epoch time: 28.208s step time: 28.208ms
epoch: 35 train loss: 0.825667 epoch time: 28.188s step time: 28.188ms
epoch: 36 train loss: 0.74803126 epoch time: 28.128s step time: 28.128ms
epoch: 37 train loss: 0.8695539 epoch time: 28.032s step time: 28.032ms
epoch: 38 train loss: 0.686597 epoch time: 28.025s step time: 28.025ms
epoch: 39 train loss: 0.9947252 epoch time: 28.032s step time: 28.032ms
epoch: 40 train loss: 0.8597307 epoch time: 28.046s step time: 28.046ms
================================Start Evaluation================================
on Gauss grid: 0.12830503433849663, on regular grid: 0.13030632202877585
=================================End Evaluation=================================
predict total time: 7.54561448097229 s
epoch: 41 train loss: 0.5904021 epoch time: 28.101s step time: 28.101ms
epoch: 42 train loss: 0.6276789 epoch time: 28.145s step time: 28.145ms
epoch: 43 train loss: 0.62192535 epoch time: 28.092s step time: 28.092ms
epoch: 44 train loss: 0.6407144 epoch time: 28.059s step time: 28.059ms
epoch: 45 train loss: 0.60519314 epoch time: 28.014s step time: 28.014ms
epoch: 46 train loss: 1.0048012 epoch time: 28.078s step time: 28.078ms
epoch: 47 train loss: 0.5551628 epoch time: 28.087s step time: 28.087ms
epoch: 48 train loss: 0.8461705 epoch time: 28.101s step time: 28.101ms
epoch: 49 train loss: 0.7118721 epoch time: 28.077s step time: 28.077ms
epoch: 50 train loss: 0.55335164 epoch time: 28.170s step time: 28.170ms
================================Start Evaluation================================
on Gauss grid: 0.08227695803437382, on regular grid: 0.08470196191734738
=================================End Evaluation=================================
predict total time: 7.394194602966309 s
epoch: 51 train loss: 0.636775 epoch time: 28.049s step time: 28.049ms
epoch: 52 train loss: 0.5920238 epoch time: 28.095s step time: 28.095ms
epoch: 53 train loss: 0.58135617 epoch time: 28.278s step time: 28.278ms
epoch: 54 train loss: 0.7213563 epoch time: 28.203s step time: 28.203ms
epoch: 55 train loss: 0.71770614 epoch time: 28.166s step time: 28.166ms
epoch: 56 train loss: 0.48096988 epoch time: 28.130s step time: 28.130ms
epoch: 57 train loss: 0.5998644 epoch time: 28.143s step time: 28.143ms
epoch: 58 train loss: 0.6089008 epoch time: 28.111s step time: 28.111ms
epoch: 59 train loss: 0.595509 epoch time: 28.200s step time: 28.200ms
epoch: 60 train loss: 0.6066635 epoch time: 28.149s step time: 28.149ms
================================Start Evaluation================================
on Gauss grid: 0.08370403416315093, on regular grid: 0.08586561499600351
=================================End Evaluation=================================
predict total time: 7.493133306503296 s
epoch: 61 train loss: 0.5519717 epoch time: 28.119s step time: 28.119ms
epoch: 62 train loss: 0.4908938 epoch time: 28.166s step time: 28.166ms
epoch: 63 train loss: 0.43803358 epoch time: 28.126s step time: 28.126ms
epoch: 64 train loss: 0.47794145 epoch time: 28.171s step time: 28.171ms
epoch: 65 train loss: 0.504622 epoch time: 28.176s step time: 28.176ms
epoch: 66 train loss: 0.44892752 epoch time: 28.074s step time: 28.074ms
epoch: 67 train loss: 0.6695643 epoch time: 28.069s step time: 28.069ms
epoch: 68 train loss: 0.5254482 epoch time: 28.147s step time: 28.147ms
epoch: 69 train loss: 0.43325588 epoch time: 28.253s step time: 28.253ms
epoch: 70 train loss: 0.4950175 epoch time: 28.150s step time: 28.150ms
================================Start Evaluation================================
on Gauss grid: 0.07004086356284096, on regular grid: 0.07265735937107769
=================================End Evaluation=================================
predict total time: 7.431047439575195 s
epoch: 71 train loss: 0.48058861 epoch time: 28.090s step time: 28.090ms
epoch: 72 train loss: 0.48115337 epoch time: 28.087s step time: 28.087ms
epoch: 73 train loss: 0.5245213 epoch time: 28.215s step time: 28.215ms
epoch: 74 train loss: 0.40916815 epoch time: 28.153s step time: 28.153ms
epoch: 75 train loss: 0.48107946 epoch time: 28.155s step time: 28.155ms
epoch: 76 train loss: 0.4762331 epoch time: 28.062s step time: 28.062ms
epoch: 77 train loss: 0.5066639 epoch time: 28.141s step time: 28.141ms
epoch: 78 train loss: 0.43607965 epoch time: 28.142s step time: 28.142ms
epoch: 79 train loss: 0.49439412 epoch time: 28.142s step time: 28.142ms
epoch: 80 train loss: 0.45099196 epoch time: 28.150s step time: 28.150ms
================================Start Evaluation================================
on Gauss grid: 0.053801001163199545, on regular grid: 0.059015438375345855
=================================End Evaluation=================================
predict total time: 7.7433998584747314 s
epoch: 81 train loss: 0.66613305 epoch time: 28.178s step time: 28.178ms
epoch: 82 train loss: 0.3882894 epoch time: 28.167s step time: 28.167ms
epoch: 83 train loss: 0.5185521 epoch time: 28.212s step time: 28.212ms
epoch: 84 train loss: 0.49510124 epoch time: 28.142s step time: 28.142ms
epoch: 85 train loss: 0.46369594 epoch time: 28.168s step time: 28.168ms
epoch: 86 train loss: 0.37444192 epoch time: 28.185s step time: 28.185ms
epoch: 87 train loss: 0.38335305 epoch time: 27.993s step time: 27.993ms
epoch: 88 train loss: 0.523732 epoch time: 27.984s step time: 27.984ms
epoch: 89 train loss: 0.46601093 epoch time: 28.099s step time: 28.099ms
epoch: 90 train loss: 0.46671164 epoch time: 28.167s step time: 28.167ms
================================Start Evaluation================================
on Gauss grid: 0.05075095896422863, on regular grid: 0.05621649529775642
=================================End Evaluation=================================
predict total time: 7.390618801116943 s
[ ]:
from src import visual
visual(model, test_input, data_params)
[19]:
from IPython.display import Image, display
display(Image(filename='images/result.jpg', format='jpg', embed=True))
../_images/data_driven_navier_stokes_SNO2D_21_0.jpg
[18]:
with open('images/result.gif', 'rb') as f:
    display(Image(data=f.read(), format='png', embed=True))
../_images/data_driven_navier_stokes_SNO2D_22_0.png