{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Solving the 2D Burgers Equation with PeRCNN\n",
    "\n",
    "[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindflow/zh_cn/data_mechanism_fusion/mindspore_percnn2d.ipynb) [![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/mindflow/zh_cn/data_mechanism_fusion/mindspore_percnn2d.py) [![View Source File](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.2/docs/mindflow/docs/source_zh_cn/data_mechanism_fusion/mindspore_percnn2d.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
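    "## Overview\n",
    "\n",
    "Huawei and Professor Hao Sun's team at Renmin University of China recently proposed the [Physics-encoded Recurrent Convolutional Neural Network (PeRCNN)](https://www.nature.com/articles/s42256-023-00685-7), built on the Ascend AI hardware and software platform and the MindSpore AI framework. Compared with physics-informed neural networks, ConvLSTM, PDE-Net, and similar methods, PeRCNN generalizes noticeably better, is more robust to noise, and improves long-term prediction accuracy by more than 10 times. It has broad potential applications in aerospace, shipbuilding, weather forecasting, and other fields. The work has been published in *Nature Machine Intelligence*."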
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
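    "## Problem Description\n",
    "\n",
    "The Burgers' equation is a nonlinear partial differential equation that models the propagation and reflection of shock waves. It is widely used in fluid mechanics, nonlinear acoustics, gas dynamics, and other fields, and is named after Johannes Martinus Burgers (1895-1981). This case uses the PeRCNN method to solve the two-dimensional Burgers' equation with viscosity."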
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Governing Equations\n",
    "\n",
    "In this case, the Burgers' equation takes the form:\n",
    "\n",
    "$$\n",
    "u_{t} = \\nu \\Delta u - (uu_{x} + vu_{y}).\n",
    "$$\n",
    "\n",
    "$$\n",
    "v_{t} = \\nu \\Delta v - (uv_{x} + vv_{y}).\n",
    "$$\n",
    "\n",
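    "where $u$ and $v$ are the velocity components in the $x$ and $y$ directions, $\\Delta$ is the Laplacian, and $\\nu = 0.005$ is the viscosity coefficient.\n"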
   ]
  },
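  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the right-hand side of the equations above concrete, the following NumPy sketch evaluates it with second-order central differences. It is an illustration only: the periodic boundaries, the uniform spacing `dx` in both directions, the `[y, x]` array layout, and the helper name `burgers_rhs` are assumptions, not details taken from the PeRCNN source.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def burgers_rhs(u, v, dx, nu=0.005):\n",
    "    # Hypothetical helper: returns (u_t, v_t) for the 2D viscous Burgers equation.\n",
    "    # Assumes periodic boundaries, spacing dx in x and y, and arrays indexed as [y, x].\n",
    "    def ddx(f):  # central difference along x (axis 1)\n",
    "        return (np.roll(f, -1, axis=1) - np.roll(f, 1, axis=1)) / (2 * dx)\n",
    "\n",
    "    def ddy(f):  # central difference along y (axis 0)\n",
    "        return (np.roll(f, -1, axis=0) - np.roll(f, 1, axis=0)) / (2 * dx)\n",
    "\n",
    "    def lap(f):  # five-point Laplacian\n",
    "        return (np.roll(f, -1, axis=0) + np.roll(f, 1, axis=0)\n",
    "                + np.roll(f, -1, axis=1) + np.roll(f, 1, axis=1) - 4 * f) / dx ** 2\n",
    "\n",
    "    u_t = nu * lap(u) - (u * ddx(u) + v * ddy(u))\n",
    "    v_t = nu * lap(v) - (u * ddx(v) + v * ddy(v))\n",
    "    return u_t, v_t\n",
    "\n",
    "\n",
    "# A uniform velocity field has no gradients, so both tendencies vanish.\n",
    "u_t, v_t = burgers_rhs(np.ones((8, 8)), 2 * np.ones((8, 8)), dx=0.1)\n",
    "print(np.abs(u_t).max(), np.abs(v_t).max())  # 0.0 0.0\n",
    "```"
   ]
  },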
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
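    "## Technical Approach\n",
    "\n",
    "The workflow for solving this problem with MindSpore Flow is as follows:\n",
    "\n",
    "1. Optimizer and single-step training\n",
    "2. Model construction\n",
    "3. Model training\n",
    "4. Model inference and visualization"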
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import context, jit, nn, ops, save_checkpoint, set_seed\n",
    "import mindspore.common.dtype as mstype\n",
    "from mindflow.utils import load_yaml_config, print_log\n",
    "from src import RecurrentCNNCell, RecurrentCNNCellBurgers, Trainer, UpScaler, post_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "set_seed(123456)\n",
    "np.random.seed(123456)\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\", device_id=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load configuration yaml\n",
    "config = load_yaml_config('./configs/data_driven_percnn_burgers.yaml')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Optimizer and Single-Step Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train_stage(trainer, stage, pattern, config, ckpt_dir, use_ascend):\n",
    "    \"\"\"train stage\"\"\"\n",
    "    if use_ascend:\n",
    "        from mindspore.amp import DynamicLossScaler, all_finite\n",
    "        loss_scaler = DynamicLossScaler(2**10, 2, 100)\n",
    "\n",
    "    if 'milestone_num' in config.keys():\n",
    "        # step decay: split the epochs into milestone_num equal segments;\n",
    "        # segment i uses learning_rate * gamma ** i\n",
    "        milestone = [(config['epochs'] // config['milestone_num']) * (i + 1)\n",
    "                     for i in range(config['milestone_num'])]\n",
    "        lr = float(config['learning_rate']) * np.array([config['gamma'] ** i\n",
    "                                                        for i in range(config['milestone_num'])])\n",
    "        learning_rate = nn.piecewise_constant_lr(milestone, list(lr))\n",
    "    else:\n",
    "        learning_rate = config['learning_rate']\n",
    "\n",
    "    if stage == 'pretrain':\n",
    "        params = trainer.upconv.trainable_params()\n",
    "    else:\n",
    "        params = trainer.upconv.trainable_params() + trainer.recurrent_cnn.trainable_params()\n",
    "\n",
    "    optimizer = nn.Adam(params, learning_rate=learning_rate)\n",
    "\n",
    "    def forward_fn():\n",
    "        if stage == 'pretrain':\n",
    "            loss = trainer.get_ic_loss()\n",
    "        else:\n",
    "            loss = trainer.get_loss()\n",
    "        if use_ascend:\n",
    "            loss = loss_scaler.scale(loss)\n",
    "        return loss\n",
    "\n",
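    "    # get_ic_loss returns a single loss; get_loss also returns the individual loss terms\n",
    "    # as auxiliary outputs, so only its first output is differentiated (has_aux=True)\n",
    "    if stage == 'pretrain':\n",
    "        grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=False)\n",
    "    else:\n",
    "        grad_fn = ops.value_and_grad(forward_fn, None, params, has_aux=True)\n",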
    "\n",
    "    @jit\n",
    "    def train_step():\n",
    "        loss, grads = grad_fn()\n",
    "        if use_ascend:\n",
    "            # undo the loss scaling; apply the update only if every gradient is finite\n",
    "            loss = loss_scaler.unscale(loss)\n",
    "            is_finite = all_finite(grads)\n",
    "            if is_finite:\n",
    "                grads = loss_scaler.unscale(grads)\n",
    "                loss = ops.depend(loss, optimizer(grads))\n",
    "            loss_scaler.adjust(is_finite)\n",
    "        else:\n",
    "            loss = ops.depend(loss, optimizer(grads))\n",
    "        return loss\n",
    "\n",
    "    best_loss = 100000\n",
    "    for epoch in range(1, 1 + config['epochs']):\n",
    "        time_beg = time.time()\n",
    "        trainer.upconv.set_train(True)\n",
    "        trainer.recurrent_cnn.set_train(True)\n",
    "        if stage == 'pretrain':\n",
    "            step_train_loss = train_step()\n",
    "            # one training step per epoch, so a single epoch time is reported\n",
    "            print_log(f\"epoch: {epoch} train loss: {step_train_loss} \"\n",
    "                      f\"epoch time: {time.time() - time_beg:.3f} s\")\n",
    "        else:\n",
    "            step_train_loss, loss_data, loss_ic, loss_phy, loss_valid = train_step()\n",
    "            print_log(f\"epoch: {epoch} train loss: {step_train_loss} ic_loss: {loss_ic} \"\n",
    "                      f\"data_loss: {loss_data} val_loss: {loss_valid} phy_loss: {loss_phy} \"\n",
    "                      f\"epoch time: {time.time() - time_beg:.3f} s\")\n",
    "            if step_train_loss < best_loss:\n",
    "                best_loss = step_train_loss\n",
    "                print_log('best loss', best_loss, 'save model')\n",
    "                save_checkpoint(trainer.upconv, os.path.join(ckpt_dir, f\"{pattern}_{config['name']}_upconv.ckpt\"))\n",
    "                save_checkpoint(trainer.recurrent_cnn,\n",
    "                                os.path.join(ckpt_dir, f\"{pattern}_{config['name']}_recurrent_cnn.ckpt\"))\n",
    "    if pattern == 'physics_driven':\n",
    "        trainer.recurrent_cnn.show_coef()"
   ]
  },
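  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the optimizer config contains `milestone_num`, `train_stage` uses a step-decay schedule: the epochs are split into `milestone_num` equal segments and segment `i` uses `learning_rate * gamma ** i`. Since `train_step` runs once per epoch, each step of the schedule corresponds to one epoch. The plain-Python sketch below reproduces that arithmetic so a config can be checked without MindSpore. The helper name `step_decay_schedule` and the example values are hypothetical, and the per-step expansion reflects our reading of `nn.piecewise_constant_lr` (one value per step up to the last milestone), not its actual implementation.\n",
    "\n",
    "```python\n",
    "def step_decay_schedule(epochs, milestone_num, learning_rate, gamma):\n",
    "    # Hypothetical helper mirroring the milestone/lr arithmetic in train_stage.\n",
    "    seg = epochs // milestone_num\n",
    "    milestones = [seg * (i + 1) for i in range(milestone_num)]\n",
    "    lrs = [float(learning_rate) * gamma ** i for i in range(milestone_num)]\n",
    "    # Expand to one value per step: steps in [previous milestone, milestone) share one rate.\n",
    "    schedule = []\n",
    "    start = 0\n",
    "    for milestone, lr in zip(milestones, lrs):\n",
    "        schedule.extend([lr] * (milestone - start))\n",
    "        start = milestone\n",
    "    return milestones, schedule\n",
    "\n",
    "\n",
    "milestones, schedule = step_decay_schedule(epochs=6, milestone_num=3, learning_rate=1e-3, gamma=0.5)\n",
    "print(milestones)  # [2, 4, 6]\n",
    "print(schedule)    # [0.001, 0.001, 0.0005, 0.0005, 0.00025, 0.00025]\n",
    "```\n",
    "\n",
    "If `epochs` is not a multiple of `milestone_num`, the last milestone falls short of `epochs`, so the schedule ends before the final epochs; choosing values that divide evenly avoids this."
   ]
  },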
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
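    "## Model Construction\n",
    "\n",
    "PeRCNN requires two networks: an `UpScaler` that performs upsampling, and a recurrent CNN that forms the main body of the model. The `UpScaler` is trained alone against the initial-condition loss in the `pretrain` stage, and both networks are trained together in the `finetune` stage. For the `data_driven` pattern the recurrent CNN is a generic `RecurrentCNNCell`; otherwise the Burgers-specific `RecurrentCNNCellBurgers` is used, initialized from `init_coef`."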
   ]
  },
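  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate how the two networks fit together, the NumPy sketch below upsamples a coarse state and then marches it forward with a recurrent residual update. It rests on stated assumptions: nearest-neighbour upsampling stands in for the learned `UpScaler`, a toy decay function stands in for the recurrent CNN cell, and the forward-Euler form `state + dt * rhs(state)` is our reading of the recurrent structure, not the `RecurrentCNNCell` code. The names `upsample_nearest` and `rollout` are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def upsample_nearest(x, factor):\n",
    "    # Stand-in for the learned UpScaler: nearest-neighbour upsampling of a (C, H, W) array.\n",
    "    return x.repeat(factor, axis=1).repeat(factor, axis=2)\n",
    "\n",
    "\n",
    "def rollout(state0, rhs, dt, steps):\n",
    "    # Recurrent forward-Euler rollout: state_{k+1} = state_k + dt * rhs(state_k).\n",
    "    states = [state0]\n",
    "    for _ in range(steps):\n",
    "        states.append(states[-1] + dt * rhs(states[-1]))\n",
    "    return np.stack(states)  # shape (steps + 1, C, H, W)\n",
    "\n",
    "\n",
    "# Toy usage: linear decay stands in for the recurrent CNN cell.\n",
    "low_res = np.ones((2, 4, 4))  # 2 channels (u, v) on a coarse 4 x 4 grid\n",
    "state0 = upsample_nearest(low_res, 2)  # (2, 8, 8)\n",
    "traj = rollout(state0, lambda s: -s, dt=0.1, steps=3)\n",
    "print(traj.shape)  # (4, 2, 8, 8)\n",
    "```\n",
    "\n",
    "In `train()` below, the number of recurrent steps used for training is taken from `rollout_steps` in the data config."
   ]
  },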
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"train\"\"\"\n",
    "    burgers_config = config\n",
    "\n",
    "    use_ascend = context.get_context(attr_key='device_target') == \"Ascend\"\n",
    "    print_log(f\"use_ascend: {use_ascend}\")\n",
    "\n",
    "    if use_ascend:\n",
    "        compute_dtype = mstype.float16\n",
    "    else:\n",
    "        compute_dtype = mstype.float32\n",
    "\n",
    "    data_config = burgers_config['data']\n",
    "    optimizer_config = burgers_config['optimizer']\n",
    "    model_config = burgers_config['model']\n",
    "    summary_config = burgers_config['summary']\n",
    "\n",
    "    upconv = UpScaler(in_channels=model_config['in_channels'],\n",
    "                      out_channels=model_config['out_channels'],\n",
    "                      hidden_channels=model_config['upscaler_hidden_channels'],\n",
    "                      kernel_size=model_config['kernel_size'],\n",
    "                      stride=model_config['stride'],\n",
    "                      has_bais=True)\n",
    "\n",
    "    if use_ascend:\n",
    "        from mindspore.amp import auto_mixed_precision\n",
    "        auto_mixed_precision(upconv, 'O1')\n",
    "\n",
    "    pattern = data_config['pattern']\n",
    "    if pattern == 'data_driven':\n",
    "        recurrent_cnn = RecurrentCNNCell(input_channels=model_config['in_channels'],\n",
    "                                         hidden_channels=model_config['rcnn_hidden_channels'],\n",
    "                                         kernel_size=model_config['kernel_size'],\n",
    "                                         compute_dtype=compute_dtype)\n",
    "    else:\n",
    "        recurrent_cnn = RecurrentCNNCellBurgers(kernel_size=model_config['kernel_size'],\n",
    "                                                init_coef=model_config['init_coef'],\n",
    "                                                compute_dtype=compute_dtype)\n",
    "\n",
    "    percnn_trainer = Trainer(upconv=upconv,\n",
    "                             recurrent_cnn=recurrent_cnn,\n",
    "                             timesteps_for_train=data_config['rollout_steps'],\n",
    "                             dx=data_config['dx'],\n",
    "                             dt=data_config['dy'],\n",
    "                             nu=data_config['nu'],\n",
    "                             data_path=os.path.join(data_config['root_dir'], data_config['file_name']),\n",
    "                             compute_dtype=compute_dtype)\n",
    "\n",
    "    ckpt_dir = os.path.join(summary_config[\"root_dir\"], summary_config['ckpt_dir'])\n",
    "    if not os.path.exists(ckpt_dir):\n",
    "        os.makedirs(ckpt_dir)\n",
    "\n",
    "    train_stage(percnn_trainer, 'pretrain', pattern, optimizer_config['pretrain'], ckpt_dir, use_ascend)\n",
    "    train_stage(percnn_trainer, 'finetune', pattern, optimizer_config['finetune'], ckpt_dir, use_ascend)\n",
    "    post_process(percnn_trainer, pattern)"
   ]
  },
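  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On Ascend, `train()` computes in `float16`, and `train_stage` wraps the loss in `DynamicLossScaler(2**10, 2, 100)`: the loss is multiplied by a scale before differentiation, the loss and gradients are divided by the same scale afterwards, and the optimizer update is skipped whenever a gradient is not finite. The pure-Python class below is a simplified, hypothetical sketch of the general dynamic loss-scaling technique, meant only to show how the scale evolves. Its update rule (divide by `scale_factor` on overflow, never going below 1; multiply by `scale_factor` after `scale_window` consecutive finite steps) is an assumption, not MindSpore's implementation, and `SimpleDynamicLossScaler` and `all_finite_np` are illustrative names.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class SimpleDynamicLossScaler:\n",
    "    # Hypothetical, simplified dynamic loss scaler (not mindspore.amp.DynamicLossScaler).\n",
    "    def __init__(self, scale_value, scale_factor, scale_window):\n",
    "        self.scale_value = float(scale_value)\n",
    "        self.scale_factor = scale_factor\n",
    "        self.scale_window = scale_window\n",
    "        self.good_steps = 0  # consecutive steps with finite gradients\n",
    "\n",
    "    def scale(self, loss):\n",
    "        return loss * self.scale_value\n",
    "\n",
    "    def unscale(self, grads):\n",
    "        return [g / self.scale_value for g in grads]\n",
    "\n",
    "    def adjust(self, grads_finite):\n",
    "        if grads_finite:\n",
    "            self.good_steps += 1\n",
    "            if self.good_steps >= self.scale_window:\n",
    "                # a full window without overflow: try a larger scale\n",
    "                self.scale_value *= self.scale_factor\n",
    "                self.good_steps = 0\n",
    "        else:\n",
    "            # overflow: shrink the scale (not below 1) and restart the window\n",
    "            self.scale_value = max(self.scale_value / self.scale_factor, 1.0)\n",
    "            self.good_steps = 0\n",
    "\n",
    "\n",
    "def all_finite_np(grads):\n",
    "    # True only if every element of every gradient array is finite\n",
    "    return all(np.all(np.isfinite(g)) for g in grads)\n",
    "\n",
    "\n",
    "scaler = SimpleDynamicLossScaler(2 ** 10, 2, 100)  # same arguments as in train_stage\n",
    "overflowed = [np.array([np.inf, 1.0])]\n",
    "scaler.adjust(all_finite_np(overflowed))\n",
    "print(scaler.scale_value)  # 512.0\n",
    "```\n",
    "\n",
    "Starting at 1024, one overflow halves the scale to 512, and a full window of 100 finite steps is then needed before it doubles again, which keeps the scale near the largest value that does not overflow in `float16`."
   ]
  },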
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use_ascend: False\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 train loss: 1.5724593 epoch time: 0.867 s\n",
      "epoch: 2 train loss: 1.5299724 epoch time: 0.002 s\n",
      "epoch: 3 train loss: 1.4901378 epoch time: 0.002 s\n",
      "epoch: 4 train loss: 1.449844 epoch time: 0.002 s\n",
      "epoch: 5 train loss: 1.4070688 epoch time: 0.002 s\n",
      "epoch: 6 train loss: 1.3605155 epoch time: 0.002 s\n",
      "epoch: 7 train loss: 1.3093143 epoch time: 0.002 s\n",
      "epoch: 8 train loss: 1.253143 epoch time: 0.002 s\n",
      "epoch: 9 train loss: 1.1923409 epoch time: 0.002 s\n",
      "epoch: 10 train loss: 1.1278089 epoch time: 0.002 s\n",
      "...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5991 train loss: 0.00017400463 epoch time: 0.001 s\n",
      "epoch: 5992 train loss: 0.00017378097 epoch time: 0.001 s\n",
      "epoch: 5993 train loss: 0.00017361519 epoch time: 0.001 s\n",
      "epoch: 5994 train loss: 0.00017362367 epoch time: 0.001 s\n",
      "epoch: 5995 train loss: 0.00017370074 epoch time: 0.001 s\n",
      "epoch: 5996 train loss: 0.00017368408 epoch time: 0.001 s\n",
      "epoch: 5997 train loss: 0.00017355102 epoch time: 0.001 s\n",
      "epoch: 5998 train loss: 0.00017341717 epoch time: 0.001 s\n",
      "epoch: 5999 train loss: 0.0001733772 epoch time: 0.001 s\n",
      "epoch: 6000 train loss: 0.00017340294 epoch time: 0.001 s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 train loss: 0.0040010856 ic_loss: 0.00017339904 data_loss: 0.0036542874 val_loss: 0.034989584 phy_loss: 385.93723 epoch time:  14.898 s\n",
      "best loss 0.0040010856 save model\n",
      "epoch: 2 train loss: 0.023029208 ic_loss: 0.0069433 data_loss: 0.009142607 val_loss: 0.035638213 phy_loss: 416.80725 epoch time:  0.247 s\n",
      "epoch: 3 train loss: 0.09626201 ic_loss: 0.030940203 data_loss: 0.0343816 val_loss: 0.05810566 phy_loss: 221.00093 epoch time:  0.162 s\n",
      "epoch: 4 train loss: 0.01788263 ic_loss: 0.0053461124 data_loss: 0.0071904045 val_loss: 0.03353381 phy_loss: 301.05966 epoch time:  0.147 s\n",
      "epoch: 5 train loss: 0.029557336 ic_loss: 0.0091625415 data_loss: 0.011232254 val_loss: 0.038305752 phy_loss: 449.9107 epoch time:  0.152 s\n",
      "epoch: 6 train loss: 0.052337468 ic_loss: 0.016626468 data_loss: 0.019084534 val_loss: 0.046096146 phy_loss: 497.9761 epoch time:  0.214 s\n",
      "epoch: 7 train loss: 0.014262615 ic_loss: 0.004195284 data_loss: 0.005872047 val_loss: 0.03377932 phy_loss: 430.3675 epoch time:  0.151 s\n",
      "epoch: 8 train loss: 0.00919872 ic_loss: 0.0025033113 data_loss: 0.0041920976 val_loss: 0.031886213 phy_loss: 344.02713 epoch time:  0.181 s\n",
      "epoch: 9 train loss: 0.032457784 ic_loss: 0.010022995 data_loss: 0.012411795 val_loss: 0.039276786 phy_loss: 301.3161 epoch time:  0.168 s\n",
      "epoch: 10 train loss: 0.027750801 ic_loss: 0.008489873 data_loss: 0.010771056 val_loss: 0.037965972 phy_loss: 310.4488 epoch time:  0.159 s\n",
      "...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14991 train loss: 0.0012423343 ic_loss: 0.00041630908 data_loss: 0.00040971604 val_loss: 0.03190168 phy_loss: 394.9725 epoch time:  0.163 s\n",
      "best loss 0.0012423343 save model\n",
      "epoch: 14992 train loss: 0.0012423296 ic_loss: 0.0004163079 data_loss: 0.00040971374 val_loss: 0.0319017 phy_loss: 394.97614 epoch time:  0.158 s\n",
      "best loss 0.0012423296 save model\n",
      "epoch: 14993 train loss: 0.0012423252 ic_loss: 0.00041630593 data_loss: 0.00040971336 val_loss: 0.031901684 phy_loss: 394.97284 epoch time:  0.196 s\n",
      "best loss 0.0012423252 save model\n",
      "epoch: 14994 train loss: 0.0012423208 ic_loss: 0.00041630483 data_loss: 0.00040971107 val_loss: 0.0319017 phy_loss: 394.97568 epoch time:  0.173 s\n",
      "best loss 0.0012423208 save model\n",
      "epoch: 14995 train loss: 0.0012423162 ic_loss: 0.0004163029 data_loss: 0.00040971037 val_loss: 0.031901684 phy_loss: 394.97305 epoch time:  0.194 s\n",
      "best loss 0.0012423162 save model\n",
      "epoch: 14996 train loss: 0.0012423118 ic_loss: 0.00041630171 data_loss: 0.00040970836 val_loss: 0.031901695 phy_loss: 394.9754 epoch time:  0.175 s\n",
      "best loss 0.0012423118 save model\n",
      "epoch: 14997 train loss: 0.0012423072 ic_loss: 0.0004162999 data_loss: 0.00040970749 val_loss: 0.031901684 phy_loss: 394.97308 epoch time:  0.164 s\n",
      "best loss 0.0012423072 save model\n",
      "epoch: 14998 train loss: 0.0012423028 ic_loss: 0.00041629872 data_loss: 0.00040970545 val_loss: 0.0319017 phy_loss: 394.97534 epoch time:  0.135 s\n",
      "best loss 0.0012423028 save model\n",
      "epoch: 14999 train loss: 0.0012422984 ic_loss: 0.00041629683 data_loss: 0.00040970472 val_loss: 0.031901687 phy_loss: 394.97314 epoch time:  0.135 s\n",
      "best loss 0.0012422984 save model\n",
      "epoch: 15000 train loss: 0.0012422939 ic_loss: 0.00041629552 data_loss: 0.00040970283 val_loss: 0.0319017 phy_loss: 394.97556 epoch time:  0.153 s\n",
      "best loss 0.0012422939 save model\n"
     ]
    }
   ],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
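    "## Model Inference and Visualization\n",
    "\n",
    "After training, the figure below compares the predictions with the ground-truth labels.\n",
    "\n",
    "![Predictions vs. ground truth](./images/results.gif)"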
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.14"
  },
  "vscode": {
   "interpreter": {
    "hash": "fd69f43f58546b570e94fd7eba7b65e6bcc7a5bbc4eab0408017d18902915d69"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}