mindspore.ops.EmbeddingLookup
- class mindspore.ops.EmbeddingLookup[源代码]
根据指定的索引,返回输入Tensor的切片。
此算子在 axis = 0 上的运行与GatherV2的功能相似,只是多一个 offset 输入。
输入:
input_params (Tensor) - shape为 \((x_1, x_2, ..., x_R)\) 的Tensor。是一个Tensor切片。当前,只支持二维。
input_indices (Tensor) - shape为 \((y_1, y_2, ..., y_S)\) 的Tensor。指定输入Tensor元素的索引。当取值超出 input_params 在该维度的最大长度时,超出部分将返回0值。不支持负值,否则结果将未定义。其数据类型为int32或int64。
offset (int) - 指定 input_params 切片的偏移值。实际索引等于 input_indices 减去 offset 。
输出:
Tensor,shape为 \((z_1, z_2, ..., z_N)\) 的Tensor。数据类型与 input_params 相同。
异常:
TypeError - input_indices 的数据类型不是int。
ValueError - input_params 的shape长度大于2。
- 支持平台:
Ascend
CPU
GPU
样例:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) >>> offset = 4 >>> output = ops.EmbeddingLookup()(input_params, input_indices, offset) >>> print(output) [[[10. 11.] [ 0. 0.]] [[ 0. 0.] [10. 11.]]]