模型蒸馏
此页面包含用于 SparseEncoder 模型的知识蒸馏示例。知识蒸馏对于训练最强大的稀疏模型至关重要,因为最有效的稀疏编码器部分或完全通过强大的教师模型进行蒸馏训练。
知识蒸馏允许我们将来自更大、计算成本更高的模型(教师模型)的知识压缩到更小、更高效的稀疏模型(学生模型)中。这种方法可以利用更大的模型结果,包括非稀疏模型,如交叉编码器和密集双编码器,将知识压缩到我们的小型稀疏模型中,同时保持大部分原始性能。
MarginMSE
训练代码:train_splade_msmarco_margin_mse.py
SparseMarginMSELoss 基于 Hofstätter 等人 的论文。与使用 SparseMultipleNegativesRankingLoss 训练时一样,我们可以使用三元组:(query, passage1, passage2)。但是,与 MultipleNegativesRankingLoss 不同,我们使用 passage1 和 passage2 的相似度分数,因此它们不必严格为正/负,两者都可以与给定查询相关或不相关。
蒸馏过程通过将知识从强大的教师模型(如交叉编码器集成)转移到我们高效的稀疏编码器学生模型来工作。我们使用教师模型计算 (query, passage1) 和 (query, passage2) 的 交叉编码器 分数。我们在 msmarco-hard-negatives 数据集 中提供了 1.6 亿对此类对的分数,该数据集包含来自 BERT 集成交叉编码器的预计算分数。然后我们计算距离:CE_distance = CEScore(query, passage1) - CEScore(query, passage2)。
对于我们的 SparseEncoder(此处为 Splade 模型)学生训练,我们将 query、passage1 和 passage2 编码为嵌入,然后测量 (query, passage1) 和 (query, passage2) 之间的点积。同样,我们测量距离:SE_distance = DotScore(query, passage1) - DotScore(query, passage2)。
知识转移通过确保 Splade 模型预测的距离与教师交叉编码器预测的距离匹配来实现,即,我们优化 CE_distance 和 SE_distance 之间的均方误差 (MSE)。这使得稀疏模型能够学习更大教师模型的复杂排名行为。