# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bijector"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
from ..distribution import Distribution
from ..distribution import TransformedDistribution
class Bijector(Cell):
    """
    Bijector class.

    Args:
        is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: False.
        is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: True.
        name (str): The name of the Bijector. Default: None.
        dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
        param (dict): The parameters used to initialize the Bijector. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        `dtype` of bijector represents the type of the distributions that the bijector could operate on.
        When `dtype` is None, there is no enforcement on the type of input value except that the input value
        has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
        of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
        Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
        will be casted into the same type as input value when `dtype` is None.
        When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
        When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
        raised. Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.
    """

    def __init__(self,
                 is_constant_jacobian=False,
                 is_injective=True,
                 name=None,
                 dtype=None,
                 param=None):
        """
        Constructor of Bijector class.

        Validates the arguments, records the entries of `param` as the bijector's
        parameters and creates the operators used by the base-class helpers.
        """
        super(Bijector, self).__init__()
        # NOTE(review): `name` defaults to None but is validated as `str` here, so the
        # default presumably fails validation -- confirm whether None should be allowed.
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_value_type(
            'is_constant_jacobian', is_constant_jacobian, [bool], name)
        validator.check_value_type('is_injective', is_injective, [bool], name)
        if dtype is not None:
            validator.check_type_name(
                "dtype", dtype, mstype.float_type, type(self).__name__)
        self._name = name
        self._dtype = dtype
        self._parameters = {}

        # `param` defaults to None: treat that as "no constructor arguments recorded"
        # instead of failing on `None.keys()`.
        if param is None:
            param = {}

        # parsing parameters: keep every entry except `param` itself, `self`
        # and names starting with an underscore
        for k in param:
            if k == 'param':
                continue
            if not (k == 'self' or k.startswith('_')):
                self._parameters[k] = param[k]

        # if no bijector is used as an argument during initialization
        if 'bijector' not in param:
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()

        self._is_constant_jacobian = is_constant_jacobian
        self._is_injective = is_injective

        self.context_mode = context.get_context('mode')
        self.checktensor = CheckTensor()

        # ops needed for the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.shape_base = P.Shape()
        self.fill_base = P.Fill()
        self.sametypeshape_base = P.SameTypeShape()
        self.issubclass_base = P.IsSubClass()

    @property
    def name(self):
        """The name passed to the constructor."""
        return self._name

    @property
    def dtype(self):
        """The dtype passed to the constructor, or None when it was not specified."""
        return self._dtype

    @property
    def parameters(self):
        """Dictionary of the constructor arguments recorded from `param`."""
        return self._parameters

    @property
    def is_constant_jacobian(self):
        """The `is_constant_jacobian` flag passed to the constructor."""
        return self._is_constant_jacobian

    @property
    def is_injective(self):
        """The `is_injective` flag passed to the constructor."""
        return self._is_injective

    @property
    def batch_shape(self):
        """
        The batch shape computed from the parameters.

        Note:
            The constructor only sets this attribute when `param` has no
            'bijector' entry.
        """
        return self._batch_shape

    @property
    def is_scalar_batch(self):
        """
        Whether the parameters used during initialization are scalars.

        Note:
            The constructor only sets this attribute when `param` has no
            'bijector' entry.
        """
        return self._is_scalar_batch

    def _check_value_dtype(self, value):
        """
        Firstly check if the input value is Tensor. Then, if `self.dtype` is None, check
        if the input tensor is or can be directly cast into a float tensor.
        If `self.dtype` is not None, check if the input tensor's dtype is `self.dtype`.
        """
        self.checktensor(value, 'input value of bijector')
        value_type = self.dtype_base(value)
        if self.dtype is None:
            if self.issubclass_base(value_type, mstype.float_):
                return value
            return raise_type_error('input value of bijector', value_type, mstype.float_)
        # build a reference tensor of `self.dtype` with the shape of `value`, then let
        # the SameTypeShape operator compare `value` against it
        dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0)
        self.sametypeshape_base(value, dtype_tensor)
        return value

    def _shape_mapping(self, shape):
        """
        Return the shape obtained by adding a tensor of shape `shape` to a tensor
        of shape `self.batch_shape`.

        Note:
            `self.parameter_type` is not set in this class; presumably it is
            provided by subclasses -- confirm before calling on the base class.
        """
        shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
        dist_shape_tensor = self.fill_base(
            self.parameter_type, self.batch_shape, 0.0)
        return (shape_tensor + dist_shape_tensor).shape

    def shape_mapping(self, shape):
        """Public wrapper around `_shape_mapping`."""
        return self._shape_mapping(shape)

    def _add_parameter(self, value, name):
        """
        Cast `value` to a tensor and add `value` to `self.default_parameters`.
        Add `name` into `self.parameter_names`.

        Args:
            value: The parameter value; must not be a bool or None.
            name (str): The name of the parameter.

        Returns:
            Tensor, `value` converted to a tensor.

        Raises:
            TypeError: When `value` is a bool or None, or when the dtype of the
                converted tensor disagrees with the bijector's dtype (or, if that
                is None, with the dtype of the previously added parameters).
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
            self.common_dtype = None
        # bool and None are rejected before the conversion to a tensor
        if isinstance(value, bool) or value is None:
            raise TypeError(f"{name} cannot be type {type(value)}")
        value_t = Tensor(value)
        # if the bijector's dtype is not specified
        if self.dtype is None:
            if self.common_dtype is None:
                self.common_dtype = value_t.dtype
            elif value_t.dtype != self.common_dtype:
                raise TypeError(
                    f"{name} should have the same dtype as other arguments.")
            # check if the parameters are casted into float-type tensors
            validator.check_type_name(
                f"dtype of {name}", value_t.dtype, mstype.float_type, type(self).__name__)
        # check if the dtype of the input_parameter agrees with the bijector's dtype
        elif value_t.dtype != self.dtype:
            raise TypeError(
                f"{name} should have the same dtype as the bijector's dtype.")
        # the raw `value` (not the tensor) is what gets recorded
        self.default_parameters += [value,]
        self.parameter_names += [name,]
        return value_t

    def _calc_batch_shape(self):
        """
        Calculate batch_shape based on parameters.

        Returns:
            The shape of the sum of all values in `self.parameters['param_dict']`
            after casting them to tensors, or None when there is no 'param_dict',
            when it is empty, or when any of its values is None.
        """
        if 'param_dict' not in self.parameters:
            return None
        param_dict = self.parameters['param_dict']
        broadcast_shape_tensor = None
        for value in param_dict.values():
            if value is None:
                return None
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = cast_to_tensor(value)
            else:
                value = cast_to_tensor(value)
                broadcast_shape_tensor = (value + broadcast_shape_tensor)
        # an empty `param_dict` leaves nothing to take a shape from
        if broadcast_shape_tensor is None:
            return None
        return broadcast_shape_tensor.shape

    def _check_is_scalar_batch(self):
        """
        Check if the parameters used during initialization are scalars.

        Returns:
            bool, False when there is no 'param_dict' or when any of its non-None
            values is not an int or float; True otherwise.
        """
        if 'param_dict' not in self.parameters:
            return False
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    def _check_value(self, value, name):
        """
        Check availability of `value` as a Tensor.
        """
        self.checktensor(value, name)
        return value

    def cast_param_by_value(self, value, para):
        """
        Cast the parameter(s) of the bijector to be the same type of input_value.

        Args:
            value: The input value whose dtype is used as the target type.
            para: The parameter to cast.

        Returns:
            `para` cast to the dtype of `value`.
        """
        local = self.cast_base(para, self.dtype_base(value))
        return local

    def forward(self, value, *args, **kwargs):
        """
        Forward transformation: transform the input value to another distribution.

        Delegates to `self._forward`, which is not defined in this class
        (presumably implemented by subclasses).
        """
        return self._forward(value, *args, **kwargs)

    def inverse(self, value, *args, **kwargs):
        """
        Inverse transformation: transform the input value back to the original distribution.

        Delegates to `self._inverse`, which is not defined in this class
        (presumably implemented by subclasses).
        """
        return self._inverse(value, *args, **kwargs)

    def forward_log_jacobian(self, value, *args, **kwargs):
        """
        Logarithm of the derivative of the forward transformation.

        Delegates to `self._forward_log_jacobian`, which is not defined in this
        class (presumably implemented by subclasses).
        """
        return self._forward_log_jacobian(value, *args, **kwargs)

    def inverse_log_jacobian(self, value, *args, **kwargs):
        """
        Logarithm of the derivative of the inverse transformation.

        Delegates to `self._inverse_log_jacobian`, which is not defined in this
        class (presumably implemented by subclasses).
        """
        return self._inverse_log_jacobian(value, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Call Bijector directly.
        This __call__ may go into two directions:
        If args[0] is a distribution instance, the call will generate a new distribution derived from
        the input distribution.
        Otherwise, args[0] must be the name of a Bijector function, e.g. "forward", then this call will
        go in the construct and invoke the corresponding Bijector function.

        Args:
            *args: args[0] shall be either a distribution or the name of a Bijector function.
        """
        # guard against an empty `args` so a keyword-only call is forwarded to Cell
        # instead of failing on `args[0]`
        if args and isinstance(args[0], Distribution):
            return TransformedDistribution(self, args[0])
        return super(Bijector, self).__call__(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include:
            'forward', 'inverse', 'forward_log_jacobian', and 'inverse_log_jacobian'.

        Args:
            name (str): The name of the function.
            *args (list): A list of positional arguments that the function needs.
            **kwargs (dict): A dictionary of keyword arguments that the function needs.

        Returns:
            The result of the selected function, or None when `name` is not one
            of the supported names.
        """
        if name == 'forward':
            return self.forward(*args, **kwargs)
        if name == 'inverse':
            return self.inverse(*args, **kwargs)
        if name == 'forward_log_jacobian':
            return self.forward_log_jacobian(*args, **kwargs)
        if name == 'inverse_log_jacobian':
            return self.inverse_log_jacobian(*args, **kwargs)
        return None