本文的重点放在GPT式模型的数据流向和损失计算, 以及其在SFT情景下的处理

有关 GPT 结构, 诸如 decoder-only, positional encoding等会在有关Transformer的文章中介绍

  1. 模型结构与序列长度的关系
    首先,tokens的长度并不会从定义上影响Transformer网络的结构。Transformer的计算本质是对Embedding向量进行处理,因此无论输入的序列长度(N)是多少,Transformer的结构(层数、头数、隐藏维度等)并不发生变化。

    不过,序列长度会影响计算量:

    • 在自注意力(Self-Attention)机制中,需要计算长度为N的序列中每个token与其他N-1个token之间的关联权重,复杂度为$O(N^2)$,因此序列越长,计算越耗时。
    • 如果使用的是绝对位置编码,那么位置编码的定义方式可能受到序列长度的影响(例如编码范围、位置映射方式等)。相对位置编码就不受此直接限制。
  2. 数据流向(Data Flow)
    以一个简单的案例为例:假设输入数据的形状为 (batch_size=1, sequence_length=10),模型的embedding维度为768,词汇表大小为50,000。

    • 首先,原始输入 tokens 会通过 Embedding Layer 映射为 (1, 10, 768) 的向量表示。
    • 然后,这些嵌入向量经过多层Transformer网络处理,每一层都会保持相同的形状 (1, 10, 768)
    • 在最后,模型通过一个 lm_head(线性层)将隐藏状态投影到词汇表维度 (1, 10, 50,000),对应每个序列位置的下一token预测分布(logits)。
      换言之,输出中对于每个输入的token位置,都有一个对应的词汇表大小的向量,代表对下一个token的概率估计。
  3. 损失计算(Loss Calculation)
    在自监督预训练中,我们对输入序列进行自回归预测。例如,对于长度为10的输入序列,模型要用第1个token预测第2个token、第2个token预测第3个token,依次类推,直到用第9个token预测第10个token。

    • 最后一个token没有下一个token可预测,因此它的标签(label)位置设为-100(在HuggingFace中,-100用于表示忽略该位置的loss)。
    • 对前9个位置的预测结果和真实下一个token进行交叉熵计算,将这些交叉熵值相加(或取平均),得到该样本的loss。

      对应的label序列举例:

    • 输入: [input1, input2, input3, input4]

    • Label: [input2, input3, input4, -100]
      前三个位置计算预测误差,最后一个位置不计入loss。

指令微调(SFT)情境下的损失计算与训练方式

在指令微调(SFT)中,我们通常将问题(Question)和答案(Answer)组合成一个连续序列作为输入:
[Q1, Q2, Q3, A1, A2, A3]

在训练中,模型依然是以自回归的方式进行”next token”预测,只是我们通过设置labels中的问题部分为-100来忽略损失。这样就确保:

  • 对问题区间 (Q1, Q2, Q3) 不计算loss,因为这些部分不需要模型从自身预测出来,而是被当做给定的上下文。
  • 对回答区间 (A1, A2, A3) 进行正常的next token预测和loss计算。

因此Label会是:
[-100, -100, -100, A1, A2, A3]
这里的 A1, A2, A3 是模型需要预测的目标token。最后一个token之后如果没有下一个token,则同理最终的末尾位置应该设置为-100,从而不参与loss计算。

为什么不用另一种训练方式?
从理论上,人们可能设想:为什么不让模型先完整看到问题部分,然后在训练中直接从问题结束的地方开始进行next token预测?理论上是可行的,但有以下实务考量:

  1. 模型结构与训练范式保持一致
    GPT类模型本身是作为自回归语言模型设计的,其预训练目标一直是给定前面所有token预测下一个token。无论是pretrain还是SFT,都是用相同的自回归建模方法,即:输入一个完整的序列,然后对序列中每个位置的下一个token进行预测。
    这种统一的训练范式极大简化了实现和训练代码,避免为SFT设置专门的一套”只在特定位置开始预测”的逻辑。

  2. 上下文一致性和无缝对接
    GPT模型在推理时是将问题与回答作为一个连续的token流来处理。SFT的目标也是在这样的场景下优化模型的行为,使模型更好地在一个连续的token序列上下文中进行预测。
    如果在训练阶段为了”简化”而人为将问题部分与回答预测分开处理(即先全量看到问题,然后再单独从answer部分开始预测),那么训练和推理时的上下文处理方式会不一致,可能对模型的泛化产生影响。

  3. 避免对框架进行大改
    多数已有的SFT框架是建立在已有的语言模型训练流程之上,只是通过将不参与loss计算的部分token的label置为-100来忽略这些位置的损失。这样无需对模型的前向计算或训练流程进行特殊调整,只需要调整label或loss mask。
    如果要在训练时将问题部分和回答部分”分段处理”,就需要对训练流程进行特殊改造,比如在forward时对输入序列分区,然后从问题结束后的位置开始计算loss。虽然这并非不可实现,但会带来额外的工程复杂度。

  4. 模型能力与训练目标统一
    SFT的最终目的是让模型更好地对指令(问题)做出回答(回答序列)。即使在训练时对问题部分不计算loss,模型仍在这种自回归范式下学习如何从”上下文包含指令”的状态,流畅地生成后续答案。这与推理时的场景严格对应:推理时你给定整个问题序列,然后让模型自回归生成答案序列。
    换句话说,这种训练和推理方式是对称和直观的。

  5. 简化与适配性
    在工业和研究环境中,为了保持训练和推理流程尽量简单和统一,人们更倾向用同一套自回归范式,只通过mask机制在loss层面完成控制,而不在模型前向计算逻辑上作出过多特化。这也提高了代码的可维护性和通用性。

总而言之,虽然理论上可以在SFT时采用 “直接从问题结束位置开始预测” 的特化训练策略,但实践中更常采用的问题+答案连续序列加上问题部分label为-100的方式。这种方式与预训练范式及推理步骤高度一致,代码改动少,维护简单,因此被广泛应用。