mindspore.ops.incre_flash_attention

mindspore.ops.incre_flash_attention(query, key, value, attn_mask=None, actual_seq_lengths=None, pse_shift=None, dequant_scale1=None, quant_scale1=None, dequant_scale2=None, quant_scale2=None, quant_offset2=None, antiquant_scale=None, antiquant_offset=None, block_table=None, num_heads=1, input_layout='BSH', scale_value=1.0, num_key_value_heads=0, block_size=0, inner_precise=1, kv_padding_size=None)[source]

The interface for incremental inference.

  • B: Batch size

  • N: Num of attention heads

  • kvN: Num of key / value heads

  • S: Sequence length

  • D: Head dim

  • H: Hidden layer size

  • kvH: Hidden size of key / value

where H=N×D, kvH=kvN×D.

Self-attention constructs an attention model based on the relationships among the elements of the input itself. Assume the input sample sequence x has length n and each element of x is a d-dimensional vector, which can be viewed as a token embedding. The sequence is transformed through 3 weight matrices to obtain 3 matrices Q, K and V, each of dimension n×d. The self-attention calculation formula is defined as:

Attention(Q, K, V) = Softmax(QKᵀ / √d) V

where the product of Q and Kᵀ represents the attention between the elements of the input x. To prevent these values from becoming too large, they are usually scaled by dividing by √d; softmax normalization is then applied to each row, and multiplying by V yields an n×d output matrix.
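
The following is a minimal NumPy sketch of this formula on a toy sequence (it is not the fused Ascend kernel, and the weight matrices Wq, Wk, Wv are invented for the illustration):

>>> import numpy as np
>>> n, d = 4, 16
>>> x = np.random.randn(n, d)                            # n token embeddings of dimension d
>>> Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))
>>> Q, K, V = x @ Wq, x @ Wk, x @ Wv                     # three n x d matrices
>>> scores = Q @ K.T / np.sqrt(d)                        # scaled attention scores, n x n
>>> weights = np.exp(scores - scores.max(-1, keepdims=True))
>>> weights = weights / weights.sum(-1, keepdims=True)   # row-wise softmax
>>> out = weights @ V                                    # Attention(Q, K, V), n x d
>>> out.shape
(4, 16)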

Note

  • If an optional parameter is not provided and has no default value, None needs to be passed.

  • The shapes of the tensors corresponding to the key and value parameters must be completely consistent.

  • The N of parameter query is equal to num_heads. The N of parameters key and value is equal to num_key_value_heads. num_heads must be a multiple of num_key_value_heads.

  • Quantization

    • When the data type of query, key, and value is float16 and the data type of output is int8, the input parameter quant_scale2 is required and quant_offset2 is optional.

    • When antiquant_scale is provided, key and value must be passed as int8. antiquant_offset is optional.

    • The data types of antiquant_scale and antiquant_offset should be consistent with that of query.

  • pse_shift

    • The data type of pse_shift needs to be consistent with that of query, and only D-axis alignment is supported, which means the D-axis must be divisible by 16.

  • Page attention:

    • The necessary condition for enabling page attention is that block_table exists, and key and value are arranged in contiguous memory according to the indices in block_table. The supported dtypes for key and value are float16/bfloat16/int8.

    • When page attention is enabled, block_size must be 16-aligned when the input dtype of key and value is float16/bfloat16, and 32-aligned when the input dtype of key and value is int8. A value of 128 is recommended.

    • The maximum max_block_num_per_seq currently supported by block_table is 16k; exceeding 16k will be rejected with an error. If S is so large that max_block_num_per_seq exceeds 16k, increase block_size to solve the problem (an arithmetic sketch follows this note).

    • In the page attention scenario, the product of all dimensions of the shapes of parameters key and value cannot exceed the representable range of int32.

    • Per-channel post-quantization cannot be enabled at the same time as page attention.

  • kv_padding_size:

    • The starting point of the KV cache transfer is calculated as S - kv_padding_size - actual_seq_lengths. The endpoint of the KV cache transfer is calculated as S - kv_padding_size. When the starting point or endpoint of the KV cache transfer is less than 0, the returned data result is all 0.

    • When kv_padding_size is less than 0, it will be set to 0.

    • kv_padding_size needs to be enabled together with the actual_seq_lengths parameter; otherwise, it is considered the KV right-padding scenario.

    • It needs to be enabled together with the attn_mask parameter, and attn_mask must be correct, that is, it must correctly mask out the invalid data; otherwise, accuracy issues will be introduced.

    • kv_padding_size does not support page attention scenarios.
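
The following arithmetic sketch illustrates the block_table shape and the KV cache transfer window described above (the numbers are illustrative only; page attention and kv_padding_size are independent features and, as noted, cannot be combined):

>>> import math
>>> B, S, block_size = 2, 4096, 128
>>> actual_seq_lengths = [1000, 3000]
>>> max_block_num_per_seq = math.ceil(max(actual_seq_lengths) / block_size)
>>> (B, max_block_num_per_seq)                    # expected block_table shape
(2, 24)
>>> kv_padding_size = 96                          # KV left padding length
>>> S - kv_padding_size - actual_seq_lengths[0]   # KV cache transfer starting point
3000
>>> S - kv_padding_size                           # KV cache transfer endpoint
4000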

Warning

Only supported on Atlas A2 training series products.

Parameters
  • query (Tensor) – The query tensor with data type of float16 or bfloat16. The shape is (B,1,H) / (B,N,1,D).

  • key (Union[tuple, list]) – The key tensor with data type of float16 or bfloat16 or int8. The shape is (B,S,kvH) / (B,kvN,S,D).

  • value (Union[tuple, list]) – The value tensor with data type of float16 or bfloat16 or int8. The shape is (B,S,kvH) / (B,kvN,S,D).

  • attn_mask (Tensor, optional) – The attention mask tensor with data type of bool or int8 or uint8. The shape is (B,S) / (B,1,S) / (B,1,1,S). Default: None.

  • actual_seq_lengths (Union[Tensor, tuple[int], list[int]], optional) – Describe actual sequence length of each input with data type of int64. The shape is (B,). Default: None.

  • pse_shift (Tensor, optional) – The position encoding tensor with data type of float16 or bfloat16. Input tensor of shape (1,N,1,S) / (B,N,1,S). Default: None.

  • dequant_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of uint64 or float32. It is disabled for now. Default: None.

  • quant_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of float32. It is disabled for now. Default: None.

  • dequant_scale2 (Tensor, optional) – Quantization parameter, a tensor with data type of uint64 or float32. It is disabled for now. Default: None.

  • quant_scale2 (Tensor, optional) – Post-quantization parameter, a tensor with data type of float32. The shape is (1,). Default: None.

  • quant_offset2 (Tensor, optional) – Post-quantization parameter, a tensor with data type of float32. The shape is (1,). Default: None.

  • antiquant_scale (Tensor, optional) – Pseudo-quantization parameter, a tensor with data type of float16 or bfloat16. The shape is (2,kvN,1,D) when input_layout is 'BNSD' or (2,kvH) when input_layout is 'BSH'; see the shape sketch after this parameter list. Default: None.

  • antiquant_offset (Tensor, optional) – Pseudo-quantization parameter, a tensor with data type of float16 or bfloat16. The shape is (2,kvN,1,D) when input_layout is 'BNSD' or (2,kvH) when input_layout is 'BSH'. Default: None.

  • block_table (Tensor, optional) – The tensor with data type of int32. The shape is (B,max_block_num_per_seq), where max_block_num_per_seq = ceil(max(actual_seq_lengths) / block_size). Default: None.

  • num_heads (int, optional) – The number of heads. Default: 1.

  • input_layout (str, optional) – The data layout of the input qkv, supports 'BSH' and 'BNSD'. Default: 'BSH'.

  • scale_value (double, optional) – The scale value indicating the scale coefficient, which is used as the scalar of Muls in the calculation. Default: 1.0.

  • num_key_value_heads (int, optional) – Head number of key/value, used in the GQA (grouped-query attention) algorithm. The value 0 means that key and value have the same head number as query, i.e. num_heads is used. Default: 0.

  • block_size (int, optional) – The maximum number of tokens stored in each block of KV in page attention. Default: 0.

  • inner_precise (int, optional) – An int from {0, 1} indicating the computing mode: 0 for high-precision mode (for float16 dtype), 1 for high-performance mode. Default: 1.

  • kv_padding_size (Tensor, optional) – The tensor with data type of int64. The valid range is 0 ≤ kv_padding_size ≤ S - max(actual_seq_lengths). The shape is () or (1,). Default: None.
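
The following is a shape-only sketch of the pseudo-quantization (antiquant) inputs for the 'BSH' layout, based on the shapes documented above (the sizes B, S, D and kvN are arbitrary illustration values; it only constructs the tensors and does not guarantee that every such configuration runs on a given device or version):

>>> import numpy as np
>>> from mindspore.common import Tensor
>>> from mindspore.common import dtype as mstype
>>> B, S, D, kvN = 1, 10, 128, 2
>>> kvH = kvN * D
>>> key = [Tensor(np.random.randint(-128, 128, (B, S, kvH)), mstype.int8)]     # int8 key
>>> value = [Tensor(np.random.randint(-128, 128, (B, S, kvH)), mstype.int8)]   # int8 value
>>> antiquant_scale = Tensor(np.ones((2, kvH)), mstype.float16)                # shape (2, kvH) for 'BSH'
>>> antiquant_offset = Tensor(np.zeros((2, kvH)), mstype.float16)              # optional, same shape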

Returns

attention_out (Tensor), the shape is (B,1,H) / (B,N,1,D).

Raises
  • TypeError – dtype of query is not float16 or bfloat16.

  • TypeError – key and value don't have the same dtype.

  • TypeError – dtype of attn_mask is not bool, int8 or uint8.

  • TypeError – dtype of pse_shift is not bfloat16 or float16.

  • TypeError – scale_value is not a double number.

  • TypeError – input_layout is not a string.

  • TypeError – num_key_value_heads or num_heads is not an int.

  • TypeError – inner_precise is not an int.

  • TypeError – quant_scale1 is not Tensor of type float32.

  • TypeError – quant_scale2 is not Tensor of type float32.

  • TypeError – quant_offset2 is not Tensor of type float32.

  • ValueError – size of actual_seq_lengths is not 1 or B.

  • ValueError – input_layout is a string but not 'BSH' or 'BNSD'.

  • ValueError – Q_H is not divisible by num_heads.

  • ValueError – num_heads is not divisible by num_key_value_heads.

  • RuntimeError – num_heads is not greater than 0.

  • RuntimeError – attn_mask shape is not valid.

Supported Platforms:

Ascend

Examples

>>> from mindspore import ops
>>> from mindspore.common import Tensor
>>> from mindspore.common import dtype as mstype
>>> import numpy as np
>>> B, N, S, D, kvN = 1, 4, 10, 128, 1
>>> query = Tensor(np.random.randn(B, 1, N * D), mstype.float16)
>>> key = [Tensor(np.random.randn(B, S, kvN * D), mstype.float16)]
>>> value = [Tensor(np.random.randn(B, S, kvN * D), mstype.float16)]
>>> ifa_ms = ops.incre_flash_attention
>>> attn_out = ifa_ms(query, key, value, num_heads=N, num_key_value_heads=kvN)
>>> attn_out
Tensor(shape=[1, 1, 512], dtype=Float16, value=
[[[ 1.6104e+00,  7.3438e-01,  1.0684e+00 ... -8.7891e-01,  1.7695e+00,  1.0264e+00]]])
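
A second example uses the 'BNSD' layout together with attn_mask and actual_seq_lengths, following the shapes documented above. Since the inputs are random, only the output shape is shown; the expected shape (B, N, 1, D) is taken from the Returns description, and the all-False mask (meaning no position is masked) is an assumption of this sketch.

>>> import numpy as np
>>> from mindspore import ops
>>> from mindspore.common import Tensor
>>> from mindspore.common import dtype as mstype
>>> B, N, S, D, kvN = 1, 4, 10, 128, 1
>>> query = Tensor(np.random.randn(B, N, 1, D), mstype.float16)
>>> key = [Tensor(np.random.randn(B, kvN, S, D), mstype.float16)]
>>> value = [Tensor(np.random.randn(B, kvN, S, D), mstype.float16)]
>>> attn_mask = Tensor(np.zeros((B, 1, 1, S)), mstype.bool_)   # all False: no position masked
>>> actual_seq_lengths = [S] * B
>>> attn_out = ops.incre_flash_attention(query, key, value, attn_mask=attn_mask,
...                                      actual_seq_lengths=actual_seq_lengths,
...                                      num_heads=N, input_layout='BNSD',
...                                      num_key_value_heads=kvN,
...                                      scale_value=1.0 / (D ** 0.5))
>>> attn_out.shape
(1, 4, 1, 128)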