Transformer之后:新架构探索
AI 导读
Transformer之后:新架构探索 Mamba、RWKV、Hyena、xLSTM:突破二次复杂度瓶颈的下一代序列模型 引言...
Transformer之后:新架构探索
Mamba、RWKV、Hyena、xLSTM:突破二次复杂度瓶颈的下一代序列模型
引言
Transformer自2017年问世以来统治了几乎所有序列建模任务。然而其核心机制——自注意力(Self-Attention)——的O(n^2)时间和空间复杂度成为处理超长序列的根本瓶颈。在128K甚至1M上下文长度的需求驱动下,一系列新架构正在挑战Transformer的统治地位。它们的共同目标是:在保持建模质量的前提下,实现O(n)或近似O(n)的复杂度。
Transformer的瓶颈
复杂度分析
Self-Attention复杂度
输入序列长度: n
隐藏维度: d
计算 Q, K, V: O(n × d²) -- 线性于n
注意力分数 QK^T: O(n² × d) -- 二次于n ← 瓶颈
Softmax + 加权: O(n² × d) -- 二次于n ← 瓶颈
输出投影: O(n × d²) -- 线性于n
KV Cache (推理):
存储: O(n × d) per layer per head
总KV Cache: O(n × d × L × H)
70B模型, 128K上下文: ~40GB KV Cache
问题本质:
n=1K → QK^T: 1M 次乘法 (可接受)
n=32K → QK^T: 1G 次乘法 (昂贵)
n=128K → QK^T: 16G 次乘法 (极昂贵)
n=1M → QK^T: 1T 次乘法 (不可行)
Mamba(状态空间模型)
核心原理
Mamba是基于结构化状态空间模型(Structured State Space Model, S4)的改进。其核心思想是将序列建模视为一个连续时间系统的离散化,通过状态空间实现线性复杂度的序列处理。
状态空间模型(SSM)数学框架
连续时间:
h'(t) = A h(t) + B x(t) -- 状态转移
y(t) = C h(t) + D x(t) -- 输出映射
离散化(Zero-Order Hold):
A_bar = exp(Δ A)
B_bar = (Δ A)^{-1} (exp(Δ A) - I) Δ B
h_k = A_bar h_{k-1} + B_bar x_k -- 递推(推理)
y_k = C h_k + D x_k
关键创新(Mamba的Selective Mechanism):
B, C, Δ 均为输入依赖(input-dependent)
→ 模型可以选择性地记住或遗忘信息
→ 等价于"数据驱动的门控"
Mamba架构实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelectiveSSM(nn.Module):
"""Simplified Mamba selective state space model block."""
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
expand: int = 2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = d_model * expand
# Input projection
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
# 1D convolution
self.conv1d = nn.Conv1d(
self.d_inner, self.d_inner, d_conv,
padding=d_conv - 1, groups=self.d_inner
)
# SSM parameters (input-dependent)
self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
# Fixed parameter A (initialized with HiPPO)
A = torch.arange(1, d_state + 1).float()
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [batch, seq_len, d_model]
Returns:
output: [batch, seq_len, d_model]
"""
batch, seq_len, _ = x.shape
# Input projection -> (z, x)
xz = self.in_proj(x) # [B, L, 2*d_inner]
x_branch, z = xz.chunk(2, dim=-1) # Each [B, L, d_inner]
# 1D causal convolution
x_branch = x_branch.transpose(1, 2) # [B, d_inner, L]
x_branch = self.conv1d(x_branch)[:, :, :seq_len]
x_branch = x_branch.transpose(1, 2) # [B, L, d_inner]
x_branch = F.silu(x_branch)
# Input-dependent SSM parameters
x_dbl = self.x_proj(x_branch) # [B, L, 2*d_state+1]
B = x_dbl[..., :self.d_state] # [B, L, N]
C = x_dbl[..., self.d_state:2*self.d_state] # [B, L, N]
delta = F.softplus(x_dbl[..., -1:]) # [B, L, 1]
# Discretize (simplified)
A = -torch.exp(self.A_log) # [N]
A_bar = torch.exp(delta * A) # [B, L, N]
B_bar = delta * B # [B, L, N]
# Selective scan (sequential for clarity; real impl uses parallel scan)
h = torch.zeros(batch, self.d_inner, self.d_state, device=x.device)
outputs = []
for t in range(seq_len):
h = A_bar[:, t].unsqueeze(1) * h + B_bar[:, t].unsqueeze(1) * x_branch[:, t].unsqueeze(-1)
y_t = (h * C[:, t].unsqueeze(1)).sum(-1) # [B, d_inner]
outputs.append(y_t)
y = torch.stack(outputs, dim=1) # [B, L, d_inner]
y = y + x_branch * self.D # Skip connection
# Gate and output
y = y * F.silu(z)
return self.out_proj(y)
Mamba的优势与局限
| 维度 | Mamba | Transformer |
|---|---|---|
| 训练复杂度 | O(n) | O(n^2) |
| 推理复杂度 | O(1) per step | O(n) per step (KV cache) |
| 长序列能力 | 天然支持 | 需要位置编码扩展 |
| 并行训练 | 需要parallel scan | 天然并行 |
| In-context learning | 较弱 | 强 |
| 精确检索 | 弱(信息压缩到固定状态) | 强(可回看所有token) |
RWKV:线性Transformer
架构特点
RWKV结合了RNN的高效推理和Transformer的并行训练,其核心是将注意力机制替换为线性递推:
RWKV核心机制:Time-Mixing与Channel-Mixing
Time-Mixing (替代Self-Attention):
r_t = W_r · (mu_r ⊙ x_t + (1-mu_r) ⊙ x_{t-1})
k_t = W_k · (mu_k ⊙ x_t + (1-mu_k) ⊙ x_{t-1})
v_t = W_v · (mu_v ⊙ x_t + (1-mu_v) ⊙ x_{t-1})
wkv_t = (sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} v_i + e^{u+k_t} v_t)
/ (sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} + e^{u+k_t})
o_t = W_o · (sigmoid(r_t) ⊙ wkv_t)
Channel-Mixing (替代FFN):
r_t = W_r · (mu_r ⊙ x_t + (1-mu_r) ⊙ x_{t-1})
k_t = W_k · (mu_k ⊙ x_t + (1-mu_k) ⊙ x_{t-1})
o_t = sigmoid(r_t) ⊙ (W_v · max(k_t, 0)²)
关键特性:
训练: 可展开为并行计算 (类似Transformer)
推理: 递推形式 (O(1) per step, 无KV Cache)
Hyena:隐式长卷积
Hyena用参数化的长卷积替代注意力机制,其核心是通过可学习的卷积核实现全局信息混合:
class HyenaOperator(nn.Module):
"""Simplified Hyena operator using implicit long convolution."""
def __init__(self, d_model: int, max_len: int = 8192, order: int = 2):
super().__init__()
self.order = order
self.d_model = d_model
# Short convolution for local patterns
self.short_conv = nn.Conv1d(d_model, d_model * (order + 1),
kernel_size=3, padding=1, groups=d_model)
# Implicit parametrization of long convolution
self.filter_fn = nn.Sequential(
nn.Linear(1, 64),
nn.SiLU(),
nn.Linear(64, 64),
nn.SiLU(),
nn.Linear(64, d_model),
)
# Position encoding for filter generation
t = torch.linspace(0, 1, max_len).unsqueeze(-1)
self.register_buffer("t", t)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x: [batch, seq_len, d_model]"""
batch, seq_len, _ = x.shape
# Generate projections via short convolution
x_conv = self.short_conv(x.transpose(1, 2)) # [B, D*(order+1), L]
splits = x_conv.chunk(self.order + 1, dim=1)
v = splits[0].transpose(1, 2) # [B, L, D]
# Generate long convolution filter
h = self.filter_fn(self.t[:seq_len]) # [L, D]
# Apply Hyena recurrence
y = v
for i in range(self.order):
x_i = splits[i + 1].transpose(1, 2) # [B, L, D]
# Element-wise gating
y = y * x_i
# Long convolution via FFT
y = self._fft_conv(y, h)
return y
def _fft_conv(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
"""Causal convolution via FFT for O(n log n) complexity."""
L = x.shape[1]
# Pad to avoid circular convolution artifacts
x_padded = F.pad(x.transpose(1, 2), (L, 0))
h_padded = F.pad(h.T, (0, L))
X = torch.fft.rfft(x_padded, dim=-1)
H = torch.fft.rfft(h_padded, dim=-1)
Y = X * H
y = torch.fft.irfft(Y, dim=-1)[..., :L]
return y.transpose(1, 2)
xLSTM:LSTM的现代化重生
关键创新
xLSTM通过两个关键改进让传统LSTM重新具备竞争力:
xLSTM架构
sLSTM (Scalar LSTM with exponential gating):
└── 指数门控 (exp gating) 替代sigmoid
└── 更大的遗忘能力范围
└── 标量记忆单元
mLSTM (Matrix LSTM with matrix memory):
└── 矩阵记忆单元 (d_model × d_model)
└── 类似于线性注意力的存储机制
└── 可并行训练
xLSTM = Residual Blocks of [sLSTM | mLSTM]
复杂度对比:
sLSTM: O(n) 时间, O(d) 空间 (per step)
mLSTM: O(n) 时间, O(d²) 空间 (per step)
Transformer: O(n²) 时间, O(n×d) 空间 (KV cache)
混合架构:取长补短
Jamba / Mamba-2 / Zamba
最新趋势是将Attention和SSM混合使用,在保持线性复杂度优势的同时弥补SSM在精确检索上的短板:
混合架构模式
模式A: 交替层(Jamba风格)
Layer 1: Mamba Block
Layer 2: Mamba Block
Layer 3: Attention Block ← 每N层插入一个
Layer 4: Mamba Block
Layer 5: Mamba Block
Layer 6: Attention Block ← 每N层插入一个
...
Attention占比: 1/N (通常N=3-6)
模式B: 并行分支
Input → [Mamba Branch] → ┐
→ [Attention Branch] → ┤→ Merge → Output
两个分支处理不同类型的依赖
模式C: 级联(Mamba-Attention-Mamba)
Input → Mamba(local) → Attention(global) → Mamba(refine) → Output
架构对比总结
| 架构 | 训练复杂度 | 推理复杂度 | 长序列 | ICL | 精确检索 | 成熟度 |
|---|---|---|---|---|---|---|
| Transformer | O(n^2) | O(n) per step | 受限 | 强 | 强 | 生产级 |
| Mamba/SSM | O(n) | O(1) per step | 天然 | 中 | 弱 | 早期生产 |
| RWKV | O(n) | O(1) per step | 天然 | 中 | 弱 | 早期生产 |
| Hyena | O(n log n) | O(n) | 好 | 中 | 中 | 研究级 |
| xLSTM | O(n) | O(1) per step | 天然 | 中 | 中 | 研究级 |
| 混合(Jamba) | O(n) ~ O(n^2) | 接近O(1) | 好 | 强 | 强 | 早期生产 |
选型建议
| 场景 | 推荐架构 | 理由 |
|---|---|---|
| 通用NLP/对话 | Transformer | 成熟度最高,生态最好 |
| 超长文档处理 | Mamba/混合 | 线性复杂度,128K+无压力 |
| 流式音频/信号 | Mamba/RWKV | O(1)推理,实时处理 |
| 端侧部署 | RWKV/Mamba | 无KV Cache,内存友好 |
| 需要精确召回 | Transformer/混合 | 注意力机制的核心优势 |
| 研究探索 | Mamba-2/xLSTM | 最新架构,潜力最大 |
结论
Transformer的二次复杂度瓶颈催生了一系列创新架构。Mamba以选择性状态空间模型实现了线性复杂度和强大的建模能力,RWKV以线性注意力变体实现了RNN的高效推理和Transformer的并行训练,Hyena以隐式长卷积提供了另一种全局信息混合方案,xLSTM则证明了经典架构通过现代化改造仍有竞争力。然而,这些架构在in-context learning和精确信息检索方面尚未完全匹敌Transformer,这也是为什么混合架构(如Jamba)正在成为最务实的方向——在关键层保留注意力机制,其余层使用线性复杂度的替代方案。
Maurice | [email protected]
深度加工(NotebookLM 生成)
基于本文内容生成的 PPT 大纲、博客摘要、短视频脚本与 Deep Dive 播客,用于多场景复用
PPT 大纲(5-8 张幻灯片) 点击展开
Transformer之后:新架构探索 — ppt
幻灯片 1:引言——Transformer的瓶颈与新架构目标
- 尽管Transformer自2017年起统治了几乎所有序列建模任务,但其自注意力机制的 O(n²) 时间和空间复杂度成为了处理超长序列的根本瓶颈 [1]。
- 处理长上下文时(如128K甚至1M),Transformer的 KV Cache 占用极其庞大的内存,例如70B模型在128K上下文下推理需要约40GB缓存 [1]。
- 随着序列长度增长,注意力分数的乘法运算量呈指数级爆发,导致百万级长度的上下文计算几乎不可行 [1]。
- 为了应对这一挑战,一系列新架构的共同目标是:在保持建模质量的前提下,实现线性 O(n) 或近似的计算复杂度 [1]。
幻灯片 2:Mamba——具有选择性机制的状态空间模型
- Mamba 的核心思想是将序列建模视为连续时间系统的离散化,通过改进的结构化状态空间模型(SSM)实现线性复杂度的处理 [1]。
- 其最大的创新在于“选择性机制(Selective Mechanism)”,利用依赖于输入的参数,使模型能够像数据驱动的门控一样,选择性地记住或遗忘信息 [1]。
- Mamba 具备 O(n) 的训练复杂度和单步 O(1) 的推理复杂度,天然支持长序列且推理时无需 KV Cache [1, 2]。
- 其主要局限在于将信息压缩到固定状态中,因此在精确检索(回忆所有token)和上下文学习(In-context learning)能力上弱于 Transformer [2]。
幻灯片 3:RWKV 与 Hyena——线性Transformer与隐式长卷积
- RWKV(线性Transformer): 结合了 RNN 的高效推理(单步 O(1),无 KV Cache)与 Transformer 的并行训练优势 [2]。
- RWKV 核心通过 Time-Mixing 和 Channel-Mixing 机制,将传统的自注意力替换为线性递推计算 [2]。
- Hyena(隐式长卷积): 采用可学习的参数化长卷积核来替代注意力机制,从而实现全局信息的混合 [2]。
- Hyena 利用快速傅里叶变换(FFT)进行因果卷积计算,成功将时间复杂度控制在 O(n log n) 的高效水平 [2, 3]。
幻灯片 4:xLSTM——经典 LSTM 架构的现代化重生
- xLSTM 证明了经典的 LSTM 架构通过现代化改造,依然能在下一代序列模型竞争中占据一席之地 [3]。
- 核心创新之一(sLSTM): 引入指数门控(exponential gating)替代传统的 Sigmoid 门,获得更大的遗忘能力范围,并保留标量记忆单元 [3]。
- 核心创新之二(mLSTM): 引入矩阵记忆单元,构建了类似于线性注意力的存储机制,并且支持完全的并行训练 [3]。
- xLSTM 的整体时间复杂度保持在 O(n),推理复杂度为单步 O(1),成功克服了传统RNN难以并行训练和容量不足的问题 [3]。
幻灯片 5:混合架构(如 Jamba)——取长补短的务实方向
- 由于纯状态空间模型等新架构在上下文学习和精确信息检索上尚未完全匹敌 Transformer,混合架构成为了目前最务实的方向 [3]。
- 最新趋势(如 Jamba、Mamba-2、Zamba)将 Attention 和 SSM 混合使用,既保留了线性复杂度的优势,又弥补了精确检索的短板 [3]。
- 常见的混合模式包括交替层设计(例如每 N 层 SSM 插入一层 Attention)、并行分支处理以及级联结构 [3]。
- 混合架构在保持接近 O(1) 的高效推理复杂度的同时,具备了强大的长序列处理、强大的 ICL 和精确召回能力 [3]。
幻灯片 6:新架构选型建议与场景应用
- 通用NLP、对话与精确召回: 推荐使用技术成熟度最高、生态最好的 Transformer,或者保留了注意力机制优势的混合架构 [3]。
- 超长文档处理: 推荐采用 Mamba 或混合架构,凭借其线性复杂度特性,可以无压力处理 128K 及以上的超长序列 [3]。
- 流式音频/信号与端侧部署: 首选 Mamba 或 RWKV,因为它们具备单步 O(1) 的推理能力且无需极占内存的 KV Cache,对硬件更友好 [3]。
- 前沿研究探索: 最新推出的 Mamba-2 和 xLSTM 架构展现出了极其巨大的发展潜力,是目前研究级场景的核心探索方向 [3]。
博客摘要 + 核心看点 点击展开
Transformer之后:新架构探索 — summary
这是一份为您定制的 SEO 友好博客摘要及核心看点:
博客摘要(约 150 字)
Transformer的O(n²)计算复杂度已成为大模型处理超长上下文的根本瓶颈[1]。本文深度解析突破该瓶颈的下一代序列模型架构,全面剖析Mamba(状态空间模型)、RWKV、Hyena及xLSTM的核心技术与原理[1-3]。文章揭示了它们如何实现O(n)的线性计算复杂度,对比了各自在推理效率与长序列处理上的优劣,并探讨了兼顾效率与精确检索的混合架构(如Jamba)的最新发展趋势[3]。这是AI从业者与开发者探索前沿大模型架构的必读指南。
3 条核心看点(每条 < 40 字)
- 突破二次复杂度瓶颈:全面解析Mamba、RWKV等实现线性复杂度的新架构原理[1, 2]。
- 核心创新机制对比:详细剖析Mamba状态空间、RWKV线性递推及xLSTM的优劣与适用场景[1-3]。
- 混合架构成最新趋势:解读融合Attention与SSM的模型如何完美兼顾推理效率与精准检索[3]。
60 秒短视频脚本 点击展开
Transformer之后:新架构探索 — video
这是一段为您定制的 60 秒短视频脚本:
【钩子开场】
谁能打破Transformer算力瓶颈?
【核心解说】
**画面一:**自注意力存在二次复杂度瓶颈[1],处理超长序列极其昂贵[1]。
**画面二:**Mamba等架构实现线性复杂度[1],推理极快但精确检索较弱[2]。
**画面三:**混合架构成为新趋势[3],融合注意力机制兼顾速度与精确检索[3]。
【收束】
下一代大模型之战,才刚刚开始!
课后巩固
与本文内容匹配的闪卡与测验,帮助巩固所学知识
延伸阅读
根据本文主题,为你推荐相关的学习资料