PINNs for Point Source Poisson
Problem description
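This example demonstrates how to use the PINNs method to solve the two-dimensional Poisson equation with a point source. On the solution domain \(\Omega\), the equation is defined by
\[
\Delta u = -\delta(x - x_{src})\,\delta(y - y_{src}), \quad (x, y) \in \Omega,
\]
with the homogeneous Dirichlet boundary condition \(u = 0\) on \(\partial\Omega\),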
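where \((x_{src}, y_{src})\) is the coordinate of the point source. The point source is represented mathematically by the Dirac \(\delta\) function:
\[
\delta(x) = \begin{cases} +\infty, & x = 0, \\ 0, & x \neq 0, \end{cases} \qquad \int_{-\infty}^{+\infty} \delta(x)\,dx = 1.
\]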
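When the solution domain is \(\Omega=[0,\pi]^2\) and \(u = 0\) on \(\partial\Omega\), the analytical solution of this equation is
\[
u(x, y) = \frac{4}{\pi^2} \sum_{i=1}^{\infty} \sum_{j=1}^{\infty} \frac{\sin(i x)\sin(i x_{src})\sin(j y)\sin(j y_{src})}{i^2 + j^2}.
\]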
The corresponding paper for this case is: Xiang Huang, Hongsheng Liu, Beiji Shi, Zidong Wang, Kang Yang, Yang Li, Min Wang, Haotian Chu, Jing Zhou, Fan Yu, Bei Hua, Bin Dong, Lei Chen. “A Universal PINNs Method for Solving Partial Differential Equations with a Point Source”. Thirty-First International Joint Conference on Artificial Intelligence (IJCAI 2022), Vienna, Austria, Jul, 2022, Pages 3839-3846.
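For a quick numerical check, the sketch below evaluates the analytical solution as the double sine series \(u = \frac{4}{\pi^2}\sum_{i,j \ge 1} \frac{\sin(ix)\sin(ix_{src})\sin(jy)\sin(jy_{src})}{i^2+j^2}\), truncated at n_terms terms per direction. The helper name, the truncation and the default source at \((\pi/2, \pi/2)\) (the position used later in this example) are choices made for this sketch. The exact solution is singular at the source, so the truncated series is only meaningful away from it.

```python
import math


def series_solution(x, y, x_src=math.pi / 2, y_src=math.pi / 2, n_terms=200):
    """Hypothetical helper: truncated sine series for
    Delta u = -delta(x - x_src) delta(y - y_src) on [0, pi]^2, u = 0 on the boundary.

    Only reliable away from (x_src, y_src), where u has a log singularity.
    """
    total = 0.0
    for i in range(1, n_terms + 1):
        # x-direction factor of the (i, j) eigenfunction product
        fx = math.sin(i * x) * math.sin(i * x_src)
        for j in range(1, n_terms + 1):
            fy = math.sin(j * y) * math.sin(j * y_src)
            total += fx * fy / (i * i + j * j)
    return 4.0 / math.pi**2 * total


print(series_solution(1.0, 1.0))  # positive: the solution is positive inside the domain
print(series_solution(0.0, 1.0))  # zero on the boundary x = 0
```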
Method
MindSpore Flow solves the problem as follows:
1. Creating the dataset.
2. Creating the neural network.
3. PINNs’ loss.
4. Creating the optimizer.
5. Model training.
6. Model inference and visualization.
import time
from mindspore import context, nn, ops, jit
from mindflow import load_yaml_config
from src.dataset import create_train_dataset, create_test_dataset
from src.poisson import Poisson
from src.utils import calculate_l2_error, visual
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
# Load config
file_cfg = "poisson_cfg.yaml"
config = load_yaml_config(file_cfg)
Creating the dataset
In this example, random sampling is performed in the solution domain, boundaries, and point source region (a rectangular area centered on the point source position) to generate the training dataset. See src/dataset.py for the implementation.
# Create the dataset
ds_train = create_train_dataset(config)
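src/dataset.py is not reproduced here. As a rough, hypothetical sketch of the three sampling regions in plain Python: the helper names, the point counts and the half-width 0.1 of the source rectangle (taken here as a square) are assumptions for illustration, not the values used by create_train_dataset.

```python
import math
import random


def sample_interior(n, rng, lo=0.0, hi=math.pi):
    """Hypothetical helper: uniform random points in the square [lo, hi]^2."""
    return [(rng.uniform(lo, hi), rng.uniform(lo, hi)) for _ in range(n)]


def sample_boundary(n, rng, lo=0.0, hi=math.pi):
    """Hypothetical helper: uniform random points on the four edges of [lo, hi]^2."""
    points = []
    for _ in range(n):
        t = rng.uniform(lo, hi)
        edge = rng.randrange(4)  # pick one of the four edges with equal probability
        if edge == 0:
            points.append((t, lo))  # bottom
        elif edge == 1:
            points.append((t, hi))  # top
        elif edge == 2:
            points.append((lo, t))  # left
        else:
            points.append((hi, t))  # right
    return points


def sample_source_region(n, rng, center, half_width):
    """Hypothetical helper: uniform points in a square of the given half-width around the source."""
    cx, cy = center
    return [(rng.uniform(cx - half_width, cx + half_width),
             rng.uniform(cy - half_width, cy + half_width)) for _ in range(n)]


rng = random.Random(0)
src = (math.pi / 2, math.pi / 2)
pde_pts = sample_interior(1000, rng)
bc_pts = sample_boundary(200, rng)
src_pts = sample_source_region(500, rng, src, half_width=0.1)  # dense points near the source
```

The point source region gets its own, much denser set of points because the forcing term is concentrated in a tiny neighbourhood of the source that uniform sampling of the whole domain would rarely hit.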
Creating the neural network
This example uses a multiscale neural network combined with the sin activation function.
from mindflow.cell import MultiScaleFCSequential

# Create the model
model = MultiScaleFCSequential(config['model']['in_channels'],
                               config['model']['out_channels'],
                               config['model']['layers'],
                               config['model']['neurons'],
                               residual=True,
                               act=config['model']['activation'],
                               num_scales=config['model']['num_scales'],
                               amp_factor=1.0,
                               scale_factor=2.0,
                               input_scale=[10., 10.],
                               )
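The idea behind a multiscale network can be sketched in plain Python as follows. This is a simplified toy, not MindFlow's implementation: it assumes that each of the num_scales sub-networks receives the input multiplied by amp_factor * scale_factor**i and that their outputs are summed, and it leaves out the residual connections and the input_scale pre-scaling configured above. The usual motivation is that sub-networks fed with amplified inputs represent higher-frequency components more easily, which helps near the sharp peak at the source.

```python
import math
import random


def make_mlp(sizes, rng):
    """Random weights for a fully connected net; sizes = [in, hidden..., out]."""
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        w = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)] for _ in range(n_out)]
        b = [0.0] * n_out
        layers.append((w, b))
    return layers


def mlp_forward(layers, x):
    """Forward pass with sin activation on hidden layers and a linear output layer."""
    h = list(x)
    for k, (w, b) in enumerate(layers):
        z = [sum(wi * hi for wi, hi in zip(row, h)) + bi for row, bi in zip(w, b)]
        h = z if k == len(layers) - 1 else [math.sin(v) for v in z]
    return h


class ToyMultiScaleNet:
    """Hypothetical toy: sub-network i sees the input scaled by
    amp_factor * scale_factor**i, and all sub-network outputs are summed."""

    def __init__(self, sizes, num_scales, amp_factor=1.0, scale_factor=2.0, seed=0):
        rng = random.Random(seed)
        self.nets = [make_mlp(sizes, rng) for _ in range(num_scales)]
        self.coefs = [amp_factor * scale_factor**i for i in range(num_scales)]

    def __call__(self, x):
        out = None
        for net, c in zip(self.nets, self.coefs):
            y = mlp_forward(net, [c * xi for xi in x])
            out = y if out is None else [a + b for a, b in zip(out, y)]
        return out


net = ToyMultiScaleNet([2, 16, 16, 1], num_scales=4)
print(net.coefs)          # [1.0, 2.0, 4.0, 8.0]
print(net([0.5, 1.0]))    # a single output value u(x, y)
```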
PINNs’ loss
When using mindflow to solve PDEs, we write a subclass of mindflow.PDEWithLoss that defines the loss terms for the governing equation and the boundary conditions (loss_pde and loss_bc, respectively). Since the point source region requires dense sampling, we add a third loss term, loss_src, which evaluates the residual of the governing equation on the points sampled in that region.
When the PINNs method uses the residual of the governing equation as a loss term to constrain the neural network, the singularity of the Dirac \(\delta\) function prevents training from converging. Therefore, we approximate the Dirac \(\delta\) function with the probability density function of the two-dimensional Laplace distribution:
\[
\eta_{\alpha}(x, y) = \frac{1}{4\alpha^2} \exp\left(-\frac{|x - x_{src}| + |y - y_{src}|}{\alpha}\right),
\]
where \(\alpha\) is the kernel width. In theory, a sufficiently small kernel width \(\alpha\) makes this density an arbitrarily good approximation of the Dirac \(\delta\) function. In practice, however, the choice of \(\alpha\) strongly affects the result: if \(\alpha\) is too large, the approximation error between \(\eta_{\alpha}(x, y)\) and the Dirac \(\delta\) function grows; if \(\alpha\) is too small, training may fail to converge or may converge to a poor accuracy. \(\alpha\) therefore has to be tuned manually; this example uses \(\alpha=0.01\).
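To make the effect of \(\alpha\) concrete, the following sketch measures how much of the mass of \(\eta_{\alpha}(x, y) = \frac{1}{4\alpha^2}\exp\left(-\frac{|x-x_{src}|+|y-y_{src}|}{\alpha}\right)\) lies within 0.1 of the source (the half-width 0.1 and the helper names are arbitrary choices for illustration). It uses the fact that this density factorizes into two one-dimensional Laplace densities, and checks the closed form against a midpoint-rule quadrature.

```python
import math


def laplace_kernel(x, y, x_src, y_src, alpha):
    """2D Laplace density eta_alpha(x, y) approximating the point source."""
    return math.exp(-(abs(x - x_src) + abs(y - y_src)) / alpha) / (4.0 * alpha**2)


def mass_in_box(alpha, half_width):
    """Mass of eta_alpha in the square [x_src +/- h] x [y_src +/- h].

    Each 1D factor exp(-|t| / alpha) / (2 * alpha) puts 1 - exp(-h / alpha)
    of its mass in [-h, h]; the 2D mass is the square of that.
    """
    return (1.0 - math.exp(-half_width / alpha)) ** 2


def mass_in_box_numeric(alpha, half_width, n=400):
    """Midpoint-rule check of mass_in_box on an n x n grid."""
    x_src = y_src = math.pi / 2
    step = 2.0 * half_width / n
    total = 0.0
    for i in range(n):
        x = x_src - half_width + (i + 0.5) * step
        for j in range(n):
            y = y_src - half_width + (j + 0.5) * step
            total += laplace_kernel(x, y, x_src, y_src, alpha)
    return total * step * step


for alpha in (0.1, 0.01):
    print(f"alpha={alpha}: mass within 0.1 of the source = {mass_in_box(alpha, 0.1):.4f}")
# alpha=0.1: mass within 0.1 of the source = 0.3996
# alpha=0.01: mass within 0.1 of the source = 0.9999
```

With \(\alpha = 0.01\) almost all of the forcing sits in a small neighbourhood of the source, which is why that region needs its own dense sampling points; with \(\alpha = 0.1\) about 60% of it is spread further out, so the smoothed source is a much cruder stand-in for the Dirac \(\delta\) function.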
An L2 (mean squared) loss is used for the solution domain, the boundaries and the point source region, and the three terms are combined with mindflow's multi-objective loss MTLWeightedLoss.
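MTLWeightedLoss weights the loss terms automatically with trainable parameters, which is why problem.loss_fn.trainable_params() is later passed to the optimizer together with the network weights. The exact formula MindFlow uses is not shown in this example; the plain-Python sketch below assumes one common uncertainty-based variant, \(\sum_i \frac{1}{2\sigma_i^2} L_i + \log(1 + \sigma_i^2)\), where the \(\sigma_i\) would be learned. The helper name is hypothetical.

```python
import math


def uncertainty_weighted_sum(losses, sigmas):
    """Hypothetical helper: combine task losses L_i with learnable scales sigma_i.

    Each term is 0.5 / sigma_i**2 * L_i + log(1 + sigma_i**2). A large sigma_i
    down-weights L_i, while the log term stops sigma_i from growing without bound.
    """
    return sum(0.5 / s**2 * loss + math.log(1.0 + s**2) for loss, s in zip(losses, sigmas))


losses = (1.0, 2.0, 3.0)  # e.g. loss_pde, loss_bc, loss_src
print(uncertainty_weighted_sum(losses, (1.0, 1.0, 1.0)))  # 3 + 3 * log(2) ≈ 5.0794
```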
import sympy
from mindspore import numpy as ms_np
from mindflow import PDEWithLoss, MTLWeightedLoss, sympy_to_mindspore


class Poisson(PDEWithLoss):
    """Define the loss of the Poisson equation."""

    def __init__(self, model):
        self.x, self.y = sympy.symbols("x y")
        self.u = sympy.Function("u")(self.x, self.y)
        self.in_vars = [self.x, self.y]
        self.out_vars = [self.u,]
        self.alpha = 0.01  # kernel width
        super(Poisson, self).__init__(model, self.in_vars, self.out_vars)
        self.bc_nodes = sympy_to_mindspore(self.bc(), self.in_vars, self.out_vars)
        self.loss_fn = MTLWeightedLoss(num_losses=3)

    def pde(self):
        """Define the governing equation."""
        uu_xx = sympy.diff(self.u, (self.x, 2))
        uu_yy = sympy.diff(self.u, (self.y, 2))

        # Use the Laplace probability density function to approximate the Dirac delta function.
        x_src = sympy.pi / 2
        y_src = sympy.pi / 2
        force_term = 0.25 / self.alpha**2 * sympy.exp(-(
            sympy.Abs(self.x - x_src) + sympy.Abs(self.y - y_src)) / self.alpha)

        poisson = uu_xx + uu_yy + force_term
        equations = {"poisson": poisson}
        return equations

    def bc(self):
        """Define the boundary condition (u = 0 on the boundary)."""
        bc_eq = self.u
        equations = {"bc": bc_eq}
        return equations

    def get_loss(self, pde_data, bc_data, src_data):
        """Define the loss function."""
        # Residuals of the governing equation, the boundary condition and, again,
        # the governing equation on the densely sampled point source region.
        res_pde = self.parse_node(self.pde_nodes, inputs=pde_data)
        res_bc = self.parse_node(self.bc_nodes, inputs=bc_data)
        res_src = self.parse_node(self.pde_nodes, inputs=src_data)

        # L2 (mean squared) loss for each term.
        loss_pde = ms_np.mean(ms_np.square(res_pde[0]))
        loss_bc = ms_np.mean(ms_np.square(res_bc[0]))
        loss_src = ms_np.mean(ms_np.square(res_src[0]))

        # Combine the three terms with learnable multi-task weights.
        return self.loss_fn((loss_pde, loss_bc, loss_src))


# Create the problem
problem = Poisson(model)
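In MindFlow the residuals are built from the sympy expressions with automatic differentiation. Purely for intuition, the sketch below computes the same three L2 terms with a five-point finite-difference Laplacian instead; the finite differences, the helper names and the manufactured test solution are assumptions of this sketch. As in get_loss, the source term reuses the PDE residual on the points sampled near the source; in this example the forcing would be \(\eta_{\alpha}\).

```python
import math


def laplacian_fd(u, x, y, h=1e-3):
    """Five-point central-difference approximation of u_xx + u_yy at (x, y)."""
    return (u(x + h, y) + u(x - h, y) + u(x, y + h) + u(x, y - h) - 4.0 * u(x, y)) / h**2


def mse(values):
    """Mean squared value: the L2 loss applied to each residual."""
    return sum(v * v for v in values) / len(values)


def pinn_losses(u, forcing, pde_pts, bc_pts, src_pts):
    """Hypothetical helper: (loss_pde, loss_bc, loss_src) for Delta u + forcing = 0, u = 0 on the boundary."""
    res_pde = [laplacian_fd(u, x, y) + forcing(x, y) for x, y in pde_pts]
    res_bc = [u(x, y) for x, y in bc_pts]
    res_src = [laplacian_fd(u, x, y) + forcing(x, y) for x, y in src_pts]
    return mse(res_pde), mse(res_bc), mse(res_src)


# Manufactured check: u = sin(x) sin(y) satisfies Delta u + 2 sin(x) sin(y) = 0
# and vanishes on the boundary of [0, pi]^2, so all three losses should be ~0.
def u_exact(x, y):
    return math.sin(x) * math.sin(y)


def forcing(x, y):
    return 2.0 * math.sin(x) * math.sin(y)


pde_pts = [(0.3 * k, 0.2 * k) for k in range(1, 10)]
bc_pts = [(0.0, 1.0), (math.pi, 2.0), (1.5, 0.0), (0.7, math.pi)]
src_pts = [(math.pi / 2 + 0.01 * k, math.pi / 2 - 0.01 * k) for k in range(-3, 4)]
print(pinn_losses(u_exact, forcing, pde_pts, bc_pts, src_pts))
```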
Creating the optimizer
This example uses the Adam optimizer. The learning rate decays to 1/10, 1/100, and 1/1000 of the initial learning rate when training reaches 40%, 60%, and 80% of the total steps, respectively.
n_epochs = 250
params = model.trainable_params() + problem.loss_fn.trainable_params()
steps_per_epoch = ds_train.get_dataset_size()
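# lr_init until 40% of the steps, then 1/10, 1/100 and 1/1000 of it from 40%, 60% and 80%;
# the final milestone (100%) makes the schedule cover every training step.
milestone = [int(steps_per_epoch * n_epochs * x) for x in [0.4, 0.6, 0.8, 1.0]]
lr_init = config["optimizer"]["initial_lr"]
learning_rates = [lr_init * (0.1**x) for x in [0, 1, 2, 3]]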
lr_ = nn.piecewise_constant_lr(milestone, learning_rates)
optimizer = nn.Adam(params, learning_rate=lr_)
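Assuming nn.piecewise_constant_lr expands the milestones into one learning rate per step, where step s uses learning_rates[k] for the first k with s < milestone[k], the schedule described above can be reproduced in plain Python to check which rate each step receives. The helper name, step count and initial rate are made-up values for illustration. Note that the last milestone must equal the total number of steps for the 1/1000 stage to be reached.

```python
def piecewise_constant(milestones, values):
    """Hypothetical helper: per-step list where step s gets values[k]
    for the first k with s < milestones[k]."""
    schedule = []
    start = 0
    for end, value in zip(milestones, values):
        schedule.extend([value] * (end - start))
        start = end
    return schedule


total_steps = 1000                # hypothetical: steps_per_epoch * n_epochs
lr_init = 1e-3                    # hypothetical initial learning rate
fractions = [0.4, 0.6, 0.8, 1.0]  # 1.0 so that the schedule covers every step
milestones = [int(total_steps * f) for f in fractions]
lrs = [lr_init * 0.1**k for k in range(4)]
schedule = piecewise_constant(milestones, lrs)
print(len(schedule), schedule[0], schedule[400], schedule[600], schedule[999])
```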
Model training
With MindSpore 2.0.0 or later, the network can be trained in a functional programming style.
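# Mixed precision with dynamic loss scaling is only used on Ascend.
use_ascend = context.get_context(attr_key='device_target') == "Ascend"
if use_ascend:
    from mindspore.amp import DynamicLossScaler, auto_mixed_precision, all_finite
    loss_scaler = DynamicLossScaler(1024, 2, 100)
    auto_mixed_precision(model, 'O3')
else:
    loss_scaler = None


def train():
    def forward_fn(pde_data, bc_data, src_data):
        loss = problem.get_loss(pde_data, bc_data, src_data)
        if use_ascend:
            # Scale the loss so that small float16 gradients do not underflow.
            loss = loss_scaler.scale(loss)
        return loss

    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

    @jit
    def train_step(pde_data, bc_data, src_data):
        loss, grads = grad_fn(pde_data, bc_data, src_data)
        if use_ascend:
            loss = loss_scaler.unscale(loss)
            is_finite = all_finite(grads)
            if is_finite:
                grads = loss_scaler.unscale(grads)
                loss = ops.depend(loss, optimizer(grads))
            loss_scaler.adjust(is_finite)
        else:
            loss = ops.depend(loss, optimizer(grads))
        return loss

    def train_epoch(model, dataset, i_epoch):
        local_time_beg = time.time()
        model.set_train()
        for _, (pde_data, bc_data, src_data) in enumerate(dataset):
            loss = train_step(pde_data, bc_data, src_data)
        print(
            f"epoch: {i_epoch} train loss: {float(loss):.8f}" +
            f" epoch time: {time.time() - local_time_beg:.2f}s")

    for i_epoch in range(1, 1 + n_epochs):
        train_epoch(model, ds_train, i_epoch)


time_beg = time.time()
train()
print(f"End-to-End total time: {time.time() - time_beg:.1f} s")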
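On Ascend, the training step relies on MindSpore's DynamicLossScaler. The toy class below is a hypothetical plain-Python model of the usual dynamic loss-scaling rule (its names and exact update are assumptions of this sketch, not MindSpore's code): the scale grows by scale_factor after scale_window consecutive steps with finite gradients, and shrinks by scale_factor, never below 1, whenever an overflow is detected, in which case the optimizer step is skipped.

```python
import math


class ToyDynamicLossScaler:
    """Hypothetical plain-Python model of dynamic loss scaling."""

    def __init__(self, scale_value=1024.0, scale_factor=2.0, scale_window=100):
        self.scale = scale_value
        self.factor = scale_factor
        self.window = scale_window
        self.counter = 0  # consecutive steps with finite gradients

    def unscale(self, grads):
        """Undo the loss scaling before the optimizer step."""
        return [g / self.scale for g in grads]

    def adjust(self, grads_finite):
        if grads_finite:
            self.counter += 1
            if self.counter == self.window:
                # A full window without overflow: try a larger scale.
                self.scale *= self.factor
                self.counter = 0
        else:
            # Overflow: skip the update and reduce the scale (never below 1).
            self.scale = max(1.0, self.scale / self.factor)
            self.counter = 0


def all_finite(grads):
    return all(math.isfinite(g) for g in grads)


scaler = ToyDynamicLossScaler(1024.0, 2.0, scale_window=3)
history = []
for grads in ([0.5, 1.0], [0.1, 0.2], [2.0, 4.0], [float("inf"), 1.0]):
    finite = all_finite(grads)
    if finite:
        grads = scaler.unscale(grads)  # these would be passed to the optimizer
    scaler.adjust(finite)
    history.append((finite, scaler.scale))
print(history)  # [(True, 1024.0), (True, 1024.0), (True, 2048.0), (False, 1024.0)]
```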
Model inference and visualization
Compute the relative L2 error and plot the reference solution against the model prediction.
from src.utils import calculate_l2_error, visual
# Create the dataset
ds_test = create_test_dataset(config)
# Evaluate the model
calculate_l2_error(model, ds_test)
# Visual comparison of label and prediction
visual(model, ds_test)
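The implementation of calculate_l2_error in src/utils.py is not shown. Assuming it reports the usual relative L2 error between the prediction and the reference solution over the test points, \(\|u_{pred} - u_{ref}\|_2 / \|u_{ref}\|_2\), the metric reduces to the following plain-Python computation (the helper name is hypothetical):

```python
import math


def relative_l2_error(pred, ref):
    """Hypothetical helper: ||pred - ref||_2 / ||ref||_2 over matching sample points."""
    if len(pred) != len(ref):
        raise ValueError("pred and ref must have the same length")
    num = math.sqrt(sum((p - r) ** 2 for p, r in zip(pred, ref)))
    den = math.sqrt(sum(r * r for r in ref))
    return num / den


print(relative_l2_error([1.0, 2.0, 2.0], [1.0, 2.0, 3.0]))  # 1 / sqrt(14) ≈ 0.2673
```

Normalizing by the norm of the reference makes the error independent of the overall magnitude of u, which is large near the point source.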