mindspore.mint.nn.functional.embedding
- mindspore.mint.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False)
Retrieve the word embeddings in weight using indices specified in input.
Warning
On Ascend, the behavior is unpredictable if input contains invalid values, such as indices outside the range [0, weight.shape[0]).
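Because invalid indices are not reliably reported on Ascend, it can help to validate input on the host before calling the function. The following is a minimal NumPy sketch of such a check; `check_indices` is a hypothetical helper for illustration, not part of MindSpore.

```python
import numpy as np

def check_indices(indices, num_embeddings):
    """Hypothetical host-side check: dtype is int32/int64 and every index is in [0, num_embeddings)."""
    indices = np.asarray(indices)
    if indices.dtype not in (np.int32, np.int64):
        raise TypeError(f"indices must be int32 or int64, got {indices.dtype}")
    if indices.size and (indices.min() < 0 or indices.max() >= num_embeddings):
        raise ValueError(
            f"indices must be in [0, {num_embeddings}), "
            f"got min={indices.min()}, max={indices.max()}"
        )
    return indices

# Passes: every index is smaller than 3.
check_indices(np.array([[1, 0, 1, 1], [0, 0, 1, 0]]), num_embeddings=3)
```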
- Parameters
input (Tensor) – The indices used to look up rows in weight. The data type must be mindspore.int32 or mindspore.int64, and every value should be in the range [0, weight.shape[0]).
weight (Parameter) – The embedding matrix to look up from. It must be 2-D.
padding_idx (int, optional) – If not None, the row of weight at this index is not updated during training. If not None, the value must be in the range [-weight.shape[0], weight.shape[0]). Default: None.
max_norm (float, optional) – If not None, first compute the p-norm of each row of weight referenced by input, where p is given by norm_type; if the result is larger than max_norm, that row is scaled in place by \(\frac{max\_norm}{result + 10^{-7}}\). Default: None.
norm_type (float, optional) – The value of p in the p-norm used with max_norm. Default: 2.0.
scale_grad_by_freq (bool, optional) – If True, the gradients are scaled by the inverse of the frequency of each index in input. Default: False.
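The max_norm renormalization and the gradient effects of padding_idx and scale_grad_by_freq can be sketched in NumPy as follows. This is an illustrative reference written from the parameter descriptions above, assuming that the gradient of a repeated index is the sum of its contributions; the function names are hypothetical and this is not MindSpore's kernel.

```python
import numpy as np

def renorm_rows(weight, indices, max_norm, norm_type=2.0):
    """Scale, in place, each row of `weight` referenced by `indices` whose p-norm exceeds max_norm."""
    for row in np.unique(indices):
        norm = np.linalg.norm(weight[row], ord=norm_type)
        if norm > max_norm:
            weight[row] *= max_norm / (norm + 1e-7)
    return weight

def embedding_weight_grad(indices, grad_output, num_embeddings,
                          padding_idx=None, scale_grad_by_freq=False):
    """Gradient of the lookup with respect to weight; returns shape (num_embeddings, dim)."""
    indices = np.asarray(indices).reshape(-1)
    grad_output = np.asarray(grad_output, dtype=np.float64)
    grad_output = grad_output.reshape(-1, grad_output.shape[-1])
    grad_weight = np.zeros((num_embeddings, grad_output.shape[1]))
    if scale_grad_by_freq:
        # Divide each contribution by how often its index occurs in `indices`.
        counts = np.bincount(indices, minlength=num_embeddings)
        grad_output = grad_output / counts[indices][:, None]
    # np.add.at accumulates correctly when the same index appears several times.
    np.add.at(grad_weight, indices, grad_output)
    if padding_idx is not None:
        # The padding row receives no gradient, so training never updates it.
        grad_weight[padding_idx] = 0.0
    return grad_weight
```

Note that `renorm_rows` mutates its argument, mirroring the in-place update described for max_norm, and only touches rows that input actually references.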
- Returns
Tensor with the same data type as weight and shape \((*input.shape, weight.shape[1])\).
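Ignoring max_norm, the forward pass is plain row gathering, and NumPy integer-array indexing yields exactly the documented output shape. A minimal sketch for illustration only; `embedding_lookup` is a hypothetical name, not the MindSpore implementation.

```python
import numpy as np

def embedding_lookup(indices, weight):
    """Gather rows of a 2-D weight; the result has shape (*indices.shape, weight.shape[1])."""
    weight = np.asarray(weight)
    if weight.ndim != 2:
        raise ValueError(f"weight must be 2-D, got shape {weight.shape}")
    return weight[np.asarray(indices)]

weight = np.arange(6, dtype=np.float32).reshape(3, 2)  # 3 embeddings of size 2
indices = np.array([[1, 0, 1, 1], [0, 0, 1, 0]])
out = embedding_lookup(indices, weight)
print(out.shape)  # (2, 4, 2)
print(out[0, 0])  # [2. 3.]
```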
- Raises
ValueError – If padding_idx is out of valid range.
ValueError – If weight is not 2-D.
TypeError – If weight is not a mindspore.Parameter.
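The padding_idx and weight-shape checks listed above can be mirrored on the host as below. This sketch assumes that a negative padding_idx counts from the end, as the documented range [-weight.shape[0], weight.shape[0]) suggests; `normalize_padding_idx` is a hypothetical helper, not a MindSpore function.

```python
def normalize_padding_idx(padding_idx, weight_shape):
    """Hypothetical mirror of the documented checks; returns a non-negative row index or None."""
    if len(weight_shape) != 2:
        raise ValueError(f"weight must be 2-D, got shape {tuple(weight_shape)}")
    if padding_idx is None:
        return None
    num_embeddings = weight_shape[0]
    if not -num_embeddings <= padding_idx < num_embeddings:
        raise ValueError(
            f"padding_idx must be in [{-num_embeddings}, {num_embeddings}), got {padding_idx}"
        )
    # Assumption: a negative value wraps around, as in ordinary Python indexing.
    return padding_idx % num_embeddings
```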
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, mint
>>> input = Tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
>>> weight = Parameter(np.random.randn(3, 3).astype(np.float32))  # random, so printed values differ between runs
>>> output = mint.nn.functional.embedding(input, weight, max_norm=0.4)
>>> print(output)
[[[ 5.49015924e-02, 3.47811311e-01, -1.89771220e-01],
  [ 2.09307984e-01, -2.24846993e-02, 3.40124398e-01],
  [ 5.49015924e-02, 3.47811311e-01, -1.89771220e-01],
  [ 5.49015924e-02, 3.47811311e-01, -1.89771220e-01]],
 [[ 2.09307984e-01, -2.24846993e-02, 3.40124398e-01],
  [ 2.09307984e-01, -2.24846993e-02, 3.40124398e-01],
  [ 5.49015924e-02, 3.47811311e-01, -1.89771220e-01],
  [ 2.09307984e-01, -2.24846993e-02, 3.40124398e-01]]]