损失函数概述

警告

要训练 SparseEncoder,你需要 SpladeLossCSRLoss,具体取决于架构。这些是包装器损失,它们在主损失函数之上添加稀疏性正则化,主损失函数必须作为参数提供。唯一可以独立使用的损失是 SparseMSELoss,因为它执行嵌入级别的蒸馏,通过直接复制教师模型的稀疏嵌入来确保稀疏性。

稀疏特定损失函数

SPLADE 损失

SpladeLoss 为 SPLADE (Sparse Lexical and Expansion) 模型实现了一个专门的损失函数。它将一个主损失函数与正则化项相结合以控制效率。

  • 支持下面提到的所有损失作为主损失,但主要有三种损失类型:SparseMultipleNegativesRankingLossSparseMarginMSELossSparseDistillKLDivLoss

  • 默认使用 FlopsLoss 进行正则化以控制稀疏性,但也支持自定义正则化器。

  • 通过对查询和文档表示进行正则化,平衡了有效性(通过主损失)和效率。

  • 允许通过 query_regularizerdocument_regularizer 参数为查询和文档使用不同的正则化器,从而可以对不同类型的输入进行细粒度的稀疏模式控制。

  • 通过 query_regularizer_thresholddocument_regularizer_threshold 参数支持对查询和文档使用不同的阈值,允许每种输入类型具有不同的稀疏性严格程度。

CSR 损失

如果你正在使用 SparseAutoEncoder 模块,那么你必须使用 CSRLoss (Contrastive Sparse Representation Loss,对比稀疏表示损失)。它结合了两个部分:

  • 一个重构损失 CSRReconstructionLoss,确保稀疏表示能够忠实地重构原始嵌入。

  • 一个主损失,在论文中是使用 SparseMultipleNegativesRankingLoss 的对比学习部分,确保语义上相似的句子有相似的表示。但理论上,像 SpladeLoss 一样,可以使用下面提到的所有损失作为主损失。

损失函数表

损失函数对微调模型的性能起着至关重要的作用。遗憾的是,没有“一刀切”的损失函数。理想情况下,这张表格应该通过将损失函数与你的数据格式相匹配,来帮助你缩小选择范围。

注意

你通常可以将一种训练数据格式转换为另一种,从而使更多的损失函数适用于你的场景。例如,带有 class 标签的 (sentence_A, sentence_B) 配对可以通过采样具有相同或不同类别的句子,转换为 (anchor, positive, negative) 三元组

注意

SentenceTransformer > 损失概览 中,这里出现的带有 Sparse 前缀的损失函数与其密集版本是相同的。该前缀仅用于指示哪些损失可以用作主损失来训练 SparseEncoder

输入 标签 适用的损失函数
(锚点, 正例) 对 SparseMultipleNegativesRankingLoss
(句子_A, 句子_B) 对 0 到 1 之间的浮点相似度分数 SparseCoSENTLoss
SparseAnglELoss
SparseCosineSimilarityLoss (稀疏余弦相似度损失)
(锚点, 正例, 负例) 三元组 SparseMultipleNegativesRankingLoss
SparseTripletLoss (稀疏三元组损失)
(锚点, 正例, 负例_1, ..., 负例_n) SparseMultipleNegativesRankingLoss

蒸馏

这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。这在训练稀疏嵌入模型时相当常用。

文本 标签 适用的损失函数
句子 模型句子嵌入 SparseMSELoss (稀疏均方误差损失)
(sentence_1, sentence_2, ..., sentence_N) (句子1, 句子2, ..., 句子N) 模型句子嵌入 SparseMSELoss (稀疏均方误差损失)
(query, passage_one, passage_two) (查询, 段落一, 段落二) gold_sim(查询, 段落_一) - gold_sim(查询, 段落_二) SparseMarginMSELoss
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例) - gold_sim(查询, 负例_i) for i in 1..n] SparseMarginMSELoss
(查询, 正例, 负例) [gold_sim(查询, 正例), gold_sim(查询, 负例)] SparseDistillKLDivLoss
SparseMarginMSELoss
(查询, 正例, 负例_1, ..., 负例_n) [gold_sim(查询, 正例), gold_sim(查询, 负例_i)...] SparseDistillKLDivLoss
SparseMarginMSELoss

常用损失函数

在实践中,并非所有损失函数的使用频率都相同。最常见的场景是:

  • 没有任何标签的 (anchor, positive) 配对SparseMultipleNegativesRankingLoss (又名 InfoNCE 或批内负采样损失) 通常用于训练性能顶尖的嵌入模型。这种数据通常获取成本较低,并且模型通常性能非常好。在这里,对于我们的稀疏检索任务,这种格式与 SpladeLossCSRLoss 配合得很好,两者通常都使用 InfoNCE 作为其底层损失函数。

  • (query, positive, negative_1, ..., negative_n) 格式:这种具有多个负样本的结构在使用配置了 SparseMarginMSELossSpladeLoss 时特别有效,尤其是在教师模型提供相似度分数的知识蒸馏场景中。最强的模型是用蒸馏损失训练的,如 SparseDistillKLDivLossSparseMarginMSELoss

自定义损失函数

高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:

  • 它们必须是 torch.nn.Module 的子类。

  • 它们的构造函数中必须有 model 作为第一个参数。

  • 它们必须实现一个 forward 方法,该方法接受 sentence_featureslabels。前者是分词后批次的列表,每个元素对应一列。这些分词后的批次可以直接输入到正在训练的 model 中以生成嵌入。后者是可选的标签张量。该方法必须返回单个损失值或一个损失分量字典(分量名称到损失值的映射),这些分量将被相加以产生最终的损失值。当返回字典时,除了总损失外,各个分量也会被单独记录,从而使你能够监控损失的各个组成部分。

为了获得自动生成模型卡的完全支持,您可能还希望实现:

  • 一个 get_config_dict 方法,返回一个包含损失参数的字典。

  • 一个 citation 属性,这样您的工作就会在所有使用该损失函数训练的模型中被引用。

可以考虑查看现有的损失函数,以了解损失函数的常用实现方式。