损失函数概述
损失函数表
损失函数在微调交叉编码器(Cross Encoder)模型的性能中起着至关重要的作用。遗憾的是,没有“一刀切”的万能损失函数。理想情况下,此表应通过将损失函数与您的数据格式相匹配,来帮助您缩小选择范围。
注意
您通常可以将一种训练数据格式转换为另一种,从而使更多的损失函数适用于您的场景。例如,带有类别
标签的(句子_A, 句子_B) 对
可以通过采样具有相同或不同类别的句子,转换为(锚点, 正例, 负例) 三元组
。
此外,mine_hard_negatives()
可以轻松地将(锚点, 正例)
转换为:
使用
output_format="triplet"
得到(锚点, 正例, 负例) 三元组
,使用
output_format="n-tuple"
得到(锚点, 正例, 负例_1, …, 负例_n) 元组
,使用
output_format="labeled-pair"
得到(锚点, 段落, 标签) 带标签的对
,其中负例标签为0,正例为1,使用
output_format="labeled-list"
得到(锚点, [文档1, 文档2, ..., 文档N], [标签1, 标签2, ..., 标签N]) 三元组
,其中负例标签为0,正例为1。
输入 | 标签 | 模型输出标签数量 | 适用的损失函数 |
---|---|---|---|
(句子_A, 句子_B) 对 |
类别 |
num_classes (类别数量) |
CrossEntropyLoss (交叉熵损失) |
(锚点, 正例) 对 |
无 |
1 |
MultipleNegativesRankingLoss (多负例排序损失) CachedMultipleNegativesRankingLoss (缓存多负例排序损失) |
(锚点, 正例/负例) 对 |
正例为 1,负例为 0 |
1 |
BinaryCrossEntropyLoss (二元交叉熵损失) |
(句子_A, 句子_B) 对 |
0 到 1 之间的浮点相似度分数 |
1 |
BinaryCrossEntropyLoss (二元交叉熵损失) |
(锚点, 正例, 负例) 三元组 |
无 |
1 |
MultipleNegativesRankingLoss (多负例排序损失) CachedMultipleNegativesRankingLoss (缓存多负例排序损失) |
(锚点, 正例, 负例_1, ..., 负例_n) |
无 |
1 |
MultipleNegativesRankingLoss (多负例排序损失) CachedMultipleNegativesRankingLoss (缓存多负例排序损失) |
(查询, [文档1, 文档2, ..., 文档N]) |
[分数1, 分数2, ..., 分数N] |
1 |
蒸馏
这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。例如,当微调一个小模型使其行为更像一个更大、更强的模型时,或者当微调一个模型使其成为多语言模型时。
文本 | 标签 | 适用的损失函数 |
---|---|---|
(句子_A, 句子_B) 对 |
相似度分数 |
MSELoss (均方误差损失) |
(查询, 段落_一, 段落_二) 三元组 |
gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) |
MarginMSELoss (边距均方误差损失) |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] |
MarginMSELoss (边距均方误差损失) |
(查询, 正例, 负例) |
[gold_sim(查询, 正例), gold_sim(查询, 负例)] |
MarginMSELoss (边距均方误差损失) |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] |
MarginMSELoss (边距均方误差损失) |
常用损失函数
在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:
带有
浮点相似度分数
或正例为1,负例为0
的(句子_A, 句子_B) 对
:BinaryCrossEntropyLoss
是一个传统选项,至今仍难以超越。无任何标签的
(锚点, 正例) 对
:结合mine_hard_negatives
当 `output_format=”labeled-list”` 时,
LambdaLoss
经常用于学习到排序(learning-to-rank)任务。当 `output_format=”labeled-pair”` 时,
BinaryCrossEntropyLoss
仍然是一个强有力的选项。
自定义损失函数
高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:
它们必须是
torch.nn.Module
的子类。它们的构造函数中必须有
model
作为第一个参数。它们必须实现一个
forward
方法,该方法接受inputs
和labels
。前者是批次中嵌套的文本列表,外层列表的每个元素代表训练数据集中的一列。您必须将这些文本组合成可以 1) 被分词 和 2) 被送入模型的文本对。后者是一个可选的(张量)列表,包含来自数据集中label
,labels
,score
, 或scores
列的标签。该方法必须返回一个单一的损失值,或者一个包含损失分量(分量名到损失值)的字典,这些分量将被相加得到最终的损失值。当返回一个字典时,除了总和损失外,每个分量也将被单独记录,使您能够监控损失的各个组成部分。
为了获得自动生成模型卡的完全支持,您可能还希望实现:
一个
get_config_dict
方法,返回一个包含损失参数的字典。一个
citation
属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。
可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。