PeRCNN for 3D Reaction-Diffusion Equation

Overview

PDEs occupy an important position in the modeling of physical systems, yet many of the underlying PDEs in epidemiology, meteorological science, fluid mechanics, and biology have not been fully explored. For the equations that are known, such as the Navier-Stokes equations, exact numerical solution requires enormous computing power, which hinders the application of numerical simulation to large-scale systems. Recently, advances in machine learning have provided a new way to solve and invert PDEs.

Recently, Huawei and Professor Sun Hao's team from Renmin University of China proposed the Physics-encoded Recurrent Convolutional Neural Network (PeRCNN), built on the Ascend platform and MindSpore. Compared with physics-informed neural networks, ConvLSTM, PDE-Net, and other methods, PeRCNN offers significantly better generalization and noise robustness, and its long-term prediction accuracy is improved by more than 10 times. The method has broad application prospects in aerospace engineering, shipbuilding, weather forecasting, and other fields. The results have been published in Nature Machine Intelligence.

Problem Description

The reaction-diffusion (RD) equation is a partial differential equation of great significance that has been widely used in a variety of disciplines, such as physics, chemistry, and biology.

Governing Equation

In this research, the RD equation is formulated as follows:

\[u_t = \mu_u \Delta u - uv^2 + F(1-u)\]
\[v_t = \mu_v \Delta v + uv^2 - (F+\kappa)v\]

where,

\[\mu_v = 0.1, \mu_u = 0.2, F = 0.025, \kappa = 0.055\]

In this case, we simulate the dynamics for 100 time steps ($dt = 0.5\,s$) on the physical domain $\Omega \times \tau = [-50, 50]^3 \times [0, 500]$. The initial condition is perturbed by Gaussian noise, and periodic boundary conditions are adopted.
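For intuition, the snippet below is a minimal NumPy reference sketch (independent of PeRCNN) of one explicit Euler step of the system above with a 7-point periodic Laplacian; the 48³ resolution and grid spacing are illustrative assumptions, not values taken from this case's config.

import numpy as np

mu_u, mu_v, F, kappa = 0.2, 0.1, 0.025, 0.055
n, dx, dt = 48, 100.0 / 48, 0.5          # assumed grid resolution and spacing
u = np.random.rand(n, n, n)
v = np.random.rand(n, n, n)

def laplacian(f):
    """7-point stencil with periodic boundary conditions via np.roll."""
    return (np.roll(f, 1, 0) + np.roll(f, -1, 0) +
            np.roll(f, 1, 1) + np.roll(f, -1, 1) +
            np.roll(f, 1, 2) + np.roll(f, -1, 2) - 6.0 * f) / dx**2

# one explicit Euler step of the RD system above
u_next = u + dt * (mu_u * laplacian(u) - u * v**2 + F * (1.0 - u))
v_next = v + dt * (mu_v * laplacian(v) + u * v**2 - (F + kappa) * v)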

Technology Path

MindSpore Flow solves the problem as follows:

  1. Optimizer and One-step Training

  2. Model Construction

  3. Model Training

  4. Model Evaluation and Visualization

[1]:
import os
import sys
import time

import numpy as np

[2]:
from mindspore import context, jit, nn, ops, save_checkpoint, set_seed
import mindspore.common.dtype as mstype
from mindflow.utils import load_yaml_config, print_log
from src import RecurrentCnn, post_process, Trainer, UpScaler, count_params

[3]:
set_seed(123456)
np.random.seed(123456)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0)

Optimizer and One-step Training

[4]:
def train_stage(trainer, stage, config, ckpt_dir, use_ascend):
    """train stage"""
    if use_ascend:
        from mindspore.amp import DynamicLossScaler, all_finite
        loss_scaler = DynamicLossScaler(2**10, 2, 100)

    if 'milestone_num' in config.keys():
        milestone = list([(config['epochs']//config['milestone_num'])*(i + 1)
                          for i in range(config['milestone_num'])])
        learning_rate = config['learning_rate']
        lr = float(config['learning_rate'])*np.array(list([config['gamma']
                                                           ** i for i in range(config['milestone_num'])]))
        learning_rate = nn.piecewise_constant_lr(milestone, list(lr))
    else:
        learning_rate = config['learning_rate']

    if stage == 'pretrain':
        params = trainer.upconv.trainable_params()
    else:
        params = trainer.upconv.trainable_params() + trainer.recurrent_cnn.trainable_params()

    optimizer = nn.Adam(params, learning_rate=learning_rate)

    def forward_fn():
        if stage == 'pretrain':
            loss = trainer.get_ic_loss()
        else:
            loss = trainer.get_loss()
        if use_ascend:
            loss = loss_scaler.scale(loss)
        return loss

    if stage == 'pretrain':
        grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=False)
    else:
        grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=True)

    @jit
    def train_step():
        loss, grads = grad_fn()
        if use_ascend:
            loss = loss_scaler.unscale(loss)
            is_finite = all_finite(grads)
            if is_finite:
                grads = loss_scaler.unscale(grads)
                loss = ops.depend(loss, optimizer(grads))
            loss_scaler.adjust(is_finite)
        else:
            loss = ops.depend(loss, optimizer(grads))
        return loss

    best_loss = sys.maxsize
    for epoch in range(1, 1 + config['epochs']):
        time_beg = time.time()
        trainer.upconv.set_train(True)
        trainer.recurrent_cnn.set_train(True)
        if stage == 'pretrain':
            step_train_loss = train_step()
            print_log(
                f"epoch: {epoch} train loss: {step_train_loss} epoch time: {(time.time() - time_beg) :.3f} s")
        else:
            if epoch == 3800:
                break
            epoch_loss, loss_data, loss_ic, loss_phy, loss_valid = train_step()
            print_log(f"epoch: {epoch} train loss: {epoch_loss} ic_loss: {loss_ic} data_loss: {loss_data} \
 phy_loss: {loss_phy}  valid_loss: {loss_valid} epoch time: {(time.time() - time_beg): .3f} s")
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                print_log('best loss', best_loss, 'save model')
                save_checkpoint(trainer.upconv, os.path.join(ckpt_dir, "train_upconv.ckpt"))
                save_checkpoint(trainer.recurrent_cnn,
                                os.path.join(ckpt_dir, "train_recurrent_cnn.ckpt"))
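As a quick check of the learning-rate branch above, the following sketch evaluates the piecewise-constant schedule in isolation with hypothetical config values (epochs, milestone_num, gamma, and learning_rate here are illustrative, not the values in percnn_3d_rd.yaml).

import numpy as np
import mindspore.nn as nn

config = {'epochs': 10000, 'milestone_num': 4, 'gamma': 0.5, 'learning_rate': 1e-3}
milestone = [(config['epochs'] // config['milestone_num']) * (i + 1)
             for i in range(config['milestone_num'])]              # [2500, 5000, 7500, 10000]
lr = float(config['learning_rate']) * np.array([config['gamma'] ** i
                                                for i in range(config['milestone_num'])])
schedule = nn.piecewise_constant_lr(milestone, list(lr))           # one learning rate per step
print(len(schedule), schedule[0], schedule[-1])                    # 10000 0.001 0.000125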

Model Construction

PeRCNN is composed of two networks: an UpScaler for upscaling and a recurrent CNN as the backbone.
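The actual layer definitions live in src. As a rough illustration only, the sketch below shows the Π-block idea behind PeRCNN's recurrent CNN, in which the outputs of parallel convolution branches are multiplied element-wise so that polynomial reaction terms such as uv² can be represented; the class name, channel sizes, and branch count are assumptions rather than the src implementation.

import mindspore.nn as nn

class PiBlock(nn.Cell):
    """Element-wise product of parallel conv branches (the Π-block idea)."""
    def __init__(self, channels, hidden_channels, kernel_size, num_branches=3):
        super().__init__()
        self.branches = nn.CellList([
            nn.Conv3d(channels, hidden_channels, kernel_size, pad_mode='same')
            for _ in range(num_branches)])
        self.aggregate = nn.Conv3d(hidden_channels, channels, kernel_size=1)

    def construct(self, x):
        out = self.branches[0](x)
        for branch in self.branches[1:]:
            out = out * branch(x)      # products of feature maps encode polynomial terms
        return self.aggregate(out)     # 1x1x1 conv combines the learned terms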

[5]:
def train():
    """train"""
    rd_config = load_yaml_config('./configs/percnn_3d_rd.yaml')
    data_config = rd_config['data']
    optim_config = rd_config['optimizer']
    summary_config = rd_config['summary']
    model_config = rd_config['model']

    use_ascend = context.get_context(attr_key='device_target') == "Ascend"
    print_log(f"use_ascend: {use_ascend}")

    if use_ascend:
        compute_dtype = mstype.float16
    else:
        compute_dtype = mstype.float32

    upconv_config = model_config['upconv']
    upconv = UpScaler(in_channels=upconv_config['in_channel'],
                      out_channels=upconv_config['out_channel'],
                      hidden_channels=upconv_config['hidden_channel'],
                      kernel_size=upconv_config['kernel_size'],
                      stride=upconv_config['stride'],
                      has_bais=True)

    if use_ascend:
        from mindspore.amp import auto_mixed_precision
        auto_mixed_precision(upconv, 'O1')

    rcnn_config = model_config['rcnn']
    recurrent_cnn = RecurrentCnn(input_channels=rcnn_config['in_channel'],
                                 hidden_channels=rcnn_config['hidden_channel'],
                                 kernel_size=rcnn_config['kernel_size'],
                                 stride=rcnn_config['stride'],
                                 compute_dtype=compute_dtype)

    percnn_trainer = Trainer(upconv=upconv,
                             recurrent_cnn=recurrent_cnn,
                             timesteps_for_train=data_config['rollout_steps'],
                             dx=data_config['dx'],
                             grid_size=data_config['grid_size'],
                             dt=data_config['dt'],
                             mu=data_config['mu'],
                             data_path=data_config['data_path'],
                             compute_dtype=compute_dtype)

    total_params = int(count_params(upconv.trainable_params()) +
                       count_params(recurrent_cnn.trainable_params()))
    print(f"There are {total_params} parameters")

    ckpt_dir = summary_config["ckpt_dir"]
    fig_path = summary_config["fig_save_path"]
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    train_stage(percnn_trainer, 'pretrain',
                optim_config['pretrain'], ckpt_dir, use_ascend)
    train_stage(percnn_trainer, 'finetune',
                optim_config['finetune'], ckpt_dir, use_ascend)

    output = percnn_trainer.get_output(100).asnumpy()
    output = np.transpose(output, (1, 0, 2, 3, 4))[:, :-1:10]

    print('output shape is ', output.shape)
    for i in range(0, 10, 2):
        post_process(output[0, i], fig_path, is_u=True, num=i)

Model Training

With MindSpore version 2.0.0 or later, we can use functional programming to train neural networks.
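Reduced to a toy example, the pattern used in train_step above looks like the sketch below (a hypothetical dense network and random data, not part of this case), combining ops.value_and_grad with @jit graph compilation.

import numpy as np
import mindspore as ms
from mindspore import nn, ops, jit

net = nn.Dense(2, 1)                                        # toy network
opt = nn.Adam(net.trainable_params(), learning_rate=1e-3)
loss_fn = nn.MSELoss()

def forward_fn(x, y):
    return loss_fn(net(x), y)

# value_and_grad builds a function that returns (loss, grads) w.r.t. the optimizer parameters
grad_fn = ops.value_and_grad(forward_fn, None, opt.parameters)

@jit
def train_step(x, y):
    loss, grads = grad_fn(x, y)
    loss = ops.depend(loss, opt(grads))                     # ensure the parameter update runs
    return loss

x = ms.Tensor(np.random.rand(4, 2), ms.float32)
y = ms.Tensor(np.random.rand(4, 1), ms.float32)
print(train_step(x, y))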

[ ]:
train()
use_ascend: False
shape of uv is  (3001, 2, 48, 48, 48)
shape of ic is  (1, 2, 48, 48, 48)
shape of init_state_low is  (1, 2, 24, 24, 24)
There are 10078 parameters
epoch: 1 train loss: 0.160835 epoch time: 5.545 s
epoch: 2 train loss: 104.36749 epoch time: 0.010 s
epoch: 3 train loss: 4.3207517 epoch time: 0.009 s
epoch: 4 train loss: 8.491383 epoch time: 0.009 s
epoch: 5 train loss: 23.683647 epoch time: 0.009 s
epoch: 6 train loss: 23.857117 epoch time: 0.010 s
epoch: 7 train loss: 16.037672 epoch time: 0.010 s
epoch: 8 train loss: 8.406443 epoch time: 0.009 s
epoch: 9 train loss: 3.527469 epoch time: 0.020 s
epoch: 10 train loss: 1.0823832 epoch time: 0.009 s
...
epoch: 9990 train loss: 8.7615306e-05 epoch time: 0.008 s
epoch: 9991 train loss: 8.76504e-05 epoch time: 0.008 s
epoch: 9992 train loss: 8.761823e-05 epoch time: 0.008 s
epoch: 9993 train loss: 8.7546505e-05 epoch time: 0.008 s
epoch: 9994 train loss: 8.7519744e-05 epoch time: 0.008 s
epoch: 9995 train loss: 8.753734e-05 epoch time: 0.008 s
epoch: 9996 train loss: 8.753101e-05 epoch time: 0.008 s
epoch: 9997 train loss: 8.748294e-05 epoch time: 0.008 s
epoch: 9998 train loss: 8.7443106e-05 epoch time: 0.008 s
epoch: 9999 train loss: 8.743979e-05 epoch time: 0.008 s
epoch: 10000 train loss: 8.744074e-05 epoch time: 0.008 s
epoch: 1 train loss: 61.754555 ic_loss: 8.7413886e-05 data_loss: 6.1754117  phy_loss: 2.6047118  valid_loss: 7.221066 truth_loss: 2.7125626 epoch time:  138.495 s
best loss 61.754555 save model
epoch: 2 train loss: 54.79151 ic_loss: 0.32984126 data_loss: 5.3142304  phy_loss: 52.50226  valid_loss: 6.812231 truth_loss: 2.7744124 epoch time:  1.342 s
best loss 54.79151 save model
epoch: 3 train loss: 46.904842 ic_loss: 0.12049961 data_loss: 4.6302347  phy_loss: 32.494545  valid_loss: 5.953037 truth_loss: 2.579268 epoch time:  1.262 s
best loss 46.904842 save model
epoch: 4 train loss: 40.674736 ic_loss: 0.031907484 data_loss: 4.05152  phy_loss: 11.360751  valid_loss: 5.08032 truth_loss: 2.3503494 epoch time:  1.233 s
best loss 40.674736 save model
epoch: 5 train loss: 36.910408 ic_loss: 0.10554239 data_loss: 3.6382694  phy_loss: 3.5776496  valid_loss: 4.4271708 truth_loss: 2.1671412 epoch time:  1.315 s
best loss 36.910408 save model
epoch: 6 train loss: 33.767193 ic_loss: 0.14396289 data_loss: 3.304738  phy_loss: 1.4308721  valid_loss: 3.954126 truth_loss: 2.0307255 epoch time:  1.322 s
best loss 33.767193 save model
epoch: 7 train loss: 30.495178 ic_loss: 0.09850004 data_loss: 3.0002677  phy_loss: 0.8241035  valid_loss: 3.586939 truth_loss: 1.9244627 epoch time:  1.178 s
best loss 30.495178 save model
epoch: 8 train loss: 27.448381 ic_loss: 0.03362463 data_loss: 2.728026  phy_loss: 0.6343211  valid_loss: 3.286183 truth_loss: 1.8369334 epoch time:  1.271 s
best loss 27.448381 save model
epoch: 9 train loss: 24.990078 ic_loss: 0.0024543565 data_loss: 2.4977806  phy_loss: 0.5740176  valid_loss: 3.0332325 truth_loss: 1.7619449 epoch time:  1.573 s
best loss 24.990078 save model
epoch: 10 train loss: 23.15583 ic_loss: 0.014634784 data_loss: 2.3082657  phy_loss: 0.5407104  valid_loss: 2.8156128 truth_loss: 1.6955423 epoch time:  1.351 s
best loss 23.15583 save model
...
epoch: 1657 train loss: 0.092819184 ic_loss: 0.00065381714 data_loss: 0.00895501  phy_loss: 0.00069560105  valid_loss: 0.011931514 truth_loss: 0.16052744 epoch time:  1.174 s
best loss 0.092819184 save model
epoch: 1658 train loss: 0.09269943 ic_loss: 0.0006537079 data_loss: 0.008943089  phy_loss: 0.0006945693  valid_loss: 0.011916851 truth_loss: 0.1604548 epoch time:  1.296 s
best loss 0.09269943 save model
epoch: 1659 train loss: 0.092579775 ic_loss: 0.00065359805 data_loss: 0.008931179  phy_loss: 0.0006935386  valid_loss: 0.0119022 truth_loss: 0.16038223 epoch time:  1.426 s
best loss 0.092579775 save model
epoch: 1660 train loss: 0.09246021 ic_loss: 0.0006534874 data_loss: 0.008919277  phy_loss: 0.00069250836  valid_loss: 0.011887563 truth_loss: 0.16030973 epoch time:  1.389 s
best loss 0.09246021 save model

Model Evaluation and Visualization

After training, the model can infer the flow field at all data points, and the results can be visualized.
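A standalone inference pass might look like the following sketch, assuming the checkpoints saved during training and the upconv, recurrent_cnn, and percnn_trainer objects built in train(); the checkpoint and figure paths here are placeholders, and the actual ones come from the YAML config.

from mindspore import load_checkpoint, load_param_into_net

# hypothetical paths; the actual values come from the summary section of the config
load_param_into_net(upconv, load_checkpoint("./ckpt/train_upconv.ckpt"))
load_param_into_net(recurrent_cnn, load_checkpoint("./ckpt/train_recurrent_cnn.ckpt"))

output = percnn_trainer.get_output(100).asnumpy()             # roll out 100 time steps
output = np.transpose(output, (1, 0, 2, 3, 4))[:, :-1:10]     # keep every 10th frame
for i in range(0, 10, 2):
    post_process(output[0, i], "./figures", is_u=True, num=i)  # plot slices of the u field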

(Figure: visualization of the u field at selected time steps.)