mindspore.nn.MultiFieldEmbeddingLookup

class mindspore.nn.MultiFieldEmbeddingLookup(vocab_size, embedding_size, field_size, param_init='normal', target='CPU', slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM')

Looks up embedding vectors in the embedding table for the specified indices, weights them by input_values, and combines the vectors of each field according to field_ids. This operation supports looking up embeddings from multi-hot and one-hot fields simultaneously.

Note

When ‘target’ is set to ‘CPU’, this module uses P.EmbeddingLookup().add_prim_attr(‘primitive_target’, ‘CPU’) with ‘offset = 0’ to look up the table. When ‘target’ is set to ‘DEVICE’, it uses P.Gather() with ‘axis = 0’ to look up the table. Vectors with the same field_ids are combined by ‘operator’, which can be ‘SUM’, ‘MAX’ or ‘MEAN’.

Set the input_values of padded ids to zero so that those positions are ignored. The output for a field is all zeros if the sum of the absolute weights in that field is zero. This class only supports the slice modes ‘table_row_slice’, ‘batch_slice’ and ‘table_column_slice’. For operator ‘MAX’ on Ascend, there is a constraint: batch_size * (seq_length + field_size) < 3500.
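To make the pooling rules above concrete, here is a pure-Python reference sketch of the documented semantics. It is not MindSpore's implementation: the function name multi_field_lookup_reference is hypothetical, the embedding table is passed in explicitly, and two details are assumptions inferred from the note: ‘MEAN’ divides the weighted sum by the sum of absolute weights in the field, and zero-weight (padded) positions are skipped entirely, including for ‘MAX’.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def multi_field_lookup_reference(
    table: Sequence[Sequence[float]],
    input_indices: Sequence[Sequence[int]],
    input_values: Sequence[Sequence[float]],
    field_ids: Sequence[Sequence[int]],
    field_size: int,
    operator: str = "SUM",
) -> List[Matrix]:
    """Hypothetical pure-Python reference of the documented pooling semantics.

    Returns a nested list shaped (batch_size, field_size, embedding_size).
    """
    if operator not in ("SUM", "MEAN", "MAX"):
        raise ValueError(f"unsupported operator: {operator}")
    embedding_size = len(table[0])
    output: List[Matrix] = []
    for idx_row, val_row, fid_row in zip(input_indices, input_values, field_ids):
        # Group the weighted embedding vectors of this sample by field id.
        vectors: List[Matrix] = [[] for _ in range(field_size)]
        abs_weight = [0.0] * field_size
        for idx, val, fid in zip(idx_row, val_row, fid_row):
            if val == 0:
                continue  # assumption: padded positions (weight 0) are skipped
            vectors[fid].append([val * x for x in table[idx]])
            abs_weight[fid] += abs(val)
        pooled: Matrix = []
        for fid in range(field_size):
            if abs_weight[fid] == 0:
                # Per the note: a field whose absolute weights sum to zero yields zeros.
                pooled.append([0.0] * embedding_size)
            elif operator == "SUM":
                pooled.append([sum(col) for col in zip(*vectors[fid])])
            elif operator == "MEAN":
                # Assumption: weighted sum divided by the sum of absolute weights.
                pooled.append([sum(col) / abs_weight[fid] for col in zip(*vectors[fid])])
            else:  # "MAX": element-wise maximum over the field's weighted vectors
                pooled.append([max(col) for col in zip(*vectors[fid])])
        output.append(pooled)
    return output


# Inputs from the Examples section, with a made-up table where row i is [i, 10 * i].
table = [[float(i), 10.0 * i] for i in range(10)]
indices = [[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]]
values = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
fields = [[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]]
out = multi_field_lookup_reference(table, indices, values, fields, field_size=2, operator="SUM")
# out == [[[2.0, 20.0], [10.0, 100.0]], [[4.0, 40.0], [5.0, 50.0]]]
```

Sample 0 has a one-hot field 0 (index 2) and a multi-hot field 1 (indices 4 and 6), so field 1 pools two rows while field 0 returns a single row unchanged.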

Parameters
  • vocab_size (int) – The size of the dictionary of embeddings.

  • embedding_size (int) – The size of each embedding vector.

  • field_size (int) – The number of fields, which is the size of the second dimension of the output.

  • param_init (Union[Tensor, str, Initializer, numbers.Number]) – Initializer for the embedding_table. When a string is given, refer to the class initializer for its valid values. Default: ‘normal’.

  • target (str) – Specifies the target where the op is executed. The value must be one of [‘DEVICE’, ‘CPU’]. Default: ‘CPU’.

  • slice_mode (str) – The slicing mode used in semi_auto_parallel/auto_parallel mode. The value must be one of the slice-mode constants defined on nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.

  • feature_num_list (tuple) – The accompanying array used in field slice mode. Currently unused. Default: None.

  • max_norm (Union[float, None]) – A maximum clipping value. The data type must be float16, float32 or None. Default: None.

  • sparse (bool) – Whether to use sparse mode. When ‘target’ is set to ‘CPU’, ‘sparse’ must be True. Default: True.

  • operator (str) – The pooling method for the features in one field. Supports ‘SUM’, ‘MEAN’ and ‘MAX’. Default: ‘SUM’.

Inputs:
  • input_indices (Tensor) - The shape of tensor is \((batch\_size, seq\_length)\). Specifies the row indices of the embedding table to look up. Must be a 2-D tensor. Data type: int32 or int64.

  • input_values (Tensor) - The shape of tensor is \((batch\_size, seq\_length)\). Specifies the weight of each element of input_indices; each looked-up vector is multiplied by its corresponding value. Data type: float32.

  • field_ids (Tensor) - The shape of tensor is \((batch\_size, seq\_length)\). Specifies the field id of each element of input_indices. Data type: int32.
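The three inputs are dense 2-D tensors, so variable-length samples have to be padded to a common seq_length first. Below is a minimal plain-Python sketch of that padding step, assuming each sample is given as a list of (index, weight, field_id) tuples; the helper pad_features and the tuple encoding are hypothetical. Padding uses index 0, weight 0.0 and field id 0, as in the Examples section; the zero weight is what lets the padded positions be ignored.

```python
from typing import List, Sequence, Tuple

# Hypothetical feature encoding: (vocab index, weight, field id).
Feature = Tuple[int, float, int]


def pad_features(
    samples: Sequence[Sequence[Feature]], seq_length: int
) -> Tuple[List[List[int]], List[List[float]], List[List[int]]]:
    """Pad ragged samples into input_indices, input_values and field_ids rows."""
    indices, values, fields = [], [], []
    for feats in samples:
        if len(feats) > seq_length:
            raise ValueError("a sample has more features than seq_length")
        pad = seq_length - len(feats)
        # Padded slots get index 0, field 0 and, crucially, weight 0.0.
        indices.append([i for i, _, _ in feats] + [0] * pad)
        values.append([float(w) for _, w, _ in feats] + [0.0] * pad)
        fields.append([f for _, _, f in feats] + [0] * pad)
    return indices, values, fields


samples = [
    [(2, 1.0, 0), (4, 1.0, 1), (6, 1.0, 1)],  # one-hot field 0, multi-hot field 1
    [(1, 1.0, 0), (3, 1.0, 0), (5, 1.0, 1)],  # multi-hot field 0, one-hot field 1
]
input_indices, input_values, field_ids = pad_features(samples, seq_length=5)
# input_indices == [[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]]
# field_ids == [[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]]
```

The resulting nested lists could then be wrapped in Tensor objects with the int32/float32 types listed above, as the Examples section does.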

Outputs:

Tensor, the shape of tensor is \((batch\_size, field\_size, embedding\_size)\). Type is Float32.
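As a quick arithmetic check, the sketch below computes the expected output shape from the input shape and evaluates the Ascend ‘MAX’ constraint from the note, batch_size * (seq_length + field_size) < 3500. The helper names are hypothetical; they only restate the documented formulas and do not query MindSpore.

```python
from typing import Tuple


def expected_output_shape(
    input_shape: Tuple[int, int], field_size: int, embedding_size: int
) -> Tuple[int, int, int]:
    """(batch_size, seq_length) -> (batch_size, field_size, embedding_size)."""
    batch_size, _seq_length = input_shape
    return (batch_size, field_size, embedding_size)


def ascend_max_ok(batch_size: int, seq_length: int, field_size: int) -> bool:
    """Documented limit for operator='MAX' on Ascend."""
    return batch_size * (seq_length + field_size) < 3500


# Values from the Examples section: inputs of shape (2, 5), field_size=2, embedding_size=2.
print(expected_output_shape((2, 5), field_size=2, embedding_size=2))  # (2, 2, 2)
print(ascend_max_ok(2, 5, 2))  # True: 2 * (5 + 2) = 14 < 3500
print(ascend_max_ok(500, 5, 2))  # False: 500 * 7 = 3500 is not < 3500
```

Note that seq_length does not appear in the output shape: it is pooled away inside each field.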

Raises
  • TypeError – If vocab_size or embedding_size or field_size is not an int.

  • TypeError – If sparse is not a bool or feature_num_list is not a tuple.

  • ValueError – If vocab_size or embedding_size or field_size is less than 1.

  • ValueError – If target is neither ‘CPU’ nor ‘DEVICE’.

  • ValueError – If slice_mode is not one of ‘batch_slice’, ‘field_slice’, ‘table_row_slice’, ‘table_column_slice’.

  • ValueError – If sparse is False and target is ‘CPU’.

  • ValueError – If slice_mode is ‘field_slice’ and feature_num_list is None.

  • ValueError – If operator is not one of ‘SUM’, ‘MAX’, ‘MEAN’.
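The conditions listed under Raises can be summarized as a plain-Python validator. This is a hypothetical sketch that mirrors the documented conditions only; MindSpore's actual validation code, its check order and its error messages may differ. Rejecting bool for the integer arguments and allowing feature_num_list to be None are added assumptions.

```python
from typing import Optional

_SLICE_MODES = ("batch_slice", "field_slice", "table_row_slice", "table_column_slice")
_OPERATORS = ("SUM", "MAX", "MEAN")


def validate_args(
    vocab_size: int,
    embedding_size: int,
    field_size: int,
    target: str = "CPU",
    slice_mode: str = "batch_slice",
    feature_num_list: Optional[tuple] = None,
    sparse: bool = True,
    operator: str = "SUM",
) -> None:
    """Raise TypeError/ValueError under the conditions listed in Raises."""
    for name, value in (("vocab_size", vocab_size),
                        ("embedding_size", embedding_size),
                        ("field_size", field_size)):
        # Assumption: bool is rejected even though it subclasses int in Python.
        if not isinstance(value, int) or isinstance(value, bool):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
        if value < 1:
            raise ValueError(f"{name} must be at least 1, got {value}")
    if not isinstance(sparse, bool):
        raise TypeError("sparse must be a bool")
    # Assumption: None (the default) is accepted; any other non-tuple is not.
    if feature_num_list is not None and not isinstance(feature_num_list, tuple):
        raise TypeError("feature_num_list must be a tuple")
    if target not in ("CPU", "DEVICE"):
        raise ValueError("target must be 'CPU' or 'DEVICE'")
    if slice_mode not in _SLICE_MODES:
        raise ValueError(f"slice_mode must be one of {_SLICE_MODES}")
    if not sparse and target == "CPU":
        raise ValueError("sparse must be True when target is 'CPU'")
    if slice_mode == "field_slice" and feature_num_list is None:
        raise ValueError("feature_num_list is required when slice_mode is 'field_slice'")
    if operator not in _OPERATORS:
        raise ValueError(f"operator must be one of {_OPERATORS}")


validate_args(10, 2, 2, target="DEVICE")  # the arguments used in the Examples section pass
```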

Supported Platforms:

Ascend GPU

Examples

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
>>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
>>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM', target='DEVICE')
>>> out = net(input_indices, input_values, field_ids)
>>> print(out.shape)
(2, 2, 2)