mindspore.nn.probability.distribution.Distribution

class mindspore.nn.probability.distribution.Distribution(seed, dtype, name, param)

Base class for all mathematical distributions.

Parameters
  • seed (int) – The seed used in sampling. If it is None, 0 is used.

  • dtype (mindspore.dtype) – The type of the event samples.

  • name (str) – The name of the distribution.

  • param (dict) – The parameters used to initialize the distribution.

Note

Derived classes must override operations such as _mean, _prob, and _log_prob. Required arguments, such as value for _prob, must be passed in through args or kwargs. The dist_spec_args, which specify a new distribution, are optional.

dist_spec_args are specific to each type of distribution. For example, mean and sd are the dist_spec_args of a Normal distribution, while rate is the dist_spec_args of an Exponential distribution.

For all functions, passing in dist_spec_args is optional. When dist_spec_args are passed in, the function evaluates the result on a new distribution specified by those dist_spec_args. The original distribution is not changed.

Supported Platforms:

Ascend GPU
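The contract described in the Note (public methods forward to underscore-prefixed hooks, and optional dist_spec_args apply to a single call only) can be illustrated with a minimal pure-Python sketch. ToyDistribution and ToyNormal are hypothetical names; this loosely mirrors the pattern and is not MindSpore's implementation, which operates on Tensors inside a Cell.

```python
import math


class ToyDistribution:
    """Hypothetical base class: public methods forward to underscore hooks."""

    def prob(self, value, *args, **kwargs):
        # Forward the required value plus any optional dist_spec_args.
        return self._prob(value, *args, **kwargs)

    def _prob(self, value, *args, **kwargs):
        raise NotImplementedError("derived classes must override _prob")


class ToyNormal(ToyDistribution):
    """Hypothetical normal distribution; its dist_spec_args are mean and sd."""

    def __init__(self, mean=0.0, sd=1.0):
        self.mean = mean
        self.sd = sd

    def _prob(self, value, mean=None, sd=None):
        # Fall back to the stored parameters; overrides are never saved.
        mean = self.mean if mean is None else mean
        sd = self.sd if sd is None else sd
        z = (value - mean) / sd
        return math.exp(-0.5 * z * z) / (sd * math.sqrt(2.0 * math.pi))


n = ToyNormal(mean=0.0, sd=1.0)
p_default = n.prob(0.0)              # uses mean=0.0, sd=1.0
p_override = n.prob(0.0, 1.0, 2.0)   # uses mean=1.0, sd=2.0 for this call only
```

After the second call, n.mean and n.sd still hold their original values, which is the "does not change the original distribution" behavior the Note describes.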

cdf(value, *args, **kwargs)

Evaluate the cumulative distribution function (cdf) at the given value.

Parameters
  • value (Tensor) – value to be evaluated.

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the cdf of the distribution.
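For a normal distribution, the cdf has a closed form in terms of the error function, Φ(x) = ½[1 + erf((x − μ)/(σ√2))]. The sketch below uses only Python's math module; normal_cdf and its mean/sd keywords are hypothetical, chosen to mirror the dist_spec_args of a Normal distribution.

```python
import math


def normal_cdf(value, mean=0.0, sd=1.0):
    """Hypothetical helper: cdf of N(mean, sd^2) via the error function."""
    z = (value - mean) / (sd * math.sqrt(2.0))
    return 0.5 * (1.0 + math.erf(z))


at_mean = normal_cdf(0.0)                   # half the mass lies below the mean
shifted = normal_cdf(1.0, mean=1.0, sd=3.0)  # same, for a shifted distribution
upper = normal_cdf(1.96)                    # the familiar ~97.5% quantile point
```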

construct(name, *args, **kwargs)

Override of construct in Cell. Dispatches the call to the distribution function specified by name.

Note

Names of supported functions include: ‘prob’, ‘log_prob’, ‘cdf’, ‘log_cdf’, ‘survival_function’, ‘log_survival’, ‘var’, ‘sd’, ‘mode’, ‘mean’, ‘entropy’, ‘kl_loss’, ‘cross_entropy’, ‘sample’, ‘get_dist_args’, and ‘get_dist_type’.

Parameters
  • name (str) – The name of the function.

  • *args (list) – A list of positional arguments that the function needs.

  • **kwargs (dict) – A dictionary of keyword arguments that the function needs.

Returns

Tensor, the result of the function specified by name.
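In MindSpore, calling a Cell instance invokes its construct method, so a call such as dist('mean') is routed by name to the matching function. The routing idea can be sketched in plain Python as follows; ToyNormalCell and its whitelist are hypothetical, and everything specific to Cells and graph mode is omitted.

```python
class ToyNormalCell:
    """Hypothetical sketch of construct-style dispatch on a function name."""

    _SUPPORTED = ("mean", "sd")  # only these names can be dispatched

    def __init__(self, mean, sd):
        self._mean_value = mean
        self._sd_value = sd

    def mean(self, mean=None, sd=None):
        # The mean of a normal distribution is its location parameter.
        return self._mean_value if mean is None else mean

    def sd(self, mean=None, sd=None):
        return self._sd_value if sd is None else sd

    def construct(self, name, *args, **kwargs):
        if name not in self._SUPPORTED:
            raise ValueError(f"unsupported function name: {name!r}")
        # Look up the method by name and forward all remaining arguments.
        return getattr(self, name)(*args, **kwargs)

    # Mimic a Cell, whose call operator runs construct.
    __call__ = construct


cell = ToyNormalCell(mean=3.0, sd=2.0)
```

For example, cell("mean") returns 3.0, while cell("sd", sd=5.0) evaluates the standard deviation of a new distribution with sd=5.0.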

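cross_entropy(dist, *args, **kwargs)

Evaluate the cross entropy between distribution a and distribution b, where a is this distribution and b is the distribution specified by dist and its dist_spec_args.

Parameters
  • dist (str) – the type of distribution b.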

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

dist_spec_args of distribution b must be passed to the function through args or kwargs. Passing in dist_spec_args of distribution a is optional.

Returns

Tensor, the cross entropy of the two distributions.
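As a worked example, the cross entropy H(a, b) = −E_a[log p_b(X)] of two normal distributions has the closed form ½·log(2πσ_b²) + (σ_a² + (μ_a − μ_b)²)/(2σ_b²). The hypothetical helper below evaluates it in plain Python; it illustrates the quantity, not how MindSpore computes it.

```python
import math


def normal_cross_entropy(mean_a, sd_a, mean_b, sd_b):
    """Hypothetical helper: H(a, b) for a = N(mean_a, sd_a^2), b = N(mean_b, sd_b^2)."""
    return (0.5 * math.log(2.0 * math.pi * sd_b ** 2)
            + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2.0 * sd_b ** 2))


# When a and b coincide, the cross entropy equals the entropy of a.
h_same = normal_cross_entropy(0.0, 1.0, 0.0, 1.0)
# Any mismatch between a and b can only increase it.
h_diff = normal_cross_entropy(0.0, 1.0, 1.0, 2.0)
```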

entropy(*args, **kwargs)

Evaluate the entropy.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the entropy of the distribution.
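For the two distributions named in the class Note, the differential entropy has a simple closed form: ½·log(2πeσ²) for Normal and 1 − log(rate) for Exponential. The helpers below are hypothetical plain-Python illustrations of those formulas.

```python
import math


def normal_entropy(sd):
    """Hypothetical helper: entropy of N(mean, sd^2); independent of the mean."""
    return 0.5 * math.log(2.0 * math.pi * math.e * sd ** 2)


def exponential_entropy(rate):
    """Hypothetical helper: entropy of an exponential distribution with this rate."""
    return 1.0 - math.log(rate)


h_std_normal = normal_entropy(1.0)
# Doubling sd adds log(2) to the entropy of a normal distribution.
h_gain = normal_entropy(2.0) - normal_entropy(1.0)
```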

get_dist_args(*args, **kwargs)

Check the availability and validity of default parameters and dist_spec_args.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

dist_spec_args must be passed in as a list or a dictionary. The order of dist_spec_args should follow the order in which the default parameters were initialized through _add_parameter. If any of the dist_spec_args is None, the corresponding default parameter is returned.

Returns

list[Tensor], the list of parameters.
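The "None falls back to the default" rule can be sketched with positional arguments only. fill_dist_args and the defaults dictionary are hypothetical; a dictionary stands in for the parameters registered through _add_parameter because it preserves insertion order. Treating omitted trailing arguments as None is an assumption of this sketch, and the real method additionally checks validity.

```python
def fill_dist_args(defaults, *args):
    """Hypothetical helper: replace each None (or missing) argument with the
    default at the same position, in the order the defaults were declared."""
    names = list(defaults)  # dict keys keep their insertion order
    if len(args) > len(names):
        raise TypeError("too many dist_spec_args")
    # Treat omitted trailing arguments as None, then fall back to defaults.
    padded = list(args) + [None] * (len(names) - len(args))
    return [defaults[name] if v is None else v for name, v in zip(names, padded)]


normal_defaults = {"mean": 0.0, "sd": 1.0}  # declaration order: mean, then sd
```

For instance, fill_dist_args(normal_defaults, None, 2.0) yields [0.0, 2.0], and fill_dist_args(normal_defaults, 3.0) yields [3.0, 1.0].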

get_dist_type()

Return the type of the distribution.

Returns

str, the name of the distribution type.

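kl_loss(dist, *args, **kwargs)

Evaluate the Kullback-Leibler (KL) divergence KL(a||b), where a is this distribution and b is the distribution specified by dist and its dist_spec_args.

Parameters
  • dist (str) – the type of distribution b.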

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

dist_spec_args of distribution b must be passed to the function through args or kwargs. Passing in dist_spec_args of distribution a is optional.

Returns

Tensor, the KL divergence between the two distributions.
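For two normal distributions, KL(a||b) = log(σ_b/σ_a) + (σ_a² + (μ_a − μ_b)²)/(2σ_b²) − ½. The hypothetical helper below evaluates this formula in plain Python. Because the KL divergence is not symmetric, swapping a and b generally changes the result, which is why the order of the two distributions matters here.

```python
import math


def normal_kl(mean_a, sd_a, mean_b, sd_b):
    """Hypothetical helper: KL(a || b) for a = N(mean_a, sd_a^2), b = N(mean_b, sd_b^2)."""
    return (math.log(sd_b / sd_a)
            + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2.0 * sd_b ** 2)
            - 0.5)


kl_same = normal_kl(0.0, 1.0, 0.0, 1.0)     # identical distributions
kl_shift = normal_kl(0.0, 1.0, 1.0, 1.0)    # mean shifted by one sd
kl_ab = normal_kl(0.0, 1.0, 0.0, 2.0)
kl_ba = normal_kl(0.0, 2.0, 0.0, 1.0)       # differs from kl_ab
```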

log_cdf(value, *args, **kwargs)

Evaluate the log of the cumulative distribution function (cdf) at the given value.

Parameters
  • value (Tensor) – value to be evaluated.

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the log cdf of the distribution.
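Taking the log of a normal cdf computed as 0.5 * (1 + erf(z)) fails far in the left tail, where 1 + erf(z) rounds to zero in double precision. A common workaround, shown here as a plain-Python sketch (not necessarily what MindSpore does internally), is the identity 1 + erf(z) = erfc(−z). normal_log_cdf is a hypothetical name.

```python
import math


def normal_log_cdf(value, mean=0.0, sd=1.0):
    """Hypothetical helper: log cdf of N(mean, sd^2) using erfc for the left tail."""
    z = (value - mean) / (sd * math.sqrt(2.0))
    # 0.5 * erfc(-z) equals 0.5 * (1 + erf(z)) but avoids the cancellation.
    return math.log(0.5 * math.erfc(-z))


log_half = normal_log_cdf(0.0)
far_left = normal_log_cdf(-10.0)  # finite, although 1 + erf(z) would be 0.0 here
```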

log_prob(value, *args, **kwargs)

Evaluate the log probability (pdf or pmf) at the given value.

Parameters
  • value (Tensor) – value to be evaluated.

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the value of log probability.
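Working in log space keeps tail values finite where the probability itself would underflow to 0. For a normal distribution, log p(x) = −½z² − log σ − ½·log(2π) with z = (x − μ)/σ. The helper below is a hypothetical plain-Python illustration of that formula.

```python
import math


def normal_log_prob(value, mean=0.0, sd=1.0):
    """Hypothetical helper: log density of N(mean, sd^2), computed directly in log space."""
    z = (value - mean) / sd
    return -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


at_mean = normal_log_prob(0.0)
far_tail = normal_log_prob(40.0)  # about -800.9, while math.exp of it underflows to 0.0
```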

log_survival(value, *args, **kwargs)

Evaluate the log of the survival function at the given value.

Parameters
  • value (Tensor) – value to be evaluated.

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the log survival function of the distribution.

mean(*args, **kwargs)

Evaluate the mean.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the mean of the distribution.
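For an Exponential distribution with the given rate, the mean is 1/rate. The sketch below checks this closed form against an empirical average drawn with Python's seeded random generator; exponential_mean is a hypothetical helper, and the sampling uses the standard library rather than MindSpore.

```python
import random


def exponential_mean(rate):
    """Hypothetical helper: closed-form mean of an exponential distribution."""
    return 1.0 / rate


# Empirical check with a seeded standard-library generator.
rng = random.Random(0)
rate = 2.0
samples = [rng.expovariate(rate) for _ in range(100_000)]
empirical = sum(samples) / len(samples)
```

With 100,000 draws, the empirical average should land within a few thousandths of exponential_mean(rate).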

mode(*args, **kwargs)

Evaluate the mode.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the mode of the distribution.

prob(value, *args, **kwargs)

Evaluate the probability (pdf or pmf) at the given value. For a discrete distribution, this is the probability mass function (pmf); for a continuous distribution, it is the probability density function (pdf).

Parameters
  • value (Tensor) – value to be evaluated.

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the value of probability.
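The pmf/pdf distinction matters when interpreting the result: a pmf value is a probability and never exceeds 1, while a pdf value is a density and can. The plain-Python sketch below contrasts a Bernoulli pmf with a narrow normal pdf; both helpers and the probs keyword are hypothetical names used for illustration.

```python
import math


def bernoulli_pmf(value, probs=0.5):
    """Hypothetical helper: pmf of a Bernoulli distribution with P(X = 1) = probs."""
    if value not in (0, 1):
        return 0.0  # no probability mass outside the support {0, 1}
    return probs if value == 1 else 1.0 - probs


def normal_pdf(value, mean=0.0, sd=1.0):
    """Hypothetical helper: pdf of N(mean, sd^2)."""
    z = (value - mean) / sd
    return math.exp(-0.5 * z * z) / (sd * math.sqrt(2.0 * math.pi))


p_one = bernoulli_pmf(1, probs=0.3)   # a probability, at most 1
peak = normal_pdf(0.0, sd=0.1)        # a density, about 3.99 here
```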

sample(*args, **kwargs)

Generate samples from the distribution.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the sample generated from the distribution.
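The seed given to the constructor controls sampling. The general idea of seeded, reproducible sampling can be illustrated with Python's random module; sample_normal and its num_samples argument are hypothetical and stand in for whatever sample-shape argument a subclass defines. This sketch does not reproduce MindSpore's random operators or guarantee identical results across calls there.

```python
import random


def sample_normal(num_samples, mean=0.0, sd=1.0, seed=0):
    """Hypothetical helper: draw samples from N(mean, sd^2) with a dedicated seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(mean, sd) for _ in range(num_samples)]


first = sample_normal(5, seed=42)
second = sample_normal(5, seed=42)   # same seed, so the same draws as first
other = sample_normal(5, seed=1)     # different seed, different draws
# Overriding mean for a single call, analogous to passing dist_spec_args.
shifted = sample_normal(1000, mean=10.0, seed=0)
```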

sd(*args, **kwargs)

Evaluate the standard deviation.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the standard deviation of the distribution.

survival_function(value, *args, **kwargs)

Evaluate the survival function at the given value.

Parameters
  • value (Tensor) – value to be evaluated.

  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the survival function of the distribution.
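The survival function is the complement of the cdf, S(x) = P(X > x) = 1 − cdf(x); its log is the quantity returned by log_survival. For a normal distribution, S(x) = ½·erfc((x − μ)/(σ√2)), which stays accurate in the right tail where 1 − cdf(x) would cancel to 0. The helpers below are hypothetical plain-Python illustrations.

```python
import math


def normal_survival(value, mean=0.0, sd=1.0):
    """Hypothetical helper: survival function P(X > value) of N(mean, sd^2)."""
    z = (value - mean) / (sd * math.sqrt(2.0))
    # 0.5 * erfc(z) equals 1 - cdf but avoids cancellation for large values.
    return 0.5 * math.erfc(z)


def normal_log_survival(value, mean=0.0, sd=1.0):
    """Hypothetical helper: log of the survival function."""
    return math.log(normal_survival(value, mean, sd))


cdf_at_one = 0.5 * (1.0 + math.erf(1.0 / math.sqrt(2.0)))
sf_at_one = normal_survival(1.0)
```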

var(*args, **kwargs)

Evaluate the variance.

Parameters
  • *args (list) – the list of positional arguments forwarded to subclasses.

  • **kwargs (dict) – the dictionary of keyword arguments forwarded to subclasses.

Note

A distribution can be optionally passed to the function by passing its dist_spec_args through args or kwargs.

Returns

Tensor, the variance of the distribution.
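The variance and the standard deviation returned by sd are related by sd = √var. For an Exponential distribution, var = 1/rate² and therefore sd = 1/rate. The hypothetical helpers below illustrate that relationship in plain Python.

```python
import math


def exponential_var(rate):
    """Hypothetical helper: variance of an exponential distribution, 1 / rate^2."""
    return 1.0 / (rate * rate)


def exponential_sd(rate):
    """Hypothetical helper: standard deviation as the square root of the variance."""
    return math.sqrt(exponential_var(rate))
```

For example, with rate=2.0 the variance is 0.25 and the standard deviation is 0.5.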