mindflow.pde

init

class mindflow.pde.Burgers(model, loss_fn='mse')

Base class for the 1-D Burgers problem, based on PDEWithLoss.

Parameters
  • model (mindspore.nn.Cell) – Network for training.

  • loss_fn (str) – Name of the loss function. Default: 'mse'.

Supported Platforms:

Ascend GPU

Examples

>>> from mindflow.pde import Burgers
>>> from mindspore import nn, ops
>>> class Net(nn.Cell):
...     def __init__(self, cin=2, cout=1, hidden=10):
...         super().__init__()
...         self.fc1 = nn.Dense(cin, hidden)
...         self.fc2 = nn.Dense(hidden, hidden)
...         self.fcout = nn.Dense(hidden, cout)
...         self.act = ops.Tanh()
...
...     def construct(self, x):
...         x = self.act(self.fc1(x))
...         x = self.act(self.fc2(x))
...         x = self.fcout(x)
...         return x
>>> model = Net()
>>> problem = Burgers(model)
>>> print(problem.pde())
burgers: u(x, t)*Derivative(u(x, t), x) + Derivative(u(x, t), t) - 0.00318309897556901*Derivative(u(x, t), (x, 2))
    Item numbers of current derivative formula nodes: 3
{'burgers': u(x, t)*Derivative(u(x, t), x) + Derivative(u(x, t), t) - 0.00318309897556901*Derivative(u(x, t), (x, 2))}
pde()

Define the 1-D Burgers governing equations based on sympy, abstract method.

Returns

dict, user-defined sympy symbolic equations.
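
Read as a residual that training drives to zero, the expression printed in the example corresponds to the viscous Burgers equation below. The symbol ν is introduced here only for illustration. Its value is copied from the printed coefficient, which appears to be 0.01/π rounded to single precision; that reading is inferred from the number, not stated on this page.

```latex
\frac{\partial u}{\partial t}
  + u\,\frac{\partial u}{\partial x}
  - \nu\,\frac{\partial^2 u}{\partial x^2} = 0,
\qquad \nu \approx 0.00318309897556901
```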

class mindflow.pde.NavierStokes(model, re=100.0, loss_fn='mse')

2-D Navier-Stokes equation problem based on PDEWithLoss.

Parameters
  • model (mindspore.nn.Cell) – Network for training.

  • re (float) – Reynolds number, the ratio of inertial forces to viscous forces in a fluid. It is a dimensionless quantity. Default: 100.0.

  • loss_fn (str) – Name of the loss function. Default: 'mse'.
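
For intuition about `re`, the sketch below computes a Reynolds number from the common definition Re = ρUL/μ and then the coefficient 1/Re. The function names and the sample fluid values are hypothetical. The link to the coefficient 0.00999999977648258 printed in the example (1/100 stored in single precision) is an inference from that output, not documented behavior.

```python
import struct


def reynolds_number(density, velocity, length, viscosity):
    """Ratio of inertial to viscous forces: Re = rho * U * L / mu."""
    return density * velocity * length / viscosity


def to_float32(value):
    """Round a Python float to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", value))[0]


# Hypothetical fluid: rho = 1.0, U = 1.0, L = 1.0, mu = 0.01
re = reynolds_number(1.0, 1.0, 1.0, 0.01)

# Coefficient of the second-derivative terms, assuming it is 1/Re in float32
coeff = to_float32(1.0 / re)

print(re)
print(coeff)
```

A larger `re` gives a smaller coefficient, so the second-derivative (viscous) terms contribute less to the momentum residuals.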

Supported Platforms:

Ascend GPU

Examples

>>> from mindflow.pde import NavierStokes
>>> from mindspore import nn, ops
>>> class Net(nn.Cell):
...     def __init__(self, cin=3, cout=3, hidden=10):
...         super().__init__()
...         self.fc1 = nn.Dense(cin, hidden)
...         self.fc2 = nn.Dense(hidden, hidden)
...         self.fcout = nn.Dense(hidden, cout)
...         self.act = ops.Tanh()
...
...     def construct(self, x):
...         x = self.act(self.fc1(x))
...         x = self.act(self.fc2(x))
...         x = self.fcout(x)
...         return x
>>> model = Net()
>>> problem = NavierStokes(model)
>>> print(problem.pde())
momentum_x: u(x, y, t)*Derivative(u(x, y, t), x) + v(x, y, t)*Derivative(u(x, y, t), y) +
Derivative(p(x, y, t), x) + Derivative(u(x, y, t), t) - 0.00999999977648258*Derivative(u(x, y, t), (x, 2)) -
0.00999999977648258*Derivative(u(x, y, t), (y, 2))
    Item numbers of current derivative formula nodes: 6
momentum_y: u(x, y, t)*Derivative(v(x, y, t), x) + v(x, y, t)*Derivative(v(x, y, t), y) +
Derivative(p(x, y, t), y) + Derivative(v(x, y, t), t) - 0.00999999977648258*Derivative(v(x, y, t), (x, 2)) -
0.00999999977648258*Derivative(v(x, y, t), (y, 2))
    Item numbers of current derivative formula nodes: 6
continuty: Derivative(u(x, y, t), x) + Derivative(v(x, y, t), y)
    Item numbers of current derivative formula nodes: 2
{'momentum_x': u(x, y, t)*Derivative(u(x, y, t), x) + v(x, y, t)*Derivative(u(x, y, t), y) +
Derivative(p(x, y, t), x) + Derivative(u(x, y, t), t) - 0.00999999977648258*Derivative(u(x, y, t), (x, 2)) -
0.00999999977648258*Derivative(u(x, y, t), (y, 2)),
'momentum_y': u(x, y, t)*Derivative(v(x, y, t), x) + v(x, y, t)*Derivative(v(x, y, t), y) +
Derivative(p(x, y, t), y) + Derivative(v(x, y, t), t) - 0.00999999977648258*Derivative(v(x, y, t), (x, 2)) -
0.00999999977648258*Derivative(v(x, y, t), (y, 2)),
'continuty': Derivative(u(x, y, t), x) + Derivative(v(x, y, t), y)}
pde()

Define the 2-D Navier-Stokes governing equations based on sympy, abstract method.

Returns

dict, user-defined sympy symbolic equations.
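
Setting each expression printed in the example to zero gives the system below. It is written with 1/Re in place of the numeric coefficient, on the assumption (inferred from the example output with the default re=100.0) that the coefficient is the reciprocal of `re`.

```latex
\begin{aligned}
\frac{\partial u}{\partial t} + u\,\frac{\partial u}{\partial x} + v\,\frac{\partial u}{\partial y}
  + \frac{\partial p}{\partial x}
  - \frac{1}{Re}\left(\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}\right) &= 0 \\
\frac{\partial v}{\partial t} + u\,\frac{\partial v}{\partial x} + v\,\frac{\partial v}{\partial y}
  + \frac{\partial p}{\partial y}
  - \frac{1}{Re}\left(\frac{\partial^2 v}{\partial x^2} + \frac{\partial^2 v}{\partial y^2}\right) &= 0 \\
\frac{\partial u}{\partial x} + \frac{\partial v}{\partial y} &= 0
\end{aligned}
```

The three rows correspond to the dictionary keys `momentum_x`, `momentum_y`, and `continuty` in that order.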

class mindflow.pde.PDEWithLoss(model, in_vars, out_vars)

Base class of user-defined PDE problems. Every user-defined problem that sets constraints on datasets should inherit from this class. It establishes the mapping between each sub-dataset and the user-defined loss functions, and the loss is calculated automatically according to the constraint type of each sub-dataset. The member function corresponding to each constraint type must be overridden by the user in order to obtain the target label output. For example, if the constraint type of dataset1 is “pde”, the member function “pde” must be overridden to specify how to compute the PDE residual. The data (e.g. inputs) used to solve the residuals is passed to parse_node, and the residual of each equation is then calculated automatically.

Parameters
  • model (mindspore.nn.Cell) – Network for training.

  • in_vars (List[sympy.core.Symbol]) – Input variables of the model, represented by the sympy symbol.

  • out_vars (List[sympy.core.Function]) – Output variables of the model, represented by the sympy function.

Note

  • The member function “pde” must be overridden to define the symbolic derivative equations based on sympy.

  • The member function “get_loss” must be overridden to calculate the loss of the symbolic derivative equations.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindflow.pde import PDEWithLoss, sympy_to_mindspore
>>> from mindspore import nn, ops, Tensor
>>> from mindspore import dtype as mstype
>>> from sympy import symbols, Function, diff
>>> class Net(nn.Cell):
...     def __init__(self, cin=2, cout=1, hidden=10):
...         super().__init__()
...         self.fc1 = nn.Dense(cin, hidden)
...         self.fc2 = nn.Dense(hidden, hidden)
...         self.fcout = nn.Dense(hidden, cout)
...         self.act = ops.Tanh()
...
...     def construct(self, x):
...         x = self.act(self.fc1(x))
...         x = self.act(self.fc2(x))
...         x = self.fcout(x)
...         return x
>>> model = Net()
>>> class MyProblem(PDEWithLoss):
...     def __init__(self, model, loss_fn=nn.MSELoss()):
...         self.x, self.y = symbols('x t')
...         self.u = Function('u')(self.x, self.y)
...         self.in_vars = [self.x, self.y]
...         self.out_vars = [self.u]
...         super(MyProblem, self).__init__(model, in_vars=self.in_vars, out_vars=self.out_vars)
...         self.loss_fn = loss_fn
...         self.bc_nodes = sympy_to_mindspore(self.bc(), self.in_vars, self.out_vars)
...
...     def pde(self):
...         my_eq = diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)) - 4.0
...         equations = {"my_eq": my_eq}
...         return equations
...
...     def bc(self):
...         bc_eq = diff(self.u, (self.x, 1)) + diff(self.u, (self.y, 1)) - 2.0
...         equations = {"bc_eq": bc_eq}
...         return equations
...
...     def get_loss(self, pde_data, bc_data):
...         pde_res = self.parse_node(self.pde_nodes, inputs=pde_data)
...         pde_loss = self.loss_fn(pde_res[0], Tensor(np.array([0.0]), mstype.float32))
...         bc_res = self.parse_node(self.bc_nodes, inputs=bc_data)
...         bc_loss = self.loss_fn(bc_res[0], Tensor(np.array([0.0]), mstype.float32))
...         return pde_loss + bc_loss
>>> problem = MyProblem(model)
>>> print(problem.pde())
>>> print(problem.bc())
my_eq: Derivative(u(x, t), (t, 2)) + Derivative(u(x, t), (x, 2)) - 4.0
    Item numbers of current derivative formula nodes: 3
bc_eq: Derivative(u(x, t), t) + Derivative(u(x, t), x) - 2.0
    Item numbers of current derivative formula nodes: 3
{'my_eq': Derivative(u(x, t), (t, 2)) + Derivative(u(x, t), (x, 2)) - 4.0}
{'bc_eq': Derivative(u(x, t), t) + Derivative(u(x, t), x) - 2.0}
get_loss()

Compute the total loss from the user-defined derivative equations. This function must be overridden.
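
To illustrate the kind of quantity a `get_loss` override returns, the sketch below evaluates the residual of the example equation `my_eq` (u_xx + u_tt - 4.0) and reduces it with a mean squared error against zero. It uses plain Python and central finite differences, not MindSpore automatic differentiation or `parse_node`. The trial function u = x² + t² and every helper name are assumptions made for this illustration.

```python
def u(x, t):
    """Hypothetical trial solution; it satisfies u_xx + u_tt = 4 exactly."""
    return x * x + t * t


def second_diff(f, x, t, axis, h=1e-3):
    """Central finite-difference estimate of a second partial derivative."""
    if axis == "x":
        return (f(x + h, t) - 2.0 * f(x, t) + f(x - h, t)) / (h * h)
    return (f(x, t + h) - 2.0 * f(x, t) + f(x, t - h)) / (h * h)


def my_eq_residual(x, t):
    """Residual of my_eq: u_xx + u_tt - 4.0."""
    return second_diff(u, x, t, "x") + second_diff(u, x, t, "t") - 4.0


def mse_against_zero(residuals):
    """Mean squared error between the residuals and a zero target."""
    return sum(r * r for r in residuals) / len(residuals)


# Hypothetical collocation points (x, t)
points = [(0.1, 0.2), (0.5, 0.5), (0.9, 0.3)]
pde_loss = mse_against_zero([my_eq_residual(x, t) for x, t in points])
print(pde_loss)
```

Because the trial function solves the equation, the loss is zero up to floating-point rounding. A network output that violates the equation would produce a larger value, and that value is what training minimizes.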

parse_node(formula_nodes, inputs=None, norm=None)

Calculate the results for each formula node.

Parameters
  • formula_nodes (list[FormulaNode]) – List of expression nodes that can be identified by MindSpore.

  • inputs (Tensor) – The input data of the network. Default: None.

  • norm (Tensor) – The surface normal at each point. The normal at a point P is the vector perpendicular to the tangent plane at P. Default: None.

Returns

List[Tensor], the results of the partial differential equations.

pde()

Governing equations based on sympy, abstract method. This function must be overridden if the corresponding constraint is a governing equation.

class mindflow.pde.Poisson(model, loss_fn='mse')

Base class for the 2-D Poisson problem, based on PDEWithLoss.

Parameters
  • model (mindspore.nn.Cell) – Network for training.

  • loss_fn (str) – Name of the loss function. Default: 'mse'.

Supported Platforms:

Ascend GPU

Examples

>>> from mindflow.pde import Poisson
>>> from mindspore import nn, ops
>>> class Net(nn.Cell):
...     def __init__(self, cin=2, cout=1, hidden=10):
...         super().__init__()
...         self.fc1 = nn.Dense(cin, hidden)
...         self.fc2 = nn.Dense(hidden, hidden)
...         self.fcout = nn.Dense(hidden, cout)
...         self.act = ops.Tanh()
...
...     def construct(self, x):
...         x = self.act(self.fc1(x))
...         x = self.act(self.fc2(x))
...         x = self.fcout(x)
...         return x
>>> model = Net()
>>> problem = Poisson(model)
>>> print(problem.pde())
poisson: Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2)) + 1.0
    Item numbers of current derivative formula nodes: 3
{'poisson': Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2)) + 1.0}
pde()

Define the 2-D Poisson governing equations based on sympy, abstract method.

Returns

dict, user-defined sympy symbolic equations.
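
Assuming the same residual convention as the other problems (the expression is driven to zero), the `poisson` entry printed in the example corresponds to the 2-D Poisson equation with a constant source term:

```latex
\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} + 1 = 0
```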

mindflow.pde.sympy_to_mindspore(equations, in_vars, out_vars)

Convert sympy expressions into formula nodes that MindSpore can identify.

Parameters
  • equations (dict) – Each item maps a user-defined key to a sympy expression.

  • in_vars (list[sympy.core.Symbol]) – List of all input variable symbols, consistent with the dimension of the input data.

  • out_vars (list[sympy.core.Function]) – List of all output variable symbols, consistent with the dimension of the output data.

Returns

List[FormulaNode], list of expression nodes that can be identified by MindSpore.

Supported Platforms:

Ascend GPU

Examples

>>> from mindflow.pde import sympy_to_mindspore
>>> from sympy import symbols, Function, diff
>>> x, y = symbols('x, y')
>>> u = Function('u')(x, y)
>>> in_vars = [x, y]
>>> out_vars = [u]
>>> eq1 = x + y
>>> eq2 = diff(u, (x, 1)) + diff(u, (y, 1))
>>> equations = {"eq1": eq1, "eq2": eq2}
>>> res = sympy_to_mindspore(equations, in_vars, out_vars)
>>> print(len(res))
eq1: x + y
    Item numbers of current derivative formula nodes: 2
eq2: Derivative(u(x, y), x) + Derivative(u(x, y), y)
    Item numbers of current derivative formula nodes: 2
2