mindspore.nn.probability.distribution.TransformedDistribution

class mindspore.nn.probability.distribution.TransformedDistribution(bijector, distribution, seed=None, name='transformed_distribution')

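Transformed Distribution. This class combines a bijector with a base distribution and produces a new distribution by transforming the original distribution through the operation defined by the bijector: if X follows the original distribution, the new distribution is that of bijector(X).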
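The relationship between the original and the transformed distribution is the standard change-of-variables rule. The block below states the textbook formula for an invertible bijector f and works it out for the Exp bijector over a standard Normal, as used in the Examples section; it describes the mathematics, not necessarily how MindSpore computes it internally.

```latex
\begin{aligned}
Y &= f(X), \qquad X \sim p_X \\
p_Y(y) &= p_X\!\bigl(f^{-1}(y)\bigr)\,\bigl|\det J_{f^{-1}}(y)\bigr| \\
\log p_Y(y) &= \log p_X\!\bigl(f^{-1}(y)\bigr) + \log\bigl|\det J_{f^{-1}}(y)\bigr| \\[4pt]
\text{Exp over } \mathcal{N}(0,1):\quad f(x) &= e^{x}, \quad f^{-1}(y) = \ln y, \quad \frac{d}{dy}\ln y = \frac{1}{y} \\
p_Y(y) &= \frac{1}{y\sqrt{2\pi}}\exp\!\left(-\frac{(\ln y)^2}{2}\right), \qquad y > 0
\end{aligned}
```

Sampling goes the other way: draw x from the original distribution and return f(x).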

Parameters
  • bijector (Bijector) – The transformation to perform.

  • distribution (Distribution) – The original distribution. Must have a float dtype.

  • seed (int) – The seed used in sampling. Default: None. If a seed is given when a TransformedDistribution object is initialized, the object’s sampling function uses this seed; otherwise, the underlying distribution’s seed is used.

  • name (str) – The name of the transformed distribution. Default: ‘transformed_distribution’.

Supported Platforms:

Ascend GPU

Note

The arguments used to initialize the original distribution cannot be None. For example, mynormal = msd.Normal(dtype=mindspore.float32) cannot be used to initialize a TransformedDistribution, since mean and sd are not specified.

  • batch_shape is the batch_shape of the original distribution.

  • broadcast_shape is the broadcast shape between the original distribution and the bijector.

  • is_scalar_batch is True only if both the original distribution and the bijector are scalar batches.

  • default_parameters, parameter_names and parameter_type are set to be consistent with the original distribution. A derived class can override default_parameters and parameter_names by calling reset_parameters followed by add_parameter.
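The Note describes broadcast_shape as the broadcast between the original distribution and the bijector, and is_scalar_batch as true only when both are scalar batches. Assuming the usual NumPy-style broadcasting rules apply and that a "scalar batch" means a batch shape of () (both are assumptions; the Note does not spell out the rules), the two properties can be sketched in plain Python. broadcast_shape and is_scalar_batch below are hypothetical helpers, not MindSpore functions.

```python
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: NumPy-style broadcast of two shapes (assumed rule).

    Shapes are aligned on their trailing dimensions; a dimension of size 1
    stretches to match the other one, and any other mismatch is an error.
    """
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Missing leading dimensions behave like size 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def is_scalar_batch(dist_batch_shape: Sequence[int], bijector_batch_shape: Sequence[int]) -> bool:
    """Hypothetical helper: true only when both batch shapes are scalar, i.e. ()."""
    return len(dist_batch_shape) == 0 and len(bijector_batch_shape) == 0


# A distribution with batch shape (3,) combined with bijector parameters of shape (2, 1).
print(broadcast_shape((3,), (2, 1)))  # (2, 3)
print(is_scalar_batch((), ()))        # True
print(is_scalar_batch((3,), ()))      # False
```

Right-aligned matching is what lets a per-element bijector parameter of shape (2, 1) combine with a distribution batch of shape (3,).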

Examples

>>> import mindspore
>>> import mindspore.context as context
>>> import mindspore.nn as nn
>>> import mindspore.nn.probability.distribution as msd
>>> import mindspore.nn.probability.bijector as msb
>>> from mindspore import Tensor
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>>
>>> # To initialize a transformed distribution
>>> # using a Normal distribution as the base distribution,
>>> # and an Exp bijector as the bijector function.
>>> trans_dist = msd.TransformedDistribution(msb.Exp(), msd.Normal(0.0, 1.0))
>>>
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
>>> prob = trans_dist.prob(value)
>>> print(prob.shape)
(3,)
>>> sample = trans_dist.sample(shape=(2, 3))
>>> print(sample.shape)
(2, 3)
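Because Exp applied to a Normal(0.0, 1.0) base yields a log-normal distribution, the quantities in the example above can be cross-checked in plain Python. The sketch below applies the change-of-variables rule by hand; exp_transformed_prob and exp_transformed_sample are hypothetical helpers, not part of MindSpore, and the values returned by trans_dist.prob would be expected to agree only up to float32 precision.

```python
import math
import random


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def exp_transformed_prob(y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Hypothetical helper: density of Y = exp(X), X ~ Normal(loc, scale)."""
    if y <= 0.0:
        return 0.0  # exp(X) is strictly positive, so there is no mass at y <= 0
    x = math.log(y)             # inverse of the Exp bijector
    log_det_inv = -math.log(y)  # log |d/dy log(y)| = -log(y)
    return math.exp(normal_log_prob(x, loc, scale) + log_det_inv)


def exp_transformed_sample(shape, loc=0.0, scale=1.0, seed=None):
    """Hypothetical helper: draw a (rows, cols) grid of base samples, then apply exp."""
    rng = random.Random(seed)  # a fixed seed makes the draws reproducible
    rows, cols = shape
    return [[math.exp(rng.gauss(loc, scale)) for _ in range(cols)] for _ in range(rows)]


probs = [exp_transformed_prob(v) for v in (1.0, 2.0, 3.0)]
print(len(probs))                   # 3, matching prob.shape == (3,)
sample = exp_transformed_sample((2, 3), seed=0)
print(len(sample), len(sample[0]))  # 2 3, matching sample.shape == (2, 3)
```

At y = 1 the inverse is ln 1 = 0 and the Jacobian term is 1, so the density reduces to 1/sqrt(2π), the peak of the standard Normal.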