mindspore.ops.prompt_flash_attention
- mindspore.ops.prompt_flash_attention(query, key, value, attn_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, pse_shift=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147483647, next_tokens=0, input_layout='BSH', num_key_value_heads=0, sparse_mode=0, inner_precise=1)[源代码]
全量推理场景接口。
B:batch维
N:注意力头数
S:序列长度
D:头维度
H:隐藏层大小
self-attention(自注意力)利用输入样本自身的关系构建了一种注意力模型。其原理是假设有一个长度为 \(n\) 的输入样本序列 \(x\) , \(x\) 的每个元素都是一个 \(d\) 维向量, 可以将每个 \(d\) 维向量看作一个token embedding,将这样一条序列经过3个权重矩阵变换得到3个维度为 \(n\times d\) 的矩阵。
self-attention的计算公式一般定义如下,
\[Attention(Q,K,V)=Softmax(\frac{QK^{T} }{\sqrt{d} } )V\]其中 \(Q\) 和 \(K^{T}\) 的乘积代表输入 \(x\) 的注意力,为避免该值变得过大,通常除以 \(d\) 的平方根进行缩放,并对每行进行softmax归一化,与 \(V\) 相乘后得到一个 \(n\times d\) 的矩阵。
警告
未来将不再支持 attn_mask 的float16数据类型。
当 sparse_mode 为2、3或4时, attn_mask 的shape必须为 \((2048, 2048)\) / \((B, 1, 2048, 2048)\) / \((1, 1, 2048, 2048)\) 。
说明
各轴的最大支持值:
支持B轴值小于等于65536(64k)。当输入类型包括D轴未对齐到32的int8或float16或bfloat16时,B轴支持最多为128。
支持N轴值小于等于256。
支持S轴值小于等于20971520(20M)。
支持D轴值小于等于512。
量化:
int8输入,int8输出:必须提供 deq_scale1、 quant_scale1、 deq_scale2 和 quant_scale2 参数。 quant_offset2 是可选的(如果未提供,默认值为0)。
int8输入,float16输出:必须提供 deq_scale1、 quant_scale1 和 deq_scale2 参数。如果提供 quant_offset2 或 quant_scale2,将导致错误。
float16或bfloat16输入,int8输出:必须提供 quant_scale2 参数。 quant_offset2 是可选的(如果未提供,默认值为0)。如果提供 deq_scale1、 quant_scale1 或 deq_scale2 参数,将导致错误。
int8输出:
per-channel格式的 quant_scale2 和 quant_offset2 不支持有左填充、Ring Attention或D轴未对齐到32字节的情况。
在GE模式中,per-tensor格式的 quant_scale2 和 quant_offset2 不支持D轴未对齐到32字节的情况。
不支持带有负值的带状稀疏和 pre_tokens/ next_tokens 。
quant_scale2 和 quant_offset2 只有在 query 为bfloat16时才可以是bfloat16。
其他使用注意事项:
query 参数的 N 必须等于 num_heads 。 key 和 value 参数的 N 必须等于 num_key_value_heads 。
num_heads 必须可以被 num_key_value_heads 整除,并且 num_heads 除以 num_key_value_heads 不得大于64。
当 query 数据类型为bfloat16时,D轴应与16对齐。
actual_seq_lengths 的每个元素不得超过q_S, actual_seq_lengths_kv 的每个元素不得超过kv_S。
警告
只支持 Atlas A2 训练系列产品。
- 参数:
query (Tensor) - 公式中的输入Q,数据类型可以是int8、float16或bfloat16。shape为 \((B, q_S, q_H)\) 或 \((B, q_N, q_S, q_D)\) 。
key (Tensor) - 公式中的输入K,数据类型与 query 相同。shape为 \((B, kv_S, kv_H)\) 或 \((B, kv_N, kv_S, kv_D)\) 。
value (Tensor) - 公式中的输入V,数据类型与 query 相同。shape为 \((B, kv_S, kv_H)\) 或 \((B, kv_N, kv_S, kv_D)\) 。
attn_mask (Tensor,可选) - 注意力掩码Tensor,数据类型为bool、int8、uint8或float16。每个元素中,0/False表示保留,1/True表示丢弃。如果 sparse_mode 为0或1,其shape可以是 \((q_S, kv_S)\)、 \((B, q_S, kv_S)\)、 \((1, q_S, kv_S)\)、 \((B, 1, q_S, kv_S)\) 或 \((1, 1, q_S, kv_S)\) 。如果 sparse_mode 为2、3或4,其shape应为 \((2048, 2048)\)、 \((1, 2048, 2048)\) 或 \((1, 1, 2048, 2048)\) 。默认值为
None
。actual_seq_lengths (Union[Tensor, tuple[int], list[int]],可选) - 描述 query 每个批次的实际序列长度,数据类型为int64。shape为 \((B,)\) ,每个元素应为正整数。默认值为
None
。actual_seq_lengths_kv (Union[Tensor, tuple[int], list[int]],可选) - 描述 key 或 value 每个批次的实际序列长度,数据类型为int64。shape为 \((B,)\) ,每个元素应为正整数。默认值为
None
。pse_shift (Tensor,可选) - 位置编码Tensor,数据类型为float16或bfloat16。输入Tensor shape为 \((B, N, q_S, kv_S)\) 或 \((1, N, q_S, kv_S)\) 。默认值为
None
。q_S必须大于等于query的S长度,kv_S必须大于等于key的S长度。
如果 pse_shift 的数据类型为float16, query 应为float16或int8,这种情况下会自动启用高精度模式。
如果 pse_shift 的数据类型为bfloat16, query 应为bfloat16。
deq_scale1 (Tensor,可选) - 量化参数,数据类型为uint64或float32。输入Tensor shape为 \((1,)\) 。默认值为
None
。quant_scale1 (Tensor,可选) - 量化参数,数据类型为float32。输入Tensor shape为 \((1,)\) 。默认值为
None
。deq_scale2 (Tensor,可选) - 量化参数,输入Tensor shape为 \((1,)\) 并与 deq_scale1 类型相同。默认值为
None
。quant_scale2 (Tensor,可选) - 量化参数,数据类型为float32或bfloat16。当输出layout为BSH时,建议shape为 \((1,)\)、 \((1, 1, q_H)\)、 \((q_H,)\) ;当layout为BNSD时,建议shape为 \((1,)\)、 \((1, q_N, 1, D)\)、 \((q_N, D)\) 。默认值为
None
。quant_offset2 (Tensor,可选) - 量化参数,数据类型为float32或bfloat16。数据类型和shape与 quant_scale2 相同。默认值为
None
。num_heads (int,可选) - 头的数量,范围为[0, 256]。默认值为
1
。scale_value (double,可选) - 表示缩放系数的值,用作计算中的乘数标量。默认值为
1.0
。pre_tokens (int,可选) - 用于稀疏计算,表示注意力需要关联的前序列元素个数。默认值为
2147483647
。next_tokens (int,可选) - 用于稀疏计算,表示注意力需要关联的后序列元素个数。默认值为
0
。input_layout (str,可选) - 输入qkv的数据layout,支持 BSH 和 BNSD 。默认值为
BSH
。num_key_value_heads (int,可选) - 一个整数,表示在GQA算法中 key 和 value 的头数量。0表示 key 和 value 具有与 query 相同的头数。如果指定(非0),其必须是 num_heads 的因子,并且等于kv_n。默认值为
0
。sparse_mode (int,可选) - 一个整数,指定稀疏模式,可以是 {0, 1, 2, 3, 4} 中的值。默认值为
0
。sparseMode = 0:如果 attn_mask 为空指针,则忽略 pre_tokens 和 next_tokens 输入(内部设置为INT_MAX)。
sparseMode = 2, 3, 4: attn_mask shape必须为 \((S, S)\)、 \((1, S, S)\) 或 \((1, 1, S, S)\) ,S固定为2048。用户必须确保 attn_mask 为下三角。shape不正确或未提供将导致错误。
sparseMode = 1, 2, 3:忽略 pre_tokens 和 next_tokens 输入,并根据特定规则设置值。
sparseMode = 4: pre_tokens 和 next_tokens 必须是非负的。
inner_precise (int,可选) - 一个 {0, 1} 中的整数,指定计算模式。
0
为高精度模式(适用于float16 数据类型),1
为高性能模式。默认值为1
。
- 返回:
attention_out (Tensor) - 输出Tensor,与 query 的shape相同: (B, q_S, q_H) 或 (B, q_N, q_S, q_D)。输出数据类型由多种因素决定,请参阅上面的Note部分获取详细信息。
- 异常:
TypeError - query 的数据类型不是int8、float16或bfloat16。
TypeError - query、 key 和 value 的数据类型不同。
TypeError - attn_mask 的数据类型不是bool、int8或uint8。
TypeError - pse_shift 的数据类型不是bfloat16或float16。
TypeError - scale_value 不是double类型。
TypeError - input_layout 不是字符串。
TypeError - num_key_value_heads 不是整数。
TypeError - sparse_mode 不是整数。
TypeError - inner_precise 不是整数。
TypeError - quant_scale1 不是float32类型的Tensor。
TypeError - deq_scale1 不是uint64或float32类型的Tensor。
TypeError - quant_scale2 不是float32类型的Tensor。
TypeError - deq_scale2 不是uint64或float32类型的Tensor。
TypeError - quant_offset2 不是float32类型的Tensor。
ValueError - input_layout 是字符串但不是BSH或BNSD。
RuntimeError - num_heads 不能被 num_key_value_heads 整除。
RuntimeError - num_heads 小于等于 0。
RuntimeError - num_key_value_heads 小于等于0。
RuntimeError - kv_n不等于 num_key_value_heads。
RuntimeError - attn_mask 的shape不合法。
RuntimeError - sparse_mode 被指定的值不是0、1、2、3或4。
RuntimeError - query 的数据类型为bfloat16并且D轴未对齐到16。
RuntimeError - 输入layout为BSH并且kv_h不能被 num_key_value_heads 整除。
RuntimeError - query、 key 和 value 的D轴不相同。
RuntimeError - 后量化per-channel情况下,D轴未对齐到32字节。
- 支持平台:
Ascend
样例:
>>> from mindspore import Tensor, ops >>> import numpy as np >>> B = 1 >>> N = 16 >>> S = 256 >>> D = 16 >>> query = Tensor(np.ones((B, N, S, D), dtype=np.float16)) >>> key = Tensor(np.ones((B, N, S, D), dtype=np.float16)) >>> value = Tensor(np.ones((B, N, S, D), dtype=np.float16)) >>> out = ops.prompt_flash_attention(query, key, value, None, None, None, None, None, None, None, None, ... None, N, input_layout='BNSD') >>> print(out.shape) (1, 16, 256, 16)