损失函数概述
损失函数表
损失函数在您微调的交叉编码器模型的性能中起着关键作用。遗憾的是,没有“一刀切”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配来帮助缩小您的选择范围。
注意
您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,通过抽样具有相同或不同类别的句子,可以将带有class标签的(sentence_A, sentence_B) pairs转换为(anchor, positive, negative) triplets。
此外,mine_hard_negatives() 可以轻松地将 (anchor, positive) 转换为
(anchor, positive, negative) 三元组,使用output_format="triplet",(anchor, positive, negative_1, …, negative_n) n元组,使用output_format="n-tuple"。(anchor, passage, label) 标记对,其中负样本的标签为0,正样本的标签为1,使用output_format="labeled-pair",(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) 三元组,其中负样本的标签为0,正样本的标签为1,使用output_format="labeled-list"
以及通过设置 output_scores=True 来使用相似性分数而非二值化标签的格式。
| 输入 | 标签 | 模型输出标签的数量 | 适用的损失函数 |
|---|---|---|---|
(句子_A, 句子_B) 对 |
类别 |
num_classes |
CrossEntropyLoss |
(锚点, 正例) 对 |
无 |
1 |
MultipleNegativesRankingLossCachedMultipleNegativesRankingLoss |
(锚点, 正例/负例) 对 |
正例为 1,负例为 0 |
1 |
BinaryCrossEntropyLoss |
(句子_A, 句子_B) 对 |
0 到 1 之间的浮点相似度分数 |
1 |
BinaryCrossEntropyLoss |
(锚点, 正例, 负例) 三元组 |
无 |
1 |
MultipleNegativesRankingLossCachedMultipleNegativesRankingLoss |
(锚点, 正例, 负例_1, ..., 负例_n) |
无 |
1 |
MultipleNegativesRankingLossCachedMultipleNegativesRankingLoss |
(query, [doc1, doc2, ..., docN]) |
[score1, score2, ..., scoreN] |
1 |
蒸馏
这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。例如,当微调一个小型模型使其行为更像一个更大更强的模型时,或者当微调一个模型使其成为多语言模型时。
| 文本 | 标签 | 适用的损失函数 |
|---|---|---|
(句子_A, 句子_B) 对 |
相似性分数 |
MSELoss |
(查询, 段落_一, 段落_二) 三元组 |
gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) |
MarginMSELoss |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] |
MarginMSELoss |
(查询, 正例, 负例) |
[gold_sim(查询, 正例), gold_sim(查询, 负例)] |
MarginMSELoss |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] |
MarginMSELoss |
常用损失函数
在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:
(sentence_A, sentence_B) 对具有浮点 相似性 分数或1 如果 正样本, 0 如果 负样本:BinaryCrossEntropyLoss是一个传统的选项,它仍然很难被超越。(anchor, positive) 对不带任何标签:与mine_hard_negatives结合使用如果使用
output_format=”labeled-list”,则LambdaLoss经常用于学习排名任务。如果使用
output_format=”labeled-pair”,则BinaryCrossEntropyLoss仍然是一个强有力的选择。
自定义损失函数
高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:
它们必须是
torch.nn.Module的子类。它们的构造函数中必须有
model作为第一个参数。它们必须实现一个接受
inputs和labels的forward方法。前者是批处理中嵌套的文本列表,外部列表中的每个元素代表训练数据集中的一列。您必须将这些文本组合成可以 1) 分词和 2) 输入到模型的对。后者是数据集中label、labels、score或scores列中的可选(列表形式的)标签张量。该方法必须返回一个单一的损失值或一个损失组件字典(组件名称到损失值),这些组件将被求和以产生最终的损失值。当返回字典时,除了总和损失外,各个组件将单独记录,允许您监控损失的各个组件。
为了获得自动生成模型卡的完全支持,您可能还希望实现:
一个
get_config_dict方法,返回一个包含损失参数的字典。一个
citation属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。
可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。