mindspore.ops.EmbeddingLookup

class mindspore.ops.EmbeddingLookup(*args, **kwargs)

Returns a slice of the input tensor based on the specified indices.

This Primitive has functionality similar to GatherV2 operating on axis = 0, but takes one additional input: offset.
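To make the relationship with an axis-0 gather concrete, here is a minimal NumPy sketch of the semantics described on this page. It is not the MindSpore kernel; `embedding_lookup_reference` is a hypothetical helper. It assumes that the shifted index `input_indices - offset` selects a row of `input_params` and that any shifted index outside `[0, x_1)` produces a zero row.

```python
import numpy as np

def embedding_lookup_reference(params, indices, offset):
    """Hypothetical NumPy sketch of EmbeddingLookup (not the MindSpore kernel).

    Gathers rows of `params` along axis 0 at `indices - offset`; lookups whose
    shifted index falls outside [0, params.shape[0]) yield rows of zeros.
    """
    params = np.asarray(params)
    real = np.asarray(indices) - offset               # shift into slice-local row numbers
    valid = (real >= 0) & (real < params.shape[0])    # lookups that hit this slice
    safe = np.where(valid, real, 0)                   # replace misses with a legal row
    out = params[safe]                                # axis-0 gather (fancy indexing copies)
    out[~valid] = 0                                   # zero-fill the misses
    return out

p = np.array([[8, 9], [10, 11], [12, 13], [14, 15]], dtype=np.float32)
idx = np.array([[5, 2], [8, 5]], dtype=np.int32)
print(embedding_lookup_reference(p, idx, 4))
```

With the inputs from the example on this page, the sketch reproduces the documented output: index 5 maps to local row 1, while indices 2 and 8 map to -2 and 4, which fall outside the slice.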

Inputs:
  • input_params (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\). It represents a slice of a Tensor rather than the entire Tensor. Currently, the number of dimensions is restricted to 2.

  • input_indices (Tensor) - The shape of tensor is \((y_1, y_2, ..., y_S)\). Specifies the indices of elements of the original Tensor. Values may fall outside the range of input_params; the corresponding parts of the output are filled with 0. Negative values are not supported, and the result is undefined if any value is negative. The data type should be int32 or int64.

  • offset (int) - Specifies the offset of this input_params slice within the original Tensor. The actual indices are therefore input_indices minus offset.
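Because input_params is only a slice and out-of-range lookups come back as zeros, one plausible use is splitting a large table row-wise: each slice performs the lookup with its own offset, and summing the partial results recovers the full gather. This sharding scheme is an assumption for illustration; the page itself does not describe it. The NumPy sketch below uses a hypothetical helper, `lookup_slice`, in place of the real operator.

```python
import numpy as np

def lookup_slice(params, indices, offset):
    """Hypothetical helper: zero-filled axis-0 gather on one table slice."""
    real = np.asarray(indices) - offset               # actual index = index - offset
    valid = (real >= 0) & (real < params.shape[0])    # rows stored in this slice
    out = params[np.where(valid, real, 0)]            # gather, using row 0 for misses
    out[~valid] = 0                                   # misses contribute zeros
    return out

full_table = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 rows, embedding size 2
indices = np.array([0, 3, 5, 2])

# Split the table into two slices; each offset is the global index of the slice's first row.
shard_a, shard_b = full_table[:3], full_table[3:]
partial_a = lookup_slice(shard_a, indices, offset=0)
partial_b = lookup_slice(shard_b, indices, offset=3)

# Every index lives in exactly one slice, so the zero-filled partials sum to the full gather.
combined = partial_a + partial_b
print(np.array_equal(combined, full_table[indices]))  # True
```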

Outputs:

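Tensor, the shape of tensor is \((z_1, z_2, ..., z_N)\), i.e. the shape of input_indices followed by the trailing dimensions of input_params: \((y_1, ..., y_S, x_2, ..., x_R)\). The data type is the same as input_params.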
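The output shape follows the usual axis-0 gather rule: the shape of input_indices followed by the trailing dimensions of input_params, so \(N = S + R - 1\). The quick check below uses `np.take` as a stand-in for the MindSpore operator, on the assumption that the two share this shape rule; the zero-filled tensors serve only to carry shapes.

```python
import numpy as np

params = np.zeros((4, 2), dtype=np.float32)   # (x_1, x_2), so R = 2
indices = np.zeros((2, 3), dtype=np.int32)    # (y_1, y_2), so S = 2

out = np.take(params, indices, axis=0)        # axis-0 gather with the same shape rule
expected_shape = indices.shape + params.shape[1:]
print(out.shape, expected_shape)  # (2, 3, 2) (2, 3, 2)
```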

Raises:
  • TypeError – If the data type of input_indices is not int.

  • ValueError – If the rank of input_params is greater than 2.

Supported Platforms:

Ascend CPU GPU

Examples:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
>>> offset = 4
>>> output = ops.EmbeddingLookup()(input_params, input_indices, offset)
>>> print(output)
[[[10. 11.]
  [ 0.  0.]]
 [[ 0.  0.]
  [10. 11.]]]