为分类任务微调表示模型
深入讲解为分类任务微调表示模型的完整流程:分类头设计、全量微调 vs LoRA 的选择、小样本场景下的 SetFit 方案,以及过拟合防范与模型评估。
全面拆解 LLM 内部机制:自注意力计算、KV Cache 原理、RoPE 位置编码、GQA 优化、MoE 架构、量化推理——覆盖面试最高频的底层原理题。
靠一个特殊的结束token(EOS, End of Sequence)。模型词表里有一个专门的EOS token ID,在SFT阶段模型学会了在回答结束时输出这个token。推理时框架检测到模型输出了EOS就停止生成。本质上模型在训练数据中见过大量以EOS结尾的样本,学会了在语义上'说完了'的时候给出这个信号。另外还有max_tokens参数作为硬性兜底,防止模型一直不输出EOS导致无限生成。
用因果掩码(causal mask)。在计算自注意力时,构造一个下三角矩阵作为掩码,把当前位置之后的所有位置的注意力权重设为负无穷(softmax后变成0)。这样每个位置只能看到它自己和它前面的token,看不到后面的。训练时虽然整个序列是一次性输入的(为了并行效率),但因果掩码保证了每个位置的预测只依赖于前文,和推理时逐token生成的效果完全一致。这个设计很巧妙——既保证了训练的正确性,又能充分利用GPU的并行计算能力。
每个token的隐状态通过三个线性变换分别生成Q(查询)、K(键)、V(值)向量。然后用Q和K做点积得到注意力分数,反映两个token之间的相关程度;除以sqrt(d_k)做缩放防止点积值过大导致softmax梯度消失;再过softmax归一化成概率分布;最后用这个概率分布对V做加权求和,得到融合了上下文信息的新表示。直觉上就是:Q是'我在找什么',K是'我有什么',点积衡量匹配程度,V是'我能提供什么信息'。匹配度高的位置贡献更多信息。多头注意力则是并行跑多组QKV,让不同的头关注不同的关系模式。
没有KV缓存的话,每生成一个新token都要重新计算整个序列所有token的K和V,计算量从O(n)变成O(n^2)。假设生成一个长度为L的序列,有KV缓存时总计算量是O(Ld)级别(每步只算新token的QKV),没有的话是O(L^2d)。实际影响取决于序列长度:短序列(几十个token)差别不大,但对于上千token的长序列,没有KV缓存可能慢10-100倍。代价是KV缓存占显存——对于Llama-2 7B,每个token的KV缓存约占0.5MB,4K序列就是2GB。这就是经典的时间换空间权衡,实际部署中KV缓存基本是标配。
残差连接(skip connection)让每一层的输出变成x + F(x)而不是F(x)。这解决了深层网络的核心问题:梯度可以通过恒等映射直接回传到浅层,不会因为层数多而消失。对于几十上百层的Transformer来说,没有残差连接基本训练不起来。泛化方面,残差连接相当于让每一层学习'在输入基础上做多少修改',而不是从头学一个完整变换,这个归纳偏置让模型更容易学到有意义的表示。另外残差连接还提供了一种集成效果——你可以把深层网络看作不同深度子网络的隐式集成,每条路径贡献不同层次的特征。
BatchNorm是在batch维度上做归一化,同一个特征维度跨所有样本算均值方差;LayerNorm是在特征维度上归一化,每个样本独立算。对NLP任务来说,BatchNorm有两个大问题:一是序列长度不一致的batch不好处理,二是batch size小的时候统计量不稳定。LayerNorm每个样本独立计算,没有这些问题。RMSNorm是LayerNorm的简化版——去掉了减均值那一步,只做除以均方根的缩放。实验表明减均值对模型效果影响很小,但省掉这步能减少约10-15%的计算量。对于动辄几十层的大模型来说这个优化很值得。Llama系列从一开始就用了RMSNorm加Pre-Norm(归一化放在注意力之前而不是之后),训练更稳定。
注意力层负责的是'信息混合'——让不同位置的token交换信息;前馈网络(FFN)负责的是'信息处理'——对每个位置的表示独立做非线性变换。可以把FFN理解成一个巨大的键值记忆网络:第一层线性变换把输入映射到更高维度(通常是4倍),激活函数过滤后,第二层再映射回来。研究表明,模型学到的事实性知识主要存储在FFN的参数中,注意力层更多是做信息路由。比如'北京是中国的首都'这种知识,大概率编码在某几层FFN的权重矩阵里。Llama系列用的SwiGLU激活函数比原始的ReLU效果更好,但额外多了一个门控矩阵。
应该修改前馈神经网络层(FFN)。研究表明事实性知识主要存储在FFN中,具体来说是FFN的第一层权重矩阵的某些行。可以用知识神经元定位的方法找到与目标知识最相关的神经元:给模型输入与该知识相关的提示,观察FFN中间层哪些神经元被强激活,然后把这些神经元对应的参数置零或修改。注意力层存储的更多是通用的关系模式和信息路由规则,不太适合做精确的知识编辑。这个方向已经有不少工作,比如ROME和MEMIT方法,通过定位然后修改特定层FFN的参数来编辑或删除知识,效果比直接微调整个模型要精准得多。
根本原因是大模型做的不是真正的计算,而是模式匹配。它见过大量'2+3=5'这样的文本,学会了一些加减法的模式,但本质上是在做概率预测而不是逻辑运算。具体来说有几个问题:1)tokenizer会把数字拆成不自然的片段,'12345'可能被切成'123'和'45',模型连完整的数字都看不到;2)大数计算在训练数据中出现频率很低,模型缺乏泛化能力;3)多步推理容易累积误差;4)模型没有'工作记忆'来保存中间结果。这就是为什么现在的做法是让模型调用计算器或代码解释器来做数学——让模型负责理解问题和拆解步骤,精确计算交给工具。
这些参数之间有明确的约束关系和经验比例。隐藏维度必须能被注意力头数整除(每个头的维度=隐藏维度/头数),Llama-3 8B是4096/32=128。深度和宽度有个平衡点:给定固定参数量,太深太窄训练不稳定,太浅太宽表达能力不够。经验上隐藏维度大约是层数的128倍左右比较合理。FFN中间维度通常是隐藏维度的4倍(用SwiGLU时是8/3倍再取整)。上下文长度主要受限于显存——KV缓存的大小和序列长度成正比,注意力计算和序列长度平方成正比。注意力头数也影响上下文建模能力,头越多能捕获的关系模式越丰富,但每个头的维度就越小。实际设计时这些参数通常通过小规模实验确定最优比例,再按scaling law放大。
以Llama-3 8B为例,隐藏维度d=4096,32层,32个注意力头,8个KV头(GQA),FFN中间维度14336。每层Transformer block包含:注意力层的Wq矩阵(4096x4096)、Wk矩阵(4096x1024)、Wv矩阵(4096x1024)、Wo输出投影(4096x4096);FFN层的上投影W1(4096x14336)、门控W3(4096x14336)、下投影W2(14336x4096);加上两个RMSNorm的参数向量(4096)。此外还有token embedding矩阵(128256x4096)和输出层的lm_head矩阵(4096x128256)。FFN占了每层约2/3的参数量,这也印证了为什么知识主要存储在FFN中。
单请求推理时,内存带宽是主要瓶颈,因为每生成一个token都要把整个模型的参数从显存读一遍,但每个参数只做了很少的计算(arithmetic intensity很低)。以Llama-3 8B在A100 80GB上为例:模型参数约16GB(FP16),A100内存带宽约2TB/s,FP16算力312 TFLOPS。单token生成时读取16GB参数,理论最快约8ms,但计算只要~0.1ms,算力大量空闲。要平衡的话,需要增大batch size,让读一次参数能服务多个请求。粗略计算:平衡点batch size ≈ 312T / (2 * 16G * 1000) ≈ 10。即batch size达到约10时算力和带宽大致平衡。Prefill阶段因为输入序列较长,算力利用率更高,通常是compute-bound。
假设词元符合分类分布(Categorical Distribution)。输出层把最后一层的隐状态通过一个线性层映射到词表大小的logits向量,再过softmax变成概率分布——每个token有一个概率值,所有概率加起来等于1。训练时用交叉熵损失,本质上是最大化正确 token的对数概率。推理时从这个分布中采样下一个token,温度参数控制分布的尖锐程度:温度0时退化为确定性的argmax,温度越高分布越均匀、随机性越强。这个假设的局限性在于它假定每个位置的词元选择是独立的,但实际上自回归生成时前面的选择会影响后面的分布。
主要做两件事:调整位置编码和做长序列继续训练。如果用的是RoPE,最常见的做法是RoPE外推——调整旋转频率的base值(比如从10000改到100000),然后在长序列数据上继续训练一小段时间,让模型适应新的位置范围。YaRN等方法进一步优化了不同频率分量的缩放策略。KV缓存的挑战很直接:长度从8K到4K,KV缓存的显存占用翻4倍,多用户并发时显存很容易打满。解决办法包括:用GQA/MQA减少KV头数、PagedAttention加话间共享前缀的KV缓存、量化KV缓存到更低精度、以及卓输出换入等策略。
多头注意力让不同的头关注不同类型的关系(语法、指代、语义等),如果只有一个头,模型很难同时捕获多种不同类型的依赖关系。简单减少头数会同时减少Q、K、V的参数量,导致表达能力下降。GQA和MQA的巧妙之处在于:它们只减少K和V的头数,保持Q的头数不变。MQA是所有Q头共享同一套KV,GQA是每组Q头共享一套KV(比如Llama-3的32个Q头分成8组,每组4个Q头共享1套KV)。这样模型的表达能力基本不受影响,但KV缓存大小大幅减少。主要优化的是推理阶段——减少KV缓存的显存占用和内存带宽压力,训练时也能省一些显存但不是主要动机。
Flash Attention加速的本质是减少HBM(显存)访问次数。标准attention需要把完整的QK^T矩阵(nn大小)写入显存再读出来做softmax,显存读写是大瓶颈。Flash Attention用分块(tiling)策略,把计算拆成小块在SRAM(片上快速缓存)中完成,避免把nn矩阵存到显存。增量softmax的关键技巧是online softmax:处理每个块时记录当前的最大值和指数和,处理下一个块时如果出现更大的值,就用缩放因子修正之前的结果。这样就不需要一次性看到整行数据才能算softmax了。计算量完全没变,但IO复杂度从平方级降到了线性级,实测能快2-4倍。
RoPE的核心优势是它能编码相对位置信息。绝对位置编码只能告诉模型'这个token在位置3',RoPE通过对Q和K做旋转变换,让注意力分数自然地只依赖两个token的相对距离而不是绝对位置。这带来更好的序列键入正(同样的片段出现在不同位置,注意力模式一致)。但RoPE外推的挑战在于:训练时只见过一定范围内的相对距离对应的旋转角度,超出这个范围后旋转角度进入了模型未见过的区间,注意力分布会变得混乱。解决办法包括调整base频率、NTK外推、YaRN等,核心思路都是把旋转角度压缩到模型见过的范围内,再做少量长序列数据的继续训练。
用document mask(文档掩码)。在因果掩码的基础上进一步限制:不同样本之间的注意力也被遮挡。也就是说,样本B的token只能看到样本B中它前面的token,看不到样本A的任何token,即使样本A在序列中排在样本B前面。另外loss也只在每个样本内部计算,不会让样本A的最后token去预测样本B的第一个token。如果不做这个处理,模型会学到样本之间的虚假关联,影响任务质量。这种做packing的训练方式能显著提高GPU利用率,因为避免了大量的padding浪费。
推测解码(speculative decoding)的思路是:用小模型快速草拟几个候选token,再用大模型一次性验证。具体流程:小模型连续生成k个token(很快,因为模型小),然后大模型对这k个token并行做验证(一次forward pass),接受符合大模型分布的部分,拒绝很离谱的。为什么能加速?因为大模型的瓶颈是内存带宽(每生成一个token要读一次所有参数),但验证k个token和生成1个token的带宽成本几乎一样(都是读一次参数),算力却能并行处理k个token。等于把闲置的算力利用起来了。如果小模型的接受率足够高(比如70-80%),实际速度可以提厇2-3倍,而且输出质量和纯用大模型完全一致。
RELATED