mindspore.ops.flash_attention_score
- mindspore.ops.flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None, attn_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, keep_prob=1.0, scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0, input_layout='BSH', sparse_mode=0)[source]
Implements the self-attention calculation in training scenarios.
B: Batch size. Value range 1 to 2k.
S1: Sequence length of query. Value range 1 to 512k.
S2: Sequence length of key and value. Value range 1 to 512k.
N1: Num heads of query. Value range 1 to 256.
N2: Num heads of key and value. N2 must be a factor of N1.
D: Head size. The value must be a multiple of 16, with a maximum of 512.
H1: Hidden size of query, which equals N1 * D.
H2: Hidden size of key and value, which equals N2 * D.
The self attention calculation formula is defined as:
\[\text{attention\_out} = \operatorname{Dropout}\left(\operatorname{Softmax}\left(\operatorname{Mask}\left(\text{scale} * \left(\text{query} * \text{key}^{\top}\right) + \text{pse},\ \text{atten\_mask}\right)\right),\ \text{keep\_prob}\right) * \text{value}\]
Warning
This is an experimental API that is subject to change or deletion.
Only supported on the Atlas A2 training series.
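For reference only, the formula above can be sketched in plain NumPy for a single (batch, head) slice. This is a conceptual illustration, not the fused FlashAttention kernel; the function name naive_attention and the simplified dropout step are assumptions made for illustration.

import numpy as np

def naive_attention(query, key, value, scale, pse=0.0, atten_mask=None, keep_prob=1.0):
    """Illustrative, unfused version of the formula above for one (B, N) slice.

    query: (S1, D), key/value: (S2, D), atten_mask: (S1, S2) with 1 = discard.
    """
    score = scale * (query @ key.T) + pse                          # scale * (query * key^T) + pse
    if atten_mask is not None:
        score = np.where(atten_mask.astype(bool), -1e30, score)    # Mask(..., atten_mask)
    score = score - score.max(axis=-1, keepdims=True)              # numerically stable Softmax
    prob = np.exp(score) / np.exp(score).sum(axis=-1, keepdims=True)
    if keep_prob < 1.0:                                             # Dropout(..., keep_prob)
        drop = (np.random.rand(*prob.shape) < keep_prob) / keep_prob
        prob = prob * drop
    return prob @ value                                             # ... * value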
- Parameters
query (Tensor) – The query tensor. Input tensor of shape \((B, S1, H1)\), \((B, N1, S1, D)\), \((S1, B, H1)\), \((B, S1, N1, D)\) or \((T1, N1, D)\). Supported dtypes are float16 and bfloat16.
key (Tensor) – The key tensor with the same dtype as query. Supported shape: \((B, S2, H2)\), \((B, N2, S2, D)\), \((S2, B, H2)\), \((B, S2, N2, D)\) or \((T2, N2, D)\).
value (Tensor) – The value tensor with the same dtype and shape as key.
head_num (int) – The head num of query, equal to N1.
real_shift (Tensor, optional) – The position embedding code, also known as pse, with the same dtype as query. Input tensor of shape \((B, N1, S1, S2)\), \((1, N1, S1, S2)\), \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\). If S is greater than 1024 and the lower-triangle mask is used, only the last 1024 rows of the lower triangle are used for memory optimization. Default: None.
ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle. In this scenario, the shape of real_shift is \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\).
Non-ALiBi scenario: the shape of real_shift is \((B, N1, S1, S2)\) or \((1, N1, S1, S2)\).
When input_layout is TND: the shape should be \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\).
drop_mask (Tensor, optional) – The dropout mask tensor of uint8. Input tensor of shape \((B, N1, S1, S2 // 8)\) or None. S2 must be a multiple of 8 when drop_mask is not None. Default: None.
padding_mask (Tensor, optional) – Reserved parameter. Not implemented yet. Default: None.
attn_mask (Tensor, optional) – The attention mask tensor of bool or uint8. For each element, 0/False indicates retention and 1/True indicates discard. Input tensor of shape \((B, N1, S1, S2)\), \((B, 1, S1, S2)\), \((S1, S2)\) or \((2048, 2048)\). Default: None.
In the compression scenario, where sparse_mode is 2, 3, or 4, the shape of attn_mask must be \((2048, 2048)\).
When sparse_mode is 5, the shape of attn_mask should be \((B, N1, S1, S2)\) or \((B, 1, S1, S2)\).
When sparse_mode is 0 or 1, the shape of attn_mask should be \((B, N1, S1, S2)\), \((B, 1, S1, S2)\) or \((S1, S2)\).
prefix (Union[Tensor, tuple[int], list[int]], optional) – N value of each Batch in the prefix sparse calculation scenario. Input tensor of shape \((B,)\). The max value of B is 32. Not None only when sparse_mode is 5. If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2. Default: None.
actual_seq_qlen (Union[Tensor, tuple[int], list[int]], optional) – Size of query corresponding to each batch, an array with increasing values whose last value equals T1. Default: None.
actual_seq_kvlen (Union[Tensor, tuple[int], list[int]], optional) – Size of key and value corresponding to each batch, an array with increasing values whose last value equals T2. Default: None.
keep_prob (double, optional) – The keep probability of dropout. Value range is (0.0, 1.0]. When keep_prob is 1.0, drop_mask should be None. Default: 1.0.
scalar_value (double, optional) – The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
pre_tokens (int, optional) – Parameter for sparse computation, representing how many tokens are counted forward. When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
next_tokens (int, optional) – Parameter for sparse computation, representing how many tokens are counted backward. When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. Together they define the valid area on the attn_mask matrix, and the resulting band must not be empty. The following combinations are not allowed:
pre_tokens < 0 and next_tokens < 0.
(pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
(pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
inner_precise (int, optional) – The parameter is reserved and not implemented yet. Default: 0.
input_layout (str, optional) – Specifies the layout of the input query, key and value. The value can be "BSH", "BNSD", "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
When input_layout is "TND", the following restrictions must be met (see the TND sketch at the end of the Examples section). Assume there are two lists that represent the lengths of the input sequences: list_seq_q and list_seq_k. Each value in the list indicates the length of a sequence in the batch. For example, list_seq_q = [4, 2, 6] and list_seq_k = [10, 3, 9]. The elements of the lists indicate S. T1 is sum(list_seq_q) = 12, and T2 is sum(list_seq_k) = 22. max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k). qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element-wise products.
The two lists must have the same length, and the size of the list is the batch size. batch must be less than or equal to 1024.
When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must not be None. Otherwise, they are None.
actual_seq_qlen and actual_seq_kvlen are the cumulative sums of the sequence lengths of query and key/value respectively, so they must be non-decreasing.
If real_shift is not None, list_seq_q and list_seq_k must be the same. The maximum value of list_seq_q and list_seq_k is greater than 1024. real_shift should be \((B, N1, 1024, S2)\) or \((1, N1, 1024, S2)\), and S2 is equal to max_seqlen_k.
attn_mask must be a lower triangular matrix, so sparse_mode should be 2 or 3. The shape of attn_mask should be \((2048, 2048)\).
The shape of drop_mask is \((qk\_pointer * N1 // 8,)\).
prefix is None.
next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
0 should not exist in list_seq_k.
sparse_mode (int, optional) – Indicates the sparse mode (see the mask construction sketch after this parameter list). Default: 0.
0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed, and next_tokens and pre_tokens (internally assigned as INT_MAX) are ignored. If attn_mask is passed in, the full attn_mask matrix (S1 * S2) is required, indicating that the part between next_tokens and pre_tokens needs to be calculated.
1: Represents allMask, that is, passing in the complete attn_mask matrix.
2: Represents the leftUpCausal mode, corresponding to the lower-triangle scenario divided by the upper left vertex. The optimized attn_mask matrix (2048*2048) is required.
3: Represents the rightDownCausal mode, corresponding to the lower-triangle scenario divided by the lower right vertex. The optimized attn_mask matrix (2048*2048) is required.
4: Represents the band scenario, that is, computing the part between next_tokens and pre_tokens. The optimized attn_mask matrix (2048*2048) is required.
5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1 and width N is added to the left side. The value of N is obtained from the new input prefix, and the N value of each Batch axis is different. Not implemented yet.
6: Represents the global scenario. Not implemented yet.
7: Represents the dilated scenario. Not implemented yet.
8: Represents the block_local scenario. Not implemented yet.
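As a minimal sketch of how the auxiliary inputs fit together (the value of D below is an assumption for illustration), the optimized \((2048, 2048)\) mask used by sparse_mode 2, 3 and 4 and the score scale factor can be built as follows:

import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor

D = 64                                   # assumed head size (a multiple of 16)
scalar_value = 1.0 / math.sqrt(D)        # generally 1.0 / (D ** 0.5)

# Optimized causal mask: upper triangle is 1 (discard),
# lower triangle and diagonal are 0 (retain).
causal_mask = Tensor(np.triu(np.ones((2048, 2048), dtype=np.uint8), k=1), dtype=mstype.uint8)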
- Returns
attention_out (Tensor) - The output of attention, which has the same shape and dtype as query.
- Raises
TypeError – Dtype of query is not float16 or bfloat16.
TypeError – query, key and value don't have the same dtype.
TypeError – Dtype of attn_mask is not bool or uint8.
TypeError – real_shift does not have the same dtype as query.
TypeError – scalar_value or keep_prob is not a double number.
TypeError – input_layout is not a string.
TypeError – head_num is not an int.
TypeError – sparse_mode is not an int.
TypeError – real_shift is not Tensor type.
TypeError – drop_mask is not Tensor type.
TypeError – padding_mask is not Tensor type.
TypeError – attn_mask is not Tensor type.
ValueError – input_layout is a string but not valid.
RuntimeError – head_num is not divisible by N2.
RuntimeError – head_num is not greater than 0.
RuntimeError – attn_mask shape is not valid.
RuntimeError – The specified value of sparse_mode is invalid.
RuntimeError – D-axis of query, key and value is not the same.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> import mindspore.common.dtype as mstype
>>> import numpy as np
>>> from mindspore import ops, Tensor
>>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
>>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
>>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
>>> head_num = 4
>>> output = ops.flash_attention_score(query, key, value, head_num)
>>> print(output.shape)
(2, 4, 64)
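A hedged sketch of a "TND"-layout call that follows the restrictions listed under input_layout. The per-batch sequence lengths (list_seq_q = [4, 2, 6], list_seq_k = [10, 3, 9]) are assumptions for illustration, and running it requires an Atlas A2 training series device.

>>> import numpy as np
>>> import mindspore.common.dtype as mstype
>>> from mindspore import ops, Tensor
>>> actual_seq_qlen = (4, 6, 12)     # cumulative sums of list_seq_q, last value is T1
>>> actual_seq_kvlen = (10, 13, 22)  # cumulative sums of list_seq_k, last value is T2
>>> query = Tensor(np.ones([12, 4, 64]), dtype=mstype.float16)   # (T1, N1, D)
>>> key = Tensor(np.ones([22, 4, 64]), dtype=mstype.float16)     # (T2, N2, D)
>>> value = Tensor(np.ones([22, 4, 64]), dtype=mstype.float16)   # (T2, N2, D)
>>> # Optimized lower-triangular mask required when sparse_mode is 2 or 3.
>>> attn_mask = Tensor(np.triu(np.ones((2048, 2048), dtype=np.uint8), k=1), dtype=mstype.uint8)
>>> output = ops.flash_attention_score(query, key, value, head_num=4,
...                                     attn_mask=attn_mask,
...                                     actual_seq_qlen=actual_seq_qlen,
...                                     actual_seq_kvlen=actual_seq_kvlen,
...                                     next_tokens=0,
...                                     input_layout="TND",
...                                     sparse_mode=3)
>>> print(output.shape)
(12, 4, 64)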