损失函数概述

损失函数表

损失函数在微调交叉编码器(Cross Encoder)模型的性能中起着至关重要的作用。遗憾的是,没有“一刀切”的万能损失函数。理想情况下,此表应通过将损失函数与您的数据格式相匹配,来帮助您缩小选择范围。

注意

您通常可以将一种训练数据格式转换为另一种,从而使更多的损失函数适用于您的场景。例如,带有类别标签的(句子_A, 句子_B) 可以通过采样具有相同或不同类别的句子,转换为(锚点, 正例, 负例) 三元组

此外,mine_hard_negatives()可以轻松地将(锚点, 正例)转换为:

  • 使用 output_format="triplet" 得到 (锚点, 正例, 负例) 三元组

  • 使用 output_format="n-tuple" 得到 (锚点, 正例, 负例_1, …, 负例_n) 元组

  • 使用 output_format="labeled-pair" 得到 (锚点, 段落, 标签) 带标签的对,其中负例标签为0,正例为1,

  • 使用 output_format="labeled-list" 得到 (锚点, [文档1, 文档2, ..., 文档N], [标签1, 标签2, ..., 标签N]) 三元组,其中负例标签为0,正例为1。

输入 标签 模型输出标签数量 适用的损失函数
(句子_A, 句子_B) 对 类别 num_classes (类别数量) CrossEntropyLoss (交叉熵损失)
(锚点, 正例) 对 1 MultipleNegativesRankingLoss (多负例排序损失)
CachedMultipleNegativesRankingLoss (缓存多负例排序损失)
(锚点, 正例/负例) 对 正例为 1,负例为 0 1 BinaryCrossEntropyLoss (二元交叉熵损失)
(句子_A, 句子_B) 对 0 到 1 之间的浮点相似度分数 1 BinaryCrossEntropyLoss (二元交叉熵损失)
(锚点, 正例, 负例) 三元组 1 MultipleNegativesRankingLoss (多负例排序损失)
CachedMultipleNegativesRankingLoss (缓存多负例排序损失)
(锚点, 正例, 负例_1, ..., 负例_n) 1 MultipleNegativesRankingLoss (多负例排序损失)
CachedMultipleNegativesRankingLoss (缓存多负例排序损失)
(查询, [文档1, 文档2, ..., 文档N]) [分数1, 分数2, ..., 分数N] 1
  1. LambdaLoss
  2. PListMLELoss
  3. ListNetLoss
  4. RankNetLoss
  5. ListMLELoss

蒸馏

这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。例如,当微调一个小模型使其行为更像一个更大、更强的模型时,或者当微调一个模型使其成为多语言模型时。

文本 标签 适用的损失函数
(句子_A, 句子_B) 对 相似度分数 MSELoss (均方误差损失)
(查询, 段落_一, 段落_二) 三元组 gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) MarginMSELoss (边距均方误差损失)
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] MarginMSELoss (边距均方误差损失)
(查询, 正例, 负例) [gold_sim(查询, 正例), gold_sim(查询, 负例)] MarginMSELoss (边距均方误差损失)
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] MarginMSELoss (边距均方误差损失)

常用损失函数

在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:

  • 带有浮点相似度分数正例为1,负例为0(句子_A, 句子_B) BinaryCrossEntropyLoss 是一个传统选项,至今仍难以超越。

  • 无任何标签的(锚点, 正例) :结合 mine_hard_negatives

    • 当 `output_format=”labeled-list”` 时,LambdaLoss 经常用于学习到排序(learning-to-rank)任务。

    • 当 `output_format=”labeled-pair”` 时,BinaryCrossEntropyLoss 仍然是一个强有力的选项。

自定义损失函数

高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:

  • 它们必须是 torch.nn.Module 的子类。

  • 它们的构造函数中必须有 model 作为第一个参数。

  • 它们必须实现一个 forward 方法,该方法接受 inputslabels。前者是批次中嵌套的文本列表,外层列表的每个元素代表训练数据集中的一列。您必须将这些文本组合成可以 1) 被分词 和 2) 被送入模型的文本对。后者是一个可选的(张量)列表,包含来自数据集中 label, labels, score, 或 scores 列的标签。该方法必须返回一个单一的损失值,或者一个包含损失分量(分量名到损失值)的字典,这些分量将被相加得到最终的损失值。当返回一个字典时,除了总和损失外,每个分量也将被单独记录,使您能够监控损失的各个组成部分。

为了获得自动生成模型卡的完全支持,您可能还希望实现:

  • 一个 get_config_dict 方法,返回一个包含损失参数的字典。

  • 一个 citation 属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。

可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。