损失函数概述
警告
要训练 SparseEncoder
,你需要 SpladeLoss
或 CSRLoss
,具体取决于架构。这些是包装器损失,它们在主损失函数之上添加稀疏性正则化,主损失函数必须作为参数提供。唯一可以独立使用的损失是 SparseMSELoss
,因为它执行嵌入级别的蒸馏,通过直接复制教师模型的稀疏嵌入来确保稀疏性。
稀疏特定损失函数
SPLADE 损失
SpladeLoss
为 SPLADE (Sparse Lexical and Expansion) 模型实现了一个专门的损失函数。它将一个主损失函数与正则化项相结合以控制效率。
支持下面提到的所有损失作为主损失,但主要有三种损失类型:
SparseMultipleNegativesRankingLoss
、SparseMarginMSELoss
和SparseDistillKLDivLoss
。默认使用
FlopsLoss
进行正则化以控制稀疏性,但也支持自定义正则化器。通过对查询和文档表示进行正则化,平衡了有效性(通过主损失)和效率。
允许通过
query_regularizer
和document_regularizer
参数为查询和文档使用不同的正则化器,从而可以对不同类型的输入进行细粒度的稀疏模式控制。通过
query_regularizer_threshold
和document_regularizer_threshold
参数支持对查询和文档使用不同的阈值,允许每种输入类型具有不同的稀疏性严格程度。
CSR 损失
如果你正在使用 SparseAutoEncoder
模块,那么你必须使用 CSRLoss
(Contrastive Sparse Representation Loss,对比稀疏表示损失)。它结合了两个部分:
一个重构损失
CSRReconstructionLoss
,确保稀疏表示能够忠实地重构原始嵌入。一个主损失,在论文中是使用
SparseMultipleNegativesRankingLoss
的对比学习部分,确保语义上相似的句子有相似的表示。但理论上,像SpladeLoss
一样,可以使用下面提到的所有损失作为主损失。
损失函数表
损失函数对微调模型的性能起着至关重要的作用。遗憾的是,没有“一刀切”的损失函数。理想情况下,这张表格应该通过将损失函数与你的数据格式相匹配,来帮助你缩小选择范围。
注意
你通常可以将一种训练数据格式转换为另一种,从而使更多的损失函数适用于你的场景。例如,带有
class
标签的(sentence_A, sentence_B) 配对
可以通过采样具有相同或不同类别的句子,转换为(anchor, positive, negative) 三元组
。
注意
在 SentenceTransformer > 损失概览 中,这里出现的带有 Sparse
前缀的损失函数与其密集版本是相同的。该前缀仅用于指示哪些损失可以用作主损失来训练 SparseEncoder
。
输入 | 标签 | 适用的损失函数 |
---|---|---|
(锚点, 正例) 对 |
无 |
SparseMultipleNegativesRankingLoss |
(句子_A, 句子_B) 对 |
0 到 1 之间的浮点相似度分数 |
SparseCoSENTLoss SparseAnglELoss SparseCosineSimilarityLoss (稀疏余弦相似度损失) |
(锚点, 正例, 负例) 三元组 |
无 |
SparseMultipleNegativesRankingLoss SparseTripletLoss (稀疏三元组损失) |
(锚点, 正例, 负例_1, ..., 负例_n) |
无 |
SparseMultipleNegativesRankingLoss |
蒸馏
这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。这在训练稀疏嵌入模型时相当常用。
文本 | 标签 | 适用的损失函数 |
---|---|---|
句子 |
模型句子嵌入 |
SparseMSELoss (稀疏均方误差损失) |
(sentence_1, sentence_2, ..., sentence_N) (句子1, 句子2, ..., 句子N) |
模型句子嵌入 |
SparseMSELoss (稀疏均方误差损失) |
(query, passage_one, passage_two) (查询, 段落一, 段落二) |
gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) |
SparseMarginMSELoss |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] |
SparseMarginMSELoss |
(查询, 正例, 负例) |
[gold_sim(查询, 正例), gold_sim(查询, 负例)] |
SparseDistillKLDivLoss SparseMarginMSELoss |
(查询, 正例, 负例_1, ..., 负例_n) |
[gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] |
SparseDistillKLDivLoss SparseMarginMSELoss |
常用损失函数
在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:
没有任何标签的
(anchor, positive) 配对
:SparseMultipleNegativesRankingLoss
(又名 InfoNCE 或批内负采样损失) 通常用于训练性能顶尖的嵌入模型。这种数据通常获取成本较低,并且模型通常性能非常好。在这里,对于我们的稀疏检索任务,这种格式与SpladeLoss
或CSRLoss
配合得很好,两者通常都使用 InfoNCE 作为其底层损失函数。(query, positive, negative_1, ..., negative_n)
格式:这种具有多个负样本的结构在使用配置了SparseMarginMSELoss
的SpladeLoss
时特别有效,尤其是在教师模型提供相似度分数的知识蒸馏场景中。最强的模型是用蒸馏损失训练的,如SparseDistillKLDivLoss
或SparseMarginMSELoss
。
自定义损失函数
高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:
它们必须是
torch.nn.Module
的子类。它们的构造函数中必须有
model
作为第一个参数。它们必须实现一个
forward
方法,该方法接受sentence_features
和labels
。前者是分词后批次的列表,每个元素对应一列。这些分词后的批次可以直接输入到正在训练的model
中以生成嵌入。后者是可选的标签张量。该方法必须返回单个损失值或一个损失分量字典(分量名称到损失值的映射),这些分量将被相加以产生最终的损失值。当返回字典时,除了总损失外,各个分量也会被单独记录,从而使你能够监控损失的各个组成部分。
为了获得自动生成模型卡的完全支持,您可能还希望实现:
一个
get_config_dict
方法,返回一个包含损失参数的字典。一个
citation
属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。
可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。