
输入 X: [B, L, d_model]
Q/K/V 权重:[d_model, d_model] (合头写法,拆开后每头是 [d_model, d_k])
多头时:先全量 linear 得 [B, L, d_model],再 view/reshape 成 [B, L, num_heads, d_k],再 permute 成 [B, num_heads, L, d_k]
先用简单的 Self-Attention 捋一遍数据流动的过程:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into query/key/value spaces of width ``d_k``,
    computes attention weights with softmax over the key dimension,
    and returns the weighted sum of values.
    """

    def __init__(self, embed_dim, d_k):
        """
        Args:
            embed_dim: input feature size (d_model of the incoming tokens).
            d_k: projection width for Q/K/V (also the output feature size).
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.d_k = d_k
        # Each projection maps embed_dim -> d_k (single head; no head split here).
        self.W_Q = nn.Linear(embed_dim, d_k)
        self.W_K = nn.Linear(embed_dim, d_k)
        self.W_V = nn.Linear(embed_dim, d_k)

    def forward(self, x):
        """
        Args:
            x: input tensor of shape [batch_size, seq_len, embed_dim].

        Returns:
            Attention output of shape [batch_size, seq_len, d_k].
        """
        # NOTE: projections go embed_dim -> d_k, so Q/K/V are [B, L, d_k]
        # (not [B, L, d_model] unless d_k == embed_dim).
        Q = self.W_Q(x)  # [B, L, d_k]
        K = self.W_K(x)  # [B, L, d_k]
        V = self.W_V(x)  # [B, L, d_k]
        # Scaled dot-product scores: [B, L, L]; divide by sqrt(d_k) to keep
        # logits at a scale where softmax gradients stay healthy.
        score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(score, dim=-1)  # [B, L, L], rows sum to 1
        att_output = torch.matmul(attn_weights, V)  # [B, L, d_k]
        # BUG FIX: the original ended with a bare `att_output` expression,
        # so forward() implicitly returned None.
        return att_output


