Solve Navier-Stokes equation based on 3D Spectral Neural Operator


Overview

Computational fluid dynamics is one of the most important techniques in the field of fluid mechanics in the 21st century. Flow analysis, prediction, and control can be achieved by solving the governing equations of fluid mechanics numerically. Traditional methods such as the finite element method (FEM) and the finite difference method (FDM) are inefficient because of the complex simulation process (physical modeling, meshing, numerical discretization, iterative solution, etc.) and high computing costs. It is therefore necessary to improve the efficiency of fluid simulation with AI.

Machine learning methods provide a new paradigm for scientific computing by offering fast solvers comparable to traditional ones. Classical neural networks learn mappings between finite-dimensional spaces and can only learn solutions tied to a specific discretization. Unlike traditional neural networks, the Fourier Neural Operator (FNO) is a deep learning architecture that can learn mappings between infinite-dimensional function spaces. It directly learns the mapping from arbitrary function parameters to the solution in order to solve a whole class of partial differential equations, and therefore has a stronger generalization capability. More information can be found in the paper Fourier Neural Operator for Parametric Partial Differential Equations.

The Spectral Neural Operator (SNO) is an FNO-like architecture that uses a polynomial transformation to spectral space (Chebyshev, Legendre, etc.) instead of the Fourier transform. Compared to FNO, SNO is characterized by a smaller systematic bias caused by aliasing errors. One of the most important benefits of the architecture is that SNO supports a choice of basis, so it is possible to find the set of polynomials that is most convenient for the representation, e.g., one that respects the symmetries of the problem or fits the solution interval well. Besides, neural operators based on orthogonal polynomials may be competitive with other spectral operators when the input function is defined on an unstructured grid.

More information can be found in the paper “Spectral Neural Operators”, arXiv preprint arXiv:2205.10573 (2022).
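To make the polynomial-transform idea concrete, below is a minimal NumPy sketch of forward ("analysis") and inverse ("synthesis") Chebyshev transform matrices on the Chebyshev-Gauss grid. This is not the MindFlow implementation; the helper mindflow.cell.get_poly_transform used later in this tutorial returns analogous matrices.

import numpy as np

def chebyshev_transforms(n_nodes, n_modes):
    """Analysis/synthesis matrices on the Chebyshev-Gauss grid.
    analysis @ f gives the first n_modes Chebyshev coefficients of f
    sampled at the nodes; synthesis @ c evaluates the truncated
    expansion back on the grid."""
    k = np.arange(n_nodes)
    x = np.cos(np.pi * (k + 0.5) / n_nodes)      # Chebyshev-Gauss nodes
    m = np.arange(n_modes)[:, None]
    basis = np.cos(m * np.arccos(x)[None, :])    # T_m(x_k), shape (n_modes, n_nodes)
    analysis = 2.0 / n_nodes * basis
    analysis[0] /= 2.0                           # the constant mode has half weight
    synthesis = basis.T                          # shape (n_nodes, n_modes)
    return analysis, synthesis

# round-trip check: a degree-3 polynomial is reproduced exactly by 20 modes
analysis, synthesis = chebyshev_transforms(64, 20)
nodes = np.cos(np.pi * (np.arange(64) + 0.5) / 64)
assert np.allclose(synthesis @ (analysis @ nodes**3), nodes**3)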

This tutorial describes how to solve the Navier-Stokes equation using the 3D Spectral Neural Operator.

Problem Description

We aim to learn the operator that maps the vorticity at the first 10 time steps to the full trajectory on $[10, T]$:

$$w|_{(0,1)^2 \times [0,10]} \mapsto w|_{(0,1)^2 \times [10,T]}$$
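In tensor terms, one training pair has the following shapes (an illustrative snippet; it matches the dataset shapes printed in the data-loading cell later in this tutorial):

import numpy as np

# a hypothetical training pair for the operator above
a = np.zeros((1, 10, 64, 64), dtype=np.float32)  # input: vorticity at time steps 0..9
u = np.zeros((1, 40, 64, 64), dtype=np.float32)  # target: vorticity at time steps 10..49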

Technology Path

MindFlow solves the problem as follows:

  1. Training Dataset Construction.

  2. Model Construction.

  3. Optimizer and Loss Function.

  4. Model Training.

Spectral Neural Operator

The following figure shows the architecture of the Spectral Neural Operator, which consists of an encoder, multiple spectral convolution layers (linear transformations in the space of coefficients in the polynomial basis) and a decoder. To compute the forward and inverse polynomial transformation matrices for the spectral convolutions, the input should be interpolated at the respective Gauss quadrature nodes (Chebyshev grid, etc.). The interpolated input is lifted to a higher-dimensional channel space by a convolutional encoder layer. The result is fed into a sequence of spectral (SNO) layers, each of which applies a linear convolution to its truncated spectral representation. The output of the SNO layers is projected back to the target dimension by a convolutional decoder, and finally interpolated back to the original nodes.

The spectral (SNO) layer performs the following operations: it applies the polynomial transformation $A$ to spectral space (Chebyshev, Legendre, etc.); a linear convolution $L$ acts on the lower polynomial modes, and the higher modes are filtered out; then the inverse transformation $S = A^{-1}$ maps the result back to physical space. Finally, a linear convolution $W$ of the layer input is added, and a nonlinear activation is applied.

Spectral Neural Operator model structure
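As a reference for the operations just described, here is a minimal 1D NumPy sketch of a single SNO layer. The actual SNO3D layer applies the transforms along all three axes and uses convolutional weights; all names below are illustrative.

import numpy as np

def sno_layer_1d(x, analysis, synthesis, spectral_weight, skip_weight, act=np.tanh):
    """One spectral (SNO) layer for a 1D field, sketched in NumPy.
    x:               (channels, n_nodes) input on the Gauss grid
    analysis:        (n_modes, n_nodes) forward polynomial transform A
    synthesis:       (n_nodes, n_modes) inverse transform S = A^{-1}
    spectral_weight: (channels, channels, n_modes) linear map L on the retained modes
    skip_weight:     (channels, channels) point-wise linear convolution W
    """
    c = x @ analysis.T                               # to spectral space; higher modes are dropped
    c = np.einsum('ijm,jm->im', spectral_weight, c)  # linear convolution L on the low modes
    spectral = c @ synthesis.T                       # back to physical space via S
    skip = skip_weight @ x                           # linear skip path W
    return act(spectral + skip)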

[1]:
import os
import time
import numpy as np

from mindspore import nn, ops, jit, data_sink, context, Tensor
from mindspore.common import set_seed
from mindspore import dtype as mstype

The following src package can be downloaded at applications/data_driven/navier_stokes/sno3d/src.

[2]:
from mindflow import get_warmup_cosine_annealing_lr, load_yaml_config
from mindflow.utils import print_log
from mindflow.cell import SNO3D, get_poly_transform

from src import calculate_l2_error, UnitGaussianNormalizer, create_training_dataset, load_interp_data, visual

set_seed(0)
np.random.seed(0)
[3]:
# set context for training: using graph mode for high performance training with GPU acceleration
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', device_id=0)
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
config = load_yaml_config('./configs/sno3d.yaml')
data_params = config["data"]
model_params = config["model"]
optimizer_params = config["optimizer"]
summary_params = config["summary"]

grid_size = data_params["resolution"]
input_timestep = model_params["input_timestep"]
output_timestep = model_params["extrapolations"]

Training Dataset Construction

Download the training and test dataset: data_driven/navier_stokes/dataset.

In this case, the training and test datasets are generated following Zongyi Li’s dataset in Fourier Neural Operator for Parametric Partial Differential Equations. The settings are as follows:

The initial condition $w_0(x)$ is sampled from the following distribution, with periodic boundary conditions:

$$w_0 \sim \mu, \quad \mu = \mathcal{N}\left(0,\ 7^{3/2}(-\Delta + 49I)^{-2.5}\right)$$

The forcing function is defined as:

$$f(x) = 0.1\left(\sin\big(2\pi(x_1 + x_2)\big) + \cos\big(2\pi(x_1 + x_2)\big)\right)$$

We use a time step of $10^{-4}$ for the Crank–Nicolson scheme in the data generation process, and the solution is recorded every $t = 1$ time units. All data are generated on a 256 × 256 grid and downsampled to 64 × 64. In this case, the viscosity coefficient is $\nu = 10^{-4}$, the number of samples in the training set is 1000, and the number of samples in the test set is 200.
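The 256 × 256 to 64 × 64 downsampling amounts to strided subsampling of the spatial axes, e.g. (a sketch with hypothetical array names; the dataset shipped with this case is already downsampled):

import numpy as np

w_fine = np.zeros((8, 50, 256, 256), dtype=np.float32)  # a few hypothetical raw trajectories
sub = 256 // 64                                         # subsampling factor of 4
w_coarse = w_fine[..., ::sub, ::sub]                    # shape (8, 50, 64, 64)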

[4]:
load_interp_data(data_params, 'train')
train_loader = create_training_dataset(data_params, shuffle=True)
test_data = load_interp_data(data_params, 'test')
test_a = Tensor(test_data['a'], mstype.float32)
test_u = Tensor(test_data['u'], mstype.float32)

test_u_unif = np.load(os.path.join(data_params['root_dir'], 'test/test_u.npy'))

train_a = Tensor(np.load(os.path.join(
        data_params["root_dir"], "train/train_a_interp.npy")), mstype.float32)
train_u = Tensor(np.load(os.path.join(
        data_params["root_dir"], "train/train_u_interp.npy")), mstype.float32)
train a, u:  (1000, 10, 64, 64) (1000, 40, 64, 64)
test a, u:  (200, 10, 64, 64) (200, 40, 64, 64)

Model Construction

The network is composed of one encoding layer, multiple spectral layers and a decoding block:

  • The encoding convolution corresponds to SNO3D.encoder in this case; it maps the input data x to a higher-dimensional channel space;

  • The sequence of SNO layers corresponds to SNO3D.sno_kernel in this case. The input matrices of the polynomial transformations (forward and inverse conversions for each of the three variables) are used to realize the transition between the space-time domain and the frequency domain;

  • The decoding layer corresponds to SNO3D.decoder and consists of two convolutions. The decoder is used to obtain the final prediction.

[5]:
n_modes = model_params['modes']
poly_type = data_params['poly_type']
transform_data = get_poly_transform(grid_size, n_modes, poly_type)
transform = Tensor(transform_data["analysis"], mstype.float32)
inv_transform = Tensor(transform_data["synthesis"], mstype.float32)

transform_t_axis = get_poly_transform(output_timestep, n_modes, poly_type)
transform_t = Tensor(transform_t_axis["analysis"], mstype.float32)
inv_transform_t = Tensor(transform_t_axis["synthesis"], mstype.float32)

transforms = [[transform, inv_transform]] * 2 + [[transform_t, inv_transform_t]]
[6]:
if use_ascend:
    compute_type = mstype.float16
else:
    compute_type = mstype.float32

# prepare model
model = SNO3D(in_channels=model_params['in_channels'],
              out_channels=model_params['out_channels'],
              hidden_channels=model_params['hidden_channels'],
              num_sno_layers=model_params['sno_layers'],
              transforms=transforms,
              kernel_size=model_params['kernel_size'],
              compute_dtype=compute_type)

model_params_list = []
for k, v in model_params.items():
    model_params_list.append(f"{k}-{v}")
model_name = "_".join(model_params_list)

total = 0
for param in model.get_parameters():
    print_log(param.shape)
    total += param.size
print_log(f"Total Parameters:{total}")
(64, 10, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 7, 7, 7)
(64, 64, 1, 1, 1)
(64, 64, 1, 1, 1)
(1, 64, 1, 1, 1)
Total Parameters:7049920

Optimizer and Loss Function

[ ]:

lr = get_warmup_cosine_annealing_lr(lr_init=optimizer_params["learning_rate"],
                                    last_epoch=optimizer_params["epochs"],
                                    steps_per_epoch=train_loader.get_dataset_size(),
                                    warmup_epochs=optimizer_params["warmup_epochs"])
steps_per_epoch = train_loader.get_dataset_size()
optimizer = nn.AdamWeightDecay(model.trainable_params(),
                               learning_rate=Tensor(lr),
                               eps=float(optimizer_params['eps']),
                               weight_decay=optimizer_params['weight_decay'])
loss_fn = nn.RMSELoss()  # LpLoss() is an alternative
a_normalizer = UnitGaussianNormalizer(train_a)
u_normalizer = UnitGaussianNormalizer(train_u)
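The UnitGaussianNormalizer comes from the src package and is not shown here. As a rough sketch (modeled on the point-wise Gaussian normalizer in the reference FNO code; the actual src implementation may differ), it stores per-point statistics of the training data and applies them in encode/decode:

import numpy as np

class PointwiseGaussianNormalizer:
    """A hypothetical stand-in for src.UnitGaussianNormalizer."""
    def __init__(self, x, eps=1e-5):
        self.mean = x.mean(axis=0)   # per-point statistics over the sample dimension
        self.std = x.std(axis=0)
        self.eps = eps

    def encode(self, x):
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x):
        return x * (self.std + self.eps) + self.mean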

Training Function

With MindSpore >= 2.0.0, you can train neural networks using the functional programming paradigm. Single-step training functions are decorated with jit, and the data_sink function is used to bind the step-wise training function to the training dataset.

[8]:
from mindspore.amp import DynamicLossScaler, auto_mixed_precision, all_finite

def forward_fn(data, label):
    bs = data.shape[0]
    data = a_normalizer.encode(data)
    data = data.reshape(bs, input_timestep, grid_size, grid_size, 1).repeat(output_timestep, axis=-1)
    logits = model(data).reshape(bs, output_timestep, grid_size, grid_size)

    logits = u_normalizer.decode(logits)
    loss = loss_fn(logits.reshape(bs, -1), label.reshape(bs, -1))
    if use_ascend:
        loss = loss_scaler.scale(loss)
    return loss

grad_fn = ops.value_and_grad(
    forward_fn, None, optimizer.parameters, has_aux=False)

if use_ascend:
    loss_scaler = DynamicLossScaler(1024, 2, 100)
    auto_mixed_precision(model, 'O2')
else:
    loss_scaler = None

@jit
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    if use_ascend:
        loss = loss_scaler.unscale(loss)
        is_finite = all_finite(grads)
        if is_finite:
            grads = loss_scaler.unscale(grads)
            loss = ops.depend(loss, optimizer(grads))
        loss_scaler.adjust(is_finite)
    else:
        loss = ops.depend(loss, optimizer(grads))
    return loss

sink_process = data_sink(train_step, train_loader, sink_size=200)

Model Training

Set up the output directories for checkpoints, then define the training loop and run it. The model is evaluated on the test set every test_interval epochs.

[11]:
summary_dir = os.path.join(summary_params["root_dir"], model_name)
ckpt_dir = os.path.join(summary_dir, summary_params["ckpt_dir"])
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
[9]:
def train():
    for epoch in range(1, 1 + optimizer_params["epochs"]):
        local_time_beg = time.time()
        model.set_train(True)
        cur_loss = sink_process()
        local_time_end = time.time()
        epoch_seconds = local_time_end - local_time_beg
        step_seconds = epoch_seconds / 200  # 200 steps per epoch (the data_sink sink_size)
        print_log(
            f"epoch: {epoch} train loss: {cur_loss} epoch time: {epoch_seconds:.3f}s step time: {step_seconds:5.3f}s")

        if epoch % summary_params['test_interval'] == 0:
            model.set_train(False)
            calculate_l2_error(model, test_a, test_u, data_params, a_normalizer, u_normalizer)
[11]:
train()
epoch: 1 train loss: 0.90118235 epoch time: 44.718s step time: 0.224s
epoch: 2 train loss: 0.91254395 epoch time: 40.240s step time: 0.201s
epoch: 3 train loss: 0.9374327 epoch time: 40.302s step time: 0.202s
epoch: 4 train loss: 0.85217404 epoch time: 40.482s step time: 0.202s
epoch: 5 train loss: 0.6309165 epoch time: 40.590s step time: 0.203s
epoch: 6 train loss: 0.4290015 epoch time: 40.576s step time: 0.203s
epoch: 7 train loss: 0.34428337 epoch time: 40.536s step time: 0.203s
epoch: 8 train loss: 0.34126174 epoch time: 40.564s step time: 0.203s
epoch: 9 train loss: 0.27420813 epoch time: 40.571s step time: 0.203s
epoch: 10 train loss: 0.2711888 epoch time: 40.554s step time: 0.203s
================================Start Evaluation================================
Error on Gauss grid: 0.28268588, on regular grid: 0.2779018060781111
predict total time: 26.014892578125 s
=================================End Evaluation=================================
epoch: 11 train loss: 0.2603902 epoch time: 40.542s step time: 0.203s
epoch: 12 train loss: 0.24578454 epoch time: 40.570s step time: 0.203s
epoch: 13 train loss: 0.23497193 epoch time: 40.543s step time: 0.203s
epoch: 14 train loss: 0.210803 epoch time: 40.543s step time: 0.203s
epoch: 15 train loss: 0.24416743 epoch time: 40.507s step time: 0.203s
epoch: 16 train loss: 0.2085956 epoch time: 40.520s step time: 0.203s
epoch: 17 train loss: 0.22456339 epoch time: 40.507s step time: 0.203s
epoch: 18 train loss: 0.20356481 epoch time: 40.494s step time: 0.202s
epoch: 19 train loss: 0.1977826 epoch time: 40.486s step time: 0.202s
epoch: 20 train loss: 0.21421571 epoch time: 40.487s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.21850872, on regular grid: 0.21519988102613533
predict total time: 24.13991665840149 s
=================================End Evaluation=================================
epoch: 21 train loss: 0.19105345 epoch time: 40.517s step time: 0.203s
epoch: 22 train loss: 0.1900783 epoch time: 40.514s step time: 0.203s
epoch: 23 train loss: 0.19938461 epoch time: 40.525s step time: 0.203s
epoch: 24 train loss: 0.17807631 epoch time: 40.475s step time: 0.202s
epoch: 25 train loss: 0.23215973 epoch time: 40.487s step time: 0.202s
epoch: 26 train loss: 0.16794981 epoch time: 40.480s step time: 0.202s
epoch: 27 train loss: 0.17212906 epoch time: 40.480s step time: 0.202s
epoch: 28 train loss: 0.18129097 epoch time: 40.507s step time: 0.203s
epoch: 29 train loss: 0.17482412 epoch time: 40.477s step time: 0.202s
epoch: 30 train loss: 0.16695607 epoch time: 40.476s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.21386118, on regular grid: 0.2106647958068735
predict total time: 24.241203784942627 s
=================================End Evaluation=================================
epoch: 31 train loss: 0.18177168 epoch time: 40.473s step time: 0.202s
epoch: 32 train loss: 0.21024966 epoch time: 40.516s step time: 0.203s
epoch: 33 train loss: 0.17173253 epoch time: 40.502s step time: 0.203s
epoch: 34 train loss: 0.16217099 epoch time: 40.476s step time: 0.202s
epoch: 35 train loss: 0.16301228 epoch time: 40.499s step time: 0.202s
epoch: 36 train loss: 0.18293448 epoch time: 40.498s step time: 0.202s
epoch: 37 train loss: 0.18147346 epoch time: 40.441s step time: 0.202s
epoch: 38 train loss: 0.16941778 epoch time: 40.447s step time: 0.202s
epoch: 39 train loss: 0.16393727 epoch time: 40.508s step time: 0.203s
epoch: 40 train loss: 0.14487892 epoch time: 40.456s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.17054155, on regular grid: 0.1669973841447345
predict total time: 24.262221574783325 s
=================================End Evaluation=================================
epoch: 41 train loss: 0.15602723 epoch time: 40.482s step time: 0.202s
epoch: 42 train loss: 0.1580032 epoch time: 40.502s step time: 0.203s
epoch: 43 train loss: 0.14684558 epoch time: 40.464s step time: 0.202s
epoch: 44 train loss: 0.1525133 epoch time: 40.450s step time: 0.202s
epoch: 45 train loss: 0.15542132 epoch time: 40.483s step time: 0.202s
epoch: 46 train loss: 0.14850396 epoch time: 40.461s step time: 0.202s
epoch: 47 train loss: 0.15148017 epoch time: 40.470s step time: 0.202s
epoch: 48 train loss: 0.1460498 epoch time: 40.457s step time: 0.202s
epoch: 49 train loss: 0.14232638 epoch time: 40.450s step time: 0.202s
epoch: 50 train loss: 0.14340377 epoch time: 40.448s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.15915471, on regular grid: 0.15595411313723867
predict total time: 21.78568935394287 s
=================================End Evaluation=================================
epoch: 51 train loss: 0.14372692 epoch time: 40.469s step time: 0.202s
epoch: 52 train loss: 0.14164849 epoch time: 40.487s step time: 0.202s
epoch: 53 train loss: 0.14629523 epoch time: 40.512s step time: 0.203s
epoch: 54 train loss: 0.1396117 epoch time: 40.464s step time: 0.202s
epoch: 55 train loss: 0.13634394 epoch time: 40.459s step time: 0.202s
epoch: 56 train loss: 0.13366798 epoch time: 40.463s step time: 0.202s
epoch: 57 train loss: 0.13632345 epoch time: 40.457s step time: 0.202s
epoch: 58 train loss: 0.13450852 epoch time: 40.474s step time: 0.202s
epoch: 59 train loss: 0.12455033 epoch time: 40.435s step time: 0.202s
epoch: 60 train loss: 0.1306016 epoch time: 40.483s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.15231597, on regular grid: 0.1492942070119114
predict total time: 23.996315956115723 s
=================================End Evaluation=================================
epoch: 61 train loss: 0.13203391 epoch time: 40.448s step time: 0.202s
epoch: 62 train loss: 0.13594246 epoch time: 40.470s step time: 0.202s
epoch: 63 train loss: 0.13565734 epoch time: 40.466s step time: 0.202s
epoch: 64 train loss: 0.12305962 epoch time: 40.435s step time: 0.202s
epoch: 65 train loss: 0.13006279 epoch time: 40.452s step time: 0.202s
epoch: 66 train loss: 0.12222704 epoch time: 40.474s step time: 0.202s
epoch: 67 train loss: 0.123683415 epoch time: 40.440s step time: 0.202s
epoch: 68 train loss: 0.120612934 epoch time: 40.453s step time: 0.202s
epoch: 69 train loss: 0.115140736 epoch time: 40.462s step time: 0.202s
epoch: 70 train loss: 0.12193731 epoch time: 40.420s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.14335541, on regular grid: 0.14026955797649515
predict total time: 23.51957106590271 s
=================================End Evaluation=================================
epoch: 71 train loss: 0.12505008 epoch time: 40.453s step time: 0.202s
epoch: 72 train loss: 0.12200938 epoch time: 40.484s step time: 0.202s
epoch: 73 train loss: 0.11936474 epoch time: 40.440s step time: 0.202s
epoch: 74 train loss: 0.12116067 epoch time: 40.480s step time: 0.202s
epoch: 75 train loss: 0.11600651 epoch time: 40.430s step time: 0.202s
epoch: 76 train loss: 0.11403544 epoch time: 40.447s step time: 0.202s
epoch: 77 train loss: 0.117489025 epoch time: 40.445s step time: 0.202s
epoch: 78 train loss: 0.10970513 epoch time: 40.479s step time: 0.202s
epoch: 79 train loss: 0.10635782 epoch time: 40.448s step time: 0.202s
epoch: 80 train loss: 0.11854948 epoch time: 40.459s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.13526691, on regular grid: 0.1325411629391487
predict total time: 22.508509397506714 s
=================================End Evaluation=================================
epoch: 81 train loss: 0.11022251 epoch time: 40.444s step time: 0.202s
epoch: 82 train loss: 0.108954415 epoch time: 40.509s step time: 0.203s
epoch: 83 train loss: 0.113180526 epoch time: 40.459s step time: 0.202s
epoch: 84 train loss: 0.106218904 epoch time: 40.431s step time: 0.202s
epoch: 85 train loss: 0.10933072 epoch time: 40.429s step time: 0.202s
epoch: 86 train loss: 0.10805362 epoch time: 40.442s step time: 0.202s
epoch: 87 train loss: 0.10749279 epoch time: 40.423s step time: 0.202s
epoch: 88 train loss: 0.112811126 epoch time: 40.471s step time: 0.202s
epoch: 89 train loss: 0.1098047 epoch time: 40.443s step time: 0.202s
epoch: 90 train loss: 0.110777 epoch time: 40.446s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.13124052, on regular grid: 0.1286052267835557
predict total time: 23.993195056915283 s
=================================End Evaluation=================================
epoch: 91 train loss: 0.10114018 epoch time: 40.435s step time: 0.202s
epoch: 92 train loss: 0.10804214 epoch time: 40.477s step time: 0.202s
epoch: 93 train loss: 0.103131406 epoch time: 40.461s step time: 0.202s
epoch: 94 train loss: 0.1079015 epoch time: 40.453s step time: 0.202s
epoch: 95 train loss: 0.10340427 epoch time: 40.445s step time: 0.202s
epoch: 96 train loss: 0.10799302 epoch time: 40.426s step time: 0.202s
epoch: 97 train loss: 0.1010814 epoch time: 40.448s step time: 0.202s
epoch: 98 train loss: 0.10470774 epoch time: 40.441s step time: 0.202s
epoch: 99 train loss: 0.105584204 epoch time: 40.422s step time: 0.202s
epoch: 100 train loss: 0.10661688 epoch time: 40.453s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.12790823, on regular grid: 0.12531590522074978
predict total time: 24.04036521911621 s
=================================End Evaluation=================================
epoch: 101 train loss: 0.09820426 epoch time: 40.432s step time: 0.202s
epoch: 102 train loss: 0.10459576 epoch time: 40.495s step time: 0.202s
epoch: 103 train loss: 0.100737445 epoch time: 40.475s step time: 0.202s
epoch: 104 train loss: 0.104481824 epoch time: 40.439s step time: 0.202s
epoch: 105 train loss: 0.10380473 epoch time: 40.432s step time: 0.202s
epoch: 106 train loss: 0.10476779 epoch time: 40.475s step time: 0.202s
epoch: 107 train loss: 0.1067871 epoch time: 40.444s step time: 0.202s
epoch: 108 train loss: 0.10912545 epoch time: 40.452s step time: 0.202s
epoch: 109 train loss: 0.09430095 epoch time: 40.442s step time: 0.202s
epoch: 110 train loss: 0.100769304 epoch time: 40.430s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.12614353, on regular grid: 0.12363198251396713
predict total time: 23.740464687347412 s
=================================End Evaluation=================================
epoch: 111 train loss: 0.12305746 epoch time: 40.437s step time: 0.202s
epoch: 112 train loss: 0.10381921 epoch time: 40.470s step time: 0.202s
epoch: 113 train loss: 0.107983574 epoch time: 40.429s step time: 0.202s
epoch: 114 train loss: 0.10639244 epoch time: 40.435s step time: 0.202s
epoch: 115 train loss: 0.098030716 epoch time: 40.430s step time: 0.202s
epoch: 116 train loss: 0.104712404 epoch time: 40.422s step time: 0.202s
epoch: 117 train loss: 0.10629137 epoch time: 40.435s step time: 0.202s
epoch: 118 train loss: 0.107867606 epoch time: 40.446s step time: 0.202s
epoch: 119 train loss: 0.11190843 epoch time: 40.422s step time: 0.202s
epoch: 120 train loss: 0.10280066 epoch time: 40.433s step time: 0.202s
================================Start Evaluation================================
Error on Gauss grid: 0.12565085, on regular grid: 0.12316060496613422
predict total time: 24.39193344116211 s
=================================End Evaluation=================================
[ ]:
visual(model, test_a, data_params, a_normalizer, u_normalizer)
[13]:
from IPython.display import Image, display
with open('images/input.gif', 'rb') as f:
    display(Image(data=f.read(), format='gif', embed=True))
[4]:
with open('images/result.gif', 'rb') as f:
    display(Image(data=f.read(), format='gif', embed=True))