mindspore.ops.incre_flash_attention
- mindspore.ops.incre_flash_attention(query, key, value, attn_mask=None, actual_seq_lengths=None, pse_shift=None, dequant_scale1=None, quant_scale1=None, dequant_scale2=None, quant_scale2=None, quant_offset2=None, antiquant_scale=None, antiquant_offset=None, block_table=None, num_heads=1, input_layout='BSH', scale_value=1.0, num_key_value_heads=0, block_size=0, inner_precise=1, kv_padding_size=None)
B – Batch size
N – Num heads
kvN – Num key value heads
S – Sequence length
D – Head dim
H – Hidden size
kvH – Hidden size of key value
where \(H=N\times D\), \(kvH=kvN\times D\)
Self-attention builds an attention model from the relationships among the input samples themselves. Assume an input sequence \(x\) of length \(n\), where each element of \(x\) is a \(d\)-dimensional vector that can be viewed as a token embedding. Three weight matrices transform this sequence into three matrices \(Q\), \(K\) and \(V\), each with dimensions \(n\times d\). The self-attention calculation formula is defined as:
\[Attention(Q,K,V)=Softmax(\frac{QK^{T} }{\sqrt{d} } )V\]
where the product of \(Q\) and \(K^{T}\) represents the attention of input \(x\). To keep the values from becoming too large, the product is usually divided by the square root of \(d\), and softmax normalization is applied to each row. Multiplying the result by \(V\) yields a matrix of \(n\times d\).
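The formula above can be sketched in plain Python for the incremental case, where a single query vector attends over all cached keys and values. This is only an illustration of the math, not the Ascend kernel; the function name `attend_single_query` and the toy inputs are hypothetical.

```python
import math

def attend_single_query(q, keys, values):
    """Compute Softmax(q K^T / sqrt(d)) V for one query vector."""
    d = len(q)
    # Scaled dot product between the query and every key.
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    # Numerically stable softmax over the sequence axis.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the value vectors.
    out = [sum(w * v[j] for w, v in zip(weights, values))
           for j in range(len(values[0]))]
    return weights, out

q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights, out = attend_single_query(q, keys, values)
```

The weights sum to 1, and keys that align better with the query receive larger weights.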
Warning
This is an experimental API that is subject to change or deletion.
Note
If a parameter has no input and no default value, None must be passed.
The key and value tensors must have exactly the same shape.
\(N\) of parameter query is equal to num_heads. \(N\) of parameter key and parameter value is equal to num_key_value_heads. num_heads must be a multiple of num_key_value_heads.
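The head constraints in the notes above, together with \(H=N\times D\) and \(kvH=kvN\times D\), can be checked before calling the operator. The helper below is a minimal sketch; `check_heads` is a hypothetical name and is not part of MindSpore.

```python
def check_heads(hidden_size, kv_hidden_size, num_heads, num_key_value_heads=0):
    """Validate head counts and return (head_dim, query_heads_per_kv_head)."""
    # Per the parameter description, 0 means key/value use num_heads.
    kv_heads = num_key_value_heads or num_heads
    if num_heads % kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_key_value_heads")
    if hidden_size % num_heads != 0:
        raise ValueError("H must equal N * D for an integer D")
    head_dim = hidden_size // num_heads
    if kv_hidden_size != kv_heads * head_dim:
        raise ValueError("kvH must equal kvN * D")
    return head_dim, num_heads // kv_heads

# Values taken from the example at the end of this page: N=4, kvN=1, D=128.
head_dim, group = check_heads(512, 128, num_heads=4, num_key_value_heads=1)
```

Here `head_dim` is 128 and each key/value head is shared by 4 query heads.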
Quantization
When the data type of query, key, and value is float16 and the data type of output is int8, the input parameter quant_scale2 is required and quant_offset2 is optional.
When antiquant_scale exists, key and value must be passed as int8. antiquant_offset is optional.
The data type of antiquant_scale and antiquant_offset must be consistent with that of query.
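The quantization rules above can be expressed as a small pre-flight check. This is a sketch under assumptions: dtypes are represented as plain strings, and `check_quant_args` and its arguments are hypothetical names rather than MindSpore API.

```python
def check_quant_args(query_dtype, kv_dtype, out_dtype,
                     has_quant_scale2=False,
                     has_antiquant_scale=False,
                     antiquant_dtype=None):
    """Return a list of rule violations; an empty list means none were found."""
    problems = []
    # float16 inputs with int8 output need the post-quantization scale.
    if (query_dtype == kv_dtype == "float16" and out_dtype == "int8"
            and not has_quant_scale2):
        problems.append("quant_scale2 is required for int8 output")
    if has_antiquant_scale:
        # Pseudo-quantization expects int8 key and value.
        if kv_dtype != "int8":
            problems.append("key and value must be int8 with antiquant_scale")
        # antiquant_scale must share the query data type.
        if antiquant_dtype != query_dtype:
            problems.append("antiquant_scale dtype must match query dtype")
    return problems

missing_scale = check_quant_args("float16", "float16", "int8")
valid_antiquant = check_quant_args("float16", "int8", "float16",
                                   has_antiquant_scale=True,
                                   antiquant_dtype="float16")
```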
pse_shift
The pse_shift data type must be consistent with the query data type. pse_shift only supports D-axis alignment, which means that the D-axis must be divisible by 16.
Page attention:
The necessary condition for enabling page attention is that the block_table exists, and the key and value are arranged in a contiguous memory according to the index in the block_table. The support for key and value dtypes is float16/bfloat16/int8.
In the enabling scenario of page attention, 16 alignment is required when input types of key and value are float16/bfloat16, and 32 alignment is required when input types of key and value are int8. It is recommended to use 128.
The maximum max_block_num_per_seq currently supported by block_table is 16k. Exceeding 16k is rejected with an error message. If a large \(S\) causes max_block_num_per_seq to exceed 16k, increase block_size to resolve the problem.
The multiplication of all dimensions of the shape of the parameters key and value in the page attention scenario cannot exceed the representation range of int32.
When performing per-channel post quantization, page attention cannot be enabled simultaneously.
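The page attention constraints can be sketched as a helper that derives the block_table shape \((B, max\_block\_num\_per\_seq)\) from the sequence lengths. Two assumptions are made here and are not confirmed by this page: the 16/32 alignment applies to block_size, and "16k" means 16 * 1024. The name `block_table_shape` is hypothetical.

```python
def block_table_shape(actual_seq_lengths, block_size, kv_dtype="float16"):
    """Return (B, max_block_num_per_seq) or raise ValueError."""
    # Assumption: the alignment requirement refers to block_size.
    align = 32 if kv_dtype == "int8" else 16
    if block_size <= 0 or block_size % align != 0:
        raise ValueError(f"block_size must be a positive multiple of {align}")
    # Integer form of ceil(max(actual_seq_lengths) / block_size).
    max_len = max(actual_seq_lengths)
    max_blocks = (max_len + block_size - 1) // block_size
    # Assumption: the 16k limit is 16 * 1024 blocks per sequence.
    if max_blocks > 16 * 1024:
        raise ValueError("max_block_num_per_seq exceeds 16k; increase block_size")
    return (len(actual_seq_lengths), max_blocks)

shape = block_table_shape([100, 300], block_size=128)
```

With lengths 100 and 300 and the recommended block_size of 128, the longest sequence needs 3 blocks, so the shape is (2, 3).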
kv_padding_size:
The calculation formula for the starting point of KV cache transfer is \(S-kv\_padding\_size-actual\_seq\_lengths\). The calculation formula for the transfer endpoint of KV cache is \(S-kv\_padding\_size\). When the starting or ending point of the KV cache transfer is less than 0, the returned data result is all 0.
When kv_padding_size is less than 0, it will be set to 0.
kv_padding_size must be enabled together with the actual_seq_lengths parameter; otherwise it is treated as the KV right padding scenario.
kv_padding_size must also be enabled together with the attn_mask parameter, and attn_mask must correctly hide invalid data. Otherwise, accuracy issues are introduced.
kv_padding_size does not support page attention scenarios.
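The start and end formulas for the KV cache transfer can be worked through with a short sketch. `kv_transfer_range` is a hypothetical helper that only reproduces the arithmetic described above for a single sequence.

```python
def kv_transfer_range(S, actual_seq_length, kv_padding_size):
    """Return (start, end, valid) for the KV cache transfer window."""
    # A negative kv_padding_size is set to 0.
    pad = max(kv_padding_size, 0)
    start = S - pad - actual_seq_length
    end = S - pad
    # If either point is negative, the returned data is all zeros.
    valid = start >= 0 and end >= 0
    return start, end, valid

normal = kv_transfer_range(S=10, actual_seq_length=4, kv_padding_size=2)
too_long = kv_transfer_range(S=10, actual_seq_length=9, kv_padding_size=3)
clamped = kv_transfer_range(S=10, actual_seq_length=4, kv_padding_size=-5)
```

In the first case the window covers positions 4 to 8. In the second the start point is -2, which corresponds to the all-zero result. This is why kv_padding_size is bounded by \(S-max(actual\_seq\_length)\).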
- Parameters
query (Tensor) – The query tensor with data type of float16 or bfloat16. The shape is \((B, 1, H)\) / \((B, N, 1, D)\).
key (TensorList) – The key tensor with data type of float16 or bfloat16 or int8. The shape is \((B, S, kvH)\) / \((B, kvN, S, D)\).
value (TensorList) – The value tensor with data type of float16 or bfloat16 or int8. The shape is \((B, S, kvH)\) / \((B, kvN, S, D)\).
attn_mask (Tensor, optional) – The attention mask tensor with data type of bool or int8 or uint8. The shape is \((B, S)\) / \((B, 1, S)\) / \((B, 1, 1, S)\). Default: None.
actual_seq_lengths (Union[Tensor, tuple[int], list[int]], optional) – Describes the actual sequence length of each input, with data type of int32 or int64. The shape is \((B, )\). Default: None.
pse_shift (Tensor, optional) – The position encoding tensor with data type of float16 or bfloat16. Input tensor of shape \((1, N, 1, S)\) / \((B, N, 1, S)\). Default: None.
dequant_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of uint64 or float32. It is disabled now. Default: None.
quant_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of float32. It is disabled now. Default: None.
dequant_scale2 (Tensor, optional) – Quantization parameter, a tensor with data type of uint64 or float32. It is disabled now. Default: None.
quant_scale2 (Tensor, optional) – Post-quantization parameter, a tensor with data type of float32. The shape is \((1,)\). Default: None.
quant_offset2 (Tensor, optional) – Post-quantization parameter, a tensor with data type of float32. The shape is \((1,)\). Default: None.
antiquant_scale (Tensor, optional) – Pseudo-quantization parameter, a tensor with data type of float16 or bfloat16. The shape is \((2, kvN, 1, D)\) when input_layout is 'BNSD' or \((2, kvH)\) when input_layout is 'BSH'. Default: None.
antiquant_offset (Tensor, optional) – Pseudo-quantization parameter, a tensor with data type of float16 or bfloat16. The shape is \((2, kvN, 1, D)\) when input_layout is 'BNSD' or \((2, kvH)\) when input_layout is 'BSH'. Default: None.
block_table (Tensor, optional) – The tensor with data type of int32. The shape is \((B, max\_block\_num\_per\_seq)\), where \(max\_block\_num\_per\_seq = ceil(\frac{max(actual\_seq\_length)}{block\_size} )\). Default: None.
num_heads (int) – The number of heads. Default: 1.
input_layout (str) – The data layout of the input qkv. Supports 'BSH' and 'BNSD'. Default: 'BSH'.
scale_value (double) – The scale value indicating the scale coefficient, which is used as the scalar of Muls in the calculation. Default: 1.0.
num_key_value_heads (int) – The number of key/value heads, used in the GQA algorithm. The value 0 means that key and value have the same number of heads as query, in which case num_heads is used. Default: 0.
block_size (int) – The maximum number of tokens stored in each block of KV in page attention. Default: 0.
inner_precise (int) – Default: 1.
kv_padding_size (Tensor, optional) – The tensor with data type of int64. The range of values is \(0\le kv\_padding\_size \le S-max(actual\_seq\_length)\). The shape is \(()\) or \((1,)\). Default: None.
- Returns
attention_out (Tensor), the shape is \((B, 1, H)\) / \((B, N, 1, D)\).
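The shapes listed for the two layouts can be collected in one place. The sketch below assumes only the shapes stated on this page; `expected_shapes` is a hypothetical helper and does not call MindSpore.

```python
def expected_shapes(B, N, kvN, S, D, input_layout="BSH"):
    """Return the documented tensor shapes for the given input_layout."""
    if input_layout == "BSH":
        return {
            "query": (B, 1, N * D),        # (B, 1, H)
            "key": (B, S, kvN * D),        # (B, S, kvH)
            "value": (B, S, kvN * D),
            "antiquant_scale": (2, kvN * D),
            "attention_out": (B, 1, N * D),
        }
    if input_layout == "BNSD":
        return {
            "query": (B, N, 1, D),
            "key": (B, kvN, S, D),
            "value": (B, kvN, S, D),
            "antiquant_scale": (2, kvN, 1, D),
            "attention_out": (B, N, 1, D),
        }
    raise ValueError("input_layout must be 'BSH' or 'BNSD'")

bsh = expected_shapes(B=1, N=4, kvN=1, S=10, D=128)
bnsd = expected_shapes(B=1, N=4, kvN=1, S=10, D=128, input_layout="BNSD")
```

With the values used in the example on this page, the 'BSH' query and output shapes are both (1, 1, 512).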
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops
>>> from mindspore.common import Tensor
>>> from mindspore.common import dtype as mstype
>>> import numpy as np
>>> B, N, S, D, kvN = 1, 4, 10, 128, 1
>>> query = Tensor(np.random.randn(B, 1, N * D), mstype.float16)
>>> key = [Tensor(np.random.randn(B, S, kvN * D), mstype.float16)]
>>> value = [Tensor(np.random.randn(B, S, kvN * D), mstype.float16)]
>>> ifa_ms = ops.functional.incre_flash_attention
>>> attn_out = ifa_ms(query, key, value, num_heads=N, num_key_value_heads=kvN)
>>> attn_out
Tensor(shape=[1, 1, 512], dtype=Float16, value=
[[[ 1.6104e+00, 7.3438e-01, 1.0684e+00 ... -8.7891e-01, 1.7695e+00, 1.0264e+00]]])