采样过程源码解析:从 logits 到 token 的采样策略
在自然语言处理(NLP)领域,尤其是生成式模型中,采样过程是将模型输出的 logits 转换为实际可读的 token 的关键步骤。这一过程不仅决定了生成文本的多样性,还影响着模型输出的质量和实用性。本文将深入解析采样过程的源码实现,探讨从 logits 到 token 的多种采样策略,帮助读者更好地理解这一核心环节。
1. Logits 的生成与理解
在深度学习模型中,尤其是基于 Transformer 的架构,模型的最后一层通常会输出一个形状为 (batch_size, sequence_length, vocab_size) 的张量,其中 vocab_size 是词汇表的大小。这个张量中的每个值,我们称之为 logits,代表了模型对每个位置上每个可能 token 的预测分数。
Logits 本身并不直接代表概率,它们需要通过 softmax 函数进行归一化,转换为概率分布。然而,在采样过程中,我们并不总是直接使用 softmax 后的概率,而是基于 logits 应用各种采样策略,以平衡生成文本的准确性和多样性。
2. 采样策略概览
采样策略的选择直接影响生成文本的质量。常见的采样策略包括贪婪采样(Greedy Sampling)、随机采样(Random Sampling)、温度采样(Temperature Sampling)、Top-k 采样(Top-k Sampling)和 Top-p(Nucleus)采样(Top-p Sampling)。下面,我们将逐一解析这些策略的源码实现。
2.1 贪婪采样
贪婪采样是最简单的采样策略,它选择每个位置上概率最高的 token 作为输出。虽然这种方法能保证生成文本的确定性,但往往缺乏多样性,容易陷入重复或模式化的输出。
import torch
import torch.nn.functional as F
def greedy_sample(logits):
# 应用 softmax 获取概率分布
probs = F.softmax(logits, dim=-1)
# 选择每个位置上概率最高的 token
_, sampled_tokens = torch.max(probs, dim=-1)
return sampled_tokens
2.2 随机采样
随机采样,也称为多项式采样,根据 softmax 后的概率分布随机选择 token。这种方法能增加输出的多样性,但也可能导致生成不连贯或无意义的文本。
def random_sample(logits):
probs = F.softmax(logits, dim=-1)
# 根据概率分布随机采样
sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
return sampled_tokens
2.3 温度采样
温度采样通过调整 softmax 函数的'温度'参数来控制输出的多样性。温度参数 T 越大,概率分布越平滑,采样结果越多样;T 越小,概率分布越尖锐,采样结果越接近贪婪采样。
():
adjusted_logits = logits / temperature
probs = F.softmax(adjusted_logits, dim=-)
sampled_tokens = torch.multinomial(probs, num_samples=).squeeze(-)
sampled_tokens

