PeRCNN for 3D Reaction-Diffusion Equation
Overview
Partial differential equations (PDEs) play an important role in modeling physical systems. However, in fields such as epidemiology, meteorological science, fluid mechanics, and biology, many of the underlying PDEs have not yet been fully explored. Even for known PDEs such as the Navier-Stokes equations, accurate numerical solution requires enormous computing power, which hinders the application of numerical simulation to large-scale systems. Recent advances in machine learning provide a new way to solve PDEs and to tackle the associated inverse problems.
Recently, Huawei and Professor Sun Hao's team from Renmin University of China proposed the Physics-encoded Recurrent Convolutional Neural Network (PeRCNN, https://www.nature.com/articles/s42256-023-00685-7), built on the Ascend platform and MindSpore. Compared with physics-informed neural networks (PINNs), ConvLSTM, PDE-Net, and other methods, PeRCNN shows significantly improved generalization and noise resistance, and its long-term prediction accuracy is improved by more than 10 times. The method has broad application prospects in aerospace, shipbuilding, weather forecasting, and other fields. The results have been published in Nature Machine Intelligence.
Problem Description
The reaction-diffusion (RD) equation is a partial differential equation of great significance that is widely used in disciplines such as physics, chemistry, and biology.
Governing Equation
In this research, the RD equation is formulated as follows:
where,
In this case, we simulate the flow dynamics over 100 time steps (dt=0.5s) in a \(\Omega \times \tau = {[-50,50]}^3 \times [0,500]\) physical domain. The initial condition is perturbed with Gaussian noise, and a periodic boundary condition (BC) is adopted.
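The periodic boundary condition can be illustrated with a finite-difference Laplacian that wraps around the domain edges. The following is a minimal NumPy sketch, not the discretization used inside PeRCNN: the function name `periodic_laplacian` and the second-order 7-point stencil are assumptions made for illustration, and the grid spacing assumes 48 points per axis (the grid size reported in the training log) on a domain of length 100.

```python
import numpy as np

def periodic_laplacian(field, dx):
    """Second-order 7-point Laplacian of a 3D field with periodic wrap-around."""
    lap = -6.0 * field
    for axis in range(3):
        # np.roll wraps values across the boundary, which realizes a periodic BC
        lap = lap + np.roll(field, 1, axis=axis) + np.roll(field, -1, axis=axis)
    return lap / dx**2

# Assumed discretization: 48 points per axis on [-50, 50), so dx = 100 / 48
n, length = 48, 100.0
dx = length / n
x = -50.0 + dx * np.arange(n)

# A single periodic sine mode along x; its exact Laplacian is -k^2 * u
k = 2.0 * np.pi / length
u = np.sin(k * x)[:, None, None] * np.ones((n, n, n))
lap_u = periodic_laplacian(u, dx)
max_err = np.max(np.abs(lap_u + k**2 * u))
print(f"dx = {dx:.4f}, max error vs. exact Laplacian = {max_err:.2e}")
```

Because the neighbors of a boundary cell are taken from the opposite face of the cube, no ghost cells or special boundary handling are needed.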
Technology Path
MindSpore Flow solves the problem as follows:
Optimizer and One-step Training
Model Construction
Model Training
Model Evaluation and Visualization
import os
import sys
import time
import numpy as np
from mindspore import context, jit, nn, ops, save_checkpoint, set_seed
import mindspore.common.dtype as mstype
from mindflow.utils import load_yaml_config, print_log
from src import RecurrentCnn, post_process, Trainer, UpScaler, count_params
set_seed(123456)
np.random.seed(123456)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0)
Optimizer and One-step Training
def train_stage(trainer, stage, config, ckpt_dir, use_ascend):
    """train stage"""
    if use_ascend:
        # dynamic loss scaling protects float16 gradients on Ascend
        from mindspore.amp import DynamicLossScaler, all_finite
        loss_scaler = DynamicLossScaler(2**10, 2, 100)

    # piecewise-constant learning rate decay if milestones are configured, else a constant rate
    if 'milestone_num' in config:
        interval = config['epochs'] // config['milestone_num']
        milestone = [interval * (i + 1) for i in range(config['milestone_num'])]
        lr = [float(config['learning_rate']) * config['gamma'] ** i
              for i in range(config['milestone_num'])]
        learning_rate = nn.piecewise_constant_lr(milestone, lr)
    else:
        learning_rate = config['learning_rate']

    # pretraining updates only the upscaler; finetuning updates both networks
    if stage == 'pretrain':
        params = trainer.upconv.trainable_params()
    else:
        params = trainer.upconv.trainable_params() + trainer.recurrent_cnn.trainable_params()
    optimizer = nn.Adam(params, learning_rate=learning_rate)

    def forward_fn():
        if stage == 'pretrain':
            loss = trainer.get_ic_loss()
        else:
            loss = trainer.get_loss()
        if use_ascend:
            loss = loss_scaler.scale(loss)
        return loss

    # has_aux=True in the finetune stage, where train_step() returns several loss terms
    grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=(stage != 'pretrain'))

    @jit
    def train_step():
        loss, grads = grad_fn()
        if use_ascend:
            loss = loss_scaler.unscale(loss)
            is_finite = all_finite(grads)
            if is_finite:
                # apply the update only when no gradient overflowed
                grads = loss_scaler.unscale(grads)
                loss = ops.depend(loss, optimizer(grads))
            loss_scaler.adjust(is_finite)
        else:
            loss = ops.depend(loss, optimizer(grads))
        return loss

    best_loss = sys.maxsize
    for epoch in range(1, 1 + config['epochs']):
        time_beg = time.time()
        trainer.upconv.set_train(True)
        trainer.recurrent_cnn.set_train(True)
        if stage == 'pretrain':
            step_train_loss = train_step()
            print_log(
                f"epoch: {epoch} train loss: {step_train_loss} epoch time: {(time.time() - time_beg):.3f} s")
        else:
            # hard-coded early stop for the finetune stage
            if epoch == 3800:
                break
            epoch_loss, loss_data, loss_ic, loss_phy, loss_valid = train_step()
            print_log(f"epoch: {epoch} train loss: {epoch_loss} ic_loss: {loss_ic} data_loss: {loss_data} "
                      f"phy_loss: {loss_phy} valid_loss: {loss_valid} epoch time: {(time.time() - time_beg):.3f} s")
            # keep the checkpoints of the best finetune loss seen so far
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                print_log('best loss', best_loss, 'save model')
                save_checkpoint(trainer.upconv, os.path.join(ckpt_dir, "train_upconv.ckpt"))
                save_checkpoint(trainer.recurrent_cnn,
                                os.path.join(ckpt_dir, "train_recurrent_cnn.ckpt"))
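The learning-rate schedule built in `train_stage` can be checked in isolation. The sketch below reproduces the milestone and rate arithmetic in plain Python and then expands it into one rate per step, which is how `nn.piecewise_constant_lr` is documented to behave. The helper names `build_milestone_lr` and `expand_piecewise` and all values in `demo_config` are hypothetical; they are not taken from `percnn_3d_rd.yaml`.

```python
def build_milestone_lr(config):
    """Return (milestones, rates) using the same arithmetic as train_stage."""
    interval = config['epochs'] // config['milestone_num']
    milestones = [interval * (i + 1) for i in range(config['milestone_num'])]
    rates = [float(config['learning_rate']) * config['gamma'] ** i
             for i in range(config['milestone_num'])]
    return milestones, rates

def expand_piecewise(milestones, rates):
    """Expand to one rate per step: rates[i] is used for steps below milestones[i]."""
    schedule, start = [], 0
    for end, rate in zip(milestones, rates):
        schedule.extend([rate] * (end - start))
        start = end
    return schedule

# Hypothetical values chosen only to make the arithmetic easy to follow
demo_config = {'epochs': 100, 'milestone_num': 4, 'gamma': 0.5, 'learning_rate': 0.01}
milestones, rates = build_milestone_lr(demo_config)
schedule = expand_piecewise(milestones, rates)
print(milestones)
print(rates)
print(len(schedule), schedule[24], schedule[25])
```

Since `train_step()` is called once per epoch in the loop above, each milestone here corresponds to an epoch count.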
Model Construction
PeRCNN is composed of two networks: UpScaler, which performs upscaling, and a recurrent CNN, which serves as the backbone.
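To make the division of labor between the two networks concrete, here is a toy NumPy sketch of the data flow only. It is not the PeRCNN architecture: nearest-neighbor repetition stands in for the learned upscaler, a diffusion-only term stands in for the recurrent CNN, and the residual update `state + dt * update(state)` is an assumed forward-Euler-style pattern. The function names, the `(batch, channel, x, y, z)` layout, and the coefficient `mu=0.1` are hypothetical; the 24 and 48 grid sizes follow the shapes printed in the training log.

```python
import numpy as np

def upscale_nearest(state_low, factor=2):
    """Stand-in for the learned upscaler: repeat each voxel along every spatial axis."""
    out = state_low
    for axis in (2, 3, 4):  # assumed layout: (batch, channel, x, y, z)
        out = np.repeat(out, factor, axis=axis)
    return out

def toy_update(state, dx, mu=0.1):
    """Stand-in for the recurrent CNN: diffusion only, periodic wrap via np.roll."""
    lap = -6.0 * state
    for axis in (2, 3, 4):
        lap = lap + np.roll(state, 1, axis=axis) + np.roll(state, -1, axis=axis)
    return mu * lap / dx**2

def rollout(init_state_low, steps, dt, dx):
    """Upscale once, then apply the residual update `steps` times, keeping every state."""
    state = upscale_nearest(init_state_low)
    states = [state]
    for _ in range(steps):
        state = state + dt * toy_update(state, dx)
        states.append(state)
    return np.stack(states)

rng = np.random.default_rng(0)
init_state_low = rng.normal(size=(1, 2, 24, 24, 24))  # low-resolution initial state
trajectory = rollout(init_state_low, steps=5, dt=0.5, dx=100.0 / 48)
print(trajectory.shape)
```

The upscaler is applied once to the initial state, while the recurrent part is applied at every time step, which is why the two are trained in separate stages in this tutorial.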
def train():
    """train"""
    rd_config = load_yaml_config('./configs/percnn_3d_rd.yaml')
    data_config = rd_config['data']
    optim_config = rd_config['optimizer']
    summary_config = rd_config['summary']
    model_config = rd_config['model']

    # float16 with mixed precision on Ascend, float32 otherwise
    use_ascend = context.get_context(attr_key='device_target') == "Ascend"
    print_log(f"use_ascend: {use_ascend}")
    compute_dtype = mstype.float16 if use_ascend else mstype.float32

    # upscaling network
    upconv_config = model_config['upconv']
    upconv = UpScaler(in_channels=upconv_config['in_channel'],
                      out_channels=upconv_config['out_channel'],
                      hidden_channels=upconv_config['hidden_channel'],
                      kernel_size=upconv_config['kernel_size'],
                      stride=upconv_config['stride'],
                      has_bais=True)  # keyword spelled as in the original call
    if use_ascend:
        from mindspore.amp import auto_mixed_precision
        auto_mixed_precision(upconv, 'O1')

    # recurrent CNN backbone
    rcnn_config = model_config['rcnn']
    recurrent_cnn = RecurrentCnn(input_channels=rcnn_config['in_channel'],
                                 hidden_channels=rcnn_config['hidden_channel'],
                                 kernel_size=rcnn_config['kernel_size'],
                                 stride=rcnn_config['stride'],
                                 compute_dtype=compute_dtype)

    percnn_trainer = Trainer(upconv=upconv,
                             recurrent_cnn=recurrent_cnn,
                             timesteps_for_train=data_config['rollout_steps'],
                             dx=data_config['dx'],
                             grid_size=data_config['grid_size'],
                             dt=data_config['dt'],
                             mu=data_config['mu'],
                             data_path=data_config['data_path'],
                             compute_dtype=compute_dtype)

    total_params = int(count_params(upconv.trainable_params()) +
                       count_params(recurrent_cnn.trainable_params()))
    print(f"There are {total_params} parameters")

    ckpt_dir = summary_config["ckpt_dir"]
    fig_path = summary_config["fig_save_path"]
    os.makedirs(ckpt_dir, exist_ok=True)

    # stage 1: pretrain the upscaler; stage 2: finetune both networks
    train_stage(percnn_trainer, 'pretrain',
                optim_config['pretrain'], ckpt_dir, use_ascend)
    train_stage(percnn_trainer, 'finetune',
                optim_config['finetune'], ckpt_dir, use_ascend)

    # inference, subsampling of the snapshots, and visualization
    output = percnn_trainer.get_output(100).asnumpy()
    output = np.transpose(output, (1, 0, 2, 3, 4))[:, :-1:10]
    print('output shape is ', output.shape)
    for i in range(0, 10, 2):
        post_process(output[0, i], fig_path, is_u=True, num=i)
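The indexing at the end of `train()` is easier to follow on a small dummy array. The sketch below replays the same transpose and slice; it assumes, without confirmation from the source, that `get_output(100)` yields an array laid out as (time, channel, x, y, z) with 101 time entries. The tiny spatial grid and the step-index tagging are only there to make the selection visible.

```python
import numpy as np

# Dummy trajectory standing in for percnn_trainer.get_output(100).asnumpy();
# the (time, channel, x, y, z) layout and the 101 time entries are assumptions
timesteps, channels, grid = 101, 2, 4
output = np.zeros((timesteps, channels, grid, grid, grid))
output += np.arange(timesteps).reshape(-1, 1, 1, 1, 1)  # tag each snapshot with its step index

# Same indexing as in train(): channel axis first, then every 10th step with the last entry dropped
output = np.transpose(output, (1, 0, 2, 3, 4))[:, :-1:10]
print(output.shape)

# Step index carried by each snapshot that the plotting loop picks for channel 0
picked_steps = [int(output[0, i, 0, 0, 0]) for i in range(0, 10, 2)]
print(picked_steps)
```

Under these assumptions the slice keeps 10 evenly spaced snapshots per channel, and the plotting loop visits every second one of them for the first channel.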
Model Training
With MindSpore version >= 2.0.0, neural networks can be trained using the functional programming paradigm.
train()
use_ascend: False
shape of uv is (3001, 2, 48, 48, 48)
shape of ic is (1, 2, 48, 48, 48)
shape of init_state_low is (1, 2, 24, 24, 24)
There are 10078 parameters
epoch: 1 train loss: 0.160835 epoch time: 5.545 s
epoch: 2 train loss: 104.36749 epoch time: 0.010 s
epoch: 3 train loss: 4.3207517 epoch time: 0.009 s
epoch: 4 train loss: 8.491383 epoch time: 0.009 s
epoch: 5 train loss: 23.683647 epoch time: 0.009 s
epoch: 6 train loss: 23.857117 epoch time: 0.010 s
epoch: 7 train loss: 16.037672 epoch time: 0.010 s
epoch: 8 train loss: 8.406443 epoch time: 0.009 s
epoch: 9 train loss: 3.527469 epoch time: 0.020 s
epoch: 10 train loss: 1.0823832 epoch time: 0.009 s
...
epoch: 9990 train loss: 8.7615306e-05 epoch time: 0.008 s
epoch: 9991 train loss: 8.76504e-05 epoch time: 0.008 s
epoch: 9992 train loss: 8.761823e-05 epoch time: 0.008 s
epoch: 9993 train loss: 8.7546505e-05 epoch time: 0.008 s
epoch: 9994 train loss: 8.7519744e-05 epoch time: 0.008 s
epoch: 9995 train loss: 8.753734e-05 epoch time: 0.008 s
epoch: 9996 train loss: 8.753101e-05 epoch time: 0.008 s
epoch: 9997 train loss: 8.748294e-05 epoch time: 0.008 s
epoch: 9998 train loss: 8.7443106e-05 epoch time: 0.008 s
epoch: 9999 train loss: 8.743979e-05 epoch time: 0.008 s
epoch: 10000 train loss: 8.744074e-05 epoch time: 0.008 s
epoch: 1 train loss: 61.754555 ic_loss: 8.7413886e-05 data_loss: 6.1754117 phy_loss: 2.6047118 valid_loss: 7.221066 truth_loss: 2.7125626 epoch time: 138.495 s
best loss 61.754555 save model
epoch: 2 train loss: 54.79151 ic_loss: 0.32984126 data_loss: 5.3142304 phy_loss: 52.50226 valid_loss: 6.812231 truth_loss: 2.7744124 epoch time: 1.342 s
best loss 54.79151 save model
epoch: 3 train loss: 46.904842 ic_loss: 0.12049961 data_loss: 4.6302347 phy_loss: 32.494545 valid_loss: 5.953037 truth_loss: 2.579268 epoch time: 1.262 s
best loss 46.904842 save model
epoch: 4 train loss: 40.674736 ic_loss: 0.031907484 data_loss: 4.05152 phy_loss: 11.360751 valid_loss: 5.08032 truth_loss: 2.3503494 epoch time: 1.233 s
best loss 40.674736 save model
epoch: 5 train loss: 36.910408 ic_loss: 0.10554239 data_loss: 3.6382694 phy_loss: 3.5776496 valid_loss: 4.4271708 truth_loss: 2.1671412 epoch time: 1.315 s
best loss 36.910408 save model
epoch: 6 train loss: 33.767193 ic_loss: 0.14396289 data_loss: 3.304738 phy_loss: 1.4308721 valid_loss: 3.954126 truth_loss: 2.0307255 epoch time: 1.322 s
best loss 33.767193 save model
epoch: 7 train loss: 30.495178 ic_loss: 0.09850004 data_loss: 3.0002677 phy_loss: 0.8241035 valid_loss: 3.586939 truth_loss: 1.9244627 epoch time: 1.178 s
best loss 30.495178 save model
epoch: 8 train loss: 27.448381 ic_loss: 0.03362463 data_loss: 2.728026 phy_loss: 0.6343211 valid_loss: 3.286183 truth_loss: 1.8369334 epoch time: 1.271 s
best loss 27.448381 save model
epoch: 9 train loss: 24.990078 ic_loss: 0.0024543565 data_loss: 2.4977806 phy_loss: 0.5740176 valid_loss: 3.0332325 truth_loss: 1.7619449 epoch time: 1.573 s
best loss 24.990078 save model
epoch: 10 train loss: 23.15583 ic_loss: 0.014634784 data_loss: 2.3082657 phy_loss: 0.5407104 valid_loss: 2.8156128 truth_loss: 1.6955423 epoch time: 1.351 s
best loss 23.15583 save model
...
epoch: 1640 train loss: 0.094870105 ic_loss: 0.0006555757 data_loss: 0.009159223 phy_loss: 0.000713372 valid_loss: 0.012182931 truth_loss: 0.16177362 epoch time: 1.289 s
best loss 0.094870105 save model
epoch: 1641 train loss: 0.09474868 ic_loss: 0.00065547746 data_loss: 0.00914713 phy_loss: 0.00071231654 valid_loss: 0.01216803 truth_loss: 0.16169967 epoch time: 1.259 s
best loss 0.09474868 save model
epoch: 1642 train loss: 0.09462735 ic_loss: 0.0006553787 data_loss: 0.009135046 phy_loss: 0.00071125705 valid_loss: 0.012153144 truth_loss: 0.16162594 epoch time: 1.310 s
best loss 0.09462735 save model
epoch: 1643 train loss: 0.094506115 ic_loss: 0.000655279 data_loss: 0.009122972 phy_loss: 0.00071020663 valid_loss: 0.01213827 truth_loss: 0.16155209 epoch time: 1.379 s
best loss 0.094506115 save model
epoch: 1644 train loss: 0.094384976 ic_loss: 0.0006551788 data_loss: 0.009110908 phy_loss: 0.0007091502 valid_loss: 0.012123411 truth_loss: 0.16147849 epoch time: 1.375 s
best loss 0.094384976 save model
epoch: 1645 train loss: 0.094263926 ic_loss: 0.0006550779 data_loss: 0.009098854 phy_loss: 0.0007081007 valid_loss: 0.012108564 truth_loss: 0.16140485 epoch time: 1.354 s
best loss 0.094263926 save model
epoch: 1646 train loss: 0.09414298 ic_loss: 0.0006549765 data_loss: 0.00908681 phy_loss: 0.00070705137 valid_loss: 0.012093734 truth_loss: 0.16133131 epoch time: 1.332 s
best loss 0.09414298 save model
epoch: 1647 train loss: 0.09402215 ic_loss: 0.0006548743 data_loss: 0.009074777 phy_loss: 0.0007060007 valid_loss: 0.012078916 truth_loss: 0.16125791 epoch time: 1.435 s
best loss 0.09402215 save model
epoch: 1648 train loss: 0.09390141 ic_loss: 0.0006547714 data_loss: 0.009062755 phy_loss: 0.00070495723 valid_loss: 0.012064112 truth_loss: 0.16118445 epoch time: 1.402 s
best loss 0.09390141 save model
epoch: 1649 train loss: 0.09378076 ic_loss: 0.00065466797 data_loss: 0.009050743 phy_loss: 0.0007039088 valid_loss: 0.012049323 truth_loss: 0.1611112 epoch time: 1.284 s
best loss 0.09378076 save model
epoch: 1650 train loss: 0.09366022 ic_loss: 0.00065456395 data_loss: 0.009038741 phy_loss: 0.00070286694 valid_loss: 0.0120345475 truth_loss: 0.16103792 epoch time: 1.502 s
best loss 0.09366022 save model
epoch: 1651 train loss: 0.093539774 ic_loss: 0.0006544591 data_loss: 0.009026748 phy_loss: 0.0007018241 valid_loss: 0.012019787 truth_loss: 0.16096477 epoch time: 1.274 s
best loss 0.093539774 save model
epoch: 1652 train loss: 0.093419425 ic_loss: 0.0006543536 data_loss: 0.009014766 phy_loss: 0.00070078264 valid_loss: 0.012005039 truth_loss: 0.16089168 epoch time: 1.456 s
best loss 0.093419425 save model
epoch: 1653 train loss: 0.09329918 ic_loss: 0.00065424765 data_loss: 0.0090027945 phy_loss: 0.00069974473 valid_loss: 0.011990305 truth_loss: 0.16081864 epoch time: 1.203 s
best loss 0.09329918 save model
epoch: 1654 train loss: 0.09317903 ic_loss: 0.00065414095 data_loss: 0.008990833 phy_loss: 0.000698706 valid_loss: 0.011975586 truth_loss: 0.16074573 epoch time: 1.285 s
best loss 0.09317903 save model
epoch: 1655 train loss: 0.09305898 ic_loss: 0.0006540336 data_loss: 0.008978881 phy_loss: 0.0006976697 valid_loss: 0.011960882 truth_loss: 0.16067289 epoch time: 1.159 s
best loss 0.09305898 save model
epoch: 1656 train loss: 0.092939034 ic_loss: 0.0006539258 data_loss: 0.00896694 phy_loss: 0.00069663546 valid_loss: 0.01194619 truth_loss: 0.16060013 epoch time: 1.274 s
best loss 0.092939034 save model
epoch: 1657 train loss: 0.092819184 ic_loss: 0.00065381714 data_loss: 0.00895501 phy_loss: 0.00069560105 valid_loss: 0.011931514 truth_loss: 0.16052744 epoch time: 1.174 s
best loss 0.092819184 save model
epoch: 1658 train loss: 0.09269943 ic_loss: 0.0006537079 data_loss: 0.008943089 phy_loss: 0.0006945693 valid_loss: 0.011916851 truth_loss: 0.1604548 epoch time: 1.296 s
best loss 0.09269943 save model
epoch: 1659 train loss: 0.092579775 ic_loss: 0.00065359805 data_loss: 0.008931179 phy_loss: 0.0006935386 valid_loss: 0.0119022 truth_loss: 0.16038223 epoch time: 1.426 s
best loss 0.092579775 save model
epoch: 1660 train loss: 0.09246021 ic_loss: 0.0006534874 data_loss: 0.008919277 phy_loss: 0.00069250836 valid_loss: 0.011887563 truth_loss: 0.16030973 epoch time: 1.389 s
best loss 0.09246021 save model
Model Evaluation and Visualization
After training, all data points in the flow field can be inferred, and the related results can be visualized.
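As a rough illustration of what evaluating and visualizing a 3D field can involve, the sketch below computes a per-snapshot relative L2 error and extracts the three mid-plane slices that a 2D plot would show. Both helpers, `relative_l2_error` and `central_slices`, are hypothetical and are not the tutorial's `post_process`; the arrays are synthetic and unrelated to the tutorial's dataset or results.

```python
import numpy as np

def relative_l2_error(pred, truth):
    """Relative L2 error per snapshot; inputs shaped (snapshots, x, y, z)."""
    axes = tuple(range(1, pred.ndim))
    num = np.sqrt(np.sum((pred - truth) ** 2, axis=axes))
    den = np.sqrt(np.sum(truth ** 2, axis=axes))
    return num / den

def central_slices(volume):
    """Return the three mid-plane 2D slices of a 3D field, ready for 2D plotting."""
    cx, cy, cz = (s // 2 for s in volume.shape)
    return volume[cx, :, :], volume[:, cy, :], volume[:, :, cz]

rng = np.random.default_rng(0)
truth = rng.normal(size=(3, 8, 8, 8))  # synthetic reference field
pred = 1.1 * truth                     # synthetic prediction with a uniform 10% deviation
errors = relative_l2_error(pred, truth)
slices = central_slices(pred[0])
print(errors)
print([s.shape for s in slices])
```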