mindspore.ops.speed_fusion_attention

mindspore.ops.speed_fusion_attention(query, key, value, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None, scale=1.0, keep_prob=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False, pse_type=1, q_start_idx=None, kv_start_idx=None)[source]

The interface is used for self-attention fusion computing. If pse_type is 1 , the calculation formula is:

\[attention\_out = Dropout(Softmax(Mask(scale * (pse + query * key^{T}), atten\_mask)), keep\_prob) * value\]

If pse_type is any other valid value, the calculation formula is:

\[attention\_out = Dropout(Softmax(Mask(scale * (query * key^{T}) + pse, atten\_mask)), keep\_prob) * value\]

where:

  • B: Batch size. Value range 1 to 2k.

  • S1: Sequence length of query. Value range 1 to 512k.

  • S2: Sequence length of key and value. Value range 1 to 512k.

  • N1: Num heads of query. Value range 1 to 256.

  • N2: Num heads of key and value, and N2 must be a factor of N1.

  • D: Head size. The value must be a multiple of 16, with a maximum value of 512.

  • H1: Hidden size of query, which equals N1 * D.

  • H2: Hidden size of key and value, which equals N2 * D.
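
The two formulas differ only in whether pse is added before or after the scale multiplication. A minimal NumPy sketch of the pse_type=1 case (illustrative shapes only; the mask and dropout steps of the fused kernel are omitted):

>>> import numpy as np
>>> def naive_attention(q, k, v, pse, scale):
...     scores = scale * (pse + q @ k.T)  # pse_type=1: add pse, then scale
...     scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
...     probs = np.exp(scores)
...     probs /= probs.sum(axis=-1, keepdims=True)  # Softmax over S2
...     return probs @ v
>>> q = np.ones((4, 16)); k = np.ones((6, 16)); v = np.ones((6, 16))
>>> naive_attention(q, k, v, pse=0.0, scale=16 ** -0.5).shape
(4, 16)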

Warning

  • This is an experimental API that is subject to change or deletion.

  • Only supported on the Atlas A2 training series.

Note

This interface is not supported in graph mode (mode=mindspore.GRAPH_MODE).

Parameters
  • query (Tensor) – The query tensor. Input tensor of shape \((B, S1, H1)\), \((B, N1, S1, D)\), \((S1, B, H1)\), \((B, S1, N1, D)\) or \((T1, N1, D)\).

  • key (Tensor) – The key tensor. Input tensor of shape \((B, S2, H2)\), \((B, N2, S2, D)\), \((S2, B, H2)\), \((B, S2, N2, D)\) or \((T2, N2, D)\).

  • value (Tensor) – The value tensor. Input tensor of shape \((B, S2, H2)\), \((B, N2, S2, D)\), \((S2, B, H2)\), \((B, S2, N2, D)\) or \((T2, N2, D)\). The key and value should have the same shape.

  • head_num (int) – The head num of query, equal to N1.

  • input_layout (str) –

    Specifies the layout of input query, key and value. The value can be "BSH" , "BNSD" , "SBH" , "BSND" or "TND" . "TND" is an experimental format. When input_layout is "TND" , the following restrictions must be met. Consider two lists that represent the lengths of the input sequences, list_seq_q and list_seq_k, where each value is the length of one sequence in the batch. For example, given list_seq_q = [4, 2, 6] and list_seq_k = [10, 3, 9], each list element is an S value, T1 = sum(list_seq_q) = 12 and T2 = sum(list_seq_k) = 22. Define max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k), and qk_pointer = sum(list_seq_q * list_seq_k), the sum of the element-wise products.

    • The two lists have the same length, which is the batch size. The batch size is less than or equal to 1024.

    • When input_layout is "TND" , actual_seq_qlen and actual_seq_kvlen must not be None . Otherwise, they are None .

    • actual_seq_qlen and actual_seq_kvlen are the cumulative sums of the sequence lengths of query and key/value respectively, so they must be non-decreasing (see the sketch after this list).

    • If pse is not None , list_seq_q and list_seq_k must be the same, and the maximum value of list_seq_q and list_seq_k must be greater than 1024. pse should be \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\), where S2 equals max_seqlen_k.

    • atten_mask must be a lower triangular matrix, so sparse_mode should be 2 or 3. The shape of atten_mask should be \((2048, 2048)\).

    • prefix is None .

    • next_tokens is 0, and pre_tokens is not less than max_seqlen_q.

    • When sparse_mode is 3, S1 of each batch should be less than or equal to S2.

    • 0 should not exist in list_seq_k.
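
    For instance, actual_seq_qlen and actual_seq_kvlen for the example lists above are just their running totals (a plain-Python sketch; list_seq_q and list_seq_k are the hypothetical lists from the example):

    >>> from itertools import accumulate
    >>> list_seq_q = [4, 2, 6]
    >>> list_seq_k = [10, 3, 9]
    >>> list(accumulate(list_seq_q))  # actual_seq_qlen; last value is T1
    [4, 6, 12]
    >>> list(accumulate(list_seq_k))  # actual_seq_kvlen; last value is T2
    [10, 13, 22]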

Keyword Arguments
  • pse (Tensor, optional) –

    The position embedding code, whose dtype is the same as query. Default: None . If S is greater than 1024 and the lower-triangle mask is used, only the last 1024 rows of the lower triangle are entered for memory optimization. Input tensor of shape \((B, N1, S1, S2)\), \((1, N1, S1, S2)\), \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\).

    • ALiBi scenario: pse must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle. In this scenario, pse is \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\).

    • Non-ALiBi scenario: pse is \((B, N1, S1, S2)\) or \((1, N1, S1, S2)\).

    • The shape of pse should be \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\) when input_layout is "TND" .

    • If pse_type is 2 or 3, dtype of pse must be float32, and shape of pse should be \((B, N1)\) or \((N1,)\).

  • padding_mask (Tensor, optional) – Reserved parameter. Not implemented yet. Default: None .

  • atten_mask (Tensor, optional) –

    The attention mask tensor. For each element, 0/False indicates retention and 1/True indicates discard. Input tensor of shape \((B, N1, S1, S2)\), \((B, 1, S1, S2)\), \((S1, S2)\) or \((2048, 2048)\). Default: None .

    • In the compression scenario, where sparse_mode is 2, 3, or 4, atten_mask must be \((2048, 2048)\).

    • When sparse_mode is 5, atten_mask must be \((B, N1, S1, S2)\) or \((B, 1, S1, S2)\).

    • When sparse_mode is 0 or 1, atten_mask should be \((B, N1, S1, S2)\), \((B, 1, S1, S2)\) or \((S1, S2)\).

  • scale (float, optional) – The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0 .

  • keep_prob (float, optional) – The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0 .

  • pre_tokens (int, optional) – Parameter for sparse computation, represents how many tokens are counted forward. When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647 .

  • next_tokens (int, optional) –

    Parameter for sparse computation, represents how many tokens are counted backward. When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647 . The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. Together they define the valid band on the atten_mask matrix, and it must be ensured that the band is not empty. The following values are not allowed (a validity sketch follows this list):

    • pre_tokens < 0 and next_tokens < 0.

    • (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).

    • (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
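
    These constraints can be expressed as a small validity check (a hypothetical helper, not part of the API):

    >>> def band_is_valid(pre_tokens, next_tokens, s1, s2):
    ...     if pre_tokens < 0 and next_tokens < 0:
    ...         return False
    ...     if pre_tokens < 0 and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= s2):
    ...         return False
    ...     if next_tokens < 0 and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= s1):
    ...         return False
    ...     return True
    >>> band_is_valid(2147483647, 2147483647, 1024, 2048)  # defaults keep the full band
    True
    >>> band_is_valid(-1, -1, 1024, 2048)
    False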

  • inner_precise (int, optional) – The parameter is reserved and not implemented yet. Default: 0 .

  • prefix (Union[tuple[int], list[int]], optional) – N value of each Batch in the prefix sparse calculation scenario. Input of shape \((B,)\), where the maximum value of B is 32. It is not None only when sparse_mode is 5. If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2. Default: None .

  • actual_seq_qlen (Union[tuple[int], list[int]], optional) – Cumulative sequence length of query for each batch; an array with increasing values whose last value equals T1. Default: None .

  • actual_seq_kvlen (Union[tuple[int], list[int]], optional) – Cumulative sequence length of key and value for each batch; an array with increasing values whose last value equals T2. Default: None .

  • sparse_mode (int, optional) –

    Indicates sparse mode. Default: 0 .

    • 0: Indicates the defaultMask mode. If atten_mask is not passed, the mask operation is not performed, and preTokens and nextTokens (internally assigned INT_MAX) are ignored. If it is passed, the full atten_mask matrix (S1 * S2) must be passed in, indicating that the part between preTokens and nextTokens needs to be calculated.

    • 1: Represents allMask, that is, passing in the complete atten_mask matrix.

    • 2: Represents the leftUpCausal mode, which corresponds to the lower-triangle scenario split at the upper-left vertex; the optimized atten_mask matrix (2048*2048) is required.

    • 3: Represents the rightDownCausal mode, which corresponds to the lower-triangle scenario split at the lower-right vertex; the optimized atten_mask matrix (2048*2048) is required.

    • 4: Represents the band scenario, that is, only the part between preTokens and nextTokens is calculated; the optimized atten_mask matrix (2048*2048) is required.

    • 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix of length S1 and width N is added to the left side. The value of N is obtained from the input prefix, and the N value differs for each Batch axis. Currently not enabled.

    • 6: Represents the global scenario. Currently not enabled.

    • 7: Represents the dilated scenario. Currently not enabled.

    • 8: Represents the block_local scenario. Currently not enabled.
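
    For modes 2, 3 and 4, the optimized \((2048, 2048)\) atten_mask is a fixed lower-triangle template in which 1 means discard. One plausible construction (a sketch, not an excerpt from the kernel):

    >>> import numpy as np
    >>> import mindspore.common.dtype as mstype
    >>> from mindspore import Tensor
    >>> atten_mask = Tensor(np.triu(np.ones((2048, 2048)), k=1), dtype=mstype.uint8)
    >>> atten_mask.shape
    (2048, 2048)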

  • gen_mask_parallel (bool, optional) – Debug parameter, a switch to control dropout_gen_mask execution method. If True , dropout_gen_mask is executed in parallel. If False , execution is serial. Not implemented yet. Default: True .

  • sync (bool, optional) – Debug parameter, a switch to control dropout_gen_mask execution method. If True , dropout_gen_mask is executed synchronously. If False , execution is asynchronous. Not implemented yet. Default: False .

  • pse_type (int, optional) –

    Indicates how to use pse. Default: 1 .

    • 0: pse is passed from outside, and the calculation first multiplies by scale and then adds pse.

    • 1: pse is passed from outside, and the calculation first adds pse and then multiplies by scale.

    • 2: pse is generated internally as standard alibi position information. Row 0 of the internally generated alibi matrix is aligned with the upper left corner of \(query * key^{T}\).

    • 3: pse is generated internally; the generated alibi position information is the standard alibi with a square root applied. Row 0 of the internally generated alibi matrix is aligned with the upper left corner of \(query * key^{T}\).
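
    For pse_type 2 or 3, pse only supplies per-head float32 coefficients of shape \((B, N1)\) or \((N1,)\); the position matrix itself is generated internally. A sketch that fills them with the standard ALiBi slopes (the slope formula is an assumption for illustration, not mandated by this API):

    >>> import numpy as np
    >>> import mindspore.common.dtype as mstype
    >>> from mindspore import Tensor
    >>> n1 = 4
    >>> slopes = [2 ** (-8.0 * (i + 1) / n1) for i in range(n1)]  # hypothetical ALiBi slopes
    >>> pse = Tensor(np.array(slopes), dtype=mstype.float32)  # shape (N1,)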

  • q_start_idx (Union[tuple[int], list[int]], optional) – Int array with length 1. Default: None . When pse_type is 2 or 3 , it indicates the number of cells by which the internally generated alibi code is offset in the S1 direction. A positive number indicates that the 0 position moves diagonally upward.

  • kv_start_idx (Union[tuple[int], list[int]], optional) – Int array with length 1. Default: None . When pse_type is 2 or 3 , it indicates the number of cells by which the internally generated alibi code is offset in the S2 direction. A positive number indicates that the 0 position moves diagonally upward.

Returns

A tuple of tensors containing attention_out, softmax_max, softmax_sum, softmax_out, seed, offset and numels .

  • attention_out is the output of attention; its shape and data type are the same as those of query.

  • softmax_max is the max intermediate result calculated by Softmax, used for grad calculation.

  • softmax_sum is the sum intermediate result calculated by Softmax, used for grad calculation.

  • softmax_out is a reserved parameter.

  • seed is generated seed, used for Dropout.

  • offset is generated offset, used for Dropout.

  • numels is the length of generated dropout_mask.

Raises
  • TypeError – query, key and value don't have the same dtype.

  • TypeError – Dtype of atten_mask is not bool or uint8.

  • TypeError – scale or keep_prob is not a float number.

  • TypeError – input_layout is not a string.

  • TypeError – head_num is not an int.

  • TypeError – sparse_mode is not an int.

  • TypeError – pse is not Tensor type.

  • TypeError – padding_mask is not Tensor type.

  • TypeError – atten_mask is not Tensor type.

  • TypeError – pse_type is not an int.

  • ValueError – input_layout is a string but not valid.

  • ValueError – The specified value of sparse_mode is invalid.

  • ValueError – The specified value of pse_type is invalid.

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> import mindspore.common.dtype as mstype
>>> import numpy as np
>>> from mindspore import ops, Tensor
>>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
>>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
>>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
>>> head_num = 4
>>> input_layout = "BSH"
>>> output = ops.speed_fusion_attention(query, key, value, head_num, input_layout)
>>> print(output[0].shape)
(2, 4, 64)
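
A variant of the same call with an explicit causal mask (1/True means discard), the conventional 1.0 / (D ** 0.5) scale, and dropout enabled; this is a sketch that reuses the tensors from the example above:

>>> atten_mask = Tensor(np.triu(np.ones([4, 4]), k=1), dtype=mstype.bool_)
>>> scale = 1.0 / (64 / head_num) ** 0.5  # D = H1 / N1 = 16
>>> output = ops.speed_fusion_attention(query, key, value, head_num, input_layout,
...                                     atten_mask=atten_mask, scale=scale, keep_prob=0.9)
>>> attention_out = output[0]  # same shape and dtype as query
>>> print(attention_out.shape)
(2, 4, 64)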