FNO for 1D Burgers

Overview

Computational fluid dynamics (CFD) is one of the most important techniques in 21st-century fluid mechanics. Flow analysis, prediction, and control can be achieved by solving the governing equations of fluid mechanics numerically. However, traditional approaches such as the finite element method (FEM) and the finite difference method (FDM) are inefficient, because the simulation process is complex (physical modeling, meshing, numerical discretization, iterative solution, etc.) and the computing cost is high. This motivates using AI to make fluid simulation more efficient.

Machine learning methods offer a new paradigm for scientific computing: fast solvers that can stand in for traditional numerical methods. Classical neural networks learn mappings between finite-dimensional spaces, so the solutions they learn are tied to a specific discretization. In contrast, the Fourier Neural Operator (FNO) is a deep learning architecture that learns mappings between infinite-dimensional function spaces. It learns the mapping from arbitrary function parameters to solutions directly, so a single model solves a whole class of partial differential equations, which gives it stronger generalization capability. For more information, see the paper Fourier Neural Operator for Parametric Partial Differential Equations.

This tutorial describes how to solve the 1D Burgers’ equation using the Fourier Neural Operator.

Burgers’ equation

The 1D Burgers’ equation is a nonlinear PDE with various applications, including modeling the one-dimensional flow of a viscous fluid. It takes the form

$$\partial_t u(x,t) + \partial_x \left( u^2(x,t)/2 \right) = \nu \, \partial_{xx} u(x,t), \quad x \in (0,1),\ t \in (0,1]$$
$$u(x,0) = u_0(x), \quad x \in (0,1)$$

where $u$ is the velocity field, $u_0$ is the initial condition, and $\nu$ is the viscosity coefficient.
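To make the two right-hand-side terms concrete, here is a minimal NumPy sketch that evaluates $\partial_t u = -\partial_x(u^2/2) + \nu\,\partial_{xx}u$ with spectral derivatives. It assumes a periodic grid $x_j = j/n$ on $[0, 1)$ (matching the periodic boundary conditions used for the dataset); `burgers_rhs` is a hypothetical helper, not part of MindSpore Flow. For $u = \sin(2\pi x)$ the result can be checked by hand: $-\partial_x(u^2/2) = -\pi\sin(4\pi x)$ and $\nu\,\partial_{xx}u = -4\pi^2\nu\sin(2\pi x)$.

```python
import numpy as np


def burgers_rhs(u, nu):
    """Evaluate du/dt = -d/dx(u^2 / 2) + nu * d^2u/dx^2 for a periodic u sampled at x_j = j / n on [0, 1)."""
    n = u.shape[0]
    # Angular wavenumbers 2*pi*k for k = 0, 1, ..., n // 2 (rfft layout).
    k = 2.0 * np.pi * np.fft.rfftfreq(n, d=1.0 / n)
    # Spectral derivatives: d/dx <-> i*k, d^2/dx^2 <-> -k^2.
    flux_x = np.fft.irfft(1j * k * np.fft.rfft(0.5 * u**2), n=n)
    u_xx = np.fft.irfft(-(k**2) * np.fft.rfft(u), n=n)
    return -flux_x + nu * u_xx


# Compare with the analytic right-hand side for u(x) = sin(2*pi*x).
x = np.arange(64) / 64
u = np.sin(2 * np.pi * x)
nu = 0.1
expected = -np.pi * np.sin(4 * np.pi * x) - 4 * np.pi**2 * nu * np.sin(2 * np.pi * x)
print(np.allclose(burgers_rhs(u, nu), expected))  # True
```

Spectral differentiation is exact here because $u$ and $u^2$ contain only a few Fourier modes, well below the grid's Nyquist limit.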

Problem Description

We aim to learn the operator mapping the initial condition to the solution at time one:

$$u_0 \mapsto u(\cdot, 1)$$

Technology Path

MindSpore Flow solves the problem as follows:

  1. Training Dataset Construction.

  2. Model Construction.

  3. Optimizer and Loss Function.

  4. Model Training.

Fourier Neural Operator

The Fourier Neural Operator consists of the Lifting Layer, Fourier Layers, and the Decoding Layer.

Fourier Neural Operator model structure

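Fourier layers: starting from the input $v$, the upper branch applies the Fourier transform $\mathcal{F}$, applies a linear transform $R$ to the lower Fourier modes while filtering out the higher modes, and then applies the inverse Fourier transform $\mathcal{F}^{-1}$. The lower branch applies a local linear transform $W$. The outputs of the two branches are summed and passed through the activation function $\sigma$ to obtain the Fourier layer output, that is, $v_{t+1} = \sigma\left(W v_t + \mathcal{F}^{-1}\left(R \cdot \mathcal{F}(v_t)\right)\right)$.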

Fourier Layer structure
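The Fourier layer described above can be sketched in a few lines of NumPy. This is an illustration of the technique, not the MindSpore Flow implementation: `spectral_conv1d` and `fourier_layer` are hypothetical names, the weight layout `(in_channels, out_channels, modes)` is an assumption, and ReLU stands in for whatever activation FNO1D actually uses.

```python
import numpy as np


def spectral_conv1d(v, weights, modes):
    """Upper branch of a Fourier layer: F^-1(R . F(v)), keeping only the lowest `modes` frequencies.

    v:       real array of shape (in_channels, n)
    weights: complex array R of shape (in_channels, out_channels, modes)
    """
    v_hat = np.fft.rfft(v, axis=-1)  # F(v): shape (in_channels, n // 2 + 1)
    out_hat = np.zeros((weights.shape[1], v_hat.shape[-1]), dtype=complex)
    # Apply R mode by mode on the low frequencies; all higher modes stay zero (filtered out).
    out_hat[:, :modes] = np.einsum("ik,iok->ok", v_hat[:, :modes], weights)
    return np.fft.irfft(out_hat, n=v.shape[-1], axis=-1)  # F^-1


def fourier_layer(v, weights, w, modes):
    """sigma(W v + F^-1(R . F(v))); ReLU is used as a stand-in activation (assumption)."""
    return np.maximum(0.0, spectral_conv1d(v, weights, modes) + w @ v)


# With R = 1 on the kept modes, a low-frequency signal passes through unchanged,
# while a frequency above the cutoff is removed entirely.
n, modes = 64, 8
x = np.arange(n) / n
r_identity = np.ones((1, 1, modes), dtype=complex)
low = np.cos(2 * np.pi * 3 * x)[None, :]
high = np.cos(2 * np.pi * 20 * x)[None, :]
print(np.allclose(spectral_conv1d(low, r_identity, modes), low))   # True
print(np.allclose(spectral_conv1d(high, r_identity, modes), 0.0))  # True

# A full layer with random weights keeps the (channels, n) shape.
rng = np.random.default_rng(0)
v = rng.standard_normal((4, n))
R = rng.standard_normal((4, 4, modes)) + 1j * rng.standard_normal((4, 4, modes))
W = rng.standard_normal((4, 4))
print(fourier_layer(v, R, W, modes).shape)  # (4, 64)
```

Because $R$ is indexed by Fourier mode rather than by grid point, the same weights can be applied to inputs of different resolution, which is the property that lets FNO learn discretization-independent operators.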

[1]:
import os
import time
import numpy as np
import mindspore as ms

from mindspore.amp import DynamicLossScaler, auto_mixed_precision, all_finite
from mindspore import nn, Tensor, set_seed, ops, data_sink, jit, save_checkpoint
from mindspore import dtype as mstype
from mindflow import FNO1D, RelativeRMSELoss, load_yaml_config, get_warmup_cosine_annealing_lr
from mindflow.pde import UnsteadyFlowWithLoss

The following src package can be downloaded from applications/data_driven/burgers_fno/src.

[2]:
from src.dataset import create_training_dataset

set_seed(0)
np.random.seed(0)

ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU", device_id=5)
use_ascend = ms.get_context(attr_key='device_target') == "Ascend"

The model, data, and optimizer parameters are read from the configuration file.

[3]:
config = load_yaml_config('burgers1d.yaml')
data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
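The code in this tutorial reads only a handful of keys from `burgers1d.yaml`. The sketch below writes them out as the Python dict that `load_yaml_config` is expected to return. The `model` values and `summary_dir` are taken from the printed model name and summary path further down, and `epochs`/`eval_interval` match the training log (100 epochs, evaluation every 10); `data.path`, `initial_lr`, and `train_epochs` are hypothetical placeholders, and the real file may contain additional keys (for example, a batch size).

```python
# Hypothetical layout of burgers1d.yaml, expressed as the dict returned after loading it.
config = {
    "data": {"path": "./dataset"},  # placeholder path (assumption)
    "model": {
        "name": "FNO1D",
        "in_channels": 1,
        "out_channels": 1,
        "resolution": 1024,
        "modes": 16,
        "width": 64,
        "depth": 4,
    },
    "optimizer": {"initial_lr": 1e-3, "train_epochs": 100},  # placeholder values (assumption)
    "epochs": 100,
    "eval_interval": 10,
    "summary_dir": "./summary",
}

# Same naming scheme as the model-construction cell: "key:value" pairs joined by "_".
model_params = config["model"]
model_name = "_".join(f"{k}:{v}" for k, v in model_params.items())
print(model_name)  # name:FNO1D_in_channels:1_out_channels:1_resolution:1024_modes:16_width:64_depth:4
```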

Training Dataset Construction

Download the training and test datasets: data_driven/burgers/dataset.

In this case, the training and test datasets are generated following Zongyi Li’s dataset in Fourier Neural Operator for Parametric Partial Differential Equations. The settings are as follows:

The initial condition $u_0(x)$ is generated with periodic boundary conditions according to:

$$u_0 \sim \mu, \quad \mu = \mathcal{N}\left(0,\ 625(-\Delta + 25I)^{-2}\right)$$
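One way to draw $u_0$ from this Gaussian measure is a truncated Karhunen-Loève expansion in the periodic Fourier basis: $-\Delta$ has eigenvalue $(2\pi k)^2$ on $\sqrt{2}\cos(2\pi k x)$ and $\sqrt{2}\sin(2\pi k x)$, so the covariance has eigenvalues $\lambda_k = 625\left((2\pi k)^2 + 25\right)^{-2}$. The sketch below is an assumption about how such samples can be produced, not the script used to build the downloadable dataset; `sample_u0` and the truncation `n_modes` are hypothetical.

```python
import numpy as np


def sample_u0(n_points=1024, n_modes=64, rng=None):
    """Draw one sample of u0 ~ N(0, 625(-Δ + 25I)^(-2)) on the periodic grid x_j = j / n_points."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.arange(n_points) / n_points
    k = np.arange(1, n_modes + 1)
    # Covariance eigenvalues for modes k >= 1 (truncated at n_modes, an assumption).
    lam = 625.0 / ((2.0 * np.pi * k) ** 2 + 25.0) ** 2
    phase = 2.0 * np.pi * np.outer(k, x)  # shape (n_modes, n_points)
    a = rng.standard_normal(n_modes) * np.sqrt(lam)
    b = rng.standard_normal(n_modes) * np.sqrt(lam)
    # k = 0 mode: eigenvalue 625 / 25**2 = 1 with the constant basis function 1.
    u0 = rng.standard_normal() * 1.0
    u0 = u0 + np.sqrt(2.0) * (a @ np.cos(phase) + b @ np.sin(phase))
    return x, u0


x, u0 = sample_u0(rng=np.random.default_rng(0))
print(u0.shape)  # (1024,)
```

The eigenvalues decay like $k^{-4}$, so only the first few modes carry noticeable energy and the sampled initial conditions are smooth.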

We set the viscosity to $\nu = 0.1$ and solve the equation with a split-step method: the heat-equation part is solved exactly in Fourier space, and the nonlinear part is then advanced, also in Fourier space, with a very fine forward Euler method. The training set contains 1000 samples and the test set contains 200 samples.
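The split-step scheme described above can be sketched in NumPy as follows. Assumptions: a periodic grid on $[0, 1)$, one forward Euler substep for the nonlinear term per time step, and a step size `dt` chosen for illustration; the tutorial does not state the step sizes or substepping used to generate the actual dataset. `burgers_split_step` is a hypothetical helper.

```python
import numpy as np


def burgers_split_step(u0, nu=0.1, t_final=1.0, dt=1e-4):
    """Advance u_t + (u^2/2)_x = nu * u_xx on the periodic domain [0, 1) with a split-step scheme."""
    n = u0.shape[0]
    k = 2.0 * np.pi * np.fft.rfftfreq(n, d=1.0 / n)  # angular wavenumbers
    heat = np.exp(-nu * k**2 * dt)  # exact heat-equation propagator over one step
    u_hat = np.fft.rfft(u0)
    for _ in range(int(round(t_final / dt))):
        # 1) Heat part, solved exactly in Fourier space.
        u_hat = u_hat * heat
        # 2) Nonlinear part u_t = -(u^2/2)_x: one forward Euler step, derivative taken in Fourier space.
        u = np.fft.irfft(u_hat, n=n)
        u_hat = u_hat - dt * 1j * k * np.fft.rfft(0.5 * u**2)
    return np.fft.irfft(u_hat, n=n)


x = np.arange(128) / 128
u_start = np.sin(2 * np.pi * x)
u_end = burgers_split_step(u_start, nu=0.1, t_final=0.1, dt=1e-4)
```

Both substeps leave the $k = 0$ Fourier coefficient untouched, so the scheme conserves the spatial mean of $u$, as the continuous equation does; viscosity makes the solution's amplitude decay over time.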

[4]:
# create training dataset
train_dataset = create_training_dataset(data_params, shuffle=True)

# create test dataset
test_input, test_label = np.load(os.path.join(data_params["path"], "test/inputs.npy")), \
                         np.load(os.path.join(data_params["path"], "test/label.npy"))
test_input = Tensor(np.expand_dims(test_input, -2), mstype.float32)
test_label = Tensor(np.expand_dims(test_label, -2), mstype.float32)
Data preparation finished
input_path:  (1000, 1024, 1)
label_path:  (1000, 1024)

Model Construction

The network consists of one lifting layer, multiple Fourier layers, and one decoding layer:

  • The lifting layer corresponds to FNO1D.fc0 in this case and maps the input data x to a higher-dimensional space;

  • The Fourier layers correspond to FNO1D.fno_seq in this case. The discrete Fourier transform is used to convert between the spatial domain and the frequency domain;

  • The decoding layer corresponds to FNO1D.fc1 and FNO1D.fc2 in this case and produces the final prediction.
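The lifting, Fourier, and decoding stages listed above can be traced end to end with a toy NumPy model. This is a simplified sketch, not FNO1D itself: weights are random and bias-free, ReLU follows every layer, and the decoder hidden size of 128 is a hypothetical choice. The sizes `width=64`, `modes=16`, and `depth=4` match the printed model name. The names `init_params` and `fno1d_forward` are illustrative.

```python
import numpy as np


def spectral_conv1d(v, weights, modes):
    """F^-1(R . F(v)) restricted to the lowest `modes` frequencies; v has shape (channels, n)."""
    v_hat = np.fft.rfft(v, axis=-1)
    out_hat = np.zeros((weights.shape[1], v_hat.shape[-1]), dtype=complex)
    out_hat[:, :modes] = np.einsum("ik,iok->ok", v_hat[:, :modes], weights)
    return np.fft.irfft(out_hat, n=v.shape[-1], axis=-1)


def init_params(width=64, modes=16, depth=4, hidden=128, seed=0):
    """Random weights for a bias-free toy FNO; `hidden` (decoder width) is a hypothetical choice."""
    rng = np.random.default_rng(seed)
    layers = []
    for _ in range(depth):
        r = (rng.standard_normal((width, width, modes))
             + 1j * rng.standard_normal((width, width, modes))) / width
        w = rng.standard_normal((width, width)) / width
        layers.append((r, w))
    return {
        "fc0": rng.standard_normal((width, 1)),               # lifting: 1 -> width channels
        "layers": layers,                                     # Fourier layers
        "fc1": rng.standard_normal((hidden, width)) / width,  # decoding, part 1
        "fc2": rng.standard_normal((1, hidden)) / hidden,     # decoding, part 2
        "modes": modes,
    }


def fno1d_forward(u0, p):
    """Map an input of shape (1, n) to a prediction of shape (1, n)."""
    v = p["fc0"] @ u0  # lifting: (1, n) -> (width, n)
    for r, w in p["layers"]:
        v = np.maximum(0.0, spectral_conv1d(v, r, p["modes"]) + w @ v)  # Fourier layer
    v = np.maximum(0.0, p["fc1"] @ v)  # (width, n) -> (hidden, n)
    return p["fc2"] @ v  # (hidden, n) -> (1, n)


params = init_params()
out_fine = fno1d_forward(np.random.default_rng(1).standard_normal((1, 1024)), params)
# The same weights accept a coarser grid, because R acts on Fourier modes, not grid points.
out_coarse = fno1d_forward(np.random.default_rng(2).standard_normal((1, 256)), params)
print(out_fine.shape, out_coarse.shape)  # (1, 1024) (1, 256)
```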

The model is initialized from the network described above; its parameters can be modified in the configuration file.

[5]:
model = FNO1D(in_channels=model_params["in_channels"],
              out_channels=model_params["out_channels"],
              resolution=model_params["resolution"],
              modes=model_params["modes"],
              channels=model_params["width"],
              depths=model_params["depth"])

model_params_list = []
for k, v in model_params.items():
    model_params_list.append(f"{k}:{v}")
model_name = "_".join(model_params_list)
print(model_name)
name:FNO1D_in_channels:1_out_channels:1_resolution:1024_modes:16_width:64_depth:4

Optimizer and Loss Function

[6]:
steps_per_epoch = train_dataset.get_dataset_size()
lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params["initial_lr"],
                                    last_epoch=optimizer_params["train_epochs"],
                                    steps_per_epoch=steps_per_epoch,
                                    warmup_epochs=1)
optimizer = nn.Adam(model.trainable_params(), learning_rate=Tensor(lr))

if use_ascend:
    loss_scaler = DynamicLossScaler(1024, 2, 100)
    auto_mixed_precision(model, 'O1')
else:
    loss_scaler = None
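`get_warmup_cosine_annealing_lr` produces one learning rate per training step, which is why its result is wrapped in a `Tensor` and passed to `nn.Adam`. The sketch below shows the general shape of such a schedule, a linear warmup followed by cosine decay. The exact formula MindSpore Flow uses is not given in this tutorial, so treat `warmup_cosine_lr` as a hypothetical stand-in; `steps_per_epoch=50` in the example is also made up.

```python
import math


def warmup_cosine_lr(lr_init, steps_per_epoch, last_epoch, warmup_epochs):
    """Return one learning rate per training step: linear warmup, then cosine decay toward zero."""
    total_steps = steps_per_epoch * last_epoch
    warmup_steps = steps_per_epoch * warmup_epochs
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            lr = lr_init * (step + 1) / warmup_steps  # linear ramp up to lr_init
        else:
            # progress runs from 0 up to (almost) 1 over the remaining steps
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            lr = 0.5 * lr_init * (1.0 + math.cos(math.pi * progress))
        lrs.append(lr)
    return lrs


lrs = warmup_cosine_lr(lr_init=1e-3, steps_per_epoch=50, last_epoch=100, warmup_epochs=1)
print(len(lrs))  # 5000
```

A short warmup (one epoch, as in the tutorial) avoids large, noisy updates while Adam's moment estimates are still unreliable; the cosine tail then lowers the step size smoothly for the final epochs.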

Model Training

With MindSpore 2.0.0 or later, neural networks can be trained in a functional programming style. MindSpore Flow provides UnsteadyFlowWithLoss, a training interface for unsteady problems, for model training and evaluation.
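The training and evaluation below use `RelativeRMSELoss`. The NumPy sketch below shows one common definition of a relative RMSE, the per-sample ratio $\lVert \hat{u} - u \rVert_2 / \lVert u \rVert_2$, summed over the batch. That `RelativeRMSELoss` uses exactly this definition with a summing reduction is an assumption, though it is consistent with the evaluation step dividing the loss by `test_input.shape[0]` to report a mean. `relative_rmse` is a hypothetical helper.

```python
import numpy as np


def relative_rmse(pred, label, reduction="sum"):
    """Per-sample ||pred - label||_2 / ||label||_2 over all non-batch axes, then reduced over the batch."""
    axes = tuple(range(1, pred.ndim))
    per_sample = np.sqrt(np.sum((pred - label) ** 2, axis=axes)) / np.sqrt(np.sum(label**2, axis=axes))
    return per_sample.sum() if reduction == "sum" else per_sample.mean()


label = np.ones((2, 1024, 1))
pred = 1.1 * label                  # every sample is off by 10 %
total = relative_rmse(pred, label)  # summed over the batch: about 0.2
print(total / label.shape[0])       # about 0.1, the mean relative error per sample
```

Normalizing by $\lVert u \rVert_2$ makes the error comparable across samples whose velocity magnitudes differ.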

[7]:
problem = UnsteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss(), data_format="NHWTC")

summary_dir = os.path.join(config["summary_dir"], model_name)
print(summary_dir)

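def forward_fn(data, label):
    loss = problem.get_loss(data, label)
    if use_ascend:
        # scale the loss so float16 gradients do not underflow; train_step unscales loss and grads again
        loss = loss_scaler.scale(loss)
    return loss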

grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

@jit
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    if use_ascend:
        loss = loss_scaler.unscale(loss)
        if all_finite(grads):
            grads = loss_scaler.unscale(grads)
    loss = ops.depend(loss, optimizer(grads))
    return loss

sink_process = data_sink(train_step, train_dataset, 1)

for epoch in range(1, config["epochs"] + 1):
    model.set_train()
    local_time_beg = time.time()
    for _ in range(steps_per_epoch):
        cur_loss = sink_process()
    print("epoch: {}, time elapsed: {}ms, loss: {}".format(epoch, (time.time() - local_time_beg) * 1000, cur_loss.asnumpy()))

    if epoch % config['eval_interval'] == 0:
        model.set_train(False)
        print("================================Start Evaluation================================")
        rms_error = problem.get_loss(test_input, test_label)/test_input.shape[0]
        print("mean rms_error:", rms_error)
        print("=================================End Evaluation=================================")
        ckpt_dir = os.path.join(summary_dir, "ckpt")
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        save_checkpoint(model, os.path.join(ckpt_dir, model_params["name"] + '_epoch' + str(epoch)))
./summary/name:FNO1D_in_channels:1_out_channels:1_resolution:1024_modes:16_width:64_depth:4
epoch: 1, time elapsed: 21747.305870056152ms, loss: 2.167046070098877
epoch: 2, time elapsed: 5525.397539138794ms, loss: 0.5935954451560974
epoch: 3, time elapsed: 5459.984540939331ms, loss: 0.7349425554275513
epoch: 4, time elapsed: 4948.82869720459ms, loss: 0.6338694095611572
epoch: 5, time elapsed: 5571.3865756988525ms, loss: 0.3174982964992523
epoch: 6, time elapsed: 5712.041616439819ms, loss: 0.3099440038204193
epoch: 7, time elapsed: 5218.639135360718ms, loss: 0.3117891848087311
epoch: 8, time elapsed: 4819.460153579712ms, loss: 0.1810857653617859
epoch: 9, time elapsed: 4968.810081481934ms, loss: 0.1386510729789734
epoch: 10, time elapsed: 4849.36785697937ms, loss: 0.2102256715297699
================================Start Evaluation================================
mean rms_error: 0.027940063
=================================End Evaluation=================================
...
epoch: 91, time elapsed: 4398.104429244995ms, loss: 0.019643772393465042
epoch: 92, time elapsed: 5479.56109046936ms, loss: 0.0641067773103714
epoch: 93, time elapsed: 5549.5476722717285ms, loss: 0.02199840545654297
epoch: 94, time elapsed: 6238.730907440186ms, loss: 0.024467874318361282
epoch: 95, time elapsed: 5434.457778930664ms, loss: 0.025712188333272934
epoch: 96, time elapsed: 6481.106281280518ms, loss: 0.02247200347483158
epoch: 97, time elapsed: 6303.435325622559ms, loss: 0.026637140661478043
epoch: 98, time elapsed: 5162.56856918335ms, loss: 0.030040305107831955
epoch: 99, time elapsed: 5364.72225189209ms, loss: 0.02589748054742813
epoch: 100, time elapsed: 5902.378797531128ms, loss: 0.028599221259355545
================================Start Evaluation================================
mean rms_error: 0.0037017763
=================================End Evaluation=================================
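On Ascend, the training step relies on dynamic loss scaling; MindSpore's `DynamicLossScaler(1024, 2, 100)` takes the initial scale value, the scale factor, and the scale window. The toy class below illustrates the usual policy in plain Python: multiply the loss by the scale before differentiation, divide the gradients by it afterwards, grow the scale after a window of finite steps, and shrink it on overflow. It is a hypothetical illustration, not MindSpore's implementation. Note that the `train_step` shown above unscales but never calls `adjust`, so with that code the scale would stay at its initial value.

```python
import math


class ToyDynamicLossScaler:
    """Illustrative dynamic loss scaler (hypothetical class, not the MindSpore implementation)."""

    def __init__(self, scale_value=1024.0, scale_factor=2.0, scale_window=100):
        self.scale_value = float(scale_value)
        self.scale_factor = float(scale_factor)
        self.scale_window = scale_window
        self.good_steps = 0

    def scale(self, loss):
        return loss * self.scale_value

    def unscale(self, values):
        return [v / self.scale_value for v in values]

    def adjust(self, grads_finite):
        if grads_finite:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:  # long run of finite steps: try a larger scale
                self.scale_value *= self.scale_factor
                self.good_steps = 0
        else:  # overflow: shrink the scale and restart the window
            self.scale_value = max(1.0, self.scale_value / self.scale_factor)
            self.good_steps = 0


scaler = ToyDynamicLossScaler(1024, 2, 100)
for _ in range(100):
    grads = scaler.unscale([scaler.scale(0.5)])  # stand-in for "gradients of the scaled loss"
    scaler.adjust(all(math.isfinite(g) for g in grads))
print(scaler.scale_value)  # 2048.0
scaler.adjust(False)       # simulate an overflow
print(scaler.scale_value)  # 1024.0
```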