mindflow.data.equation 源代码

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#pylint: disable=W0223
#pylint: disable=W0221
"""
Sampling data of equation domain.
"""
from __future__ import absolute_import
import numpy as np

from mindspore import log as logger

from .data_base import Data
from ..geometry import Geometry, SamplingConfig
from ..utils.check_func import check_param_type

# Separator handed to ``str.join`` when Equation.__init__ builds the parameter
# label for the sampling_config type check.
# NOTE(review): the name suggests a single space, but the value is an empty
# string, so the joined label has no gap after the colon -- confirm intended.
_SPACE = ""


class Equation(Data):
    """
    Sampling data of equation domain.

    Args:
        geometry (Geometry): specifies geometry information of equation domain.

    Raises:
        TypeError: if geometry is not instance of class Geometry.
        ValueError: if sampling_config.domain of geometry is None or empty.

    Note:
        `geometry` and `geometry.sampling_config` are validated through
        `check_param_type`; the exception raised on a type mismatch is
        determined by that helper.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindflow.geometry import generate_sampling_config, Geometry
        >>> from mindflow.data import Equation
        >>> geometry_config = dict({'domain' : dict({'random_sampling' : True, 'size' : 100, 'sampler' : 'uniform',})})
        >>> sampling_config = generate_sampling_config(geometry_config)
        >>> geom = Geometry("geom", 1, 0.0, 1.0, sampling_config=sampling_config)
        >>> equation = Equation(geometry=geom)
    """

    def __init__(self, geometry):
        check_param_type(geometry, "geometry", data_type=Geometry)
        check_param_type(geometry.sampling_config,
                         _SPACE.join(("sampling_config of geometry:", geometry.name)),
                         data_type=SamplingConfig)

        # Only the "domain" part of the sampling configuration is used here.
        self.sampling_config = geometry.sampling_config.domain
        if not self.sampling_config:
            raise ValueError("domain info of geometry: {} should not be None".format(geometry.name))

        self.geometry = geometry
        # Sampled data and its bookkeeping are filled lazily by _initialization().
        self.data = None
        self.data_size = None
        self.batch_size = 1
        self.shuffle = False
        self.batched_data_size = None
        self.columns_list = None
        # Index array used to map a requested position onto a sampled point.
        self._domain_index = None
        # Number of batches already served since the last (re)sampling.
        self._domain_index_num = 0
        self._random_merge = self.sampling_config.random_merge

        name = geometry.name + "_domain"
        columns_list = [geometry.name + "_domain_points"]
        constraint_type = "Equation"
        super(Equation, self).__init__(name, columns_list, constraint_type)

    def _get_sampling_data(self):
        """Sample the domain of the geometry and return (data, column names)."""
        sample_data = self.geometry.sampling(geom_type="domain")
        return sample_data, self.geometry.columns_dict["domain"]

    def _initialization(self, batch_size=1, shuffle=False):
        """
        Sample the domain and set the batching attributes.

        Args:
            batch_size (int): number of points per batch. Default: 1.
            shuffle (bool): whether indices are reshuffled when all sampled
                data is traversed. Default: False.

        Returns:
            tuple, the sampled data (also stored in `self.data`).

        Raises:
            ValueError: if `batch_size` is smaller than 1 or larger than the
                number of sampled points.
        """
        # Reject non-positive sizes up front: 0 would otherwise fail with an
        # opaque ZeroDivisionError and negatives would yield a negative length.
        if batch_size < 1:
            raise ValueError("batch_size should be a positive integer, but got: {}".format(batch_size))

        data, self.columns_list = self._get_sampling_data()
        if not isinstance(data, tuple):
            data = (data,)
        self.data = data
        self.data_size = len(self.data[0])
        self.batch_size = batch_size
        if batch_size > self.data_size:
            raise ValueError("If prebatch data, batch_size: {} should not be larger than data size: {}.".format(
                batch_size, self.data_size
            ))
        # Trailing points that do not fill a whole batch are not counted.
        self.batched_data_size = self.data_size // batch_size
        self.shuffle = shuffle
        self._domain_index = np.arange(self.data_size)
        logger.info("Get domain dataset: {}, columns: {}, size: {}, batched_size: {}, shuffle: {}".format(
            self.name, self.columns_list, self.data_size, self.batched_data_size, self.shuffle))
        return data

    def _get_index_when_sample_iter(self, index):
        """
        Return the next sequential batch index, resampling once exhausted.

        The incoming `index` is ignored: batches are served in order from an
        internal counter, and fresh data is sampled after every batch of the
        current sample has been served.
        """
        if self._domain_index_num == self.batched_data_size:
            # _initialization() stores the new sample on self.data and also returns it.
            self.data = self._initialization(self.batch_size, self.shuffle)
            self._domain_index_num = 0
        index = self._domain_index_num
        self._domain_index_num += 1
        return index

    def _get_index_when_sample_all(self, index):
        """
        Wrap `index` into the batched range, reshuffling at each new pass.

        A new permutation of the point indices is drawn whenever `index` is a
        multiple of the dataset length and random merge or shuffle is enabled.
        """
        data_size = self.__len__()
        if (self._random_merge or self.shuffle) and index % data_size == 0:
            self._domain_index = np.random.permutation(self.data_size)
        index = index % data_size if index >= data_size else index
        return index

    def __getitem__(self, index):
        """
        Return one item (or one batch when batch_size > 1) of every column.

        Returns:
            tuple with one entry per column, or None when there are no columns.
        """
        if not self.data:
            self._initialization()
        if self.sampling_config.random_sampling:
            index = self._get_index_when_sample_iter(index)
        else:
            index = self._get_index_when_sample_all(index)

        num_columns = len(self.columns_list)
        if num_columns == 0:
            return None

        # The selection is the same for every column, so compute it once.
        if self.batch_size == 1:
            selection = self._domain_index[index]
        else:
            start = index * self.batch_size
            selection = self._domain_index[start:start + self.batch_size]
        return tuple(self.data[i][selection] for i in range(num_columns))

    def __len__(self):
        """Return the number of batches, sampling first if nothing is loaded."""
        if not self.data:
            self._initialization()
        return self.batched_data_size