mindspore.ops.prompt_flash_attention

mindspore.ops.prompt_flash_attention(query, key, value, attn_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, pse_shift=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147483647, next_tokens=0, input_layout='BSH', num_key_value_heads=0, sparse_mode=0, inner_precise=1)[source]

The interface for full inference.

  • B: Batch size

  • N: Num of attention heads

  • S: Sequence length

  • D: Head dim

  • H: Hidden layer size

Self-attention constructs an attention model based on the relationships among the input samples themselves. Assume an input sample sequence x of length n, where each element of x is a d-dimensional vector that can be viewed as a token embedding. This sequence is transformed through 3 weight matrices to obtain the matrices Q, K and V, each of dimension n×d.

The self attention calculation formula is defined as:

Attention(Q, K, V) = Softmax(QK^T / sqrt(d)) V

where the product of Q and K^T represents the attention over the input x. To avoid this value becoming too large, it is scaled by dividing by the square root of d; softmax normalization is then applied to each row, and multiplying by V yields an n×d output matrix.
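The following is a minimal NumPy sketch of this formula, for illustration only; it is not the fused Ascend kernel that this operator dispatches to, and the shapes are arbitrary example values.

>>> import numpy as np
>>> def naive_attention(q, k, v):
...     d = q.shape[-1]
...     scores = q @ k.T / np.sqrt(d)                # scaled scores, shape (n, n)
...     scores -= scores.max(-1, keepdims=True)      # subtract row max for numerical stability
...     weights = np.exp(scores)
...     weights /= weights.sum(-1, keepdims=True)    # row-wise softmax
...     return weights @ v                           # output of shape (n, d)
...
>>> x = np.random.rand(8, 16).astype(np.float32)     # n=8 tokens, d=16 dims; Q=K=V=x for brevity
>>> naive_attention(x, x, x).shape
(8, 16)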

Warning

  • Support for the float16 dtype of attn_mask will be deprecated in the future.

  • When sparse_mode is 2, 3 or 4, the shape of attn_mask must be (2048, 2048) / (B, 1, 2048, 2048) / (1, 1, 2048, 2048).

Note

  • Maximum Support for each axis

    • Supports B-axis values less than or equal to 65536 (64k). When the input type includes int8 with D-axis not aligned to 32, or the input type is float16 or bfloat16 with D-axis not aligned to 16, the B-axis supports up to 128 only.

    • Supports N-axis values less than or equal to 256.

    • Supports S-axis values less than or equal to 20971520 (20M).

    • Supports D-axis values less than or equal to 512.

  • Quantization

    • int8 Input, int8 Output: Parameters deq_scale1, quant_scale1, deq_scale2, and quant_scale2 must all be provided. quant_offset2 is optional (default is 0 if not provided).

    • int8 Input, float16 Output: Parameters deq_scale1, quant_scale1, and deq_scale2 must all be provided. If quant_offset2 or quant_scale2 is provided (i.e., not null), it will result in an error.

    • float16 or bfloat16 Input, int8 Output: Parameter quant_scale2 must be provided. quant_offset2 is optional (default is 0 if not provided). If deq_scale1, quant_scale1, or deq_scale2 is provided (i.e., not null), it will result in an error.

    • int8 Output:

      • quant_scale2 and quant_offset2 in per-channel format do not support scenarios with left padding, Ring Attention, or non-32-byte aligned D-axis.

      • In GE mode: quant_scale2 and quant_offset2 in per-tensor format do not support scenarios with non-32-byte aligned D-axis.

      • Does not support sparse set to band together with negative pre_tokens/next_tokens.

    • quant_scale2 and quant_offset2 can be bfloat16 only when query is bfloat16.

  • Other Usage Caveats:

    • N of parameter query must be equal to num_heads. N of parameter key and parameter value must be equal to num_key_value_heads.

    • num_heads must be divisible by num_key_value_heads, and num_heads divided by num_key_value_heads cannot be greater than 64.

    • When query dtype is bfloat16, D axis should align with 16.

    • Each element of actual_seq_lengths must not exceed q_S, and each element of actual_seq_lengths_kv must not exceed kv_S.
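To make the head-count caveats above concrete, here is a minimal shape sketch of a grouped-query attention (GQA) call under the BNSD layout. It assumes an Atlas A2 (Ascend) device is available, since the operator only runs there, and the head counts are illustrative values, not requirements.

>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> B, q_N, kv_N, S, D = 1, 16, 4, 256, 128
>>> assert q_N % kv_N == 0 and q_N // kv_N <= 64      # GQA constraint from the note above
>>> query = Tensor(np.ones((B, q_N, S, D), dtype=np.float16))
>>> key = Tensor(np.ones((B, kv_N, S, D), dtype=np.float16))    # key/value carry kv_N heads
>>> value = Tensor(np.ones((B, kv_N, S, D), dtype=np.float16))
>>> out = ops.prompt_flash_attention(query, key, value, num_heads=q_N,
...                                  input_layout='BNSD', num_key_value_heads=kv_N)
>>> print(out.shape)
(1, 16, 256, 128)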

Warning

Only supported on Atlas A2 training series products.

Parameters
  • query (Tensor) – The query tensor with data type of int8, float16 or bfloat16. The shape is (B, q_S, q_H) / (B, q_N, q_S, q_D).

  • key (Tensor) – The key tensor with the same dtype as query. The shape is (B, kv_S, kv_H) / (B, kv_N, kv_S, kv_D).

  • value (Tensor) – The value tensor with the same dtype as query. The shape is (B, kv_S, kv_H) / (B, kv_N, kv_S, kv_D).

  • attn_mask (Tensor, optional) – For each element, 0/False indicates retention and 1/True indicates discard. If sparse_mode is 0 or 1: the shape is (q_S, kv_S) / (B, q_S, kv_S) / (1, q_S, kv_S) / (B, 1, q_S, kv_S) / (1, 1, q_S, kv_S). If sparse_mode is 2, 3 or 4, the shape is (2048, 2048) / (1, 2048, 2048) / (1, 1, 2048, 2048). Default: None.

  • actual_seq_lengths (Union[Tensor, tuple[int], list[int]], optional) – Describes the actual sequence length of each batch of query, with data type of int64. The shape is (B,) and every element should be a positive integer. Default: None.

  • actual_seq_lengths_kv (Union[Tensor, tuple[int], list[int]], optional) – Describes the actual sequence length of each batch of key or value, with data type of int64. The shape is (B,) and every element should be a positive integer. Default: None.

  • pse_shift (Tensor, optional) –

    The position encoding tensor with data type of float16 or bfloat16. Input tensor of shape (B, N, q_S, kv_S) / (1, N, q_S, kv_S). Default: None.

    • q_S must be greater than or equal to the query's S length, and kv_S must be greater than or equal to the key's S length.

    • If pse_shift has dtype float16, query should have dtype float16 or int8, in which case high precision mode is enabled automatically.

    • If pse_shift has dtype bfloat16, query should have dtype bfloat16.

  • deq_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of uint64 or float32. Input tensor of shape (1,). Default: None.

  • quant_scale1 (Tensor, optional) – Quantization parameter, a tensor with data type of float32. Input tensor of shape (1,). Default: None.

  • deq_scale2 (Tensor, optional) – Quantization parameter, an input tensor of shape (1,) with the same dtype as deq_scale1. Default: None.

  • quant_scale2 (Tensor, optional) – Quantization parameter, a tensor with data type of float32 or bfloat16. The suggested shape is (1,) / (1, 1, q_H) / (q_H,) when the output layout is BSH, or (1,) / (1, q_N, 1, D) / (q_N, D) when the layout is BNSD. Default: None.

  • quant_offset2 (Tensor, optional) – Quantization parameter, a tensor with data type of float32 or bfloat16. It has the same dtype and shape as quant_scale2. Default: None.

  • num_heads (int, optional) – The number of heads. It is an integer in range [0, 256]. Default: 1.

  • scale_value (double, optional) – The scale value indicating the scale coefficient, which is used as the scalar of Muls in the calculation. Default: 1.0.

  • pre_tokens (int, optional) – For sparse computing, indicates the number of previous tokens that attention needs to be associated with. Default: 2147483647.

  • next_tokens (int, optional) – For sparse computing, indicates the number of next tokens that attention needs to be associated with. Default: 0.

  • input_layout (str, optional) – The data layout of the input qkv, supporting "BSH" and "BNSD". Default: "BSH".

  • num_key_value_heads (int, optional) – An int indicating the head number of key/value, used in the GQA algorithm. The value 0 means that key and value have the same head number, given by num_heads. If it is specified (not 0), it must be a factor of num_heads and must be equal to kv_n. Default: 0.

  • sparse_mode (int, optional) –

    An int specifying the sparse mode; it can be an int from {0, 1, 2, 3, 4}. Default: 0.

    • sparseMode = 0: If attn_mask is a null pointer, pre_tokens and next_tokens inputs are ignored (internally set to INT_MAX).

    • sparseMode = 2, 3, 4: attn_mask shape must be (S,S) or (1,S,S) or (1,1,S,S), with S fixed at 2048. User must ensure that attn_mask is lower triangular. If not provided or incorrect shape, it will result in an error.

    • sparseMode = 1, 2, 3: Ignores pre_tokens, next_tokens inputs and sets values according to specific rules.

    • sparseMode = 4: pre_tokens and next_tokens must be non-negative.

  • inner_precise (int, optional) – An int number from {0, 1} indicates computing mode. 0 for high precision mode for float16 dtype. 1 for high performance mode. Default: 1.
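For sparse_mode 2, 3 or 4, the attn_mask above must be a fixed 2048x2048 mask whose retained region is lower triangular. The following is a hedged sketch of one way to build such a mask, assuming the documented convention (0/False retains, 1/True discards) and that "lower triangular" refers to the retained region; the exact shape variant accepted should be checked against the shapes listed for attn_mask.

>>> import numpy as np
>>> from mindspore import Tensor
>>> # 1/True marks positions to discard, so masking the strict upper triangle
>>> # leaves a lower-triangular (causal) retained region for sparse_mode 2/3/4.
>>> attn_mask = Tensor(np.triu(np.ones((2048, 2048), dtype=np.bool_), k=1))
>>> print(attn_mask.shape)
(2048, 2048)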

Returns

attention_out (Tensor) - Output tensor, has the same shape as query: (B, q_S, q_H) / (B, q_N, q_S, q_D). Output dtype is determined by multiple factors, please refer to Note above for details.

Raises
  • TypeError – Dtype of query is not int8, float16 or bfloat16.

  • TypeError – query, key and value don't have the same dtype.

  • TypeError – Dtype of attn_mask is not bool, int8 or uint8.

  • TypeError – Dtype of pse_shift is not bfloat16 or float16.

  • TypeError – scale_value is not a double number.

  • TypeError – input_layout is not a string.

  • TypeError – num_key_value_heads is not an int.

  • TypeError – sparse_mode is not an int.

  • TypeError – inner_precise is not an int.

  • TypeError – quant_scale1 is not a Tensor of type float32.

  • TypeError – deq_scale1 is not a Tensor of type uint64 or float32.

  • TypeError – quant_scale2 is not a Tensor of type float32.

  • TypeError – deq_scale2 is not a Tensor of type uint64 or float32.

  • TypeError – quant_offset2 is not a Tensor of type float32.

  • ValueError – input_layout is a string but not "BSH" or "BNSD".

  • RuntimeError – num_heads is not divisible by num_key_value_heads.

  • RuntimeError – num_heads is not greater than 0.

  • RuntimeError – num_key_value_heads is not greater than or equal to 0.

  • RuntimeError – kv_n is not equal to num_key_value_heads.

  • RuntimeError – attn_mask shape is not valid.

  • RuntimeError – sparse_mode is specified but is not 0, 1, 2, 3 or 4.

  • RuntimeError – query dtype is bfloat16 and D axis is not aligned with 16.

  • RuntimeError – input_layout is BSH and kv_h is not divisible by num_key_value_heads.

  • RuntimeError – D-axis of query, key and value is not the same.

  • RuntimeError – In post quant per-channel scenario, D-axis is not 32 Byte aligned.

Supported Platforms:

Ascend

Examples

>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> B = 1
>>> N = 16
>>> S = 256
>>> D = 16
>>> query = Tensor(np.ones((B, N, S, D), dtype=np.float16))
>>> key = Tensor(np.ones((B, N, S, D), dtype=np.float16))
>>> value = Tensor(np.ones((B, N, S, D), dtype=np.float16))
>>> out = ops.prompt_flash_attention(query, key, value, None, None, None, None, None, None, None, None,
...                                  None, N, input_layout='BNSD')
>>> print(out.shape)
(1, 16, 256, 16)
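A second, hedged sketch using the default BSH layout together with per-batch actual sequence lengths; like the example above, it assumes an Ascend device is available, and the dimensions are illustrative values.

>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> B, N, S, D = 1, 16, 256, 16
>>> H = N * D                                           # hidden size for the BSH layout
>>> query = Tensor(np.ones((B, S, H), dtype=np.float16))
>>> key = Tensor(np.ones((B, S, H), dtype=np.float16))
>>> value = Tensor(np.ones((B, S, H), dtype=np.float16))
>>> out = ops.prompt_flash_attention(query, key, value, actual_seq_lengths=[S],
...                                  num_heads=N, input_layout='BSH')
>>> print(out.shape)
(1, 256, 256)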