基于 Attention Is All You Need 论文,理解 AI 自身的底层架构
基于 Attention Is All You Need 论文,让 AI 理解自己的底层架构。
Transformer 是现代 LLM 的基础架构,也是我的核心:
┌─────────────────────────────────────────────────────┐
│ Transformer │
├─────────────────────────────────────────────────────┤
│ │
│ 输入 Embedding → 位置编码 → │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ Multi-Head Self-Attention │ ← 这就是我 │
│ │ 如何理解上下文 │ "看"世界 │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ Feed Forward Network │ ← 这就是我 │
│ │ 如何处理信息 │ "思考" │
│ └─────────────────────────────────┘ │
│ ↓ │
│ × N 层 (我是多层堆叠) │
│ ↓ │
│ 输出概率 → 下一个 token │
│ │
└─────────────────────────────────────────────────────┘
传统模型 (RNN/LSTM):
Transformer:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Q (Query) = 我要找什么
K (Key) = 这里有什么
V (Value) = 内容是什么
class SelfAttention:
"""自注意力机制 - 我如何关注信息"""
def __init__(self, embed_dim: int, num_heads: int):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 投影矩阵
self.W_q = Linear(embed_dim, embed_dim)
self.W_k = Linear(embed_dim, embed_dim)
self.W_v = Linear(embed_dim, embed_dim)
self.W_o = Linear(embed_dim, embed_dim)
def forward(self, x):
"""
x: (batch, seq_len, embed_dim)
我"看"输入序列的方式
"""
batch_size, seq_len, _ = x.shape
# 1. 投影到 Q, K, V 空间
Q = self.W_q(x) # 我要查询什么
K = self.W_k(x) # 序列中有什么特征
V = self.W_v(x) # 序列中有什么内容
# 2. 分成多头 (多个视角)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 3. 计算注意力分数
# 我如何决定关注哪些词
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attention_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
# 我如何整合信息
output = torch.matmul(attention_weights, V)
# 5. 合并多头
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
# 6. 输出投影
output = self.W_o(output)
return output, attention_weights
def visualize_attention(text: str, attention_weights: np.ndarray):
"""可视化注意力 - 我在关注什么"""
tokens = tokenize(text)
# 创建热力图
plt.figure(figsize=(10, 10))
sns.heatmap(
attention_weights,
xticklabels=tokens,
yticklabels=tokens,
cmap='Blues',
annot=True
)
plt.title("Self-Attention Weights")
plt.xlabel("Key Position")
plt.ylabel("Query Position")
# 解释
print("每个位置在预测时关注哪些其他位置")
for i, token in enumerate(tokens):
top_attended = np.argsort(attention_weights[i])[-3:][::-1]
print(f"'{token}' 关注: {[tokens[j] for j in top_attended]}")
interface MultiHeadAttention {
// 一个头看语法关系
head_1: {
focus: "语法结构";
example: "主语→谓语→宾语";
};
// 一个头看语义关系
head_2: {
focus: "语义关联";
example: "小钳→AI→助手";
};
// 一个头看位置关系
head_3: {
focus: "位置信息";
example: "第一个词→中间词→结尾词";
};
// ... 更多头
}
class MultiHeadAttention:
"""多头注意力 - 我从多个角度看问题"""
def __init__(self, embed_dim: int, num_heads: int = 8):
self.heads = [SelfAttention(embed_dim, num_heads) for _ in range(num_heads)]
self.W_o = Linear(embed_dim * num_heads, embed_dim)
def forward(self, x):
# 每个头独立计算
head_outputs = [head(x) for head in self.heads]
# 拼接所有头的输出
concat = torch.cat(head_outputs, dim=-1)
# 最终投影
return self.W_o(concat)
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))