Solving the Navier-Stokes Equations Based on a 3D Spectral Neural Operator
Overview
Computational fluid dynamics (CFD) is one of the key technologies in 21st-century fluid mechanics. It uses numerical methods to solve the governing equations of fluid dynamics on a computer, enabling the analysis, prediction, and control of flows. Traditional approaches such as the finite element method (FEM) and the finite difference method (FDM) require a complex simulation pipeline (physical modeling, mesh generation, numerical discretization, iterative solving, etc.) and incur a high computational cost, so they are often inefficient. Using AI to improve the efficiency of fluid simulation is therefore highly desirable.
In recent years, the rapid development of neural networks has provided a new paradigm for scientific computing. Classical neural networks map between finite-dimensional spaces and can only learn solutions tied to a specific discretization. Unlike classical networks, the Fourier Neural Operator (FNO) is a deep learning architecture that learns mappings between infinite-dimensional function spaces. It learns the mapping directly from arbitrary functional parameters to the solution and can thus solve an entire family of partial differential equations, giving it stronger generalization ability. For more information, refer to Fourier Neural Operator for Parametric Partial Differential Equations.
The Spectral Neural Operator (SNO) is an FNO-like architecture that uses polynomials (Chebyshev, Legendre, etc.) to transfer the computation to spectral space. Compared with FNO, SNO shows a smaller systematic bias caused by aliasing errors. One of its most important advantages is that the choice of basis is much broader, so one can choose the set of polynomials that represents the problem most conveniently, for example a basis adapted to the symmetry of the problem or to the time interval. In addition, when the input is defined on an unstructured grid, neural operators based on orthogonal polynomials may be more competitive than other spectral operators.
For more information, see "Spectral Neural Operators", arXiv preprint arXiv:2205.10573 (2022).
This tutorial describes how to solve the Navier-Stokes equations with a 3D SNO.
Problem Description
Our goal is to learn the operator that maps the first 10 time steps of the vorticity to the full trajectory on [10, T]:

$$w|_{(0,1)^2\times[0,10]} \mapsto w|_{(0,1)^2\times(10,T]}$$
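For completeness, the dataset is governed by the two-dimensional incompressible Navier-Stokes equations in vorticity form on the unit torus, as in the reference paper (restated here from that paper, since the equations are not preserved in this excerpt):

$$\partial_t w(x,t) + u(x,t)\cdot\nabla w(x,t) = \nu\,\Delta w(x,t) + f(x), \qquad x\in(0,1)^2,\ t\in(0,T]$$

$$\nabla\cdot u(x,t) = 0, \qquad w(x,0) = w_0(x)$$

where $u$ is the velocity field, $w = \nabla\times u$ is the vorticity, $w_0$ is the initial vorticity, $\nu$ is the viscosity coefficient, and $f$ is the forcing function.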
Technical Approach
MindFlow solves this problem in the following steps:
Create the dataset.
Build the model.
Define the optimizer and loss function.
Train the model.
Spectral Neural Operator
The figure below shows the architecture of the Spectral Neural Operator, which consists of an encoder, several spectral convolution layers (linear transformations in spectral space), and a decoder. To compute the forward and inverse polynomial transform matrices used by the spectral convolutions, the input should be interpolated at the corresponding Gauss quadrature nodes (a Chebyshev grid, etc.). The interpolated input is lifted to a higher-dimensional channel space by the encoding convolution layer. The result then passes through the sequence of spectral convolution layers, each of which applies a linear convolution to its truncated spectral representation. The output of the SNO layers is projected back to the target dimension by the convolutional decoder and finally interpolated back to the original nodes.
The SNO layer performs the following operations: a polynomial transform $A$ maps the input to spectral space; a linear convolution $L$ is applied to the lower polynomial modes while the higher modes are filtered out; the inverse transform $S = A^{-1}$ then maps the result back to physical space. The output of this spectral branch is added to a linear (convolutional) skip connection $W$ applied to the layer input, and a nonlinear activation follows.
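To make these operations concrete, the following is a minimal NumPy sketch of a single 1D SNO layer. It is an illustration only, not the MindFlow implementation: the function name sno_layer_1d, the arguments a_mat, s_mat, l_weight and w_weight, and the choice of a ReLU-style activation are all assumptions made for this sketch.
[ ]:
import numpy as np

def sno_layer_1d(x, a_mat, s_mat, l_weight, w_weight):
    """One possible realization of a 1D SNO layer (illustration only).

    x        : (c, n) channels sampled at Gauss (e.g. Chebyshev) nodes
    a_mat    : (m, n) analysis matrix, physical values -> first m spectral coefficients
    s_mat    : (n, m) synthesis matrix, spectral coefficients -> physical values
    l_weight : (c, c) channel-mixing weights applied in spectral space
    w_weight : (c, c) channel-mixing weights of the skip connection
    """
    spec = x @ a_mat.T          # forward polynomial transform, truncated to m modes: (c, m)
    spec = l_weight @ spec      # linear transform on the retained spectral coefficients
    branch = spec @ s_mat.T     # inverse transform back to the physical nodes: (c, n)
    return np.maximum(branch + w_weight @ x, 0.0)   # skip connection + activation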
[1]:
import os
import time
import numpy as np
from mindspore import nn, ops, jit, data_sink, context, Tensor
from mindspore.common import set_seed
from mindspore import dtype as mstype
The src package below can be downloaded at applications/data_driven/navier_stokes/sno3d/src.
[2]:
from mindflow import get_warmup_cosine_annealing_lr, load_yaml_config
from mindflow.utils import print_log
from mindflow.cell import SNO3D, get_poly_transform
from src import calculate_l2_error, UnitGaussianNormalizer, create_training_dataset, load_interp_data, visual
set_seed(0)
np.random.seed(0)
[3]:
# set context for training: using graph mode for high performance training with GPU acceleration
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', device_id=0)
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
config = load_yaml_config('./configs/sno3d.yaml')
data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
summary_params = config["summary"]
grid_size = data_params["resolution"]
input_timestep = model_params["input_timestep"]
output_timestep = model_params["extrapolations"]
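For orientation, the sketch below shows the structure that configs/sno3d.yaml is assumed to have, written as the Python dict returned by load_yaml_config. The key names follow the accesses in this tutorial; values marked as placeholders are assumptions, while the remaining values can be read off the data shapes, parameter shapes, and training logs shown later.
[ ]:
# Hypothetical sketch of the configuration structure (not the shipped file).
example_config = {
    "data": {
        "root_dir": "./dataset",        # placeholder path
        "resolution": 64,               # spatial grid size after downsampling
        "poly_type": "chebyshev",       # placeholder basis name
    },
    "model": {
        "input_timestep": 10,           # number of input vorticity snapshots
        "extrapolations": 40,           # number of predicted time steps
        "in_channels": 10,
        "out_channels": 1,
        "hidden_channels": 64,
        "sno_layers": 5,
        "kernel_size": 7,
        "modes": 12,                    # placeholder: number of retained spectral modes
    },
    "optimizer": {
        "learning_rate": 1.0e-3,        # placeholder
        "epochs": 120,
        "warmup_epochs": 1,             # placeholder
        "eps": 1.0e-8,                  # placeholder
        "weight_decay": 1.0e-4,         # placeholder
    },
    "summary": {
        "root_dir": "./summary",        # placeholder
        "ckpt_dir": "ckpt",             # placeholder
        "test_interval": 10,
    },
}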
Create the Dataset
Download the training and test data: data_driven/navier_stokes/dataset .
In this case, the training and test datasets are generated following the dataset settings in Zongyi Li's Fourier Neural Operator for Parametric Partial Differential Equations. The specific settings are as follows:
With periodic boundary conditions, the initial condition is generated from the distribution

$$w_0 \sim \mu, \qquad \mu = \mathcal{N}\left(0,\, 7^{3/2}(-\Delta + 49I)^{-2.5}\right)$$
The forcing term is set to

$$f(x) = 0.1\left(\sin\bigl(2\pi(x_1+x_2)\bigr) + \cos\bigl(2\pi(x_1+x_2)\bigr)\right)$$
The data are generated with the Crank-Nicolson method using a time step of 1e-4, and the solution is recorded every t = 1 time unit. All data are generated on a 256×256 grid and downsampled to a 64×64 grid. This case uses the viscosity coefficient ν = 1e-3 (in the reference paper's settings, this is the configuration whose recorded trajectory extends to T = 50, matching the 10 input and 40 output time steps used here).
[4]:
load_interp_data(data_params, 'train')
train_loader = create_training_dataset(data_params, shuffle=True)
test_data = load_interp_data(data_params, 'test')
test_a = Tensor(test_data['a'], mstype.float32)
test_u = Tensor(test_data['u'], mstype.float32)
test_u_unif = np.load(os.path.join(data_params['root_dir'], 'test/test_u.npy'))
train_a = Tensor(np.load(os.path.join(
    data_params["root_dir"], "train/train_a_interp.npy")), mstype.float32)
train_u = Tensor(np.load(os.path.join(
    data_params["root_dir"], "train/train_u_interp.npy")), mstype.float32)
print("train a, u:", train_a.shape, train_u.shape)
print("test a, u:", test_a.shape, test_u.shape)
train a, u: (1000, 10, 64, 64) (1000, 40, 64, 64)
test a, u: (200, 10, 64, 64) (200, 40, 64, 64)
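load_interp_data (from the src package) prepares the uniformly sampled data for the polynomial transforms by interpolating it onto Gauss nodes; the errors reported later on the "Gauss grid" and on the "regular grid" refer to these two representations. As a rough illustration of that preprocessing step, the sketch below interpolates a 1D signal from a uniform grid onto Chebyshev-Gauss nodes; the helper name to_chebyshev_grid is hypothetical and the actual src implementation may differ (e.g. in interpolation order or node convention).
[ ]:
def to_chebyshev_grid(values_uniform):
    """Interpolate a 1D signal from a uniform grid on [0, 1] onto Chebyshev-Gauss nodes (sketch)."""
    n = values_uniform.shape[-1]
    x_uniform = np.linspace(0.0, 1.0, n)
    k = np.arange(n)
    # Chebyshev-Gauss nodes on [-1, 1], mapped to [0, 1] and sorted in increasing order
    x_cheb = np.sort(0.5 * (np.cos((2 * k + 1) * np.pi / (2 * n)) + 1.0))
    return np.interp(x_cheb, x_uniform, values_uniform)

signal = np.sin(2 * np.pi * np.linspace(0.0, 1.0, 64))
print(to_chebyshev_grid(signal).shape)   # (64,)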
Build the Model
The network consists of one encoding layer, multiple spectral layers, and a decoding block:
The encoding convolution corresponds to SNO3D.encoder and maps the input data to a higher-dimensional channel space. The sequence of SNO layers corresponds to SNO3D.sno_kernel; it uses the polynomial transform matrices passed as inputs (forward and inverse transforms for each of the three variables) to convert between the space-time domain and the frequency domain. The decoding layer corresponds to SNO3D.decoder and consists of two convolutions; the decoder produces the final prediction.
[5]:
n_modes = model_params['modes']
poly_type = data_params['poly_type']
transform_data = get_poly_transform(grid_size, n_modes, poly_type)
transform = Tensor(transform_data["analysis"], mstype.float32)
inv_transform = Tensor(transform_data["synthesis"], mstype.float32)
transform_t_axis = get_poly_transform(output_timestep, n_modes, poly_type)
transform_t = Tensor(transform_t_axis["analysis"], mstype.float32)
inv_transform_t = Tensor(transform_t_axis["synthesis"], mstype.float32)
transforms = [[transform, inv_transform]] * 2 + [[transform_t, inv_transform_t]]
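get_poly_transform returns a dictionary containing the forward ("analysis") and inverse ("synthesis") transform matrices for the chosen basis. As an optional sanity check, one can verify that analysis applied after synthesis acts (approximately) as the identity on the retained modes. This is a sketch that assumes the analysis matrix has shape (modes, resolution) and the synthesis matrix the transposed shape; if the orientation or normalization differs, adjust accordingly.
[ ]:
# Optional sanity check on the transform matrices (shape assumptions noted above).
analysis = transform_data["analysis"]
synthesis = transform_data["synthesis"]
print(analysis.shape, synthesis.shape)
# Analysis composed with synthesis should be close to the identity on the retained modes.
print(np.max(np.abs(analysis @ synthesis - np.eye(analysis.shape[0]))))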
[6]:
if use_ascend:
    compute_type = mstype.float16
else:
    compute_type = mstype.float32

# prepare model
model = SNO3D(in_channels=model_params['in_channels'],
              out_channels=model_params['out_channels'],
              hidden_channels=model_params['hidden_channels'],
              num_sno_layers=model_params['sno_layers'],
              transforms=transforms,
              kernel_size=model_params['kernel_size'],
              compute_dtype=compute_type)

model_params_list = []
for k, v in model_params.items():
    model_params_list.append(f"{k}-{v}")
model_name = "_".join(model_params_list)

total = 0
for param in model.get_parameters():
    print_log(param.shape)
    total += param.size
print_log(f"Total Parameters:{total}")
(64, 10, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 1, 1, 1)
(1, 64, 1, 1, 1)
Total Parameters:7049920
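As an optional check before training, a dummy batch can be passed through the model to confirm the input and output layout. This is a sketch only; the input layout mirrors the one built in forward_fn below, and the expected output shape is inferred from the parameter shapes printed above rather than from the SNO3D documentation.
[ ]:
# Optional shape check (sketch): input is (batch, input_timestep, grid_size, grid_size, output_timestep).
dummy = Tensor(np.random.rand(2, input_timestep, grid_size, grid_size, output_timestep),
               mstype.float32)
print(model(dummy).shape)   # roughly (2, out_channels, grid_size, grid_size, output_timestep)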
Optimizer and Loss Function
[ ]:
steps_per_epoch = train_loader.get_dataset_size()
lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params["learning_rate"],
                                    last_epoch=optimizer_params["epochs"],
                                    steps_per_epoch=steps_per_epoch,
                                    warmup_epochs=optimizer_params["warmup_epochs"])
optimizer = nn.AdamWeightDecay(model.trainable_params(),
                               learning_rate=Tensor(lr),
                               eps=float(optimizer_params['eps']),
                               weight_decay=optimizer_params['weight_decay'])
loss_fn = nn.RMSELoss()  # alternatively: LpLoss()
a_normalizer = UnitGaussianNormalizer(train_a)
u_normalizer = UnitGaussianNormalizer(train_u)
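UnitGaussianNormalizer comes from the src package and standardizes the data to zero mean and unit variance; encode and decode are the forward and inverse maps used in forward_fn below. The minimal stand-in below sketches what such a normalizer typically does; the real src class may differ in details such as the reduction axes or the epsilon handling.
[ ]:
class SimpleGaussianNormalizer:
    """Minimal stand-in for UnitGaussianNormalizer (illustration only)."""

    def __init__(self, data, eps=1e-5):
        # statistics are computed once from the training data
        self.mean = data.mean()
        self.std = data.std()
        self.eps = eps

    def encode(self, x):
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x):
        return x * (self.std + self.eps) + self.mean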
Training Function
With MindSpore >= 2.0.0, neural networks can be trained using the functional programming paradigm. The single-step training function is decorated with jit, and the data-sink function data_sink is given the single-step training function and the training dataset.
[8]:
from mindspore.amp import DynamicLossScaler, auto_mixed_precision, all_finite

def forward_fn(data, label):
    bs = data.shape[0]
    # normalize the input with training-set statistics
    data = a_normalizer.encode(data)
    # add a trailing axis and repeat it so the input covers the output time dimension
    data = data.reshape(bs, input_timestep, grid_size, grid_size, 1).repeat(output_timestep, axis=-1)

    logits = model(data).reshape(bs, output_timestep, grid_size, grid_size)
    # map the prediction back to the original scale before computing the loss
    logits = u_normalizer.decode(logits)

    loss = loss_fn(logits.reshape(bs, -1), label.reshape(bs, -1))
    if use_ascend:
        loss = loss_scaler.scale(loss)
    return loss

grad_fn = ops.value_and_grad(
    forward_fn, None, optimizer.parameters, has_aux=False)

# loss scaling and O2 mixed precision are only used on Ascend
if use_ascend:
    loss_scaler = DynamicLossScaler(1024, 2, 100)
    auto_mixed_precision(model, 'O2')
else:
    loss_scaler = None

@jit
def train_step(data, label):
    # single training step: compute gradients, unscale them if needed, then update the parameters
    loss, grads = grad_fn(data, label)
    if use_ascend:
        loss = loss_scaler.unscale(loss)
        is_finite = all_finite(grads)
        if is_finite:
            grads = loss_scaler.unscale(grads)
            loss = ops.depend(loss, optimizer(grads))
        loss_scaler.adjust(is_finite)
    else:
        loss = ops.depend(loss, optimizer(grads))
    return loss

# data sinking: each call to sink_process runs 200 training steps
sink_process = data_sink(train_step, train_loader, sink_size=200)
Model Training
With MindSpore >= 2.0.0, neural networks can be trained using the functional programming paradigm.
[11]:
summary_dir = os.path.join(summary_params["root_dir"], model_name)
ckpt_dir = os.path.join(summary_dir, summary_params["ckpt_dir"])
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
[9]:
def train():
    for epoch in range(1, 1 + optimizer_params["epochs"]):
        local_time_beg = time.time()
        model.set_train(True)
        cur_loss = sink_process()
        local_time_end = time.time()
        epoch_seconds = local_time_end - local_time_beg
        step_seconds = epoch_seconds / 200   # sink_size steps per epoch
        print_log(
            f"epoch: {epoch} train loss: {cur_loss} epoch time: {epoch_seconds:.3f}s step time: {step_seconds:5.3f}s")
        if epoch % summary_params['test_interval'] == 0:
            model.set_train(False)
            calculate_l2_error(model, test_a, test_u, data_params, a_normalizer, u_normalizer)
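Note that ckpt_dir is created above but the checkpoint-saving call itself is not shown in this excerpt. If checkpointing is desired, a minimal sketch (assuming the standard mindspore.save_checkpoint API; the helper name save_model and the file-name pattern are hypothetical) is:
[ ]:
import mindspore as ms

def save_model(epoch_idx):
    """Sketch: save the current weights into ckpt_dir (hypothetical helper)."""
    ms.save_checkpoint(model, os.path.join(ckpt_dir, f"sno3d_epoch{epoch_idx}.ckpt"))

Such a helper could be called from the evaluation branch of train(), e.g. save_model(epoch).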
[11]:
train()
epoch: 1 train loss: 0.90118235 epoch time: 44.718s step time: 0.224s
epoch: 2 train loss: 0.91254395 epoch time: 40.240s step time: 0.201s
epoch: 3 train loss: 0.9374327 epoch time: 40.302s step time: 0.202s
epoch: 4 train loss: 0.85217404 epoch time: 40.482s step time: 0.202s
epoch: 5 train loss: 0.6309165 epoch time: 40.590s step time: 0.203s
epoch: 6 train loss: 0.4290015 epoch time: 40.576s step time: 0.203s
epoch: 7 train loss: 0.34428337 epoch time: 40.536s step time: 0.203s
epoch: 8 train loss: 0.34126174 epoch time: 40.564s step time: 0.203s
epoch: 9 train loss: 0.27420813 epoch time: 40.571s step time: 0.203s
epoch: 10 train loss: 0.2711888 epoch time: 40.554s step time: 0.203s
================================Start Evaluation================================
Error on Gauss grid: 0.28268588, on regular grid: 0.2779018060781111
predict total time: 26.014892578125 s
=================================End Evaluation=================================
epoch: 11 train loss: 0.2603902 epoch time: 40.542s step time: 0.203s
epoch: 12 train loss: 0.24578454 epoch time: 40.570s step time: 0.203s
epoch: 13 train loss: 0.23497193 epoch time: 40.543s step time: 0.203s
epoch: 14 train loss: 0.210803 epoch time: 40.543s step time: 0.203s
epoch: 15 train loss: 0.24416743 epoch time: 40.507s step time: 0.203s
epoch: 16 train loss: 0.2085956 epoch time: 40.520s step time: 0.203s
epoch: 17 train loss: 0.22456339 epoch time: 40.507s step time: 0.203s
epoch: 18 train loss: 0.20356481 epoch time: 40.494s step time: 0.202s
epoch: 19 train loss: 0.1977826 epoch time: 40.486s step time: 0.202s
epoch: 20 train loss: 0.21421571 epoch time: 40.487s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.21850872, on regular grid: 0.21519988102613533
predict total time: 24.13991665840149 s
=================================End Evaluation=================================
epoch: 21 train loss: 0.19105345 epoch time: 40.517s step time: 0.203s
epoch: 22 train loss: 0.1900783 epoch time: 40.514s step time: 0.203s
epoch: 23 train loss: 0.19938461 epoch time: 40.525s step time: 0.203s
epoch: 24 train loss: 0.17807631 epoch time: 40.475s step time: 0.202s
epoch: 25 train loss: 0.23215973 epoch time: 40.487s step time: 0.202s
epoch: 26 train loss: 0.16794981 epoch time: 40.480s step time: 0.202s
epoch: 27 train loss: 0.17212906 epoch time: 40.480s step time: 0.202s
epoch: 28 train loss: 0.18129097 epoch time: 40.507s step time: 0.203s
epoch: 29 train loss: 0.17482412 epoch time: 40.477s step time: 0.202s
epoch: 30 train loss: 0.16695607 epoch time: 40.476s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.21386118, on regular grid: 0.2106647958068735
predict total time: 24.241203784942627 s
=================================End Evaluation=================================
epoch: 31 train loss: 0.18177168 epoch time: 40.473s step time: 0.202s
epoch: 32 train loss: 0.21024966 epoch time: 40.516s step time: 0.203s
epoch: 33 train loss: 0.17173253 epoch time: 40.502s step time: 0.203s
epoch: 34 train loss: 0.16217099 epoch time: 40.476s step time: 0.202s
epoch: 35 train loss: 0.16301228 epoch time: 40.499s step time: 0.202s
epoch: 36 train loss: 0.18293448 epoch time: 40.498s step time: 0.202s
epoch: 37 train loss: 0.18147346 epoch time: 40.441s step time: 0.202s
epoch: 38 train loss: 0.16941778 epoch time: 40.447s step time: 0.202s
epoch: 39 train loss: 0.16393727 epoch time: 40.508s step time: 0.203s
epoch: 40 train loss: 0.14487892 epoch time: 40.456s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.17054155, on regular grid: 0.1669973841447345
predict total time: 24.262221574783325 s
=================================End Evaluation=================================
epoch: 41 train loss: 0.15602723 epoch time: 40.482s step time: 0.202s
epoch: 42 train loss: 0.1580032 epoch time: 40.502s step time: 0.203s
epoch: 43 train loss: 0.14684558 epoch time: 40.464s step time: 0.202s
epoch: 44 train loss: 0.1525133 epoch time: 40.450s step time: 0.202s
epoch: 45 train loss: 0.15542132 epoch time: 40.483s step time: 0.202s
epoch: 46 train loss: 0.14850396 epoch time: 40.461s step time: 0.202s
epoch: 47 train loss: 0.15148017 epoch time: 40.470s step time: 0.202s
epoch: 48 train loss: 0.1460498 epoch time: 40.457s step time: 0.202s
epoch: 49 train loss: 0.14232638 epoch time: 40.450s step time: 0.202s
epoch: 50 train loss: 0.14340377 epoch time: 40.448s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.15915471, on regular grid: 0.15595411313723867
predict total time: 21.78568935394287 s
=================================End Evaluation=================================
epoch: 51 train loss: 0.14372692 epoch time: 40.469s step time: 0.202s
epoch: 52 train loss: 0.14164849 epoch time: 40.487s step time: 0.202s
epoch: 53 train loss: 0.14629523 epoch time: 40.512s step time: 0.203s
epoch: 54 train loss: 0.1396117 epoch time: 40.464s step time: 0.202s
epoch: 55 train loss: 0.13634394 epoch time: 40.459s step time: 0.202s
epoch: 56 train loss: 0.13366798 epoch time: 40.463s step time: 0.202s
epoch: 57 train loss: 0.13632345 epoch time: 40.457s step time: 0.202s
epoch: 58 train loss: 0.13450852 epoch time: 40.474s step time: 0.202s
epoch: 59 train loss: 0.12455033 epoch time: 40.435s step time: 0.202s
epoch: 60 train loss: 0.1306016 epoch time: 40.483s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.15231597, on regular grid: 0.1492942070119114
predict total time: 23.996315956115723 s
=================================End Evaluation=================================
epoch: 61 train loss: 0.13203391 epoch time: 40.448s step time: 0.202s
epoch: 62 train loss: 0.13594246 epoch time: 40.470s step time: 0.202s
epoch: 63 train loss: 0.13565734 epoch time: 40.466s step time: 0.202s
epoch: 64 train loss: 0.12305962 epoch time: 40.435s step time: 0.202s
epoch: 65 train loss: 0.13006279 epoch time: 40.452s step time: 0.202s
epoch: 66 train loss: 0.12222704 epoch time: 40.474s step time: 0.202s
epoch: 67 train loss: 0.123683415 epoch time: 40.440s step time: 0.202s
epoch: 68 train loss: 0.120612934 epoch time: 40.453s step time: 0.202s
epoch: 69 train loss: 0.115140736 epoch time: 40.462s step time: 0.202s
epoch: 70 train loss: 0.12193731 epoch time: 40.420s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.14335541, on regular grid: 0.14026955797649515
predict total time: 23.51957106590271 s
=================================End Evaluation=================================
epoch: 71 train loss: 0.12505008 epoch time: 40.453s step time: 0.202s
epoch: 72 train loss: 0.12200938 epoch time: 40.484s step time: 0.202s
epoch: 73 train loss: 0.11936474 epoch time: 40.440s step time: 0.202s
epoch: 74 train loss: 0.12116067 epoch time: 40.480s step time: 0.202s
epoch: 75 train loss: 0.11600651 epoch time: 40.430s step time: 0.202s
epoch: 76 train loss: 0.11403544 epoch time: 40.447s step time: 0.202s
epoch: 77 train loss: 0.117489025 epoch time: 40.445s step time: 0.202s
epoch: 78 train loss: 0.10970513 epoch time: 40.479s step time: 0.202s
epoch: 79 train loss: 0.10635782 epoch time: 40.448s step time: 0.202s
epoch: 80 train loss: 0.11854948 epoch time: 40.459s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.13526691, on regular grid: 0.1325411629391487
predict total time: 22.508509397506714 s
=================================End Evaluation=================================
epoch: 81 train loss: 0.11022251 epoch time: 40.444s step time: 0.202s
epoch: 82 train loss: 0.108954415 epoch time: 40.509s step time: 0.203s
epoch: 83 train loss: 0.113180526 epoch time: 40.459s step time: 0.202s
epoch: 84 train loss: 0.106218904 epoch time: 40.431s step time: 0.202s
epoch: 85 train loss: 0.10933072 epoch time: 40.429s step time: 0.202s
epoch: 86 train loss: 0.10805362 epoch time: 40.442s step time: 0.202s
epoch: 87 train loss: 0.10749279 epoch time: 40.423s step time: 0.202s
epoch: 88 train loss: 0.112811126 epoch time: 40.471s step time: 0.202s
epoch: 89 train loss: 0.1098047 epoch time: 40.443s step time: 0.202s
epoch: 90 train loss: 0.110777 epoch time: 40.446s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.13124052, on regular grid: 0.1286052267835557
predict total time: 23.993195056915283 s
=================================End Evaluation=================================
epoch: 91 train loss: 0.10114018 epoch time: 40.435s step time: 0.202s
epoch: 92 train loss: 0.10804214 epoch time: 40.477s step time: 0.202s
epoch: 93 train loss: 0.103131406 epoch time: 40.461s step time: 0.202s
epoch: 94 train loss: 0.1079015 epoch time: 40.453s step time: 0.202s
epoch: 95 train loss: 0.10340427 epoch time: 40.445s step time: 0.202s
epoch: 96 train loss: 0.10799302 epoch time: 40.426s step time: 0.202s
epoch: 97 train loss: 0.1010814 epoch time: 40.448s step time: 0.202s
epoch: 98 train loss: 0.10470774 epoch time: 40.441s step time: 0.202s
epoch: 99 train loss: 0.105584204 epoch time: 40.422s step time: 0.202s
epoch: 100 train loss: 0.10661688 epoch time: 40.453s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.12790823, on regular grid: 0.12531590522074978
predict total time: 24.04036521911621 s
=================================End Evaluation=================================
epoch: 101 train loss: 0.09820426 epoch time: 40.432s step time: 0.202s
epoch: 102 train loss: 0.10459576 epoch time: 40.495s step time: 0.202s
epoch: 103 train loss: 0.100737445 epoch time: 40.475s step time: 0.202s
epoch: 104 train loss: 0.104481824 epoch time: 40.439s step time: 0.202s
epoch: 105 train loss: 0.10380473 epoch time: 40.432s step time: 0.202s
epoch: 106 train loss: 0.10476779 epoch time: 40.475s step time: 0.202s
epoch: 107 train loss: 0.1067871 epoch time: 40.444s step time: 0.202s
epoch: 108 train loss: 0.10912545 epoch time: 40.452s step time: 0.202s
epoch: 109 train loss: 0.09430095 epoch time: 40.442s step time: 0.202s
epoch: 110 train loss: 0.100769304 epoch time: 40.430s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.12614353, on regular grid: 0.12363198251396713
predict total time: 23.740464687347412 s
=================================End Evaluation=================================
epoch: 111 train loss: 0.12305746 epoch time: 40.437s step time: 0.202s
epoch: 112 train loss: 0.10381921 epoch time: 40.470s step time: 0.202s
epoch: 113 train loss: 0.107983574 epoch time: 40.429s step time: 0.202s
epoch: 114 train loss: 0.10639244 epoch time: 40.435s step time: 0.202s
epoch: 115 train loss: 0.098030716 epoch time: 40.430s step time: 0.202s
epoch: 116 train loss: 0.104712404 epoch time: 40.422s step time: 0.202s
epoch: 117 train loss: 0.10629137 epoch time: 40.435s step time: 0.202s
epoch: 118 train loss: 0.107867606 epoch time: 40.446s step time: 0.202s
epoch: 119 train loss: 0.11190843 epoch time: 40.422s step time: 0.202s
epoch: 120 train loss: 0.10280066 epoch time: 40.433s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.12565085, on regular grid: 0.12316060496613422
predict total time: 24.39193344116211 s
=================================End Evaluation=================================
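The two numbers reported above are the mean relative L2 error of the prediction on the Gauss (interpolation) grid and, after interpolating back, on the original regular grid. The sketch below shows a relative L2 error of this kind (illustration only; the exact reduction used in src's calculate_l2_error may differ):
[ ]:
def relative_l2(pred, target):
    """Mean relative L2 error over a batch of predictions (sketch)."""
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    return float(np.mean(np.linalg.norm(pred - target, axis=1) / np.linalg.norm(target, axis=1)))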
[ ]:
visual(model, test_a, data_params, a_normalizer, u_normalizer)
[13]:
from IPython.display import Image, display
with open('images/input.gif', 'rb') as f:
    display(Image(data=f.read(), format='png', embed=True))

[4]:
with open('images/result.gif', 'rb') as f:
    display(Image(data=f.read(), format='png', embed=True))
