损失函数概述

损失函数表

损失函数在微调模型的性能中扮演着至关重要的角色。遗憾的是,并不存在“一刀切”的万能损失函数。理想情况下,下表应能通过将损失函数与您的数据格式相匹配,帮助您缩小损失函数的选择范围。

注意

您通常可以将一种训练数据格式转换为另一种,从而使更多的损失函数适用于您的场景。例如,带有 类别 标签的 (句子_A, 句子_B) 可以通过对相同或不同类别的句子进行采样,转换为 (锚点, 正例, 负例) 三元组

输入 标签 适用的损失函数
单个句子 类别 BatchAllTripletLoss (批处理全三元组损失)
BatchHardSoftMarginTripletLoss (批处理硬软间隔三元组损失)
BatchHardTripletLoss (批处理硬三元组损失)
BatchSemiHardTripletLoss (批处理半硬三元组损失)
单个句子 ContrastiveTensionLoss (对比张力损失)
DenoisingAutoEncoderLoss (去噪自编码器损失)
(锚点, 锚点) 对 ContrastiveTensionLossInBatchNegatives (批处理负例对比张力损失)
(损坏的句子, 原始句子) 对 DenoisingAutoEncoderLoss (去噪自编码器损失)
(句子_A, 句子_B) 对 类别 SoftmaxLoss (Softmax 损失)
(锚点, 正例) 对 MultipleNegativesRankingLoss (多负例排序损失)
CachedMultipleNegativesRankingLoss (缓存多负例排序损失)
MultipleNegativesSymmetricRankingLoss (多负例对称排序损失)
CachedMultipleNegativesSymmetricRankingLoss (缓存多负例对称排序损失)
MegaBatchMarginLoss (大批量间隔损失)
GISTEmbedLoss
CachedGISTEmbedLoss (缓存 GISTEmbed 损失)
(锚点, 正例/负例) 对 正例为 1,负例为 0 ContrastiveLoss (对比损失)
OnlineContrastiveLoss (在线对比损失)
(句子_A, 句子_B) 对 0 到 1 之间的浮点相似度分数 CoSENTLoss
AnglELoss
CosineSimilarityLoss (余弦相似度损失)
(锚点, 正例, 负例) 三元组 MultipleNegativesRankingLoss (多负例排序损失)
CachedMultipleNegativesRankingLoss (缓存多负例排序损失)
TripletLoss (三元组损失)
CachedGISTEmbedLoss (缓存 GISTEmbed 损失)
GISTEmbedLoss
(锚点, 正例, 负例_1, ..., 负例_n) MultipleNegativesRankingLoss (多负例排序损失)
CachedMultipleNegativesRankingLoss (缓存多负例排序损失)
CachedGISTEmbedLoss (缓存 GISTEmbed 损失)

损失修改器

这些损失函数可以被视为损失修改器:它们在标准损失函数的基础上工作,但以不同的方式应用这些损失函数,试图为训练好的嵌入模型注入有用的特性。

例如,使用 MatryoshkaLoss 训练的模型产生的嵌入,其大小可以被截断而不会有显著的性能损失;而使用 AdaptiveLayerLoss 训练的模型,在移除模型层以加快推理速度后,仍然表现良好。

文本 标签 适用的损失函数
任何 任何 MatryoshkaLoss (套娃损失)
AdaptiveLayerLoss (自适应层损失)
Matryoshka2dLoss (二维套娃损失)

蒸馏

这些损失函数是专门为将知识从一个模型蒸馏到另一个模型而设计的。例如,在微调一个小模型以使其行为更像一个更大更强的模型时,或者在微调一个模型使其成为多语言模型时。

文本 标签 适用的损失函数
句子 模型句子嵌入 MSELoss (均方误差损失)
(句子_1, 句子_2, ..., 句子_N) 模型句子嵌入 MSELoss (均方误差损失)
(查询, 段落一, 段落二) gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) MarginMSELoss (间隔均方误差损失)
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] MarginMSELoss (间隔均方误差损失)
(查询, 正例, 负例) [gold_sim(查询, 正例), gold_sim(查询, 负例)] DistillKLDivLoss (蒸馏 KL 散度损失)
MarginMSELoss (间隔均方误差损失)
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] DistillKLDivLoss (蒸馏 KL 散度损失)
MarginMSELoss (间隔均方误差损失)

常用损失函数

在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:

  • 没有任何标签的 (锚点, 正例) MultipleNegativesRankingLoss(又名 InfoNCE 或批处理内负例损失)通常用于训练性能顶尖的嵌入模型。这种数据通常获取成本较低,而且模型通常性能非常出色。CachedMultipleNegativesRankingLoss 通常用于增加批处理大小,从而获得更优的性能。

  • 带有浮点 相似度 分数(句子_A, 句子_B) :传统上多使用CosineSimilarityLoss,但最近CoSENTLossAnglELoss作为其直接替代品,性能更优。

自定义损失函数

高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:

  • 它们必须是 torch.nn.Module 的子类。

  • 它们的构造函数中必须有 model 作为第一个参数。

  • 它们必须实现一个 forward 方法,该方法接受 sentence_featureslabels。前者是分词批次的列表,每个元素对应一列。这些分词批次可以直接输入到正在训练的模型中以生成嵌入。后者是一个可选的标签张量。该方法必须返回一个单一的损失值或一个损失分量的字典(分量名称到损失值的映射),这些分量将被相加以产生最终的损失值。当返回字典时,除了总和损失外,各个分量也会被单独记录,以便您监控损失的各个组成部分。

为了获得自动生成模型卡的完全支持,您可能还希望实现:

  • 一个 get_config_dict 方法,返回一个包含损失参数的字典。

  • 一个 citation 属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。

可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。