损失函数概述

损失函数表

损失函数在您微调的交叉编码器模型的性能中起着关键作用。遗憾的是,没有“一刀切”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配来帮助缩小您的选择范围。

注意

您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,通过抽样具有相同或不同类别的句子,可以将带有class标签的(sentence_A, sentence_B) pairs转换为(anchor, positive, negative) triplets

此外,mine_hard_negatives() 可以轻松地将 (anchor, positive) 转换为

  • (anchor, positive, negative) 三元组,使用 output_format="triplet"

  • (anchor, positive, negative_1, …, negative_n) n元组,使用 output_format="n-tuple"

  • (anchor, passage, label) 标记对,其中负样本的标签为0,正样本的标签为1,使用 output_format="labeled-pair"

  • (anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) 三元组,其中负样本的标签为0,正样本的标签为1,使用 output_format="labeled-list"

以及通过设置 output_scores=True 来使用相似性分数而非二值化标签的格式。

输入 标签 模型输出标签的数量 适用的损失函数
(句子_A, 句子_B) 对 类别 num_classes CrossEntropyLoss
(锚点, 正例) 对 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(锚点, 正例/负例) 对 正例为 1,负例为 0 1 BinaryCrossEntropyLoss
(句子_A, 句子_B) 对 0 到 1 之间的浮点相似度分数 1 BinaryCrossEntropyLoss
(锚点, 正例, 负例) 三元组 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(锚点, 正例, 负例_1, ..., 负例_n) 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(query, [doc1, doc2, ..., docN]) [score1, score2, ..., scoreN] 1
  1. LambdaLoss
  2. PListMLELoss
  3. ListNetLoss
  4. RankNetLoss
  5. ListMLELoss

蒸馏

这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。例如,当微调一个小型模型使其行为更像一个更大更强的模型时,或者当微调一个模型使其成为多语言模型时。

文本 标签 适用的损失函数
(句子_A, 句子_B) 对 相似性分数 MSELoss
(查询, 段落_一, 段落_二) 三元组 gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) MarginMSELoss
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] MarginMSELoss
(查询, 正例, 负例) [gold_sim(查询, 正例), gold_sim(查询, 负例)] MarginMSELoss
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] MarginMSELoss

常用损失函数

在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:

  • (sentence_A, sentence_B) 具有 浮点 相似性 分数1 如果 正样本, 0 如果 负样本BinaryCrossEntropyLoss 是一个传统的选项,它仍然很难被超越。

  • (anchor, positive) 不带任何标签:与 mine_hard_negatives 结合使用

    • 如果使用 output_format=”labeled-list”,则 LambdaLoss 经常用于学习排名任务。

    • 如果使用 output_format=”labeled-pair”,则 BinaryCrossEntropyLoss 仍然是一个强有力的选择。

自定义损失函数

高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:

  • 它们必须是torch.nn.Module的子类。

  • 它们的构造函数中必须有 model 作为第一个参数。

  • 它们必须实现一个接受 inputslabelsforward 方法。前者是批处理中嵌套的文本列表,外部列表中的每个元素代表训练数据集中的一列。您必须将这些文本组合成可以 1) 分词和 2) 输入到模型的对。后者是数据集中 labellabelsscorescores 列中的可选(列表形式的)标签张量。该方法必须返回一个单一的损失值或一个损失组件字典(组件名称到损失值),这些组件将被求和以产生最终的损失值。当返回字典时,除了总和损失外,各个组件将单独记录,允许您监控损失的各个组件。

为了获得自动生成模型卡的完全支持,您可能还希望实现:

  • 一个 get_config_dict 方法,返回一个包含损失参数的字典。

  • 一个 citation 属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。

可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。