2D & 3D Poisson
This notebook requires MindSpore >= 2.0.0, which provides new APIs including mindspore.jit, mindspore.jit_class and mindspore.jacrev.
Problem Description
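This example demonstrates how to use PINNs to solve 2D and 3D Poisson equations in different geometries. The 2D equation is defined by
\[ \Delta u = -\sin(4\pi x)\sin(4\pi y), \]
and the 3D equation is given by
\[ \Delta u = -\sin(4\pi x)\sin(4\pi y)\sin(4\pi z). \]
It is easy to verify that the following functions satisfy the Poisson equation in two and three dimensions, respectively:
\[ u = \frac{1}{32\pi^2}\sin(4\pi x)\sin(4\pi y), \]
\[ u = \frac{1}{48\pi^2}\sin(4\pi x)\sin(4\pi y)\sin(4\pi z). \]
Differentiating \(\sin(4\pi x)\) twice multiplies it by \(-16\pi^2\). In \(n\) dimensions, the Laplacian of the product of sines is therefore \(-16n\pi^2\) times the product, and the prefactor \(1/(16n\pi^2)\) cancels this factor.
If we use these functions to set the Dirichlet boundary conditions, they are the solutions of the corresponding boundary value problems. In this example, we solve the 2D equation in a rectangle, a disk, a triangle and a pentagon. In the 3D case, we solve the problem in a tetrahedron, a cylinder and a cone.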
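The claim can also be checked numerically. The sketch below uses only the Python standard library. It evaluates a second-order central-difference Laplacian of the candidate solution \(u = \prod_i \sin(4\pi x_i)/(16n\pi^2)\) in \(n = 2\) and \(n = 3\) dimensions, then adds the source term \(\prod_i \sin(4\pi x_i)\). The sum should vanish up to the \(O(h^2)\) discretization error. The helper names, test points and step size h are our own illustrative choices, not part of MindFlow.

```python
import math


def exact_u(point):
    """Candidate solution u = prod_i sin(4*pi*x_i) / (16 * n * pi^2) in n dimensions."""
    n = len(point)
    prod = 1.0
    for x in point:
        prod *= math.sin(4 * math.pi * x)
    return prod / (16 * n * math.pi ** 2)


def source(point):
    """Source term f = prod_i sin(4*pi*x_i); the PDE reads Laplacian(u) = -f."""
    prod = 1.0
    for x in point:
        prod *= math.sin(4 * math.pi * x)
    return prod


def fd_laplacian(func, point, h=1e-3):
    """Second-order central-difference approximation of the Laplacian of func."""
    total = 0.0
    for i in range(len(point)):
        plus = list(point)
        minus = list(point)
        plus[i] += h
        minus[i] -= h
        total += (func(plus) - 2 * func(point) + func(minus)) / h ** 2
    return total


for pt in [(0.1, 0.3), (0.1, 0.3, 0.7)]:
    # Residual of Laplacian(u) + f, which is zero for an exact solution
    residual = fd_laplacian(exact_u, pt) + source(pt)
    print(f"{len(pt)}D residual: {residual:.2e}")
```

At these points, the residuals are several orders of magnitude smaller than the source term. This is what an exact solution combined with a finite-difference check should give.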
Method
MindSpore Flow solves the problem as follows:
Dataset Construction.
Model Construction.
Poisson.
Optimizer.
Model Training.
Model Evaluation.
The poisson_cfg.yaml file can be downloaded at applications/physics_driven/poisson/point_source/poisson_cfg.yaml, and the src package can be downloaded at applications/physics_driven/poisson/point_source/src.
[1]:
import time
import mindspore as ms
from mindspore import nn, ops, jit
from mindflow import load_yaml_config
from src.model import create_model
from src.lr_scheduler import OneCycleLR
from src.dataset import create_dataset
ms.set_context(mode=ms.GRAPH_MODE, save_graphs=False, device_target="GPU")
# Load config
file_cfg = "poisson_cfg.yaml"
config = load_yaml_config(file_cfg)
Dataset Construction
This example creates the dataset by random sampling in the domain and on the boundaries (see src/dataset.py for the implementation). Set geom_name to select the geometry, which can be 'triangle', 'pentagon', 'tetrahedron', 'cylinder' or 'cone'.
[2]:
geom_name = "triangle"
ds_train, n_dim = create_dataset(geom_name, config)
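The file src/dataset.py is not reproduced in this notebook. The NumPy sketch below illustrates what random sampling in the domain and on the boundaries means for a triangle; it is not the module's actual code. Interior points are drawn uniformly by reflecting barycentric coordinates that fall outside the triangle. Boundary points are drawn uniformly along the perimeter by picking an edge with probability proportional to its length. The vertices, sample counts and function names are hypothetical.

```python
import numpy as np


def sample_triangle_interior(vertices, n, rng):
    """Uniform samples inside the triangle given by a (3, 2) array of vertices."""
    a, b, c = vertices
    r1 = rng.random(n)
    r2 = rng.random(n)
    # Reflect points from the upper half of the unit square into the lower triangle
    flip = r1 + r2 > 1
    r1[flip], r2[flip] = 1 - r1[flip], 1 - r2[flip]
    return a + r1[:, None] * (b - a) + r2[:, None] * (c - a)


def sample_triangle_boundary(vertices, n, rng):
    """Uniform samples on the perimeter of the triangle."""
    starts = vertices
    ends = np.roll(vertices, -1, axis=0)
    lengths = np.linalg.norm(ends - starts, axis=1)
    # Pick an edge with probability proportional to its length, then a uniform point on it
    edge = rng.choice(3, size=n, p=lengths / lengths.sum())
    t = rng.random(n)[:, None]
    return starts[edge] + t * (ends[edge] - starts[edge])


rng = np.random.default_rng(0)
tri = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # hypothetical vertices
domain_pts = sample_triangle_interior(tri, 1000, rng)
bc_pts = sample_triangle_boundary(tri, 200, rng)
print(domain_pts.shape, bc_pts.shape)
```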
Model Construction
This example adopts an MLP with 3 hidden layers and the following features:
Using \(f(x) = x \exp(-x^2/(2e))\) as the activation function.
Using weight normalization for the last layer.
Using HeUniform from mindspore for the initialization of all weights.
See src/model.py for the details.
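The three features can be illustrated without MindSpore in a NumPy sketch:

- the activation \(f(x) = x\exp(-x^2/(2e))\), reading \(e\) as Euler's number (our assumption);
- He-uniform initialization with bound \(\sqrt{6/\mathrm{fan\_in}}\), which is what MindSpore's HeUniform yields with its default arguments, to our understanding;
- a weight-normalized last layer \(w = g\,v/\lVert v\rVert\).

Layer widths and function names are hypothetical. The last layer has no bias here, and src/model.py may differ in details such as the axis used for the norm.

```python
import numpy as np


def activation(x):
    # f(x) = x * exp(-x^2 / (2e)), with e taken as Euler's number (assumption)
    return x * np.exp(-x ** 2 / (2 * np.e))


def he_uniform(fan_in, fan_out, rng):
    # Uniform(-b, b) with b = sqrt(6 / fan_in)
    bound = np.sqrt(6.0 / fan_in)
    return rng.uniform(-bound, bound, size=(fan_in, fan_out))


def weight_norm(v, g):
    # Reparameterize w = g * v / ||v||, one norm per output unit (column)
    return g * v / np.linalg.norm(v, axis=0, keepdims=True)


def mlp_forward(x, layers, last_v, last_g):
    """Hidden layers with the custom activation, then a weight-normalized output layer (no bias)."""
    h = x
    for w, b in layers:
        h = activation(h @ w + b)
    return h @ weight_norm(last_v, last_g)


rng = np.random.default_rng(0)
widths = [2, 64, 64, 64]  # hypothetical: 2 inputs, 3 hidden layers of width 64
layers = [(he_uniform(i, o, rng), np.zeros(o)) for i, o in zip(widths[:-1], widths[1:])]
last_v = he_uniform(widths[-1], 1, rng)
last_g = np.ones((1, 1))
u = mlp_forward(rng.random((5, 2)), layers, last_v, last_g)
print(u.shape)
```

Weight normalization separates the direction v from the magnitude g of the output weights. The scale of the output can then be learned independently of the direction.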
[3]:
model = create_model(**config['model'][f'{n_dim}d'])
Poisson
When using mindflow to solve PDEs, we write a subclass of mindflow.PDEWithLoss that defines the governing equation, the boundary condition and the loss function. In this example, the loss in the domain and the loss on the boundaries are both mean squared residuals (L2 loss). The two losses are combined by the multi-task weighted loss MTLWeightedLossCell.
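MTLWeightedLossCell has trainable parameters: the optimizer below receives problem.loss_fn.trainable_params() together with the model's parameters. The relative weights of the two losses are therefore learned rather than fixed by hand. A common scheme of this kind is a variant of the uncertainty weighting of Kendall et al. (2018). Each loss \(L_i\) gets a learnable scale \(s_i\), and the total loss is \(\sum_i \left( \frac{L_i}{2 s_i^2} + \log(1 + s_i^2) \right)\). The NumPy sketch below computes that form. Treat it as an assumption about MindFlow's implementation, not a copy of it.

```python
import numpy as np


def mtl_weighted_loss(losses, scales):
    """Sum_i L_i / (2 s_i^2) + log(1 + s_i^2), where `scales` are the learnable s_i (sketch)."""
    losses = np.asarray(losses, dtype=float)
    s2 = np.asarray(scales, dtype=float) ** 2
    return float(np.sum(0.5 * losses / s2 + np.log1p(s2)))


# With every s_i = 1 the result is 0.5 * (L_pde + L_bc) + 2 * log(2)
total = mtl_weighted_loss([0.2, 0.1], scales=[1.0, 1.0])
print(f"{total:.4f}")  # prints 1.5363
```

Without the log term, the optimizer could drive every weight \(1/(2s_i^2)\) to zero by letting \(s_i\) grow. The regularizer penalizes large \(s_i\), so a loss that becomes small ends up with a larger weight.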
[4]:
import sympy
from mindspore import numpy as ms_np
from mindflow import PDEWithLoss, MTLWeightedLossCell, sympy_to_mindspore


class Poisson(PDEWithLoss):
    """Define the loss of the Poisson equation."""

    def __init__(self, model, n_dim):
        if n_dim == 2:
            var_str = 'x y'
        elif n_dim == 3:
            var_str = 'x y z'
        else:
            raise ValueError("`n_dim` can only be 2 or 3.")
        self.in_vars = sympy.symbols(var_str)
        self.out_vars = (sympy.Function('u')(*self.in_vars),)
        super(Poisson, self).__init__(model, self.in_vars, self.out_vars)
        # Convert the symbolic boundary condition into MindSpore-evaluable nodes
        self.bc_nodes = sympy_to_mindspore(self.bc(n_dim), self.in_vars, self.out_vars)
        # Combine the domain and boundary losses with learnable weights
        self.loss_fn = MTLWeightedLossCell(num_losses=2)

    def pde(self):
        """Define the governing equation."""
        poisson = 0
        src_term = 1
        sym_u = self.out_vars[0]
        for var in self.in_vars:
            poisson += sympy.diff(sym_u, (var, 2))
            src_term *= sympy.sin(4*sympy.pi*var)
        poisson += src_term
        equations = {"poisson": poisson}
        return equations

    def bc(self, n_dim):
        """Define the boundary condition."""
        bc_term = 1
        for var in self.in_vars:
            bc_term *= sympy.sin(4*sympy.pi*var)
        bc_term *= 1/(16*n_dim*sympy.pi*sympy.pi)
        bc_eq = self.out_vars[0] - bc_term
        equations = {"bc": bc_eq}
        return equations

    def get_loss(self, pde_data, bc_data):
        """Define the loss function."""
        res_pde = self.parse_node(self.pde_nodes, inputs=pde_data)
        res_bc = self.parse_node(self.bc_nodes, inputs=bc_data)
        loss_pde = ms_np.mean(ms_np.square(res_pde[0]))
        loss_bc = ms_np.mean(ms_np.square(res_bc[0]))
        return self.loss_fn((loss_pde, loss_bc))


# Create the problem
problem = Poisson(model, n_dim)
poisson: sin(4*pi*x)*sin(4*pi*y) + Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2))
Item numbers of current derivative formula nodes: 3
bc: u(x, y) - sin(4*pi*x)*sin(4*pi*y)/(32*pi**2)
Item numbers of current derivative formula nodes: 2
Optimizer
This example applies the Adam optimizer with the one-cycle learning rate policy proposed in Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates. See src/lr_scheduler.py for the implementation of this dynamic learning rate.
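The following sketch illustrates the shape of a one-cycle schedule by computing the learning rate at each step. It warms up with a cosine curve from max_lr / div_factor to max_lr over the first pct_start fraction of the steps. It then anneals, also along a cosine curve, down to max_lr / (div_factor * final_div_factor). The parameter names follow PyTorch's OneCycleLR convention and the numeric values are illustrative assumptions. The keys in config['optimizer'] and the exact behaviour of src/lr_scheduler.OneCycleLR may differ.

```python
import math


def one_cycle_lr(step, total_steps, max_lr=1e-3, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step` under a cosine one-cycle policy (sketch)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = max(1, int(pct_start * total_steps))

    def cos_anneal(start, end, frac):
        # Cosine interpolation from `start` (frac = 0) to `end` (frac = 1)
        return end + (start - end) * (1 + math.cos(math.pi * frac)) / 2

    if step < warmup_steps:
        return cos_anneal(initial_lr, max_lr, step / warmup_steps)
    decay_steps = max(1, total_steps - 1 - warmup_steps)
    frac = min(1.0, (step - warmup_steps) / decay_steps)
    return cos_anneal(max_lr, min_lr, frac)


total = 200 * 50  # steps per epoch times epochs, as in the training log
lrs = [one_cycle_lr(s, total) for s in range(total)]
print(f"start {lrs[0]:.1e}, peak {max(lrs):.1e}, end {lrs[-1]:.1e}")
```

The short phase of large learning rates is what the paper credits for fast convergence. The long final anneal then lets the loss settle.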
[5]:
n_epochs = 50
params = model.trainable_params() + problem.loss_fn.trainable_params()
steps_per_epoch = config['data']['domain']['size']//config['batch_size']
learning_rate = OneCycleLR(total_steps=steps_per_epoch*n_epochs, **config['optimizer'])
opt = nn.Adam(params, learning_rate=learning_rate)
Model Training
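With MindSpore >= 2.0.0, we can train the network in a functional programming style. The loss and its gradients are computed by the function returned from mindspore.value_and_grad, and the training step is compiled with the jit decorator.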
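The same functional pattern can be shown without MindSpore. In the NumPy sketch below, a hand-written value_and_grad returns the loss together with its gradient for a least-squares toy problem, and train_step applies one gradient-descent update. The names mirror the training code only for readability. ms.value_and_grad obtains gradients by automatic differentiation, and the notebook's opt(grads) updates the parameters in place. This sketch instead derives the gradient by hand and returns new parameters.

```python
import numpy as np


def loss_fn(w, x, y):
    # Mean squared error of the linear model y_hat = x @ w
    r = x @ w - y
    return np.mean(r ** 2)


def value_and_grad(w, x, y):
    # Returns (loss, dloss/dw); the gradient of mean(r^2) is 2 * x^T r / n
    r = x @ w - y
    return loss_fn(w, x, y), 2.0 * x.T @ r / len(y)


def train_step(w, x, y, lr=0.1):
    # Pure function: parameters and a batch in, updated parameters and loss out
    loss, grad = value_and_grad(w, x, y)
    return w - lr * grad, loss


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = x @ w_true
w = np.zeros(3)
losses = []
for step in range(200):
    w, loss = train_step(w, x, y)
    losses.append(loss)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.2e}")
```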
[6]:
def train():
    # Create a function that returns the loss and its gradients w.r.t. the optimizer's parameters
    grad_fn = ms.value_and_grad(problem.get_loss, None, opt.parameters, has_aux=False)

    @jit
    def train_step(pde_data, bc_data):
        loss, grads = grad_fn(pde_data, bc_data)
        # Run the optimizer update; ops.depend ties it to the returned loss
        loss = ops.depend(loss, opt(grads))
        return loss

    def train_epoch(model, dataset, i_epoch):
        n_step = dataset.get_dataset_size()
        model.set_train()
        for i_step, (pde_data, bc_data) in enumerate(dataset):
            local_time_beg = time.time()
            loss = train_step(pde_data, bc_data)
            if i_step % 50 == 0 or i_step + 1 == n_step:
                print("\repoch: {}, loss: {:>f}, time elapsed: {:.1f}ms [{}/{}]".format(
                    i_epoch, float(loss), (time.time() - local_time_beg)*1000, i_step + 1, n_step))

    for i_epoch in range(n_epochs):
        train_epoch(model, ds_train, i_epoch)
[7]:
time_beg = time.time()
train()
print("End-to-End total time: {} s".format(time.time() - time_beg))
epoch: 0, loss: 1.527029, time elapsed: 12050.1ms [1/200]
epoch: 0, loss: 1.468655, time elapsed: 52.4ms [51/200]
epoch: 0, loss: 1.442717, time elapsed: 52.3ms [101/200]
epoch: 0, loss: 1.430150, time elapsed: 52.4ms [151/200]
epoch: 0, loss: 1.420228, time elapsed: 53.4ms [200/200]
epoch: 1, loss: 1.419910, time elapsed: 53.0ms [1/200]
epoch: 1, loss: 1.407040, time elapsed: 52.5ms [51/200]
epoch: 1, loss: 1.386505, time elapsed: 52.4ms [101/200]
epoch: 1, loss: 1.362307, time elapsed: 52.4ms [151/200]
epoch: 1, loss: 1.349054, time elapsed: 52.5ms [200/200]
epoch: 2, loss: 1.349143, time elapsed: 53.7ms [1/200]
epoch: 2, loss: 1.336657, time elapsed: 52.7ms [51/200]
epoch: 2, loss: 1.323158, time elapsed: 52.6ms [101/200]
epoch: 2, loss: 1.307419, time elapsed: 52.9ms [151/200]
epoch: 2, loss: 1.289993, time elapsed: 52.7ms [200/200]
epoch: 3, loss: 1.289594, time elapsed: 53.5ms [1/200]
epoch: 3, loss: 1.270476, time elapsed: 52.4ms [51/200]
epoch: 3, loss: 1.246817, time elapsed: 52.6ms [101/200]
epoch: 3, loss: 1.222093, time elapsed: 52.6ms [151/200]
epoch: 3, loss: 1.194862, time elapsed: 52.3ms [200/200]
epoch: 4, loss: 1.194533, time elapsed: 52.5ms [1/200]
epoch: 4, loss: 1.164445, time elapsed: 52.6ms [51/200]
epoch: 4, loss: 1.134136, time elapsed: 52.5ms [101/200]
epoch: 4, loss: 1.100014, time elapsed: 52.6ms [151/200]
epoch: 4, loss: 1.064941, time elapsed: 52.4ms [200/200]
...
epoch: 45, loss: 0.001281, time elapsed: 53.0ms [1/200]
epoch: 45, loss: 0.001264, time elapsed: 52.6ms [51/200]
epoch: 45, loss: 0.001263, time elapsed: 52.5ms [101/200]
epoch: 45, loss: 0.001236, time elapsed: 52.6ms [151/200]
epoch: 45, loss: 0.001237, time elapsed: 52.5ms [200/200]
epoch: 46, loss: 0.001218, time elapsed: 52.7ms [1/200]
epoch: 46, loss: 0.001209, time elapsed: 52.6ms [51/200]
epoch: 46, loss: 0.001191, time elapsed: 52.6ms [101/200]
epoch: 46, loss: 0.001202, time elapsed: 52.7ms [151/200]
epoch: 46, loss: 0.001182, time elapsed: 52.9ms [200/200]
epoch: 47, loss: 0.001174, time elapsed: 53.0ms [1/200]
epoch: 47, loss: 0.001186, time elapsed: 52.7ms [51/200]
epoch: 47, loss: 0.001182, time elapsed: 52.6ms [101/200]
epoch: 47, loss: 0.001169, time elapsed: 52.8ms [151/200]
epoch: 47, loss: 0.001172, time elapsed: 52.7ms [200/200]
epoch: 48, loss: 0.001165, time elapsed: 52.7ms [1/200]
epoch: 48, loss: 0.001168, time elapsed: 52.6ms [51/200]
epoch: 48, loss: 0.001148, time elapsed: 52.5ms [101/200]
epoch: 48, loss: 0.001159, time elapsed: 52.7ms [151/200]
epoch: 48, loss: 0.001171, time elapsed: 52.5ms [200/200]
epoch: 49, loss: 0.001156, time elapsed: 52.7ms [1/200]
epoch: 49, loss: 0.001155, time elapsed: 52.6ms [51/200]
epoch: 49, loss: 0.001148, time elapsed: 52.6ms [101/200]
epoch: 49, loss: 0.001159, time elapsed: 52.9ms [151/200]
epoch: 49, loss: 0.001153, time elapsed: 52.6ms [200/200]
End-to-End total time: 584.182409286499 s
Model Evaluation
We can use the following function to calculate the relative L2 error.
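The relative L2 error is commonly defined as \(\lVert u_\theta - u\rVert_2 / \lVert u\rVert_2\) over the test points, where \(u_\theta\) is the network prediction and \(u\) is the exact solution. Below is a minimal NumPy sketch assuming that definition. eval.calculate_l2_error is not shown here and may differ, for example in how it batches the data. The test points and the stand-in prediction are hypothetical.

```python
import numpy as np


def exact_solution(points):
    # u = prod_i sin(4*pi*x_i) / (16 * n * pi^2) for points of shape (N, n)
    n_dim = points.shape[1]
    return np.prod(np.sin(4 * np.pi * points), axis=1) / (16 * n_dim * np.pi ** 2)


def relative_l2_error(u_pred, u_true):
    # ||u_pred - u_true||_2 / ||u_true||_2 over all test points
    u_pred = np.ravel(u_pred)
    u_true = np.ravel(u_true)
    return np.linalg.norm(u_pred - u_true) / np.linalg.norm(u_true)


rng = np.random.default_rng(0)
pts = rng.random((5000, 2))  # hypothetical test points in the unit square
u_true = exact_solution(pts)
u_pred = 1.02 * u_true  # stand-in for model predictions: a uniform 2% overshoot
print(f"{relative_l2_error(u_pred, u_true):.4f}")  # prints 0.0200
```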
[8]:
from eval import calculate_l2_error
n_samps = 5000 # Number of test samples
ds_test, _ = create_dataset(geom_name, config, n_samps)
calculate_l2_error(model, ds_test, n_dim)
Relative L2 error (domain): 0.0310
Relative L2 error (bc): 0.0833