mindspore.ops.prompt_flash_attention
- mindspore.ops.prompt_flash_attention(query, key, value, attn_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, pse_shift=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147483647, next_tokens=0, input_layout='BSH', num_key_value_heads=0, sparse_mode=0, inner_precise=1)[source]
The interface for full inference.
B – Batch size
N – Num heads
S – Sequence length
D – Head dim
H – Hidden size
Self attention constructs an attention model based on the relationship between input samples themselves. Assume an input sample sequence \(x\) of length \(n\), where each element of \(x\) is a \(d\)-dimensional vector that can be viewed as a token embedding. This sequence can be transformed through three weight matrices to obtain three matrices, each of dimension \(n\times d\).
The self attention calculation formula is defined as:

\[Attention(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d}})V\]

where the product of \(Q\) and \(K^{T}\) represents the attention of input \(x\). To avoid the values becoming too large, the product is usually scaled by dividing by the square root of \(d\); softmax normalization is then applied to each row, and multiplying the result by \(V\) yields an \(n\times d\) matrix.
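As a plain illustration of this formula, the computation can be written in a few lines of NumPy (a minimal sketch of the math only, not the fused Ascend kernel this interface dispatches to):

>>> import numpy as np
>>> n, d = 4, 8
>>> Q = np.random.rand(n, d).astype(np.float32)
>>> K = np.random.rand(n, d).astype(np.float32)
>>> V = np.random.rand(n, d).astype(np.float32)
>>> scores = Q @ K.T / np.sqrt(d)                        # (n, n), scaled attention of input x
>>> weights = np.exp(scores - scores.max(-1, keepdims=True))
>>> weights = weights / weights.sum(-1, keepdims=True)   # row-wise softmax
>>> out = weights @ V                                    # (n, d)
>>> print(out.shape)
(4, 8)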
Warning
Support for attn_mask with the float16 dtype will be deprecated in the future.
When sparse_mode is 2, 3 or 4, the shape of attn_mask must be \((2048, 2048)\) / \((B, 1, 2048, 2048)\) / \((1, 1, 2048, 2048)\).
Note
Maximum Support for each axis
Supports B-axis values less than or equal to 65536 (64K). When the input dtype is int8 with the D-axis not aligned to 32, or the input dtype is float16 or bfloat16 with the D-axis not aligned to 16, the B-axis supports values up to 128 only.
Supports N-axis values less than or equal to 256.
Supports S-axis values less than or equal to 20971520 (20M).
Supports D-axis values less than or equal to 512.
Quantization
int8 Input, int8 Output: Parameters deq_scale1, quant_scale1, deq_scale2, and quant_scale2 must all be provided. quant_offset2 is optional (default is 0 if not provided); a sketch of this configuration follows this list.
int8 Input, float16 Output: Parameters deq_scale1, quant_scale1, and deq_scale2 must all be provided. If quant_offset2 or quant_scale2 is provided (i.e., not null), it will result in an error.
float16 or bfloat16 Input, int8 Output: Parameter quant_scale2 must be provided. quant_offset2 is optional (default is 0 if not provided). If deq_scale1, quant_scale1, or deq_scale2 is provided (i.e., not null), it will result in an error.
int8 Output:
quant_scale2 and quant_offset2 in per-channel format do not support scenarios with left padding, Ring Attention, or non-32-byte aligned D-axis.
In GE mode, when quant_scale2 and quant_offset2 are in per-tensor format, the input dtype is bfloat16, and input_layout is "BSH", scenarios with a non-32-byte-aligned D-axis are not supported.
Does not support the band sparse mode together with negative pre_tokens/next_tokens.
quant_scale2 and quant_offset2 can be bfloat16 only when query is bfloat16.
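As an illustration of the first configuration above (int8 input, int8 output), the following is a hedged sketch assuming an Ascend device; the all-ones scale tensors are placeholders rather than meaningful calibration values, and per the rules above the output dtype is expected to be int8 (quant_offset2 is omitted, i.e. 0):

>>> from mindspore.ops.function.nn_func import prompt_flash_attention
>>> from mindspore import Tensor
>>> import numpy as np
>>> B, N, S, D = 1, 16, 256, 32
>>> query = Tensor(np.ones((B, N, S, D), dtype=np.int8))
>>> key = Tensor(np.ones((B, N, S, D), dtype=np.int8))
>>> value = Tensor(np.ones((B, N, S, D), dtype=np.int8))
>>> deq_scale1 = Tensor(np.ones((1,), dtype=np.float32))    # required for int8 input
>>> quant_scale1 = Tensor(np.ones((1,), dtype=np.float32))  # required for int8 input
>>> deq_scale2 = Tensor(np.ones((1,), dtype=np.float32))    # required for int8 input
>>> quant_scale2 = Tensor(np.ones((1,), dtype=np.float32))  # required for int8 output
>>> out = prompt_flash_attention(query, key, value, None, None, None, None, deq_scale1, quant_scale1, deq_scale2, quant_scale2, None, N, input_layout='BNSD')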
Other Usage Caveats:
\(N\) of parameter query must be equal to num_heads. \(N\) of parameter key and parameter value must be equal to num_key_value_heads.
num_heads must be divisible by num_key_value_heads, and num_heads divided by num_key_value_heads cannot be greater than 64.
When the query dtype is bfloat16, the D axis should be aligned to 16.
Each element of actual_seq_lengths must not exceed q_S, and each element of actual_seq_lengths_kv must not exceed kv_S (these constraints are illustrated in the sketch below).
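These caveats can be checked up front before dispatching to the device. The helper below is a hypothetical sketch (check_pfa_constraints is not part of MindSpore) that simply mirrors the documented rules:

>>> def check_pfa_constraints(num_heads, num_key_value_heads, d, query_dtype, actual_seq_lengths, q_s):
...     kv_heads = num_key_value_heads if num_key_value_heads != 0 else num_heads  # 0 means "same as num_heads"
...     assert num_heads % kv_heads == 0, "num_heads must be divisible by num_key_value_heads"
...     assert num_heads // kv_heads <= 64, "num_heads / num_key_value_heads cannot be greater than 64"
...     if query_dtype == "bfloat16":
...         assert d % 16 == 0, "D axis should be aligned to 16 for a bfloat16 query"
...     if actual_seq_lengths is not None:
...         assert all(0 < s <= q_s for s in actual_seq_lengths), "each element must not exceed q_S"
>>> check_pfa_constraints(16, 4, 128, "bfloat16", [200, 256], 256)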
- Parameters
query (Tensor) – The query tensor with data type of int8, float16 or bfloat16. The shape is \((B, q_S, q_H)\) / \((B, q_N, q_S, q_D)\).
key (Tensor) – The key tensor with the same dtype as query. The shape is \((B, kv_S, kv_H)\) / \((B, kv_N, kv_S, kv_D)\).
value (Tensor) – The value tensor with the same dtype as query. The shape is \((B, kv_S, kv_H)\) / \((B, kv_N, kv_S, kv_D)\).
attn_mask (Tensor, optional) – For each element, 0/False indicates retention and 1/True indicates discard. If sparse_mode is 0 or 1, the shape is \((q_S, kv_S)\) / \((B, q_S, kv_S)\) / \((1, q_S, kv_S)\) / \((B, 1, q_S, kv_S)\) / \((1, 1, q_S, kv_S)\). If sparse_mode is 2, 3 or 4, the shape is \((2048, 2048)\) / \((1, 2048, 2048)\) / \((1, 1, 2048, 2048)\). Default: None.
actual_seq_lengths (Union[Tensor, tuple[int], list[int]], optional) – Describes the actual sequence length of each batch of query, with data type of int64. The shape is \((B,)\) and every element should be a positive integer. Default: None.
actual_seq_lengths_kv (Union[Tensor, tuple[int], list[int]], optional) – Describes the actual sequence length of each batch of key or value, with data type of int64. The shape is \((B,)\) and every element should be a positive integer. Default: None.
pse_shift (Tensor, optional) – The position encoding tensor with data type of float16 or bfloat16. Input tensor of shape \((B, N, q_S, kv_S)\) / \((1, N, q_S, kv_S)\). Default: None.
q_S must be greater than or equal to the query's S length, and kv_S must be greater than or equal to the key's S length.
If pse_shift has dtype float16, query should have dtype float16 or int8, in which case high precision mode is enabled automatically.
If pse_shift has dtype bfloat16, query should have dtype bfloat16.
deq_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of uint64 or float32. Input tensor of shape \((1,)\). Default: None.
quant_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of float32. Input tensor of shape \((1,)\). Default: None.
deq_scale2 (Tensor, optional) – Quantization parameter, an input tensor of shape \((1,)\) with the same dtype as deq_scale1. Default: None.
quant_scale2 (Tensor, optional) – Quantization parameter, a tensor with data type of float32 or bfloat16. The suggested shape is \((1,)\) / \((1, 1, q_H)\) / \((q_H,)\) when the output layout is BSH, or \((1,)\) / \((1, q_N, 1, D)\) / \((q_N, D)\) when the layout is BNSD. Default: None.
quant_offset2 (Tensor, optional) – Quantization parameter, a tensor with data type of float32 or bfloat16. It has the same dtype and shape as quant_scale2. Default: None.
num_heads (int, optional) – The number of heads. It is an integer in the range [0, 256]. Default: 1.
scale_value (double, optional) – The scale value indicating the scale coefficient, which is used as the scalar of Muls in the calculation. Default: 1.0.
pre_tokens (int, optional) – For sparse computing, indicates the number of previous tokens that attention needs to be associated with. Default: 2147483647.
next_tokens (int, optional) – For sparse computing, indicates the number of next tokens that attention needs to be associated with. Default: 0.
input_layout (str, optional) – The data layout of the input qkv, supports "BSH" and "BNSD" (see the usage sketch after this parameter list). Default: "BSH".
num_key_value_heads (int, optional) – The head number of key/value, used in the GQA (Grouped-Query Attention) algorithm. The value 0 means that key and value have the same head number as query, i.e. num_heads is used. If it is specified (not 0), it must be a factor of num_heads and must be equal to kv_N. Default: 0.
sparse_mode (int, optional) – An int specifying the sparse mode, one of {0, 1, 2, 3, 4}. Default: 0.
sparse_mode = 0: if attn_mask is a null pointer, the pre_tokens and next_tokens inputs are ignored (internally set to INT_MAX).
sparse_mode = 2, 3 or 4: the attn_mask shape must be \((S, S)\), \((1, S, S)\) or \((1, 1, S, S)\), with S fixed at 2048. The user must ensure that attn_mask is lower triangular; if it is not provided or has an incorrect shape, an error is raised.
sparse_mode = 1, 2 or 3: the pre_tokens and next_tokens inputs are ignored and the values are set according to specific rules.
sparse_mode = 4: pre_tokens and next_tokens must be non-negative.
inner_precise (int, optional) – An int from {0, 1} indicating the computing mode: 0 is the high precision mode for the float16 dtype, 1 is the high performance mode. Default: 1.
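For orientation, the sketch below is an illustrative example (assuming an Ascend device; not taken from the official documentation) of a call with the default "BSH" layout, where q_H = num_heads * D, together with per-batch actual_seq_lengths. Per the Returns section, the output shape matches the query shape.

>>> from mindspore.ops.function.nn_func import prompt_flash_attention
>>> from mindspore import Tensor
>>> import numpy as np
>>> B, N, S, D = 2, 8, 1024, 128
>>> query = Tensor(np.ones((B, S, N * D), dtype=np.float16))  # (B, q_S, q_H), q_H = N * D
>>> key = Tensor(np.ones((B, S, N * D), dtype=np.float16))
>>> value = Tensor(np.ones((B, S, N * D), dtype=np.float16))
>>> actual_seq_lengths = [768, 1024]  # one valid length per batch, each <= q_S
>>> out = prompt_flash_attention(query, key, value, None, actual_seq_lengths, None, None, None, None, None, None, None, N, input_layout='BSH')
>>> print(out.shape)
(2, 1024, 1024)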
- Returns
attention_out (Tensor) - Output tensor, with the same shape as query: \((B, q_S, q_H)\) / \((B, q_N, q_S, q_D)\). The output dtype is determined by multiple factors; please refer to the Note above for details.
- Raises
TypeError – Dtype of query is not int8, float16 or bfloat16.
TypeError – query, key and value don't have the same dtype.
TypeError – Dtype of attn_mask is not bool, int8 or uint8.
TypeError – Dtype of pse_shift is not bfloat16 or float16.
TypeError – scale_value is not a double number.
TypeError – input_layout is not a string.
TypeError – num_key_value_heads is not an int.
TypeError – sparse_mode is not an int.
TypeError – inner_precise is not an int.
TypeError – quant_scale1 is not Tensor of type float32.
TypeError – deq_scale1 is not Tensor of type uint64 or float32.
TypeError – quant_scale2 is not Tensor of type float32.
TypeError – deq_scale2 is not Tensor of type uint64 or float32.
TypeError – quant_offset2 is not Tensor of type float32.
ValueError – input_layout is a string but is neither "BSH" nor "BNSD".
RuntimeError – num_heads is not divisible by num_key_value_heads.
RuntimeError – num_heads is not greater than 0.
RuntimeError – num_key_value_heads is not greater than or equal to 0.
RuntimeError – kv_n is not equal to num_key_value_heads.
RuntimeError – attn_mask shape is not valid.
RuntimeError – sparse_mode is specified but is not 0, 1, 2, 3 or 4.
RuntimeError – query dtype is bfloat16 and D axis is not aligned with 16.
RuntimeError – input_layout is BSH and kv_h is not divisible by num_key_value_heads.
RuntimeError – D-axis of query, key and value is not the same.
RuntimeError – In post quant per-channel scenario, D-axis is not 32 Byte aligned.
- Supported Platforms:
Ascend
Examples
>>> from mindspore.ops.function.nn_func import prompt_flash_attention
>>> from mindspore import Tensor
>>> import numpy as np
>>> B = 1
>>> N = 16
>>> S = 256
>>> D = 16
>>> query = Tensor(np.ones((B, N, S, D), dtype=np.float16))
>>> key = Tensor(np.ones((B, N, S, D), dtype=np.float16))
>>> value = Tensor(np.ones((B, N, S, D), dtype=np.float16))
>>> out = prompt_flash_attention(query, key, value, None, None, None, None, None, None, None, None, None, N, input_layout='BNSD')
>>> print(out.shape)
(1, 16, 256, 16)
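Continuing from the example above, the following is a hedged sketch (not from the official documentation) of a grouped-query attention call, where the key/value tensors use kv_N = 4 heads shared across the N = 16 query heads via num_key_value_heads:

>>> kv_N = 4
>>> gqa_key = Tensor(np.ones((B, kv_N, S, D), dtype=np.float16))    # key N-axis equals num_key_value_heads
>>> gqa_value = Tensor(np.ones((B, kv_N, S, D), dtype=np.float16))  # value N-axis equals num_key_value_heads
>>> out = prompt_flash_attention(query, gqa_key, gqa_value, None, None, None, None, None, None, None, None, None, N, input_layout='BNSD', num_key_value_heads=kv_N)
>>> print(out.shape)
(1, 16, 256, 16)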