损失函数概述
损失函数表
损失函数在微调模型的性能中起着关键作用。遗憾的是,没有一个“一刀切”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配来帮助缩小您的选择范围。
注意
您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,通过抽样具有相同或不同类别的句子,可以将带有class标签的(sentence_A, sentence_B) pairs转换为(anchor, positive, negative) triplets。
| 输入 | 标签 | 适用的损失函数 |
|---|---|---|
单个句子 |
类别 |
BatchAllTripletLossBatchHardSoftMarginTripletLossBatchHardTripletLossBatchSemiHardTripletLoss |
单个句子 |
无 |
ContrastiveTensionLossDenoisingAutoEncoderLoss |
(锚点, 锚点) 对 |
无 |
ContrastiveTensionLossInBatchNegatives |
(损坏的句子, 原始句子) 对 |
无 |
DenoisingAutoEncoderLoss |
(句子_A, 句子_B) 对 |
类别 |
SoftmaxLoss |
(锚点, 正例) 对 |
无 |
MultipleNegativesRankingLossCachedMultipleNegativesRankingLossMultipleNegativesSymmetricRankingLossCachedMultipleNegativesSymmetricRankingLossMegaBatchMarginLossGISTEmbedLossCachedGISTEmbedLoss |
(锚点, 正例/负例) 对 |
正例为 1,负例为 0 |
ContrastiveLossOnlineContrastiveLoss |
(句子_A, 句子_B) 对 |
0 到 1 之间的浮点相似度分数 |
CoSENTLossAnglELossCosineSimilarityLoss |
(锚点, 正例, 负例) 三元组 |
无 |
MultipleNegativesRankingLossCachedMultipleNegativesRankingLossTripletLossCachedGISTEmbedLossGISTEmbedLoss |
(锚点, 正例, 负例_1, ..., 负例_n) |
无 |
MultipleNegativesRankingLossCachedMultipleNegativesRankingLossCachedGISTEmbedLoss |
损失修饰符
这些损失函数可以被视为损失修饰符:它们在标准损失函数之上工作,但以不同的方式应用这些损失函数,以试图向训练好的嵌入模型灌输有用的属性。
例如,使用MatryoshkaLoss训练的模型会产生大小可以截断而性能没有明显损失的嵌入,而使用AdaptiveLayerLoss训练的模型在您移除模型层以加快推理速度时仍然表现良好。
| 文本 | 标签 | 适用的损失函数 |
|---|---|---|
任何 |
任何 |
MatryoshkaLossAdaptiveLayerLossMatryoshka2dLoss |
蒸馏
这些损失函数专门设计用于将知识从一个模型提炼到另一个模型。例如,当微调一个小模型使其表现更像一个更大更强的模型时,或者当微调一个模型使其成为多语言模型时。
| 文本 | 标签 | 适用的损失函数 |
|---|---|---|
句子 |
模型句子嵌入 |
MSELoss |
(句子_1, 句子_2, ..., 句子_N) |
模型句子嵌入 |
MSELoss |
(查询, 段落一, 段落二) |
gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) |
MarginMSELoss |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] |
MarginMSELoss |
(查询, 正例, 负例) |
[gold_sim(查询, 正例), gold_sim(查询, 负例)] |
DistillKLDivLossMarginMSELoss |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] |
DistillKLDivLossMarginMSELoss |
常用损失函数
在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:
不带任何标签的
(anchor, positive) pairs:MultipleNegativesRankingLoss(又名 InfoNCE 或批内负样本损失)常用于训练表现最佳的嵌入模型。这种数据通常相对容易获取,并且模型通常表现非常好。CachedMultipleNegativesRankingLoss通常用于增加批量大小,从而获得卓越的性能。带有
float similarity score的(sentence_A, sentence_B) pairs:CosineSimilarityLoss传统上使用很多,尽管最近CoSENTLoss和AnglELoss被用作具有卓越性能的直接替代品。
自定义损失函数
高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:
它们必须是
torch.nn.Module的子类。它们的构造函数中必须有
model作为第一个参数。它们必须实现一个接受
sentence_features和labels的forward方法。前者是一个标记化批次的列表,每列一个元素。这些标记化批次可以直接馈送到正在训练的model以生成嵌入。后者是一个可选的标签张量。该方法必须返回一个单一的损失值或一个损失组件字典(组件名称到损失值),这些组件将被求和以产生最终的损失值。当返回字典时,除了求和的损失之外,各个组件将单独记录,从而允许您监控损失的各个组件。
为了获得自动生成模型卡的完全支持,您可能还希望实现:
一个
get_config_dict方法,返回一个包含损失参数的字典。一个
citation属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。
可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。