Source code for mindelec.loss.constraints

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
get constraints function map.
"""
from __future__ import absolute_import

from mindspore import log as logger
from mindspore import nn

from ..data.dataset import Dataset
from ..architecture.util import check_mode

_CONSTRAINT_TYPES = ["Equation", "BC", "IC", "Label", "Function"]


class _Instance(nn.Cell):
    """
    Obtain the corresponding constraint function based on different conditions.

    Args:
        problem (Problem): The pde problem.
        constraint_type (str): The pde equation and constraint type of pde problem. Defaults: "Equation"

    Returns:
        Function, the constraint function of the problem.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, problem, constraint_type="Equation"):
        super(_Instance, self).__init__()
        self.type = constraint_type
        self.problem = problem
        self.governing_equation = self.problem.governing_equation
        self.boundary_condition = self.problem.boundary_condition
        self.initial_condition = self.problem.initial_condition
        self.constraint_function = self.problem.constraint_function

        if self.type not in _CONSTRAINT_TYPES:
            raise ValueError("Unknown constrain type: {}, only: {} are supported".format(self.type, _CONSTRAINT_TYPES))

    def construct(self, *output, **kwargs):
        """Defines the computation to be performed"""
        if self.type == "Equation":
            return self.governing_equation(*output, **kwargs)
        if self.type == "BC":
            return self.boundary_condition(*output, **kwargs)
        if self.type == "IC":
            return self.initial_condition(*output, **kwargs)
        if self.type == "Function":
            return self.constraint_function(*output, **kwargs)
        if self.type == "Label":
            return self.constraint_function(*output, **kwargs)
        return None


[docs]class Constraints: """ Definition of the loss for all sub-dataset. Args: dataset (Dataset): The dataset including pde equation, boundary condition and initial condition. pde_dict(dict): The dict of pde problem. Supported Platforms: ``Ascend`` """ def __init__(self, dataset, pde_dict): check_mode("Constraints") if not isinstance(dataset, Dataset): raise TypeError("dataset should be an instance of Dataset") if not isinstance(pde_dict, dict): raise TypeError("pde_dict: {} should be dictionary, but got: {}".format(pde_dict, type(pde_dict))) self.dataset_columns_map = dataset.dataset_columns_map self.column_index_map = dataset.column_index_map self.dataset_constraint_map = dataset.dataset_constraint_map if not self.dataset_columns_map or not self.column_index_map or not self.dataset_constraint_map: raise ValueError("Maps info for dataset should not be none, please call create_dataset() first to " "avoid unexpected error") self.pde_dict = pde_dict self.dataset_cell_index_map = {} self.fn_cell_list = nn.CellList() self._get_loss_func() logger.info("check dataset_columns_map: {}".format(self.dataset_columns_map)) logger.info("check column_index_map: {}".format(self.column_index_map)) logger.info("check dataset_constraint_map: {}".format(self.dataset_constraint_map)) logger.info("check dataset_cell_index_map: {}".format(self.dataset_cell_index_map)) logger.info("check fn_cell_list: {}".format(self.fn_cell_list)) def _get_loss_func(self): """Get the loss fn map.""" index = -1 for dataset_name in self.dataset_columns_map.keys(): for name in self.pde_dict.keys(): if name is None: raise ValueError("pde_dict should not be None") if dataset_name in [name, name + "_BC", name + "_IC", name + "_domain"]: index += 1 self.dataset_cell_index_map[dataset_name] = index constraint_type = self.dataset_constraint_map[dataset_name] problem = _Instance(self.pde_dict[name], constraint_type) self.fn_cell_list.append(problem)