Document feedback

Question document fragment

When a question document fragment contains a formula, it is displayed as a space.

Submission type
issue

It's a little complicated...

I'd like to ask someone.

PR

Just a small problem.

I can fix it online!

Please select the submission type

Problem type
Specifications and Common Mistakes

- Specifications and Common Mistakes:

- Misspellings or punctuation mistakes,incorrect formulas, abnormal display.

- Incorrect links, empty cells, or wrong formats.

- Chinese characters in English context.

- Minor inconsistencies between the UI and descriptions.

- Low writing fluency that does not affect understanding.

- Incorrect version numbers, including software package names and version numbers on the UI.

Usability

- Usability:

- Incorrect or missing key steps.

- Missing main function descriptions, keyword explanation, necessary prerequisites, or precautions.

- Ambiguous descriptions, unclear reference, or contradictory context.

- Unclear logic, such as missing classifications, items, and steps.

Correctness

- Correctness:

- Technical principles, function descriptions, supported platforms, parameter types, or exceptions inconsistent with that of software implementation.

- Incorrect schematic or architecture diagrams.

- Incorrect commands or command parameters.

- Incorrect code.

- Commands inconsistent with the functions.

- Wrong screenshots.

- Sample code running error, or running results inconsistent with the expectation.

Risk Warnings

- Risk Warnings:

- Lack of risk warnings for operations that may damage the system or important data.

Content Compliance

- Content Compliance:

- Contents that may violate applicable laws and regulations or geo-cultural context-sensitive words and expressions.

- Copyright infringement.

Please select the type of question

Problem description

Describe the bug so that we can quickly locate the problem.

Neural Representation Method for Boltzmann Equation

DownloadNotebookDownloadCodeView Source On Gitee

Problem Description

This case demonstrates how to solve the Boltzmann equation in 1+3 dimensions using neural networks. The Boltzmann equation is an equation for the gas density distribution function f, defined as

f(x,v,t)t+vxf(x,v,t)=Q[f](x,v,t),tR+,xR3,vR3,

where Q[f] is the collision term operator. With f, we can obtain the macroscopic physical variables of the gas

ρ(x,t)=fdv,m(x,t)ρu(x,t)=fvfdv,E(x,t)32ρT(x,t)+12ρ|u|2=12|v|2fdv,

For simplified models such as the BGK model, the collision term operator is defined as

QBGK[f]=1τ(Mf).

where

M=ρ2πT3exp(|vu|22T).

In this case, the Boltzmann-BGK equation with periodic boundary conditions in 1-dimensional physical space and 3-dimensional microscopic velocity space will be studied. Its initial value is

ρ(x)=1+0.5sin(2πx),u(x)=0,T(x)=1+0.5sin(2πx+0.2),

Method

MindSpore Flow solves the problem as follows:

  1. Creating the dataset.

  2. Creating the neural network.

  3. BoltzmannBGK

  4. Creating the optimizer.

  5. Model training.

  6. Model evaluation.

[1]:
import time
import numpy as np

import mindspore as ms
from mindspore import ops, nn
import mindspore.numpy as mnp

ms.set_context(mode=ms.context.GRAPH_MODE, device_target="GPU")
ms.set_seed(0)
[2]:
from mindflow.utils import load_yaml_config

from src.boltzmann import BoltzmannBGK, BGKKernel
from src.utils import get_vdis, visual, mesh_nd
from src.cells import SplitNet, MultiRes, Maxwellian, MtlLoss, JacFwd, RhoUTheta, PrimNorm
from src.dataset import Wave1DDataset
[3]:
config = load_yaml_config("WaveD1V3.yaml")

Dataset Construction

The computational region of our selection is [0.5,0.5]×[0,0.1] and we select 100 points in the initial value, 100 points in the boundary and 700 points in the interior of the region.

[4]:
class Wave1DDataset(nn.Cell):
    """dataset for 1D wave problem"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        xmax = config["xtmesh"]["xmax"]
        xmin = config["xtmesh"]["xmin"]
        nv = config["vmesh"]["nv"]
        vmin = config["vmesh"]["vmin"]
        vmax = config["vmesh"]["vmax"]
        v, _ = mesh_nd(vmin, vmax, nv)
        self.xmax = xmax
        self.xmin = xmin
        self.vdis = ms.Tensor(v.astype(np.float32))
        self.maxwellian = Maxwellian(self.vdis)
        self.iv_points = self.config["dataset"]["iv_points"]
        self.bv_points = self.config["dataset"]["bv_points"]
        self.in_points = self.config["dataset"]["in_points"]
        self.uniform = ops.UniformReal(seed=0)

    def construct(self):
        # Initial value points
        iv_x = self.uniform((self.iv_points, 1)) * \
            (self.xmax - self.xmin) + self.xmin
        iv_t = mnp.zeros_like(iv_x)

        # boundary value points
        bv_x1 = -0.5 * mnp.ones(self.bv_points)[..., None]
        bv_t1 = self.uniform((self.bv_points, 1)) * 0.1
        bv_x2 = 0.5 * mnp.ones(self.bv_points)[..., None]
        bv_t2 = bv_t1

        # inner points
        in_x = self.uniform((self.in_points, 1)) - 0.5
        in_t = self.uniform((self.in_points, 1)) * 0.1

        return {
            "in": ops.concat([in_x, in_t], axis=-1),
            "iv": ops.concat([iv_x, iv_t], axis=-1),
            "bv1": ops.concat([bv_x1, bv_t1], axis=-1),
            "bv2": ops.concat([bv_x2, bv_t2], axis=-1),
        }


dataset = Wave1DDataset(config)

Model Construction

This case uses a neural network structure with 6 layers and 80 neurons per layer. Our network consists of two parts, one part of the network is in the form of Maxwellian equilibrium state and the other part is in the form of discrete velocity distribution, such a structure helps in training.

[5]:
class SplitNet(nn.Cell):
    """the network combined the maxwellian and non-maxwellian"""

    def __init__(self, in_channel, layers, neurons, vdis, alpha=0.01):
        super().__init__()
        self.net_eq = MultiRes(in_channel, 5, layers, neurons)
        self.net_neq = MultiRes(in_channel, vdis.shape[0], layers, neurons)
        self.maxwellian = Maxwellian(vdis)
        self.alpha = alpha

    def construct(self, xt):
        www = self.net_eq(xt)
        rho, u, theta = www[..., 0:1], www[..., 1:4], www[..., 4:5]
        rho = ops.exp(-rho)
        theta = ops.exp(-theta)
        x1 = self.maxwellian(rho, u, theta)
        x2 = self.net_neq(xt)
        y = x1 * (x1 + self.alpha * x2)
        return y


vdis, _ = get_vdis(config["vmesh"])
model = SplitNet(2, config["model"]["layers"],
                 config["model"]["neurons"], vdis)

BoltzmannBGK

[6]:
class BoltzmannBGK(nn.Cell):
    """The Boltzmann BGK model"""

    def __init__(self, net, kn, vconfig, iv_weight=100, bv_weight=100, pde_weight=10):
        super().__init__()
        self.net = net
        self.kn = kn

        vdis, wdis = get_vdis(vconfig)

        self.vdis = vdis
        loss_num = 3 * (vdis.shape[0] + 1 + 2 * vdis.shape[-1])
        self.mtl = MtlLoss(loss_num)
        self.jac = JacFwd(self.net)
        self.iv_weight = iv_weight
        self.bv_weight = bv_weight
        self.pde_weight = pde_weight
        self.maxwellian_nd = Maxwellian(vdis)
        self.rho_u_theta = RhoUTheta(vdis, wdis)
        self.criterion_norm = lambda x: ops.square(x).mean(axis=0)
        self.criterion = lambda x, y: ops.square(x - y).mean(axis=0)
        self.prim_norm = PrimNorm(vdis, wdis)
        self.collision = BGKKernel(
            vconfig["vmin"], vconfig["vmax"], vconfig["nv"])

    def governing_equation(self, inputs):
        f, fxft = self.jac(inputs)
        fx, ft = fxft[0], fxft[1]
        pde = ft + self.vdis[..., 0] * fx - self.collision(f, self.kn)
        return pde

    def boundary_condition(self, bv_points1, bv_points2):
        fl = self.net(bv_points1)
        fr = self.net(bv_points2)
        return fl - fr

    def initial_condition(self, inputs):
        iv_pred = self.net(inputs)
        iv_x = inputs[..., 0:1]
        rho_l = ops.sin(2 * np.pi * iv_x) * 0.5 + 1
        u_l = ops.zeros((iv_x.shape[0], 3), ms.float32)
        theta_l = ops.sin(2 * np.pi * iv_x + 0.2) * 0.5 + 1
        iv_truth = self.maxwellian_nd(rho_l, u_l, theta_l)
        return iv_pred - iv_truth

    def loss_fn(self, inputs):
        """the loss function"""
        return self.criterion_norm(inputs), self.prim_norm(inputs)

    def construct(self, domain_points, iv_points, bv_points1, bv_points2):
        """combined all loss function"""
        pde = self.governing_equation(domain_points)
        iv = self.initial_condition(iv_points)
        bv = self.boundary_condition(bv_points1, bv_points2)

        loss_pde = self.pde_weight * self.criterion_norm(pde)
        loss_pde2 = self.pde_weight * self.prim_norm(pde)

        loss_bv = self.bv_weight * self.criterion_norm(bv)
        loss_bv2 = self.bv_weight * self.prim_norm(bv)

        loss_iv = self.iv_weight * self.criterion_norm(iv)
        loss_iv2 = self.iv_weight * self.prim_norm(iv)

        loss_sum = self.mtl(
            ops.concat(
                [loss_iv, loss_iv2, loss_bv, loss_bv2, loss_pde, loss_pde2], axis=-1
            )
        )
        return loss_sum, (loss_iv, loss_iv2, loss_bv, loss_bv2, loss_pde, loss_pde2)


problem = BoltzmannBGK(model, config["kn"], config["vmesh"])

Optimizer

[7]:
cosine_decay_lr = nn.CosineDecayLR(
    config["optim"]["lr_scheduler"]["min_lr"],
    config["optim"]["lr_scheduler"]["max_lr"],
    config["optim"]["Adam_steps"],
)
optim = nn.Adam(params=problem.trainable_params(),
                learning_rate=cosine_decay_lr)

Training

[8]:
grad_fn = ops.value_and_grad(problem, None, optim.parameters, has_aux=True)


@ms.jit
def train_step(*inputs):
    loss, grads = grad_fn(*inputs)
    optim(grads)
    return loss


start_time = time.time()
for i in range(1, config["optim"]["Adam_steps"] + 1):
    time_beg = time.time()
    ds = dataset()
    loss, _ = train_step(*ds)
    if i % 500 == 0:
        e_sum = loss.mean().asnumpy().item()
        print(
            f"epoch: {i} loss: {e_sum:.3e} epoch time: {(time.time() - time_beg) * 1000 :.3f} ms"
        )
print("End-to-End total time: {} s".format(time.time() - start_time))
ms.save_checkpoint(problem, f"./model.ckpt")
epoch: 500 loss: 2.396e-01 epoch time: 194.953 ms
epoch: 1000 loss: 3.136e-02 epoch time: 193.463 ms
epoch: 1500 loss: 1.583e-03 epoch time: 191.280 ms
epoch: 2000 loss: 2.064e-04 epoch time: 191.099 ms
epoch: 2500 loss: 1.434e-04 epoch time: 190.477 ms
epoch: 3000 loss: 1.589e-04 epoch time: 190.489 ms
epoch: 3500 loss: 9.694e-05 epoch time: 190.476 ms
epoch: 4000 loss: 8.251e-05 epoch time: 191.893 ms
epoch: 4500 loss: 7.238e-05 epoch time: 190.835 ms
epoch: 5000 loss: 5.705e-05 epoch time: 190.611 ms
epoch: 5500 loss: 4.932e-05 epoch time: 190.530 ms
epoch: 6000 loss: 4.321e-05 epoch time: 190.756 ms
epoch: 6500 loss: 4.205e-05 epoch time: 191.470 ms
epoch: 7000 loss: 3.941e-05 epoch time: 190.781 ms
epoch: 7500 loss: 3.328e-05 epoch time: 190.543 ms
epoch: 8000 loss: 3.113e-05 epoch time: 190.786 ms
epoch: 8500 loss: 2.995e-05 epoch time: 190.864 ms
epoch: 9000 loss: 2.875e-05 epoch time: 190.796 ms
epoch: 9500 loss: 2.806e-05 epoch time: 188.351 ms
epoch: 10000 loss: 2.814e-05 epoch time: 189.244 ms
End-to-End total time: 2151.8691852092743 s

Visualization

[9]:
fig = visual(problem, config["visual_resolution"], "result.png")
fig.show()
../_images/physics_driven_boltzmannD1V3_16_0.png