模型蒸馏

本页包含 SparseEncoder 模型知识蒸馏的示例。知识蒸馏对于训练最强的稀疏模型至关重要,因为最有效的稀疏编码器部分或完全通过强大教师模型的蒸馏进行训练。

知识蒸馏允许我们将知识从更大、计算成本更高的模型(教师模型)压缩到更小、更高效的稀疏模型(学生模型)中。这种方法可以利用大型模型的结果,包括 Cross-Encoders 和密集双编码器等非稀疏模型,将知识压缩到我们的小型稀疏模型中,同时保持大部分原始性能。

MarginMSE

训练代码:train_splade_msmarco_margin_mse.py

SparseMarginMSELoss 基于 Hofstätter et al. 的论文。与使用 SparseMultipleNegativesRankingLoss 进行训练时类似,我们可以使用三元组:(query, passage1, passage2)。然而,与 MultipleNegativesRankingLoss 不同的是,我们使用 passage1passage2 的相似度得分,因此它们不一定是严格的正/负,两者都可以与给定查询相关或不相关。

蒸馏过程通过将知识从强大的教师模型(如 Cross-Encoder 集成)转移到我们高效的稀疏编码器学生模型。我们使用教师模型计算 (query, passage1)(query, passage2)Cross-Encoder 分数。我们的 msmarco-hard-negatives 数据集中提供了 1.6 亿对这样的分数,其中包含来自 BERT 集成 Cross-Encoder 的预计算分数。然后我们计算距离:CE_distance = CEScore(query, passage1) - CEScore(query, passage2)

对于我们的 SparseEncoder(这里是 Splade 模型)学生训练,我们将 querypassage1passage2 编码为嵌入,然后测量 (query, passage1)(query, passage2) 之间的点积。同样,我们测量距离:SE_distance = DotScore(query, passage1) - DotScore(query, passage2)

知识转移通过确保 Splade 模型预测的距离与教师 Cross-Encoder 预测的距离匹配来实现,即我们优化 CE_distanceSE_distance 之间的均方误差 (MSE)。这使得稀疏模型能够学习更大教师模型的复杂排序行为。