# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Defines parameter operators with functional form."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore.common.seed import _get_graph_seed
from mindspore.common.tensor import Tensor
from mindspore.ops.operations.random_ops import RandomShuffle, RandomChoiceWithMask
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops._utils import get_broadcast_shape
def random_gamma(shape, alpha, seed=0, seed2=0):
    r"""
    Outputs random values from the Gamma distribution(s) described by alpha.

    Args:
        shape (Tensor): The shape of random tensor to be generated.
            Must be one of the following types: int32, int64. 1-D integer tensor.
        alpha (Tensor): The alpha α distribution parameter.
            A Tensor. Must be one of the following types: half, float32, float64.
        seed (int): Seed is used as entropy source for the random number engines to generate
            pseudo-random numbers, must be non-negative. Default: 0.
        seed2 (int): Seed2 is used as entropy source for the random number engines to generate
            pseudo-random numbers, must be non-negative. Default: 0.

    Returns:
        Tensor. The shape should be equal to the concat shape between the input `shape` and the broadcast
        of `alpha`.
        The dtype is the same type as alpha.

    Raises:
        TypeError: If `shape` is not a Tensor.
        TypeError: If `alpha` is not a Tensor.
        TypeError: If `seed` is not an int.
        TypeError: If dtype of `alpha` is not half, float32 or float64.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> shape = Tensor(np.array([7, 5]), mindspore.int32)
        >>> alpha = Tensor(np.array([0.5, 1.5]), mindspore.float32)
        >>> output = ops.random_gamma(shape, alpha, seed=5)
        >>> result = output.shape
        >>> print(result)
        (7, 5, 2)
    """
    alpha_type = P.DType()(alpha)
    # A unit beta tensor of shape (1,) with alpha's dtype. It is only used to compute the
    # broadcast shape of alpha; assuming NumPy-style rules in get_broadcast_shape, this
    # promotes a 0-D alpha to shape (1,) and leaves higher-rank alpha shapes unchanged.
    beta = Tensor(np.array([1.0]), alpha_type)
    alpha_shape = P.Shape()(alpha)
    beta_shape = P.Shape()(beta)
    broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma", arg_name1="alpha", arg_name2="beta")
    broadcast_shape_t = tuple(broadcast_shape)
    broadcast_to = P.BroadcastTo(broadcast_shape_t)
    alpha_broadcast = broadcast_to(alpha)
    random_gamma_op = _get_cache_prim(P.RandomGamma)(seed=seed, seed2=seed2)
    output = random_gamma_op(shape, alpha_broadcast)
    return output
@constexpr(reuse_result=False)
def _get_seed(op_seed, kernel_name):
    """Resolve the seed pair used by a random operator from the graph-level seed.

    Args:
        op_seed: The operator-level seed supplied by the caller (may be None, as in
            `random_poisson` and `shuffle`).
        kernel_name (str): Name of the calling operator, forwarded to `_get_graph_seed`.

    Returns:
        A two-element result that callers unpack as ``(seed, seed2)``.

    Note:
        Wrapped in `constexpr`; `reuse_result=False` presumably forces re-evaluation on
        every call rather than reusing a cached value — confirm against `constexpr` docs.
    """
    graph_seeds = _get_graph_seed(op_seed, kernel_name)
    return graph_seeds
def standard_laplace(shape, seed=0, seed2=0):
    r"""
    Draws samples from the standard Laplace distribution (mean = 0, lambda = 1).

    The probability density function is:

    .. math::

        \text{f}(x) = \frac{1}{2}\exp(-|x|)

    Args:
        shape (Union[tuple, Tensor]): Shape of the random tensor to generate. A tuple must be a constant;
            dynamic shape is supported only when a Tensor is passed.
        seed (int): Random seed. Default: 0.
        seed2 (int): Random seed2. Default: 0.

    Returns:
        Tensor with the shape described by `shape` and dtype float32.

    Raises:
        TypeError: If seed or seed2 is not an int.
        TypeError: If shape is neither a tuple nor a Tensor.
        ValueError: If seed or seed2 is not a non-negative int.
        ValueError: If shape is a tuple containing non-positive items.
        ValueError: If shape is a Tensor, and the rank of the Tensor is not equal to 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> shape = (4, 4)
        >>> output = ops.standard_laplace(shape)
        >>> result = output.shape
        >>> print(result)
        (4, 4)
    """
    # Fetch (or build) the cached StandardLaplace primitive for this seed pair and apply it.
    return _get_cache_prim(P.StandardLaplace)(seed=seed, seed2=seed2)(shape)
def random_categorical(logits, num_sample, seed=0, dtype=mstype.int64):
    r"""
    Generates random samples from a given categorical distribution tensor.

    Args:
        logits (Tensor): The input tensor. 2-D Tensor with shape :math:`(batch\_size, num\_classes)`.
        num_sample (int): Number of samples to be drawn. Only constant values are allowed.
        seed (int): Random seed. Only constant values are allowed. Default: 0.
        dtype (mindspore.dtype): The type of output. Its value must be one of mindspore.int16,
            mindspore.int32 and mindspore.int64. Default: mindspore.int64.

    Returns:
        Tensor, The output Tensor with shape :math:`(batch\_size, num\_sample)`.

    Raises:
        TypeError: If `dtype` is not one of the following: mindspore.int16, mindspore.int32, mindspore.int64.
        TypeError: If `logits` is not a Tensor.
        TypeError: If `num_sample` or `seed` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>> import numpy as np
        >>> logits = Tensor(np.random.random((10, 5)).astype(np.float32), mstype.float32)
        >>> output = ops.random_categorical(logits, 8)
        >>> result = output.shape
        >>> print(result)
        (10, 8)
    """
    # Only `dtype` is a primitive attribute; `num_sample` and `seed` are passed as call inputs.
    random_categorical_ = P.RandomCategorical(dtype)
    return random_categorical_(logits, num_sample, seed)
def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False):
    r"""
    Returns a tensor where each row contains `numsamples` elements sampled from the multinomial distribution.

    Note:
        The rows of input do not need to sum to one (in which case we use the values as weights),
        but must be non-negative, finite and have a non-zero sum.

    Args:
        x (Tensor): the input tensor containing the cumsum of probabilities, must be 1 or 2
            dimensions. Must be one of the following types: float16, float32, float64.
        seed (int): If seed is set to be -1, and offset is set to be 0, the random number
            generator is seeded by a random seed. Otherwise, it is seeded by the given seed.
        offset (int): To avoid seed collision.
        numsamples (int): the number of samples to draw.
        replacement (bool): Whether to draw with replacement or not. Default: False.

    Returns:
        Tensor with the same rows as `x`, each row has numsamples sampled indices.

    Raises:
        TypeError: If `x` is not a Tensor whose dtype is float16, float32, float64.
        TypeError: If `numsamples` is not an int.
        TypeError: If `replacement` is not a bool.
        ValueError: If `x` rank is not 1 or 2.
        ValueError: If `replacement` is False and `numsamples` is larger than x_shape[-1].
        ValueError: If the sum of any row of `x` is less than 0.
        ValueError: If any element of `x` is less than 0.
        ValueError: If `numsamples` is less than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
        >>> output = multinomial_with_replacement(x, 2, 5, 2, True)
        >>> print(output)
        [[1 1]]
    """
    # numsamples/replacement are primitive attributes; seed/offset are passed as call inputs.
    multinomial_with_replacement_ = _get_cache_prim(P.MultinomialWithReplacement)(
        numsamples=numsamples, replacement=replacement)
    return multinomial_with_replacement_(x, seed, offset)
def standard_normal(shape, seed=0, seed2=0):
    r"""
    Samples a tensor from the standard Normal (Gaussian) distribution.

    Every element of the returned tensor is drawn independently from a normal distribution
    with mean 0 and standard deviation 1:

    .. math::

        f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}

    Args:
        shape (Union[tuple, Tensor]): Shape of the random tensor to generate. A tuple must be a constant;
            dynamic shape is supported only when a Tensor is passed.
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. A second seed to avoid seed collision. Default: 0.

    Returns:
        Tensor with the shape described by `shape` and dtype float32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `shape` is neither a tuple nor a Tensor.
        ValueError: If `seed` or `seed2` is not a non-negative int.
        ValueError: If `shape` is a tuple containing non-positive items.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> shape = (4, 4)
        >>> output = ops.standard_normal(shape)
        >>> result = output.shape
        >>> print(result)
        (4, 4)
    """
    normal_sampler = _get_cache_prim(P.StandardNormal)(seed=seed, seed2=seed2)
    samples = normal_sampler(shape)
    return samples
def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
    r"""
    Generates random number Tensor with shape `shape` according to a Poisson distribution with mean `rate`.

    .. math::

        \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}

    Args:
        shape (Tensor): The shape of random tensor to be sampled from each poisson distribution, 1-D `Tensor` whose
            dtype is mindspore.dtype.int32 or mindspore.dtype.int64.
        rate (Tensor): The μ parameter the distribution is constructed with. It represents the mean of the distribution
            and also the variance of the distribution. It should be a `Tensor` whose dtype is mindspore.dtype.int64,
            mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16.
        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
            and must be non-negative. Default: None, which will be treated as 0.
        dtype (mindspore.dtype): The data type of output: mindspore.dtype.int64, mindspore.dtype.int32,
            mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. Default: mindspore.dtype.float32.

    Returns:
        A Tensor whose shape is `shape` concatenated with the shape of `rate` (i.e. ``(*shape, *rate.shape)``),
        and whose data type is equal to argument `dtype`.

    Raises:
        TypeError: If `shape` is not a Tensor.
        TypeError: If datatype of `shape` is neither mindspore.dtype.int64 nor mindspore.dtype.int32.
        ValueError: If shape of `shape` is not 1-D.
        TypeError: If `rate` is neither a Tensor nor a scalar.
        TypeError: If datatype of `rate` is not in [mindspore.dtype.int64, mindspore.dtype.int32,
            mindspore.dtype.float64, mindspore.dtype.float32, mindspore.dtype.float16].
        TypeError: If `seed` is not a non-negative int.
        TypeError: If `dtype` is not in [mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64,
            mindspore.dtype.float32, mindspore.dtype.float16].
        ValueError: If any element of input `shape` tensor is not positive.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> # case 1: 1-D shape, 2-D rate, float64 output
        >>> shape = Tensor(np.array([2, 2]), mindspore.int64)
        >>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
        >>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.float64)
        >>> print(output.shape, output.dtype)
        (2, 2, 2, 2) Float64
        >>> # case 2: 1-D shape, scalar rate, int64 output
        >>> shape = Tensor(np.array([2, 2]), mindspore.int64)
        >>> rate = Tensor(5.0, mindspore.float64)
        >>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.int64)
        >>> print(output.shape, output.dtype)
        (2, 2) Int64
    """
    # Derive the operator's seed pair from `seed` via the graph-level seed helper.
    seed1, seed2 = _get_seed(seed, "random_poisson")
    prim_random_poisson = P.random_ops.RandomPoisson(seed1, seed2, dtype)
    value = prim_random_poisson(shape, rate)
    return value
def shuffle(x, seed=None):
    r"""
    Randomly permutes a Tensor along its first dimension.

    Args:
        x (Tensor): The Tensor to shuffle.
        seed (int): Random seed for the permutation; must be non-negative. A value of 0 is replaced with a
            randomly generated one. Default: None, which will be treated as 0.

    Returns:
        Tensor with the same shape and type as the input `x`.

    Raises:
        TypeError: If data type of `seed` is not None or non-negative int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore import Tensor, ops
        >>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)
        >>> output = ops.shuffle(x, seed=1)
        >>> print(output.shape)
        (4,)
    """
    # Resolve the user seed into the (seed, seed2) pair instead of rebinding the `seed` parameter.
    op_seed, op_seed2 = _get_seed(seed, "shuffle")
    shuffle_op = _get_cache_prim(RandomShuffle)(seed=op_seed, seed2=op_seed2)
    return shuffle_op(x)
def choice_with_mask(input_x, count=256, seed=0, seed2=0):
    """
    Generates a random sample as index tensor with a mask tensor from a given tensor.

    The `input_x` must be a tensor whose rank is not less than 1. If its rank is greater than or equal to 2,
    the first dimension specifies the number of samples.
    The returned index tensor denotes the index of the nonzero
    sample, the mask tensor denotes which elements in the index tensor are valid.

    Args:
        input_x (Tensor[bool]): The input tensor.
            The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
        count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
        seed (int): Random seed. Default: 0.
        seed2 (int): Random seed2. Default: 0.

    Returns:
        Two tensors, the first one is the index tensor and the other one is the mask tensor.

        - **index** (Tensor) - The output shape is 2-D.
        - **mask** (Tensor) - The output shape is 1-D.

    Raises:
        TypeError: If `count` is not an int.
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool_))
        >>> output_y, output_mask = ops.choice_with_mask(input_x)
        >>> result = output_y.shape
        >>> print(result)
        (256, 2)
        >>> result = output_mask.shape
        >>> print(result)
        (256,)
    """
    choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(count=count, seed=seed, seed2=seed2)
    # The primitive returns the (index, mask) pair, which is passed through unchanged.
    output = choice_with_mask_(input_x)
    return output
# Public API of this module, kept in alphabetical order.
__all__ = sorted((
    'choice_with_mask',
    'log_uniform_candidate_sampler',
    'random_categorical',
    'random_gamma',
    'random_poisson',
    'shuffle',
    'standard_laplace',
    'standard_normal',
    'uniform',
    'uniform_candidate_sampler',
))