mindspore.ops.fused_infer_attention_score
- mindspore.ops.fused_infer_attention_score(query, key, value, *, pse_shift=None, atten_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, dequant_scale1=None, quant_scale1=None, dequant_scale2=None, quant_scale2=None, quant_offset2=None, antiquant_scale=None, antiquant_offset=None, key_antiquant_scale=None, key_antiquant_offset=None, value_antiquant_scale=None, value_antiquant_offset=None, block_table=None, query_padding_size=None, kv_padding_size=None, key_shared_prefix=None, value_shared_prefix=None, actual_shared_prefix_len=None, num_heads=1, scale=1.0, pre_tokens=2147483647, next_tokens=2147483647, input_layout='BSH', num_key_value_heads=0, sparse_mode=0, inner_precise=1, block_size=0, antiquant_mode=0, key_antiquant_mode=0, value_antiquant_mode=0, softmax_lse_flag=False)
This is a FlashAttention function designed for both full inference (PromptFlashAttention) and incremental inference (IncreFlashAttention) scenarios. When the S dimension of the query tensor (Q_S) equals 1, it enters the IncreFlashAttention branch; otherwise, it enters the PromptFlashAttention branch.
\[Attention(Q,K,V) = Softmax(\frac{QK^{T}}{\sqrt{d}})V\]
Warning
This is an experimental API that is subject to change or deletion.
For Ascend, only the Atlas A2 training series products and Atlas 800I A2 inference products are currently supported.
Note
The data layout formats of query, key and value can be interpreted from multiple dimensions, as shown below:
B, Batch size. Represents the batch size of the input samples.
S, Sequence length. Represents the sequence length of the input samples. S1 represents the sequence length of the query, and S2 represents the sequence length of the key/value.
H, Head size. Represents the size of the hidden layer.
N, Head nums. Represents the number of attention heads.
D, Head dims. Represents the smallest unit size of the hidden layer, satisfying \(D = H / N\).
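As an illustration of these layout conventions (a minimal sketch, not part of the operator specification), the BNSD and BSH forms of the same query relate as follows:
>>> import numpy as np
>>> B, N, S, D = 2, 8, 256, 128       # batch, head num, sequence length, head dim
>>> H = N * D                         # hidden size, so D = H / N
>>> q_bnsd = np.random.rand(B, N, S, D).astype(np.float16)
>>> # BSH folds the head axis into the hidden size: (B, N, S, D) -> (B, S, N*D)
>>> q_bsh = q_bnsd.transpose(0, 2, 1, 3).reshape(B, S, H)
>>> q_bsh.shape
(2, 256, 1024)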
- Parameters
query (Tensor) – The query input of the attention structure, with data type of float16, bfloat16 or int8. Input tensor of shape \((B, S, H)\), \((B, N, S, D)\), or \((B, S, N, D)\).
key (Union[Tensor, tuple[Tensor], list[Tensor]]) – The key input of the attention structure, with data type of float16, bfloat16 or int8. Input tensor of shape \((B, S, H)\), \((B, N, S, D)\), or \((B, S, N, D)\).
value (Union[Tensor, tuple[Tensor], list[Tensor]]) – The value input of the attention structure, with data type of float16, bfloat16 or int8. Input tensor of shape \((B, S, H)\), \((B, N, S, D)\), or \((B, S, N, D)\).
- Keyword Arguments
pse_shift (Tensor, optional) – The padding mask tensor with data type of float16 or bfloat16. Default: None.
When Q_S is not 1, if pse_shift is of type float16, the query must be of type float16 or int8. If pse_shift is of type bfloat16, the query must also be of type bfloat16. The input shape must be either \((B, N, Q\_S, KV\_S)\) or \((1, N, Q\_S, KV\_S)\), where Q_S corresponds to the S dimension of the query shape, and KV_S corresponds to the S dimension of the key and value shapes. For scenarios where the KV_S of pse_shift is not 32-aligned, it is recommended to pad it to 32 bytes to improve performance. The padding values for the extra portions are not restricted.
When Q_S is 1, if pse_shift is of type float16, the query must also be of type float16. If pse_shift is of type bfloat16, the query must be of type bfloat16. The input shape must be \((B, N, 1, KV\_S)\) or \((1, N, 1, KV\_S)\), where KV_S corresponds to the S dimension of the key/value shapes. For scenarios where the KV_S of pse_shift is not 32-aligned, it is recommended to pad it to 32 bytes to improve performance. The padding values for the extra portions are not restricted.
atten_mask (Tensor, optional) – The attention mask tensor for the result of query*key with data type of int8, uint8 or bool. For each element, 0 indicates retention and 1 indicates discard. Default: None.
When Q_S is not 1, the recommended input shapes are \((Q\_S, KV\_S)\), \((B, Q\_S, KV\_S)\), \((1, Q\_S, KV\_S)\), \((B, 1, Q\_S, KV\_S)\) or \((1, 1, Q\_S, KV\_S)\).
When Q_S is 1, the recommended input shapes are \((B, KV\_S)\), \((B, 1, KV\_S)\) or \((B, 1, 1, KV\_S)\).
actual_seq_lengths (Union[tuple[int], list[int], Tensor], optional) – Describes the actual sequence length of the query with data type of int64. If this parameter is not specified, it can be set to None, indicating that it matches the S dimension of the query shape. Constraint: the effective sequence length for each batch in this parameter should not exceed the corresponding batch's sequence length in the query. When Q_S is 1, this parameter is ignored. Default: None.
actual_seq_lengths_kv (Union[tuple[int], list[int], Tensor], optional) – Describes the actual sequence length of the key and value with data type of int64. If this parameter is not specified, it can be set to None, indicating that it matches the S dimension of the key and value shape. Constraint: the effective sequence length for each batch in this parameter should not exceed the corresponding batch's sequence length in the key and value. Default: None.
dequant_scale1 (Tensor, optional) – Quantization factors for inverse quantization after BMM1 with data type of uint64. Supports per-tensor mode. If not used, set it to None. Default: None.
quant_scale1 (Tensor, optional) – Quantization factors for quantization before BMM2 with data type of float32. Supports per-tensor mode. If not used, set it to None. Default: None.
dequant_scale2 (Tensor, optional) – Quantization factors for inverse quantization after BMM2 with data type of uint64. Supports per-tensor mode. If not used, set it to None. Default: None.
quant_scale2 (Tensor, optional) – Quantization factors for output quantization with data type of float32 or bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None. Default: None.
quant_offset2 (Tensor, optional) – Quantization offset for output quantization with data type of float32 or bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None. Default: None.
For scenarios where the input is int8 and the output is int8: the parameters dequant_scale1, quant_scale1, dequant_scale2, and quant_scale2 must all be provided. The parameter quant_offset2 is optional and defaults to 0 if not specified.
When the output is int8 and quant_scale2 and quant_offset2 are per-channel, left padding, Ring Attention, or D-axis misalignment (not 32-aligned) scenarios are not supported.
When the output is int8, scenarios with sparse_mode as band and pre_tokens/next_tokens being negative are not supported.
When the output is int8, if quant_offset2 is not None (i.e., not an empty tensor) and sparse_mode, pre_tokens, and next_tokens meet the following conditions, certain rows of the matrix may not participate in calculations, leading to errors. This scenario will be intercepted (solution: if this scenario should not be intercepted, perform quantization outside the FIA interface rather than enabling it inside the FIA interface):
sparse_mode = 0: if atten_mask is not None and each batch's actual_seq_lengths - actual_seq_lengths_kv - pre_tokens > 0 or next_tokens < 0, the interception condition is met.
sparse_mode = 1 or 2: no interception condition occurs.
sparse_mode = 3: if each batch's actual_seq_lengths - actual_seq_lengths_kv < 0, the interception condition is met.
sparse_mode = 4: if pre_tokens < 0 or each batch's next_tokens + actual_seq_lengths - actual_seq_lengths_kv < 0, the interception condition is met.
For scenarios where the input is int8 and the output is float16: the parameters dequant_scale1, quant_scale1, and dequant_scale2 must all be provided.
For scenarios where the input is entirely float16 or bfloat16 and the output is int8: the parameter quant_scale2 must be provided. The parameter quant_offset2 is optional and defaults to 0 if not specified.
The parameters quant_scale2 and quant_offset2 support both per-tensor and per-channel modes and two data types: float32 and bfloat16. If quant_offset2 is provided, its type and shape must match those of quant_scale2. When the input is bfloat16, both float32 and bfloat16 are supported; otherwise, only float32 is supported. For per-channel mode: When the output layout is BSH, the product of all dimensions in quant_scale2 must equal H. For other layouts, the product must equal N * D. When the output layout is BSH, it is recommended to set the shape of quant_scale2 as \((1, 1, H)\) or \((H)\). When the output layout is BNSD, it is recommended to set the shape as \((1, N, 1, D)\) or \((N, D)\). When the output layout is BSND, it is recommended to set the shape as \((1, 1, N, D)\) or \((N, D)\).
antiquant_scale (Tensor, optional) – Inverse quantization factors with data type of float16, float32 or bfloat16. Only float16 is supported when Q_S > 1. Supports per-tensor, per-channel and per-token modes. Default: None.
antiquant_offset (Tensor, optional) – Inverse quantization offset with data type of float16, float32 or bfloat16. Only float16 is supported when Q_S > 1. Supports per-tensor, per-channel and per-token modes. Default: None.
Constraints for the antiquant_scale and antiquant_offset parameters:
Supports three modes: per-channel, per-tensor, and per-token:
Per-channel mode: The shape of both parameters in the BNSD layout is \((2, N, 1, D)\), the shape in the BSND layout is \((2, N, D)\), and the shape in the BSH layout is \((2, H)\), where 2 corresponds to the key and value, and N represents num_key_value_heads. The parameter data type is the same as the query data type, and antiquant_mode should be set to 0.
Per-tensor mode: The shape of both parameters is \((2)\), the data type is the same as the query data type, and antiquant_mode should be set to 0.
Per-token mode: The shape of both parameters is \((2, B, S)\), the data type is fixed to float32, and antiquant_mode should be set to 1.
Supports both symmetric and asymmetric quantization:
Asymmetric quantization mode: Both antiquant_scale and antiquant_offset must be provided.
Symmetric quantization mode: antiquant_offset can be empty (None). If antiquant_offset is empty, symmetric quantization is performed; if antiquant_offset is provided, asymmetric quantization is performed.
key_antiquant_scale (Tensor, optional) – Inverse quantization factors for the key, with data type of float16, float32 or bfloat16, when the KV fake quantization parameters are separated. Supports per-tensor, per-channel and per-token modes. Default: None. Invalid when Q_S > 1.
key_antiquant_offset (Tensor, optional) – Inverse quantization offset for the key, with data type of float16, float32 or bfloat16, when the KV fake quantization parameters are separated. Supports per-tensor, per-channel and per-token modes. Default: None. Invalid when Q_S > 1.
value_antiquant_scale (Tensor, optional) – Inverse quantization factors for the value, with data type of float16, float32 or bfloat16, when the KV fake quantization parameters are separated. Supports per-tensor, per-channel and per-token modes. Default: None. Invalid when Q_S > 1.
value_antiquant_offset (Tensor, optional) – Inverse quantization offset for the value, with data type of float16, float32 or bfloat16, when the KV fake quantization parameters are separated. Supports per-tensor, per-channel and per-token modes. Default: None. Invalid when Q_S > 1.
block_table (Tensor, optional) – Block mapping table in KV cache for PageAttention, with data type of int32. If not used, set it to None. Default: None. Invalid when Q_S > 1.
query_padding_size (Tensor, optional) – The query padding size with data type of int64. Indicates whether the data in each batch of the query is right-aligned, and how many elements are right-aligned. Default: None. Invalid when Q_S is 1.
kv_padding_size (Tensor, optional) – The key and value padding size with data type of int64. Indicates whether the data in each batch of the key and value is right-aligned, and how many elements are right-aligned. Default: None. Invalid when Q_S is 1.
key_shared_prefix (Tensor, optional) – Shared prefix of the key. This is a reserved parameter and is not yet enabled. Default: None.
value_shared_prefix (Tensor, optional) – Shared prefix of the value. This is a reserved parameter and is not yet enabled. Default: None.
actual_shared_prefix_len (Union[tuple[int], list[int], Tensor], optional) – Describes the actual length of the shared prefix. This is a reserved parameter and is not yet enabled. Default: None.
num_heads (int, optional) – The number of heads in the query, equal to N when input_layout is BNSD. Default: 1.
scale (double, optional) – The scale value indicating the scale coefficient, which serves as the scalar value for the Muls in the calculation. Generally, the value is \(1.0 / \sqrt{d}\). Default: 1.0.
pre_tokens (int, optional) – Parameter for sparse computation, represents how many tokens are counted forward. Default: 2147483647. Invalid when Q_S is 1.
next_tokens (int, optional) – Parameter for sparse computation, represents how many tokens are counted backward. Default: 2147483647. Invalid when Q_S is 1.
input_layout (str, optional) – Specifies the layout of the input query, key and value. BSH, BNSD, BSND or BNSD_BSND is supported. When the layout is BNSD_BSND, the input is in the BNSD format and the output is in the BSND format; this is only supported when Q_S > 1. Default: BSH.
num_key_value_heads (int, optional) – Head number of the key/value, used in the GQA (Grouped-Query Attention) scenario. Default: 0. A value of 0 means it is equal to the number of key/value heads. num_heads must be divisible by num_key_value_heads, and the ratio of num_heads to num_key_value_heads must not be greater than 64. When the layout is BNSD, num_key_value_heads must also equal the N dimension of the key/value shapes; otherwise, an execution error will occur.
sparse_mode (int, optional) – Indicates the sparse mode (see the sketch after this parameter list for a usage example). Default: 0. Invalid when Q_S is 1.
0: Indicates the defaultMask mode. If atten_mask is not passed, the mask operation is not performed, and pre_tokens and next_tokens (internally assigned as INT_MAX) are ignored. If passed in, the complete atten_mask matrix (S1 * S2) must also be passed in, indicating that the part between pre_tokens and next_tokens needs to be calculated.
1: Represents allMask. The complete atten_mask matrix (S1 * S2) is required.
2: Represents the mask in leftUpCausal mode. The optimized atten_mask matrix (2048*2048) is required.
3: Represents the mask in rightDownCausal mode, corresponding to the lower triangular scenario divided by the right vertex. The optimized atten_mask matrix (2048*2048) is required.
4: Represents the mask in band mode, that is, the part between counting pre_tokens and next_tokens. The optimized atten_mask matrix (2048*2048) is required.
5: Represents the prefix scenario, not implemented yet.
6: Represents the global scenario, not implemented yet.
7: Represents the dilated scenario, not implemented yet.
8: Represents the block_local scenario, not implemented yet.
inner_precise (int, optional) – There are four modes: 0, 1, 2, and 3, represented by 2 bits: bit 0 selects high-precision or high-performance mode, and bit 1 indicates whether row-wise invalidity correction is applied. Default: 1.
0: High-precision mode, without row-wise invalidity correction.
1: High-performance mode, without row-wise invalidity correction.
2: High-precision mode, with row-wise invalidity correction.
3: High-performance mode, with row-wise invalidity correction.
When Q_S > 1, if sparse_mode is 0 or 1 and a user-defined mask is provided, it is recommended to enable row-wise invalidity correction. Only 0 and 1 are supported when Q_S is 1.
High-precision and high-performance modes are only effective for float16 inputs; row-wise invalidity correction is effective for float16, bfloat16, and int8 inputs. Currently, 0 and 1 are reserved configuration values. If an entire row of the "mask portion involved in computation" is all 1s, precision may degrade. In such cases, you can try setting this parameter to 2 or 3 to enable row-wise invalidity correction for improved precision; however, this configuration will decrease performance. If the function can detect the presence of an invalid-row scenario, e.g. when sparse_mode is 3 and S_q > S_kv, it will automatically enable row-wise invalidity computation.
block_size (int, optional) – Maximum number of tokens per block in the KV cache block for PageAttention. Default: 0. Invalid when Q_S > 1.
antiquant_mode (int, optional) – Fake-quantization mode, 0: per-channel (per-channel includes per-tensor), 1: per-token. The per-channel and per-tensor modes can be distinguished by the dimension of the input shape: when the dimension is 1, it runs in per-tensor mode; otherwise, it runs in per-channel mode. Default: 0. Invalid when Q_S > 1.
key_antiquant_mode (int, optional) – Fake-quantization mode for the key. 0: per-channel (per-channel includes per-tensor), 1: per-token. Default: 0. Invalid when Q_S > 1.
value_antiquant_mode (int, optional) – Fake-quantization mode for the value. 0: per-channel (per-channel includes per-tensor), 1: per-token. Default: 0. Invalid when Q_S > 1.
softmax_lse_flag (bool, optional) – Whether to output softmax_lse. Default: False.
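As referenced in the sparse_mode description above, the following is a hedged sketch (shapes and values are illustrative, not a reference configuration) of a full-inference call that combines the rightDownCausal sparse mode with the 2048*2048 optimized attention mask and per-batch actual sequence lengths:
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> B, N, S, D = 1, 8, 1024, 128
>>> query = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> key = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> value = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> # Optimized 2048*2048 mask for the causal (lower-triangular attention) pattern:
>>> # with the 0-retain/1-discard convention, positions strictly above the diagonal are set to 1.
>>> atten_mask = Tensor(np.triu(np.ones((2048, 2048), dtype=np.int8), k=1))
>>> out = ops.fused_infer_attention_score(
...     query, key, value,
...     atten_mask=atten_mask,
...     actual_seq_lengths=[S] * B,
...     actual_seq_lengths_kv=[S] * B,
...     num_heads=N, input_layout='BNSD',
...     scale=1.0 / (D ** 0.5), sparse_mode=3)
>>> print(out[0].shape)
(1, 8, 1024, 128)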
- Returns
attention_out (Tensor), the attention score with data type of float16, bfloat16 or int8. When the input_layout is BNSD_BSND, the shape is \((B, S, N, D)\). In all other cases, the shape is consistent with the input query shape.
softmax_lse (Tensor), the softmax_lse with data type of float32, obtained by taking the lse (log, sum and exp) of the result of query*key. Specifically, the Ring Attention algorithm first takes the max of the result of query*key, obtaining softmax_max. The result of query*key is then subtracted by softmax_max, followed by taking exp, and then the sum is computed to obtain softmax_sum. Finally, the log of softmax_sum is taken, and softmax_max is added to obtain softmax_lse. The softmax_lse is only calculated when softmax_lse_flag is True, and the shape would be \((B, N, Q\_S, 1)\). If softmax_lse_flag is False, then a tensor with shape \((1)\) filled with zeros would be returned. In graph mode with JitConfig set to O2, please ensure that the softmax_lse_flag is enabled before using softmax_lse; otherwise, an exception will occur.
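A sketch of the two-output form follows (shapes follow the description above and are illustrative); softmax_lse is only meaningful when softmax_lse_flag is True:
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> B, N, S, D = 1, 8, 1024, 128
>>> query = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> key = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> value = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> attention_out, softmax_lse = ops.fused_infer_attention_score(
...     query, key, value, num_heads=N, input_layout='BNSD',
...     scale=1.0 / (D ** 0.5), softmax_lse_flag=True)
>>> print(attention_out.shape, softmax_lse.shape)
(1, 8, 1024, 128) (1, 8, 1024, 1)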
- Constraints:
Full Inference Scenario (Q_S > 1):
Query, key, and value inputs functional usage restrictions:
The B axis supports values less than or equal to 65535. If the input type includes int8, or if the input type is float16 or bfloat16 and the D axis is not 16-aligned, the B axis is only supported up to 128.
The N axis supports values less than or equal to 256, and the D axis supports values less than or equal to 512.
The S axis supports values less than or equal to 20,971,520 (20M). In some long sequence scenarios, if the computation load is too large, it may cause a timeout in the PFA operator (AICore error type with errorStr: "timeout or trap error"). In this case, it is recommended to perform an S split. Note: The computational load is affected by B, S, N, D, etc.; the larger the values, the greater the computational load. Typical long sequence timeout scenarios (where the product of B, S, N, and D is large) include, but are not limited to:
B=1, Q_N=20, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
B=1, Q_N=2, Q_S=20971520, D=256, KV_N=2, KV_S=20971520;
B=20, Q_N=1, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
B=1, Q_N=10, Q_S=2097152, D=512, KV_N=1, KV_S=2097152.
When the query, key, value, or attention_out type includes int8, the D axis must be 32-aligned. If all types are float16 or bfloat16, the D axis must be 16-aligned.
The sparse_mode parameter currently only supports values 0, 1, 2, 3, and 4. Using any other values will result in an error.
When sparse_mode = 0, if the atten_mask is None, or if the atten_mask is provided in the left padding scenario, the input parameters pre_tokens and next_tokens are ignored.
When sparse_mode = 2, 3, or 4, the shape of the atten_mask must be \((S, S)\), \((1, S, S)\) or \((1, 1, S, S)\), where S must be fixed at 2048, and the user must ensure the atten_mask is a lower triangular matrix. If no atten_mask is provided or if the shape is incorrect, an error will occur.
In sparse_mode = 1, 2, 3 scenarios, the pre_tokens and next_tokens inputs are ignored and assigned according to the relevant rules.
The KV cache de-quantization only supports queries of type float16, where int8 keys and values are de-quantized to float16. If the product of the data range of the input key/value and the antiquant_scale is within the range (-1, 1), high-performance mode can guarantee precision; otherwise, high-precision mode should be enabled to ensure accuracy.
Query left padding scenario:
In the query left padding scenario, the formula for calculating the starting point of the query transport is: Q_S - query_padding_size - actual_seq_lengths. The formula for the ending point of the query transport is: Q_S - query_padding_size. The query transport starting point must not be less than 0, and the ending point must not exceed Q_S; otherwise, the results will be incorrect.
If the kv_padding_size in the query left padding scenario is less than 0, it will be set to 0.
The query left padding scenario must be enabled together with the actual_seq_lengths parameter, otherwise, the default is the query right padding scenario.
The query left padding scenario does not support PageAttention and cannot be enabled together with the block_table parameter.
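A worked instance of these transport-window formulas in plain Python (values purely illustrative):
>>> Q_S = 16                   # S dimension of the query
>>> query_padding_size = 2     # left padding size for this batch
>>> actual_seq_length = 10     # valid query tokens in this batch
>>> start = Q_S - query_padding_size - actual_seq_length
>>> end = Q_S - query_padding_size
>>> start, end                 # must satisfy 0 <= start and end <= Q_S
(4, 14)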
KV left padding scenario:
In the KV left padding scenario, the formula for calculating the starting point of the key and value transport is: KV_S - kv_padding_size - actual_seq_lengths_kv. The formula for the ending point of the key and value transport is: KV_S - kv_padding_size. The key and value transport starting point must not be less than 0, and the ending point must not exceed KV_S; otherwise, the results will be incorrect.
If the kv_padding_size in the KV left padding scenario is less than 0, it will be set to 0.
The KV left padding scenario must be enabled together with the actual_seq_lengths_kv parameter, otherwise, the default is the KV right padding scenario.
The KV left padding scenario does not support PageAttention and cannot be enabled together with the block_table parameter.
pse_shift functional usage restrictions:
This function is supported when the query data type is float16, bfloat16, or int8.
If the query data type is float16 and pse_shift is enabled, it will force high-precision mode, inheriting the limitations of high-precision mode.
The Q_S dimension of pse_shift must be greater than or equal to the query's S length, and its KV_S dimension must be greater than or equal to the key's S length.
KV fake quantization parameter separation is not currently supported.
Incremental Inference Scenario (Q_S is 1):
Query, key, and value inputs functional usage restrictions:
The B axis supports values less than or equal to 65,536.
The N axis supports values less than or equal to 256.
The D axis supports values less than or equal to 512.
Scenarios where the input types of query, key, and value are all int8 are not supported.
Page attention scenario:
The necessary condition to enable page attention is that the block_table exists and is valid. The key and value are arranged in contiguous memory according to the indices in the block_table. The key and value dtypes supported are float16, bfloat16, and int8. In this scenario, the input_layout parameter for key and value is invalid.
block_size is a user-defined parameter, and its value will affect the performance of page attention. When enabling page attention, a non-zero value for block_size must be provided, and the maximum value for block_size is 512.
If the input types of key and value are float16 or bfloat16, they must be 16-aligned. If the input types are int8, they must be 32-aligned, with 128 being recommended. In general, page attention can increase throughput but may lead to a performance decrease.
In the page attention enabled scenario, when the KV cache layout is (blocknum, block_size, H) and num_key_value_heads * D exceeds 64K, an error will be reported due to hardware instruction constraints. This can be resolved by enabling GQA (reducing num_key_value_heads) or adjusting the KV cache layout to (blocknum, num_key_value_heads, block_size, D).
The product of all dimensions of the shape of the key and value tensors in the page attention scenario must not exceed the representable range of int32.
In the page attention enabled scenario, the input S must be greater than or equal to max_block_num_per_seq * block_size.
The following are also supported in this scenario:
Enabling the attention mask (e.g., mask shape = (B, 1, 1, S));
Enabling pse_shift (e.g., pse_shift shape = (B, N, 1, S));
Enabling fake quantization in per-token mode (e.g., antiquant_scale and antiquant_offset shapes = (2, B, S)).
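The sketch below illustrates a basic page attention call in the incremental scenario (Q_S is 1). The (blocknum, block_size, H) cache layout follows the description above, while the (B, max_block_num_per_seq) block_table shape, the contiguous block assignment, and all concrete values are assumptions for illustration only:
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> B, N, D = 2, 8, 128
>>> H = N * D
>>> block_size, max_block_num_per_seq = 128, 4
>>> blocknum = B * max_block_num_per_seq
>>> query = Tensor(np.random.rand(B, 1, H).astype(np.float16))   # BSH layout, Q_S = 1
>>> key_cache = Tensor(np.random.rand(blocknum, block_size, H).astype(np.float16))
>>> value_cache = Tensor(np.random.rand(blocknum, block_size, H).astype(np.float16))
>>> # Each row maps one batch element to the cache blocks holding its keys/values.
>>> block_table = Tensor(np.arange(blocknum, dtype=np.int32).reshape(B, max_block_num_per_seq))
>>> actual_seq_lengths_kv = Tensor(np.array([300, 400], dtype=np.int64))
>>> out = ops.fused_infer_attention_score(
...     query, key_cache, value_cache,
...     block_table=block_table,
...     actual_seq_lengths_kv=actual_seq_lengths_kv,
...     num_heads=N, input_layout='BSH',
...     scale=1.0 / (D ** 0.5), block_size=block_size)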
KV left padding scenario:
In the KV left padding scenario, the formula for calculating the starting point of the KV cache transport is: KV_S - kv_padding_size - actual_seq_lengths. The formula for the endpoint of the KV cache transport is: KV_S - kv_padding_size. If the starting point or endpoint of the KV cache is less than 0, the returned data result will be all zeros.
If kv_padding_size is less than 0 in the KV left padding scenario, it will be set to 0.
The KV left padding scenario must be enabled together with the actual_seq_lengths parameter, otherwise, it defaults to the KV right padding scenario.
The KV left padding scenario must be enabled together with the atten_mask parameter, and the atten_mask must be correctly applied to hide invalid data. Otherwise, accuracy issues may arise.
pse_shift functional usage restrictions:
The data type of pse_shift must match the data type of the query.
Only D-axis-aligned inputs are supported, meaning the D axis must be divisible by 16.
KV fake quantization parameter separation:
key_antiquant_mode and value_antiquant_mode must be consistent.
key_antiquant_scale and value_antiquant_scale must either both be empty or both non-empty.
key_antiquant_offset and value_antiquant_offset must either both be empty or both non-empty.
When both key_antiquant_scale and value_antiquant_scale are non-empty, their shapes must be consistent.
When both key_antiquant_offset and value_antiquant_offset are non-empty, their shapes must be consistent.
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> import numpy as np
>>> B, N, S, D = 1, 8, 1024, 128
>>> query = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> key = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> value = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
>>> out = ops.fused_infer_attention_score(query, key, value, num_heads=N, input_layout='BNSD')
>>> print(out[0].shape)
(1, 8, 1024, 128)
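The following additional sketch (not part of the original examples) shows a BSH-layout call with grouped-query attention, where num_key_value_heads divides num_heads; shapes and values are illustrative:
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> import numpy as np
>>> B, S, N, kv_N, D = 1, 1024, 8, 2, 128
>>> query = Tensor(np.random.rand(B, S, N * D).astype(np.float16))
>>> key = Tensor(np.random.rand(B, S, kv_N * D).astype(np.float16))
>>> value = Tensor(np.random.rand(B, S, kv_N * D).astype(np.float16))
>>> out = ops.fused_infer_attention_score(query, key, value, num_heads=N,
...                                       num_key_value_heads=kv_N,
...                                       input_layout='BSH',
...                                       scale=1.0 / (D ** 0.5))
>>> print(out[0].shape)
(1, 1024, 1024)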