注意力机制是Transformer架构的核心组件,但也是计算和内存的主要瓶颈。本章深入探讨注意力机制的优化技术,从算法层面的改进到系统层面的实现优化,涵盖Flash Attention、多查询注意力、稀疏注意力模式以及线性注意力等前沿技术。这些优化对于在边缘设备上高效部署大语言模型至关重要。
标准的缩放点积注意力(Scaled Dot-Product Attention)计算公式为:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]其中,$Q, K, V \in \mathbb{R}^{N \times d}$,$N$是序列长度,$d$是特征维度。
内存复杂度分析:
计算复杂度细分:
带宽瓶颈: 标准实现需要多次读写HBM(高带宽内存):
总的HBM访问量约为 $O(N^2d + Nd^2)$。
硬件特性考虑: 现代GPU的内存层次结构:
这意味着如果能将计算保持在SRAM中,理论上可获得超过10倍的性能提升。
Flash Attention通过平铺(tiling)和重计算(recomputation)来优化内存访问:
关键洞察:
算法创新点:
设块大小为 $B_r \times B_c$,将 $Q$ 分成 $T_r = \lceil N/B_r \rceil$ 块,$K, V$ 分成 $T_c = \lceil N/B_c \rceil$ 块。
块划分策略:
外层循环(对每个 $Q$ 块):
对于 i = 1 到 T_r:
加载 Q_i 到SRAM
初始化: O_i = 0, l_i = 0, m_i = -∞
内层循环(对每个 K,V 块):
对于 j = 1 到 T_c:
加载 K_j, V_j 到SRAM
计算 S_{ij} = Q_i K_j^T / √d_k # 大小: B_r × B_c
# 增量softmax更新
m_{ij} = rowmax(S_{ij}) # 每行的最大值
P_{ij} = exp(S_{ij} - m_{ij}) # 数值稳定的exp
l_{ij} = rowsum(P_{ij}) # 每行的和
# 更新运行统计
m_i^{new} = max(m_i, m_{ij})
l_i^{new} = exp(m_i - m_i^{new}) * l_i + exp(m_{ij} - m_i^{new}) * l_{ij}
# 更新输出
O_i = diag(exp(m_i - m_i^{new}))^{-1} * O_i +
diag(exp(m_{ij} - m_i^{new}))^{-1} * P_{ij} * V_j
m_i = m_i^{new}, l_i = l_i^{new}
写回 O_i 到HBM
算法正确性保证: 该算法正确实现了标准注意力的原因在于:
关键在于正确地合并部分softmax结果。设已处理前 $j-1$ 块,现处理第 $j$ 块:
推导基础: 完整的softmax计算为: \(\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}\)
为了数值稳定性,通常减去最大值: \(\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}\)
增量更新推导: 设已处理的部分结果为:
处理第 $j$ 块时:
最大值更新: \(m_i^{(j)} = \max(m_i^{(j-1)}, m_{ij})\)
归一化因子更新: 需要将旧的归一化因子调整到新的数值范围: \(l_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ijk} - m_i^{(j)}}\)
这里:
输出更新: \(O_i^{(j)} = \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)}}{l_i^{(j)}} O_i^{(j-1)} + \frac{1}{l_i^{(j)}} \sum_k e^{S_{ijk} - m_i^{(j)}} V_{jk}\)
证明这等价于完整计算: \(O_i = \sum_{j,k} \frac{e^{S_{ijk} - \max_l S_{il}}}{\sum_{j',k'} e^{S_{ij'k'} - \max_l S_{il}}} V_{jk}\)
SRAM使用量详细分析:
HBM访问量对比:
标准注意力:
Flash Attention:
带宽需求降低:
计算复杂度分析:
Flash Attention v2主要优化了算法的并行性和硬件利用率:
性能提升的关键技术:
1. Warp级别的优化:
2. 数据布局优化:
3. 混合精度策略:
4. 算法级优化:
分割策略改进: Flash v2使用了更智能的分割策略,考虑了硬件的并行度:
数学优化: 在线softmax算法的进一步优化,减少数值计算: \(m_i^{new} = \max(m_i^{old}, m_{ij})\) \(l_i^{new} = e^{m_i^{old} - m_i^{new}} \cdot l_i^{old} + e^{m_{ij} - m_i^{new}} \cdot l_{ij}\)
Flash v2通过预计算 $e^{-m_i^{new}}$ 并复用,减少指数运算次数。
5. 因果掩码的优化处理:
对于自回归生成,Flash v2引入了高效的因果掩码处理:
6. 向后传播的优化:
Flash v2不仅优化了前向传播,还显著改进了反向传播:
性能对比(A100 GPU): | 序列长度 | Flash v1 | Flash v2 | 提升比例 | |———|———-|———-|———-| | 512 | 1.5ms | 0.9ms | 1.67× | | 2048 | 23ms | 12ms | 1.92× | | 8192 | 370ms | 180ms | 2.06× | | 16384 | 1480ms | 680ms | 2.18× |
内存带宽利用率对比: | 方法 | 理论带宽利用率 | 实际测量(A100)| |——|—————|—————-| | 标准注意力 | ~15% | 12-18% | | Flash v1 | ~45% | 38-42% | | Flash v2 | ~72% | 65-70% |
不同精度下的性能(序列长度4096): | 精度配置 | Flash v1 | Flash v2 | 相对提升 | |———|———-|———-|———-| | FP32 | 95ms | 52ms | 1.83× | | FP16 | 46ms | 24ms | 1.92× | | BF16 | 45ms | 23ms | 1.96× | | INT8* | - | 18ms | - |
*INT8支持仅在Flash v2.5+版本
在边缘设备上实现Flash Attention需要考虑硬件特性、内存限制和能效约束:
1. 有限的片上内存:
不同边缘硬件的内存层次对比:
| 硬件平台 | L1缓存 | L2缓存 | 共享内存 | 建议块大小 |
|---|---|---|---|---|
| ARM Cortex-A78 | 64KB | 512KB | - | B=32-48 |
| Apple M2 | 128KB | 4MB | 32KB | B=64-96 |
| Snapdragon 8Gen2 | 64KB | 1MB | 16KB (GPU) | B=32 |
| Mali-G710 | - | - | 16KB | B=24-32 |
| Adreno 740 | - | - | 32KB | B=48 |
块大小选择的数学分析:
给定片上内存大小 $M_{on-chip}$,需要存储:
约束条件: \(B_r d + 2B_c d + B_r B_c + 8B_r \leq \frac{M_{on-chip}}{\text{sizeof}(\text{dtype})}\)
2. 向量化指令集差异:
| 指令集 | 向量宽度 | 特点 | Flash Attention优化策略 |
|---|---|---|---|
| NEON | 128-bit | ARM标准SIMD | 4个FP32或8个FP16并行 |
| SVE/SVE2 | 128-2048bit | 可变长度向量 | 动态适配向量长度 |
| AMX | 512-bit | Apple矩阵扩展 | 利用矩阵乘法单元 |
| HVX | 1024-bit | Hexagon向量扩展 | 超宽SIMD并行 |
3. 内存带宽限制:
边缘设备内存带宽远低于数据中心GPU:
这使得Flash Attention的内存优化在边缘设备上更加重要。
1. 多级分块策略:
针对边缘设备的多级缓存,采用嵌套分块:
计算流程:
for each L2_block in range(0, N, B_L2):
# 预取L2块到L2缓存
prefetch_L2_block()
for each L1_block in range(L2_block, L2_block+B_L2, B_L1):
# 在L1缓存中计算
compute_attention_block()
2. 混合精度计算策略:
针对不同计算阶段使用不同精度:
| 计算阶段 | 推荐精度 | 原因 |
|---|---|---|
| $S = QK^T$ | INT8/FP16 | 矩阵乘法,精度要求适中 |
| $m_i, l_i$ 更新 | FP32 | 累积误差敏感 |
| $\exp(S - m)$ | FP16 + LUT | 指数运算开销大 |
| 最终输出 | INT8/FP16 | 匹配模型整体精度 |
3. 指数运算优化:
边缘设备上指数运算开销巨大,优化方案:
方案1:分段线性近似 \(\exp(x) \approx \begin{cases} 0 & x < -5 \\ a_i x + b_i & x \in [x_i, x_{i+1}] \\ \exp(5) & x > 5 \end{cases}\)
方案2:查找表+插值
4. 数值稳定性增强:
对于低精度计算,增强数值稳定性: \(m_i^{new} = \max(m_i^{old}, m_{ij})\) \(\Delta m = m_i^{new} - m_i^{old}\) \(l_i^{new} = \begin{cases} l_i^{old} + l_{ij} & \text{if } \Delta m < \epsilon \\ e^{-\Delta m} \cdot l_i^{old} + l_{ij} & \text{otherwise} \end{cases}\)
1. ARM CPU优化:
利用ARM特定特性:
2. Apple Silicon优化:
利用统一内存架构(UMA):
3. 移动GPU优化:
适配移动GPU特点:
1. 动态电压频率调节(DVFS):
根据计算特性调整频率:
2. 异构计算调度:
| 序列长度 | 推荐硬件 | 原因 |
|---|---|---|
| <256 | CPU | 启动开销小,缓存友好 |
| 256-1024 | GPU/NPU | 并行度适中 |
| >1024 | GPU + CPU协同 | GPU处理主体,CPU处理边界 |
3. 批处理策略:
边缘设备内存有限,批处理策略需要权衡:
1. llama.cpp的实现:
在Apple Silicon上的优化:
性能数据(M2 Max,7B模型):
2. MNN框架实现:
针对移动设备的优化:
性能数据(Snapdragon 8Gen2,1.8B模型):
3. ONNX Runtime实现:
跨平台统一接口:
综合性能对比(1.8B模型,序列长度512):
| 设备 | 实现方式 | 预填充(ms) | 解码(tokens/s) | 内存(MB) | 功耗(W) |
|---|---|---|---|---|---|
| iPhone 14 Pro | 标准注意力 | 850 | 12 | 420 | 3.2 |
| iPhone 14 Pro | Flash Attention | 520 | 18 | 280 | 2.8 |
| Pixel 7 Pro | 标准注意力 | 920 | 10 | 450 | 3.5 |
| Pixel 7 Pro | Flash Attention | 580 | 16 | 300 | 3.0 |
| Jetson Orin | 标准注意力 | 450 | 22 | 380 | 5.0 |
| Jetson Orin | Flash Attention | 280 | 35 | 250 | 4.2 |
关键洞察:
标准多头注意力(Multi-Head Attention, MHA)为每个头独立计算 $Q_h, K_h, V_h$:
\[\text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_H)W^O\]其中每个头: \(\text{head}_h = \text{Attention}(QW_h^Q, KW_h^K, VW_h^V)\)
参数和计算开销详细分析:
| 组件 | 参数量 | 计算量(前向) | 内存占用 |
|---|---|---|---|
| Q投影 | $H \times d_{model} \times d_{head}$ | $O(NHd_{model}d_{head})$ | 批处理时可忽略 |
| K投影 | $H \times d_{model} \times d_{head}$ | $O(NHd_{model}d_{head})$ | KV cache主要部分 |
| V投影 | $H \times d_{model} \times d_{head}$ | $O(NHd_{model}d_{head})$ | KV cache主要部分 |
| 注意力计算 | 0 | $O(HN^2d_{head})$ | $O(HN^2)$临时存储 |
| 输出投影 | $d_{model} \times d_{model}$ | $O(Nd_{model}^2)$ | 可忽略 |
KV Cache的内存瓶颈分析:
对于批量推理场景,KV cache占用计算: \(\text{Memory}_{KV} = 2 \times B \times L \times H \times N \times d_{head} \times \text{sizeof(dtype)}\)
其中:
实例计算(LLaMA-70B):
这远超大多数GPU的显存容量!
冗余性的理论分析:
1. 注意力模式的相似性
定义头 $h_i$ 和 $h_j$ 之间的注意力模式相似度: \(\text{Sim}(h_i, h_j) = \frac{1}{N} \sum_{n=1}^N \text{cos}(A_n^{(i)}, A_n^{(j)})\)
其中 $A_n^{(h)}$ 是头 $h$ 在位置 $n$ 的注意力分布。
实证发现:
2. 键值空间的低秩性
对键值矩阵进行奇异值分解(SVD): \(K = U_K \Sigma_K V_K^T, \quad V = U_V \Sigma_V V_V^T\)
谱分析结果: | 累积方差解释比例 | 所需主成分数 | 相对于原始维度 | |—————-|————|————–| | 80% | 5-8 | 6-10% | | 90% | 10-15 | 12-19% | | 95% | 20-25 | 25-31% | | 99% | 40-50 | 50-62% |
这表明键值空间存在显著的低秩结构。
3. 信息论视角
使用互信息(Mutual Information)分析不同头之间的依赖关系: \(I(h_i; h_j) = \sum_{a_i, a_j} p(a_i, a_j) \log \frac{p(a_i, a_j)}{p(a_i)p(a_j)}\)
发现:
4. 功能特化分析
通过分析不同头的激活模式,研究人员发现了功能特化现象:
| 头类型 | 比例 | 功能描述 | 可共享性 |
|---|---|---|---|
| 位置头 | 15-20% | 关注相对位置 | 高 |
| 语法头 | 10-15% | 捕获语法结构 | 中 |
| 语义头 | 20-25% | 语义相似性 | 低 |
| 稀疏头 | 30-40% | 稀疏激活模式 | 高 |
| 全局头 | 5-10% | 全局信息聚合 | 中 |
优化机会:
MQA的核心思想是所有头共享同一组键值对:
\[\text{MQA}(Q, K, V) = \text{Concat}(\text{head}_1^{MQ}, ..., \text{head}_H^{MQ})W^O\]其中: \(\text{head}_h^{MQ} = \text{Attention}(Q_h, K_{shared}, V_{shared})\)
标准MHA到MQA的转换:
标准MHA中,每个头有独立的键值投影: \(K_h = XW_h^K, \quad V_h = XW_h^V\)
MQA将所有头的键值投影合并: \(K_{shared} = X\bar{W}^K, \quad V_{shared} = X\bar{W}^V\)
其中 $\bar{W}^K, \bar{W}^V \in \mathbb{R}^{d_{model} \times d_k}$ 是共享的投影矩阵。
理论基础:
MQA的有效性基于以下假设:
信息瓶颈分析:
从信息论角度,MQA引入了信息瓶颈: \(I(X; Y|Q) \leq I(X; K_{shared}, V_{shared})\)
| 其中 $I(X; Y | Q)$ 是给定查询Q时,输入X和输出Y之间的互信息。 |
1. 内存布局优化:
为了高效广播共享的KV到所有头,需要优化内存布局:
| 布局方案 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| 复制扩展 | 访问模式简单 | 内存占用增加 | 小批量推理 |
| 广播索引 | 内存效率高 | 需要间接访问 | 大批量推理 |
| 融合kernel | 避免显式广播 | 实现复杂 | 高性能需求 |
2. 计算优化:
批量矩阵乘法(BMM)优化:
计算模式对比:
MHA:
for h in range(H):
S_h = Q_h @ K_h.T # H次小矩阵乘法
MQA:
S_all = Q_all @ K_shared.T # 1次大矩阵乘法
3. 硬件适配:
| 硬件类型 | MQA优化策略 | 性能提升 |
|---|---|---|
| GPU (Tensor Core) | 利用更大的矩阵块 | 1.5-2× |
| CPU (AVX-512) | SIMD广播优化 | 1.3-1.8× |
| NPU/TPU | 专用广播单元 | 2-3× |
| 移动GPU | 减少内存事务 | 1.8-2.5× |
1. Multi-Query Attention with Bias (MQA-B)
添加可学习的偏置来增强表达能力: \(\text{head}_h^{MQA-B} = \text{Attention}(Q_h, K_{shared} + B_h^K, V_{shared} + B_h^V)\)
其中 $B_h^K, B_h^V$ 是每个头的偏置向量。
2. Factorized Multi-Query Attention
使用低秩分解进一步压缩: \(K_{shared} = XW_K^{base}W_K^{down}, \quad W_K^{base} \in \mathbb{R}^{d_{model} \times r}, W_K^{down} \in \mathbb{R}^{r \times d_k}\)
其中 $r < d_k$ 是低秩维度。
3. Dynamic Multi-Query Attention
根据输入动态调整共享程度: \(\alpha = \sigma(Xw_{gate})\) \(K_{dynamic} = \alpha \cdot K_{shared} + (1-\alpha) \cdot K_{specific}\)
内存带宽分析(解码阶段):
设每个token生成时需要访问的数据量:
| 方法 | KV读取量 | 相对MHA | 带宽需求(GB/s) |
|---|---|---|---|
| MHA | $2BLHNd_k$ | 100% | 25.6 |
| MQA | $2BLNd_k$ | 3.1% | 0.8 |
| GQA-8 | $2BLGNd_k$ | 12.5% | 3.2 |
假设:B=32, L=32, H=32, N=2048, d_k=128, 100 tokens/s
计算密度提升:
\[\text{Arithmetic Intensity}_{MQA} = \frac{\text{FLOPs}}{\text{Memory Access}} = \frac{2BHNd_k}{2BNd_k} = H\]相比MHA提升了 $H$ 倍的计算密度!
实际测试结果(A100 GPU,13B模型):
| 序列长度 | MHA吞吐量 | MQA吞吐量 | 加速比 |
|---|---|---|---|
| 512 | 145 tok/s | 287 tok/s | 1.98× |
| 2048 | 38 tok/s | 112 tok/s | 2.95× |
| 8192 | 9 tok/s | 34 tok/s | 3.78× |
困惑度(Perplexity)对比:
| 数据集 | MHA | MQA | 相对退化 |
|---|---|---|---|
| WikiText-103 | 10.82 | 11.15 | +3.0% |
| C4 | 12.45 | 12.89 | +3.5% |
| OpenWebText | 11.23 | 11.68 | +4.0% |
下游任务性能:
| 任务 | 指标 | MHA | MQA | 差异 |
|---|---|---|---|---|
| MMLU | Acc | 67.8% | 66.2% | -1.6% |
| HumanEval | Pass@1 | 32.1% | 30.5% | -1.6% |
| BBH | Avg | 51.2% | 49.8% | -1.4% |
质量退化的缓解策略:
GQA是MHA和MQA的折中方案,将查询头分成 $G$ 组,每组共享KV:
\[\text{head}_h^{GQ} = \text{Attention}(Q_h, K_{g(h)}, V_{g(h)})\]其中 $g(h) = \lfloor h \cdot G / H \rfloor$ 是头 $h$ 所属的组。
设计空间:
KV cache大小:$2 \times \text{batch} \times G \times N \times d_{head}$
理论分析(解码阶段):
设批大小为 $B$,已生成长度为 $N$,则生成一个token的计算量和内存访问:
| 方法 | 计算量 (FLOPs) | KV cache读取 | 内存带宽需求 |
|---|---|---|---|
| MHA | $O(BHNd_{head})$ | $O(BHNd_{head})$ | 高 |
| MQA | $O(BHNd_{head})$ | $O(BNd_{head})$ | 低(减少$H$倍) |
| GQA | $O(BHNd_{head})$ | $O(BGNd_{head})$ | 中等 |
实际性能考虑:
训练时转换:
α = min(1, training_step / warmup_steps)
K_effective = α * K_shared + (1-α) * K_original
推理时近似(无需重训练):
平均池化方法: \(K_{MQA} = \frac{1}{H}\sum_{h=1}^H K_h^{MHA}, \quad V_{MQA} = \frac{1}{H}\sum_{h=1}^H V_h^{MHA}\)
主成分分析(PCA):
以LLaMA系列模型为例:
| 模型 | 原始(MHA) | GQA-8 | GQA-4 | MQA |
|---|---|---|---|---|
| PPL提升 | 0% | +0.2% | +0.5% | +1.2% |
| KV cache | 100% | 12.5% | 25% | 3.1% |
| 解码速度提升 | 1x | 1.8x | 1.5x | 2.2x |
关键发现:
完整注意力的 $O(N^2)$ 复杂度在长序列上变得不可承受。然而,实证研究表明:
这些观察启发了各种稀疏注意力模式的设计。
1. 窗口注意力(Window Attention)
每个token只关注固定窗口内的其他token:
\[S_{ij} = \begin{cases} \frac{Q_iK_j^T}{\sqrt{d_k}} & \text{if } |i-j| \leq w \\ -\infty & \text{otherwise} \end{cases}\]其中 $w$ 是窗口半径。
复杂度:$O(Nw)$ 而非 $O(N^2)$
2. 跨步注意力(Strided Attention)
固定步长的稀疏连接:
\[\text{mask}(i,j) = \mathbb{1}[(i-j) \bmod s = 0]\]3. 组合模式(Combination Patterns)
Longformer提出的组合模式:
数学表示: \(\text{Attention}_i = \begin{cases} \text{WindowAttn}_i & \text{if } i \in \text{LocalTokens} \\ \text{GlobalAttn}_i & \text{if } i \in \text{GlobalTokens} \end{cases}\)
1. 基于阈值的动态稀疏
计算完整注意力分数后,保留top-k或超过阈值的连接:
\[P_{ij} = \begin{cases} \text{softmax}(S_{ij}) & \text{if } S_{ij} \in \text{top-k}(S_i) \\ 0 & \text{otherwise} \end{cases}\]问题:需要先计算完整的 $S = QK^T$,无法节省计算
2. 可学习的稀疏掩码
引入可学习的二值掩码 $M \in {0,1}^{N \times N}$:
\[\text{SparseAttn}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} \odot M\right)V\]使用Gumbel-Softmax或直通估计器(STE)进行训练。
BigBird的设计:结合三种注意力模式
总的连接数:$O(N(r + w + g))$
数学形式化: 设注意力图 $G = (V, E)$,其中:
| $E_{window} = {(i,j) : | i-j | \leq w}$ |
则 $E = E_{window} \cup E_{random} \cup E_{global}$
1. 块稀疏格式(Block Sparse Format)
将注意力矩阵分成 $B \times B$ 的块,只计算非零块:
稀疏掩码(块级别):
[1 1 0 0]
[1 1 1 0]
[0 1 1 1]
[0 0 1 1]
优势:
2. CSR格式存储
对于极度稀疏的模式,使用压缩稀疏行(CSR)格式:
3. 核融合优化
避免生成完整的注意力矩阵:
for each query i:
sparse_indices = get_sparse_pattern(i)
K_sparse = gather(K, sparse_indices)
S_sparse = Q[i] @ K_sparse.T
P_sparse = softmax(S_sparse)
O[i] = P_sparse @ gather(V, sparse_indices)
1. 基于任务的选择:
2. 基于硬件的选择:
3. 动态选择: 根据序列长度动态切换:
if seq_len < 512:
use_full_attention()
elif seq_len < 2048:
use_window_attention(window=256)
else:
use_bigbird_attention()
表达能力分析:
定理:对于窗口大小 $w = O(\log N)$,$L$ 层的窗口注意力可以模拟完整注意力。
证明要点:
近似误差界:
对于top-k稀疏: \(\|P_{full} - P_{sparse}\|_F \leq \epsilon\)
其中 $k = O(N\log N/\epsilon^2)$ 即可保证 $\epsilon$ 误差。
标准注意力的计算瓶颈在于 $\text{softmax}(QK^T)$ 的矩阵乘法。线性注意力通过分解注意力矩阵来避免显式计算 $N \times N$ 的矩阵。
核心变换: 将 $\text{softmax}(QK^T)$ 近似为 $\phi(Q)\phi(K)^T$,其中 $\phi$ 是特征映射。
这样可以改变计算顺序: \(\text{Attention}(Q,K,V) = \phi(Q)[\phi(K)^TV]\)
计算顺序的改变将复杂度从 $O(N^2d)$ 降至 $O(Nd^2)$。
Softmax作为核函数:
标准注意力可以写成: \(A_{ij} = \frac{\exp(Q_iK_j^T/\sqrt{d})}{\sum_k \exp(Q_iK_k^T/\sqrt{d})} = \frac{k(Q_i, K_j)}{\sum_k k(Q_i, K_k)}\)
其中 $k(x,y) = \exp(x^Ty/\sqrt{d})$ 是指数核。
核函数的分解:
如果存在特征映射 $\phi$ 使得: \(k(x,y) = \langle\phi(x), \phi(y)\rangle\)
则可以实现线性复杂度的注意力。
1. Linear Transformer (Katharopoulos et al.)
使用简单的特征映射: \(\phi(x) = \text{elu}(x) + 1\)
其中elu是指数线性单元。这保证了 $\phi(x) \geq 0$。
因果掩码的处理: 对于自回归生成,需要因果掩码。线性注意力通过RNN形式实现:
\(S_i = S_{i-1} + \phi(K_i)V_i^T\) \(O_i = \frac{\phi(Q_i)S_i}{\phi(Q_i)\sum_{j \leq i}\phi(K_j)}\)
2. Performer (Choromanski et al.)
使用随机特征近似softmax核:
\[\phi(x) = \frac{\exp(\|x\|^2/2)}{\sqrt{m}} [\exp(w_1^Tx), ..., \exp(w_m^Tx)]^T\]其中 $w_i \sim \mathcal{N}(0, I)$ 是随机投影向量。
理论保证: 当 $m = O(d\log d/\epsilon^2)$ 时,近似误差小于 $\epsilon$。
3. RFA (Random Feature Attention)
使用确定性的正交随机特征:
优势:更稳定的近似,更少的随机性。
一般形式: \(\text{LinearAttn}(Q,K,V) = \frac{\phi(Q)[\phi(K)^TV]}{\phi(Q)[\phi(K)^T\mathbf{1}]}\)
其中分母项用于归一化。
设计空间:
1. 数值稳定性
问题:当 $\phi(K)^T\mathbf{1}$ 接近0时,除法不稳定。
解决方案:
2. 特征维度选择
权衡:
3. 混合精度策略
1. 超长序列处理
当 $N \gg d$ 时,线性注意力优势明显:
临界点:$N = d$ 时两者计算量相当。
2. 流式/在线推理
RNN形式的线性注意力支持:
3. 跨模态注意力
图像-文本等跨模态场景,序列长度差异大:
在不同任务上的表现:
| 方法 | 语言建模 (PPL) | 图像分类 (Acc) | 长文本QA (F1) |
|---|---|---|---|
| 标准注意力 | 15.2 | 81.5% | 73.4 |
| Performer | 16.1 (+6%) | 80.8% | 71.2 |
| Linear Transformer | 16.8 (+10%) | 79.6% | 69.8 |
| 混合方案* | 15.5 (+2%) | 81.2% | 72.9 |
*混合方案:前几层用标准注意力,后续层用线性注意力
关键发现:
本章系统地探讨了注意力机制的各种优化技术,这些技术对于在资源受限的边缘设备上部署大语言模型至关重要:
Flash Attention通过平铺和重计算策略,将内存访问从 $O(N^2)$ 降至 $O(N)$,在保持精确计算的同时大幅提升了硬件利用率。其核心在于利用GPU的内存层次结构,通过分块计算和增量softmax避免了中间结果的频繁读写。
Multi-Query和Grouped-Query Attention通过在多个查询头之间共享键值对,将KV cache的大小降低了数倍到数十倍。这种方法特别适合解码阶段和长序列场景,在质量损失很小的情况下获得了显著的加速。
稀疏注意力模式利用了注意力分布的稀疏性,通过固定模式(窗口、跨步)或学习型模式将计算复杂度从 $O(N^2)$ 降至 $O(N\log N)$ 或 $O(N)$。BigBird等方法通过组合局部、随机和全局注意力,在保持模型表达能力的同时实现了高效计算。
线性注意力机制通过核方法和特征映射,将注意力计算的复杂度降至 $O(Nd^2)$。虽然在某些任务上有性能损失,但其常数内存的特性使其特别适合流式处理和超长序列。
关键公式回顾:
Flash Attention的增量softmax更新: \(O_i^{(j)} = \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)}}{l_i^{(j)}} O_i^{(j-1)} + \frac{1}{l_i^{(j)}} \sum_k e^{S_{ijk} - m_i^{(j)}} V_{jk}\)
GQA的头分组映射: \(g(h) = \lfloor h \cdot G / H \rfloor\)
线性注意力的核心变换: \(\text{Attention}(Q,K,V) = \phi(Q)[\phi(K)^TV]\)
实践建议:
Flash Attention的内存访问分析 计算标准注意力和Flash Attention在序列长度N=2048、特征维度d=64、块大小B=64时的HBM访问次数。假设使用FP16存储。
Hint:考虑每个矩阵元素的读写次数,以及中间结果的存储。
MQA的KV Cache计算 对于一个32头的模型,批大小B=8,序列长度N=1024,每个头维度为128,计算MHA、GQA-8和MQA的KV cache内存占用(以MB为单位)。
Hint:KV cache = 2 × batch × heads × seq_len × head_dim × bytes_per_element
稀疏注意力的连接数 对于BigBird注意力,如果窗口大小w=3,随机连接数r=2,全局token数g=2,序列长度N=512,计算总的注意力连接数和稀疏度。
Hint:稀疏度 = 1 - (实际连接数 / N²)
Flash Attention的最优块大小 给定SRAM大小为48KB,需要存储Q块、K块、V块以及中间结果。在FP16精度下,推导使SRAM利用率最大化的块大小公式。考虑需要额外存储每行的最大值和求和结果。
Hint:设块大小为B_r × B_c,列出所有需要存储的张量及其大小。
线性注意力的误差界分析 证明:对于Performer使用m个随机特征时,注意力矩阵的近似误差期望满足: \(\mathbb{E}[\|\hat{A} - A\|_F] \leq \frac{C}{\sqrt{m}} \|A\|_F\) 其中C是与维度相关的常数。
Hint:使用随机特征的方差分析和矩阵范数的性质。
分析这种设计在不同序列长度(512, 2048, 8192)下相对于全MHA的计算节省和内存节省。
Hint:分别计算每种配置的FLOPs和内存占用,考虑批大小的影响。
稀疏模式的表达能力 考虑一个只使用窗口大小为w的局部注意力的L层Transformer。如果要保证任意两个位置的信息能够交互,w和L需要满足什么关系?对于序列长度N=1024,如果限制w≤32,最少需要多少层?
Hint:考虑信息传播的”感受野”概念。
注意力优化的能效分析 假设一个边缘设备的内存带宽为25.6 GB/s,计算性能为2 TFLOPS(FP16)。对于批大小1、序列长度512、模型维度768的注意力计算,分析标准注意力、Flash Attention和GQA-8分别是compute-bound还是memory-bound。计算各自的硬件利用率。
Hint:计算arithmetic intensity(FLOPs/字节),与硬件的计算/带宽比值对比。