Source code for mindspore.nn.probability.bijector.bijector

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bijector"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore import _checkparam as validator
from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
from ..distribution import Distribution
from ..distribution import TransformedDistribution


class Bijector(Cell):
    """
    Bijector class.

    A bijector performs a mapping from one distribution to another via some
    function. If :math:`X` is a random variable following the original
    distribution, and :math:`g(x)` is the mapping function, then
    :math:`Y = g(X)` is the random variable following the transformed
    distribution.

    Args:
        is_constant_jacobian (bool): Whether the Bijector has a constant derivative. Default: ``False``.
        is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: ``True``.
        name (str): The name of the Bijector. Default: ``None``.
        dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on.
            Default: ``None``.
        param (dict): The parameters used to initialize the Bijector. Default: ``None``.

    Note:
        - `dtype` of the bijector represents the type of the distributions that the bijector could operate on.
        - When `dtype` is None, there is no enforcement on the type of the input value except that the input
          value has to be a float type. During initialization, when `dtype` is None, there is no enforcement
          on the dtype of the parameters either. All parameters should have the same float type, otherwise a
          TypeError will be raised. Specifically, the parameter type will follow the dtype of the input value.
        - Parameters of the bijector will be cast into the same type as the input value when `dtype` is None.
        - When `dtype` is specified, the parameters and the input value are forced to have the same dtype as
          `dtype`. When the type of the parameters or the type of the input value is not the same as `dtype`,
          a TypeError will be raised.
        - Only subtypes of mindspore.float_type can be used to specify the bijector's `dtype`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self,
                 is_constant_jacobian=False,
                 is_injective=True,
                 name=None,
                 dtype=None,
                 param=None):
        """
        Constructor of the Bijector class.
        """
        super(Bijector, self).__init__()
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_value_type(
            'is_constant_jacobian', is_constant_jacobian, [bool], name)
        validator.check_value_type('is_injective', is_injective, [bool], name)
        if dtype is not None:
            validator.check_type_name(
                "dtype", dtype, mstype.float_type, type(self).__name__)
        self._name = name
        self._dtype = dtype
        self._parameters = {}

        # parse parameters
        for k in param.keys():
            if k == 'param':
                continue
            if not (k == 'self' or k.startswith('_')):
                self._parameters[k] = param[k]

        # if no bijector is used as an argument during initialization
        if 'bijector' not in param.keys():
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()

        self._is_constant_jacobian = is_constant_jacobian
        self._is_injective = is_injective

        self.context_mode = context.get_context('mode')
        self.checktensor = CheckTensor()

        # ops needed by the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.shape_base = P.Shape()
        self.sametypeshape_base = inner.SameTypeShape()
        self.issubclass_base = inner.IsSubClass()

    @property
    def name(self):
        return self._name

    @property
    def dtype(self):
        return self._dtype

    @property
    def parameters(self):
        return self._parameters

    @property
    def is_constant_jacobian(self):
        return self._is_constant_jacobian

    @property
    def is_injective(self):
        return self._is_injective

    @property
    def batch_shape(self):
        return self._batch_shape

    @property
    def is_scalar_batch(self):
        return self._is_scalar_batch

    def _check_value_dtype(self, value):
        """
        First check whether the input value is a Tensor. Then, if `self.dtype` is None,
        check whether the input tensor is, or can be directly cast into, a float tensor.
        If `self.dtype` is not None, check whether the input tensor's dtype is `self.dtype`.
        """
        self.checktensor(value, 'input value of bijector')
        value_type = self.dtype_base(value)
        if self.dtype is None:
            if self.issubclass_base(value_type, mstype.float_):
                return value
            return raise_type_error('input value of bijector', value_type,
                                    mstype.float_)
        dtype_tensor = F.fill(self.dtype, self.shape_base(value), 0.0)
        self.sametypeshape_base(value, dtype_tensor)
        return value

    def _shape_mapping(self, shape):
        shape_tensor = F.fill(self.parameter_type, shape, 0.0)
        dist_shape_tensor = F.fill(
            self.parameter_type, self.batch_shape, 0.0)
        return (shape_tensor + dist_shape_tensor).shape

    def shape_mapping(self, shape):
        return self._shape_mapping(shape)

    def _add_parameter(self, value, name):
        """
        Cast `value` to a tensor and append it to `self.default_parameters`.
        Append `name` to `self.parameter_names`.
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
            self.common_dtype = None
        # reject invalid values, then cast value to a tensor
        if isinstance(value, bool) or value is None:
            raise TypeError(f"{name} cannot be type {type(value)}")
        value_t = Tensor(value)
        # if the bijector's dtype is not specified
        if self.dtype is None:
            if self.common_dtype is None:
                self.common_dtype = value_t.dtype
            elif value_t.dtype != self.common_dtype:
                raise TypeError(
                    f"{name} should have the same dtype as other arguments.")
            # check if the parameters are cast into float-type tensors
            validator.check_type_name(
                f"dtype of {name}", value_t.dtype, mstype.float_type, type(self).__name__)
        # check if the dtype of the input parameter agrees with the bijector's dtype
        elif value_t.dtype != self.dtype:
            raise TypeError(
                f"{name} should have the same dtype as the bijector's dtype.")
        self.default_parameters += [value,]
        self.parameter_names += [name,]
        return value_t

    def _calc_batch_shape(self):
        """
        Calculate batch_shape based on parameters.
        """
        if 'param_dict' not in self.parameters:
            return None
        param_dict = self.parameters.get('param_dict')
        broadcast_shape_tensor = None
        for value in param_dict.values():
            if value is None:
                return None
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = cast_to_tensor(value)
            else:
                value = cast_to_tensor(value)
                broadcast_shape_tensor = (value + broadcast_shape_tensor)
        return broadcast_shape_tensor.shape

    def _check_is_scalar_batch(self):
        """
        Check whether the parameters used during initialization are scalars.
        """
        if 'param_dict' not in self.parameters.keys():
            return False
        param_dict = self.parameters.get('param_dict')
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    def _check_value(self, value, name):
        """
        Check the availability of `value` as a Tensor.
        """
        self.checktensor(value, name)
        return value
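    # A hedged sketch of how the helpers above are meant to be used: a concrete
    # bijector such as ScalarAffine (part of this package) registers its
    # parameters via _add_parameter, after which batch_shape and
    # is_scalar_batch reflect the registered values (the exact internals of
    # ScalarAffine.__init__ may differ from this sketch):
    #
    #   >>> import mindspore.nn.probability.bijector as msb
    #   >>> affine = msb.ScalarAffine(scale=2.0, shift=1.0)
    #   >>> affine.is_scalar_batch   # int/float parameters -> True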
    def cast_param_by_value(self, value, para):
        """
        Convert the data type of `para` to be the same as the data type of `value`.
        Typically used by subclasses of Bijector to cast the data types of
        their own parameters.

        Args:
            value (Tensor): input value.
            para (Tensor): parameter(s) of the bijector.

        Returns:
            Tensor, the value of the parameter after casting.
        """
        local = self.cast_base(para, self.dtype_base(value))
        return local
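    # The intended call pattern, sketched (`self.scale` and `x` are
    # illustrative names, not defined in this file): inside a subclass's
    # computation, the stored parameter is first aligned to the dtype of the
    # incoming value before any arithmetic:
    #
    #   scale = self.cast_param_by_value(x, self.scale)
    #   y = scale * x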
    def forward(self, value, *args, **kwargs):
        """
        Forward transformation: transform the input value to another distribution.

        Args:
            value (Tensor): the value of the input variables.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Returns:
            Tensor, the value of the transformed random variable.
        """
        return self._forward(value, *args, **kwargs)
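    # A usage sketch (hedged: Exp is the exponential bijector exported by this
    # package; values shown are approximate):
    #
    #   >>> import mindspore.nn.probability.bijector as msb
    #   >>> from mindspore import Tensor, dtype as mstype
    #   >>> exp_bij = msb.Exp()
    #   >>> exp_bij.forward(Tensor([0.0, 1.0], dtype=mstype.float32))
    #   # ~[1.0, 2.7183], i.e. exp(x) elementwise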
    def inverse(self, value, *args, **kwargs):
        """
        Inverse transformation: transform the input value back to the original distribution.

        Args:
            value (Tensor): the value of the transformed variables.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Returns:
            Tensor, the value of the input random variable.
        """
        return self._inverse(value, *args, **kwargs)
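    # Since Exp's inverse is the natural logarithm, inverse(forward(x))
    # recovers x up to floating-point error (continuing the sketch above):
    #
    #   >>> x = Tensor([0.5, 1.5], dtype=mstype.float32)
    #   >>> exp_bij.inverse(exp_bij.forward(x))
    #   # ~[0.5, 1.5]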
    def forward_log_jacobian(self, value, *args, **kwargs):
        """
        Logarithm of the derivative of the forward transformation.

        Args:
            value (Tensor): the value of the input variables.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Returns:
            Tensor, the value of the logarithm of the derivative of the forward transformation.
        """
        return self._forward_log_jacobian(value, *args, **kwargs)
    def inverse_log_jacobian(self, value, *args, **kwargs):
        """
        Logarithm of the derivative of the inverse transformation.

        Args:
            value (Tensor): the value of the transformed variables.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Returns:
            Tensor, the value of the logarithm of the derivative of the inverse transformation.
        """
        return self._inverse_log_jacobian(value, *args, **kwargs)
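    # For Exp, d/dx exp(x) = exp(x), so forward_log_jacobian(x) = x and, by the
    # inverse function rule, inverse_log_jacobian(y) = -log(y) (continuing the
    # sketch above; values approximate):
    #
    #   >>> exp_bij.forward_log_jacobian(Tensor([2.0], dtype=mstype.float32))
    #   # ~[2.0]
    #   >>> exp_bij.inverse_log_jacobian(Tensor([1.0], dtype=mstype.float32))
    #   # ~[0.0], since -log(1) = 0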
    def __call__(self, *args, **kwargs):
        """
        Call the Bijector directly.

        This __call__ may go in one of two directions:
        If args[0] is a distribution instance, the call will generate a new
        distribution derived from the input distribution.
        Otherwise, args[0] must be the name of a Bijector function, e.g.
        "forward"; the call will then go through construct and invoke the
        corresponding Bijector function.

        Args:
            *args: args[0] shall be either a distribution or the name of a Bijector function.
        """
        if isinstance(args[0], Distribution):
            return TransformedDistribution(self, args[0])
        return super(Bijector, self).__call__(*args, **kwargs)
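    # Calling a bijector on a Distribution instance returns a
    # TransformedDistribution rather than invoking construct. A sketch (hedged:
    # msd refers to mindspore.nn.probability.distribution):
    #
    #   >>> import mindspore.nn.probability.bijector as msb
    #   >>> import mindspore.nn.probability.distribution as msd
    #   >>> from mindspore import dtype as mstype
    #   >>> log_normal = msb.Exp()(msd.Normal(0.0, 1.0, dtype=mstype.float32))
    #   >>> isinstance(log_normal, msd.TransformedDistribution)
    #   # True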
    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include: 'forward', 'inverse',
            'forward_log_jacobian', and 'inverse_log_jacobian'.

        Args:
            name (str): The name of the function.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Returns:
            Tensor, the result of the function corresponding to `name`.
        """
        if name == 'forward':
            return self.forward(*args, **kwargs)
        if name == 'inverse':
            return self.inverse(*args, **kwargs)
        if name == 'forward_log_jacobian':
            return self.forward_log_jacobian(*args, **kwargs)
        if name == 'inverse_log_jacobian':
            return self.inverse_log_jacobian(*args, **kwargs)
        raise ValueError("Invalid function name: supported names are 'forward', "
                         "'inverse', 'forward_log_jacobian', and 'inverse_log_jacobian'.")
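# Because __call__ falls through to Cell.__call__ for non-Distribution
# arguments, dispatching by name is equivalent to calling the method directly
# (a sketch, assuming exp_bij and x as in the examples above):
#
#   >>> exp_bij('forward', x)        # routed through construct
#   >>> exp_bij.forward(x)           # same result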