{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "89a1fe73",
   "metadata": {},
   "source": [
    "# PeRCNN for 3D Reaction-Diffusion Equation\n",
    "\n",
    "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindflow/en/data_mechanism_fusion/mindspore_percnn3d.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindflow/en/data_mechanism_fusion/mindspore_percnn3d.py) [![ViewSource](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.2/docs/mindflow/docs/source_en/data_mechanism_fusion/percnn3d.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8627dd49",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "Partial differential equations (PDEs) occupy an important position in the modeling of physical systems. However, the governing PDEs of many systems in epidemiology, meteorology, fluid mechanics, and biology have not yet been fully identified. Even for well-known PDEs, such as the Navier-Stokes equations, accurate numerical solution requires huge computing power, which hinders the application of numerical simulation to large-scale systems. Recent advances in machine learning provide a new approach to PDE solution and inversion.\n",
    "\n",
    "Huawei and Professor Sun Hao's team from Renmin University of China proposed the Physics-encoded Recurrent Convolutional Neural Network ([PeRCNN](https://www.nature.com/articles/s42256-023-00685-7)), developed on the Ascend platform and MindSpore. Compared with physics-informed neural networks (PINNs), ConvLSTM, PDE-Net and other methods, PeRCNN shows significantly better generalization and noise resistance, and its long-term prediction accuracy is improved by more than 10 times. The method has broad application prospects in aerospace, shipbuilding, weather forecasting and other fields. The results have been published in Nature Machine Intelligence."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed20f19c",
   "metadata": {},
   "source": [
    "## Problem Description"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23b71838",
   "metadata": {},
   "source": [
    "The reaction-diffusion (RD) equation is a partial differential equation of great significance that has been widely used in disciplines such as physics, chemistry and biology."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4174fdee",
   "metadata": {},
   "source": [
    "## Governing Equation\n",
    "\n",
    "In this study, the RD equation is formulated as follows:\n",
    "\n",
    "$$\n",
    "u_t = \\mu_u \\Delta u - uv^2 + F(1-u)\n",
    "$$\n",
    "\n",
    "$$\n",
    "v_t = \\mu_v \\Delta v + uv^2 - (F+\\kappa)v\n",
    "$$\n",
    "\n",
    "where\n",
    "\n",
    "$$\n",
    "\\mu_v = 0.1, \\mu_u = 0.2, F = 0.025, \\kappa = 0.055\n",
    "$$\n",
    "\n",
    "In this case, we simulate the dynamics for 100 time steps (dt = 0.5 s) in the physical domain $ \\Omega \\times \\tau = {[-50,50]}^3 \\times [0,500] $. Gaussian noise is applied to the initial condition, and periodic boundary conditions (BC) are adopted."
   ]
  },
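  {
   "cell_type": "markdown",
   "id": "7c1e5a90",
   "metadata": {},
   "source": [
    "As a rough illustration of how a system of this form can be advanced in time on a periodic grid, one explicit (forward Euler) step may be written as below. This is only a sketch under assumptions that are not taken from the original case: the second-order 7-point Laplacian $\\Delta_h$ and the symbols $\\delta t$, $\\delta x$, $N$, $R_u$ and $R_v$ are introduced here for illustration, and the data generation and the network itself may use a different (for example higher-order) scheme.\n",
    "\n",
    "$$\n",
    "u^{k+1} = u^k + \\delta t \\left[ \\mu_u \\Delta_h u^k + R_u(u^k, v^k) \\right], \\qquad v^{k+1} = v^k + \\delta t \\left[ \\mu_v \\Delta_h v^k + R_v(u^k, v^k) \\right]\n",
    "$$\n",
    "\n",
    "Here $R_u$ and $R_v$ stand for the reaction terms, that is, everything on the right-hand sides of the governing equations except the diffusion term. With periodic BC, the discrete Laplacian at grid node $(i,j,l)$ is\n",
    "\n",
    "$$\n",
    "\\Delta_h u_{i,j,l} = \\frac{u_{i+1,j,l} + u_{i-1,j,l} + u_{i,j+1,l} + u_{i,j-1,l} + u_{i,j,l+1} + u_{i,j,l-1} - 6u_{i,j,l}}{\\delta x^2}\n",
    "$$\n",
    "\n",
    "where every index is taken modulo the number of nodes per axis $N$, so that node $N$ coincides with node $0$. For example, if $N = 48$ nodes cover one period of length $100$ along each axis, then $\\delta x = 100/48 \\approx 2.08$."
   ]
  },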
  {
   "cell_type": "markdown",
   "id": "cd99f3dc",
   "metadata": {},
   "source": [
    "## Technology Path\n",
    "\n",
    "MindSpore Flow solves the problem as follows:\n",
    "\n",
    "1. Optimizer and One-step Training\n",
    "2. Model Construction\n",
    "3. Model Training\n",
    "4. Model Evaluation and Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "27fdf4af",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c65a682d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import context, jit, nn, ops, save_checkpoint, set_seed\n",
    "import mindspore.common.dtype as mstype\n",
    "from mindflow.utils import load_yaml_config, print_log\n",
    "from src import RecurrentCnn, post_process, Trainer, UpScaler, count_params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2b0edc43",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(123456)\n",
    "np.random.seed(123456)\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\", device_id=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44f12211",
   "metadata": {},
   "source": [
    "## Optimizer and One-step Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe6ae474",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_stage(trainer, stage, config, ckpt_dir, use_ascend):\n",
    "    \"\"\"Train one stage: 'pretrain' fits the upscaler only; otherwise both networks are trained.\"\"\"\n",
    "    if use_ascend:\n",
    "        # Dynamic loss scaling guards float16 gradients against underflow and overflow.\n",
    "        from mindspore.amp import DynamicLossScaler, all_finite\n",
    "        loss_scaler = DynamicLossScaler(2**10, 2, 100)\n",
    "\n",
    "    if 'milestone_num' in config.keys():\n",
    "        # Piecewise-constant schedule: multiply the base rate by gamma after each milestone.\n",
    "        milestone = [(config['epochs'] // config['milestone_num']) * (i + 1)\n",
    "                     for i in range(config['milestone_num'])]\n",
    "        lr = float(config['learning_rate']) * np.array(\n",
    "            [config['gamma'] ** i for i in range(config['milestone_num'])])\n",
    "        learning_rate = nn.piecewise_constant_lr(milestone, list(lr))\n",
    "    else:\n",
    "        learning_rate = config['learning_rate']\n",
    "\n",
    "    if stage == 'pretrain':\n",
    "        params = trainer.upconv.trainable_params()\n",
    "    else:\n",
    "        params = trainer.upconv.trainable_params() + trainer.recurrent_cnn.trainable_params()\n",
    "\n",
    "    optimizer = nn.Adam(params, learning_rate=learning_rate)\n",
    "\n",
    "    def forward_fn():\n",
    "        if stage == 'pretrain':\n",
    "            loss = trainer.get_ic_loss()\n",
    "        else:\n",
    "            loss = trainer.get_loss()\n",
    "        if use_ascend:\n",
    "            loss = loss_scaler.scale(loss)\n",
    "        return loss\n",
    "\n",
    "    if stage == 'pretrain':\n",
    "        grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=False)\n",
    "    else:\n",
    "        grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=True)\n",
    "\n",
    "    @jit\n",
    "    def train_step():\n",
    "        loss, grads = grad_fn()\n",
    "        if use_ascend:\n",
    "            loss = loss_scaler.unscale(loss)\n",
    "            is_finite = all_finite(grads)\n",
    "            if is_finite:\n",
    "                grads = loss_scaler.unscale(grads)\n",
    "                loss = ops.depend(loss, optimizer(grads))\n",
    "            loss_scaler.adjust(is_finite)\n",
    "        else:\n",
    "            loss = ops.depend(loss, optimizer(grads))\n",
    "        return loss\n",
    "\n",
    "    best_loss = sys.maxsize\n",
    "    for epoch in range(1, 1 + config['epochs']):\n",
    "        time_beg = time.time()\n",
    "        trainer.upconv.set_train(True)\n",
    "        trainer.recurrent_cnn.set_train(True)\n",
    "        if stage == 'pretrain':\n",
    "            step_train_loss = train_step()\n",
    "            print_log(\n",
    "                f\"epoch: {epoch} train loss: {step_train_loss} epoch time: {(time.time() - time_beg) :.3f} s\")\n",
    "        else:\n",
    "            if epoch == 3800:\n",
    "                break\n",
    "            epoch_loss, loss_data, loss_ic, loss_phy, loss_valid = train_step()\n",
    "            print_log(f\"epoch: {epoch} train loss: {epoch_loss} ic_loss: {loss_ic} data_loss: {loss_data} \\\n",
    " phy_loss: {loss_phy}  valid_loss: {loss_valid} epoch time: {(time.time() - time_beg): .3f} s\")\n",
    "            if epoch_loss < best_loss:\n",
    "                best_loss = epoch_loss\n",
    "                print_log('best loss', best_loss, 'save model')\n",
    "                save_checkpoint(trainer.upconv, os.path.join(ckpt_dir, \"train_upconv.ckpt\"))\n",
    "                save_checkpoint(trainer.recurrent_cnn,\n",
    "                                os.path.join(ckpt_dir, \"train_recurrent_cnn.ckpt\"))"
   ]
  },
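  {
   "cell_type": "markdown",
   "id": "3f9b2d14",
   "metadata": {},
   "source": [
    "The learning-rate schedule built in `train_stage` can be summarized as follows. With $E$ = `epochs`, $M$ = `milestone_num`, $\\eta_0$ = `learning_rate` and $\\gamma$ = `gamma`, the milestones are $m_k = (k+1)\\lfloor E/M \\rfloor$, and `nn.piecewise_constant_lr` assigns to step $i$ the rate\n",
    "\n",
    "$$\n",
    "\\eta(i) = \\eta_0 \\gamma^{k}, \\qquad k \\lfloor E/M \\rfloor \\le i < (k+1) \\lfloor E/M \\rfloor, \\qquad k = 0, \\dots, M-1\n",
    "$$\n",
    "\n",
    "As a worked example with purely hypothetical values (not read from `percnn_3d_rd.yaml`), $E = 1000$, $M = 4$, $\\eta_0 = 10^{-3}$ and $\\gamma = 0.5$ give milestones $250, 500, 750, 1000$ and rates $10^{-3}$, $5 \\times 10^{-4}$, $2.5 \\times 10^{-4}$ and $1.25 \\times 10^{-4}$ on the four successive intervals."
   ]
  },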
  {
   "cell_type": "markdown",
   "id": "933c4152",
   "metadata": {},
   "source": [
    "## Model Construction\n",
    "\n",
    "PeRCNN is composed of two networks: UpScaler for upscaling and a recurrent CNN as the backbone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d41734e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"train\"\"\"\n",
    "    rd_config = load_yaml_config('./configs/percnn_3d_rd.yaml')\n",
    "    data_config = rd_config['data']\n",
    "    optim_config = rd_config['optimizer']\n",
    "    summary_config = rd_config['summary']\n",
    "    model_config = rd_config['model']\n",
    "\n",
    "    use_ascend = context.get_context(attr_key='device_target') == \"Ascend\"\n",
    "    print_log(f\"use_ascend: {use_ascend}\")\n",
    "\n",
    "    if use_ascend:\n",
    "        compute_dtype = mstype.float16\n",
    "    else:\n",
    "        compute_dtype = mstype.float32\n",
    "\n",
    "    upconv_config = model_config['upconv']\n",
    "    upconv = UpScaler(in_channels=upconv_config['in_channel'],\n",
    "                      out_channels=upconv_config['out_channel'],\n",
    "                      hidden_channels=upconv_config['hidden_channel'],\n",
    "                      kernel_size=upconv_config['kernel_size'],\n",
    "                      stride=upconv_config['stride'],\n",
    "                      has_bais=True)\n",
    "\n",
    "    if use_ascend:\n",
    "        from mindspore.amp import auto_mixed_precision\n",
    "        auto_mixed_precision(upconv, 'O1')\n",
    "\n",
    "    rcnn_config = model_config['rcnn']\n",
    "    recurrent_cnn = RecurrentCnn(input_channels=rcnn_config['in_channel'],\n",
    "                                 hidden_channels=rcnn_config['hidden_channel'],\n",
    "                                 kernel_size=rcnn_config['kernel_size'],\n",
    "                                 stride=rcnn_config['stride'],\n",
    "                                 compute_dtype=compute_dtype)\n",
    "\n",
    "    percnn_trainer = Trainer(upconv=upconv,\n",
    "                             recurrent_cnn=recurrent_cnn,\n",
    "                             timesteps_for_train=data_config['rollout_steps'],\n",
    "                             dx=data_config['dx'],\n",
    "                             grid_size=data_config['grid_size'],\n",
    "                             dt=data_config['dt'],\n",
    "                             mu=data_config['mu'],\n",
    "                             data_path=data_config['data_path'],\n",
    "                             compute_dtype=compute_dtype)\n",
    "\n",
    "    total_params = int(count_params(upconv.trainable_params()) +\n",
    "                       count_params(recurrent_cnn.trainable_params()))\n",
    "    print(f\"There are {total_params} parameters\")\n",
    "\n",
    "    ckpt_dir = summary_config[\"ckpt_dir\"]\n",
    "    fig_path = summary_config[\"fig_save_path\"]\n",
    "    if not os.path.exists(ckpt_dir):\n",
    "        os.makedirs(ckpt_dir)\n",
    "\n",
    "    train_stage(percnn_trainer, 'pretrain',\n",
    "                optim_config['pretrain'], ckpt_dir, use_ascend)\n",
    "    train_stage(percnn_trainer, 'finetune',\n",
    "                optim_config['finetune'], ckpt_dir, use_ascend)\n",
    "\n",
    "    output = percnn_trainer.get_output(100).asnumpy()\n",
    "    output = np.transpose(output, (1, 0, 2, 3, 4))[:, :-1:10]\n",
    "\n",
    "    print('output shape is ', output.shape)\n",
    "    for i in range(0, 10, 2):\n",
    "        post_process(output[0, i], fig_path, is_u=True, num=i)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddb93ab1",
   "metadata": {},
   "source": [
    "## Model Training\n",
    "\n",
    "With **MindSpore version >= 2.0.0**, we can use the functional programming paradigm to train neural networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d79bc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use_ascend: False\n",
      "shape of uv is  (3001, 2, 48, 48, 48)\n",
      "shape of ic is  (1, 2, 48, 48, 48)\n",
      "shape of init_state_low is  (1, 2, 24, 24, 24)\n",
      "There are 10078 parameters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 train loss: 0.160835 epoch time: 5.545 s\n",
      "epoch: 2 train loss: 104.36749 epoch time: 0.010 s\n",
      "epoch: 3 train loss: 4.3207517 epoch time: 0.009 s\n",
      "epoch: 4 train loss: 8.491383 epoch time: 0.009 s\n",
      "epoch: 5 train loss: 23.683647 epoch time: 0.009 s\n",
      "epoch: 6 train loss: 23.857117 epoch time: 0.010 s\n",
      "epoch: 7 train loss: 16.037672 epoch time: 0.010 s\n",
      "epoch: 8 train loss: 8.406443 epoch time: 0.009 s\n",
      "epoch: 9 train loss: 3.527469 epoch time: 0.020 s\n",
      "epoch: 10 train loss: 1.0823832 epoch time: 0.009 s\n",
      "...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9990 train loss: 8.7615306e-05 epoch time: 0.008 s\n",
      "epoch: 9991 train loss: 8.76504e-05 epoch time: 0.008 s\n",
      "epoch: 9992 train loss: 8.761823e-05 epoch time: 0.008 s\n",
      "epoch: 9993 train loss: 8.7546505e-05 epoch time: 0.008 s\n",
      "epoch: 9994 train loss: 8.7519744e-05 epoch time: 0.008 s\n",
      "epoch: 9995 train loss: 8.753734e-05 epoch time: 0.008 s\n",
      "epoch: 9996 train loss: 8.753101e-05 epoch time: 0.008 s\n",
      "epoch: 9997 train loss: 8.748294e-05 epoch time: 0.008 s\n",
      "epoch: 9998 train loss: 8.7443106e-05 epoch time: 0.008 s\n",
      "epoch: 9999 train loss: 8.743979e-05 epoch time: 0.008 s\n",
      "epoch: 10000 train loss: 8.744074e-05 epoch time: 0.008 s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 train loss: 61.754555 ic_loss: 8.7413886e-05 data_loss: 6.1754117  phy_loss: 2.6047118  valid_loss: 7.221066 truth_loss: 2.7125626 epoch time:  138.495 s\n",
      "best loss 61.754555 save model\n",
      "epoch: 2 train loss: 54.79151 ic_loss: 0.32984126 data_loss: 5.3142304  phy_loss: 52.50226  valid_loss: 6.812231 truth_loss: 2.7744124 epoch time:  1.342 s\n",
      "best loss 54.79151 save model\n",
      "epoch: 3 train loss: 46.904842 ic_loss: 0.12049961 data_loss: 4.6302347  phy_loss: 32.494545  valid_loss: 5.953037 truth_loss: 2.579268 epoch time:  1.262 s\n",
      "best loss 46.904842 save model\n",
      "epoch: 4 train loss: 40.674736 ic_loss: 0.031907484 data_loss: 4.05152  phy_loss: 11.360751  valid_loss: 5.08032 truth_loss: 2.3503494 epoch time:  1.233 s\n",
      "best loss 40.674736 save model\n",
      "epoch: 5 train loss: 36.910408 ic_loss: 0.10554239 data_loss: 3.6382694  phy_loss: 3.5776496  valid_loss: 4.4271708 truth_loss: 2.1671412 epoch time:  1.315 s\n",
      "best loss 36.910408 save model\n",
      "epoch: 6 train loss: 33.767193 ic_loss: 0.14396289 data_loss: 3.304738  phy_loss: 1.4308721  valid_loss: 3.954126 truth_loss: 2.0307255 epoch time:  1.322 s\n",
      "best loss 33.767193 save model\n",
      "epoch: 7 train loss: 30.495178 ic_loss: 0.09850004 data_loss: 3.0002677  phy_loss: 0.8241035  valid_loss: 3.586939 truth_loss: 1.9244627 epoch time:  1.178 s\n",
      "best loss 30.495178 save model\n",
      "epoch: 8 train loss: 27.448381 ic_loss: 0.03362463 data_loss: 2.728026  phy_loss: 0.6343211  valid_loss: 3.286183 truth_loss: 1.8369334 epoch time:  1.271 s\n",
      "best loss 27.448381 save model\n",
      "epoch: 9 train loss: 24.990078 ic_loss: 0.0024543565 data_loss: 2.4977806  phy_loss: 0.5740176  valid_loss: 3.0332325 truth_loss: 1.7619449 epoch time:  1.573 s\n",
      "best loss 24.990078 save model\n",
      "epoch: 10 train loss: 23.15583 ic_loss: 0.014634784 data_loss: 2.3082657  phy_loss: 0.5407104  valid_loss: 2.8156128 truth_loss: 1.6955423 epoch time:  1.351 s\n",
      "best loss 23.15583 save model\n",
      "...\n",
      "epoch: 1657 train loss: 0.092819184 ic_loss: 0.00065381714 data_loss: 0.00895501  phy_loss: 0.00069560105  valid_loss: 0.011931514 truth_loss: 0.16052744 epoch time:  1.174 s\n",
      "best loss 0.092819184 save model\n",
      "epoch: 1658 train loss: 0.09269943 ic_loss: 0.0006537079 data_loss: 0.008943089  phy_loss: 0.0006945693  valid_loss: 0.011916851 truth_loss: 0.1604548 epoch time:  1.296 s\n",
      "best loss 0.09269943 save model\n",
      "epoch: 1659 train loss: 0.092579775 ic_loss: 0.00065359805 data_loss: 0.008931179  phy_loss: 0.0006935386  valid_loss: 0.0119022 truth_loss: 0.16038223 epoch time:  1.426 s\n",
      "best loss 0.092579775 save model\n",
      "epoch: 1660 train loss: 0.09246021 ic_loss: 0.0006534874 data_loss: 0.008919277  phy_loss: 0.00069250836  valid_loss: 0.011887563 truth_loss: 0.16030973 epoch time:  1.389 s\n",
      "best loss 0.09246021 save model\n"
     ]
    }
   ],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90b495f",
   "metadata": {},
   "source": [
    "## Model Evaluation and Visualization\n",
    "\n",
    "After training, the model can infer the field at all data points in the domain, and the results can be visualized."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bad8ea25",
   "metadata": {},
   "source": [
    "![](./images/result-percnn3d.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.14 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "fd69f43f58546b570e94fd7eba7b65e6bcc7a5bbc4eab0408017d18902915d69"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}