模型蒸馏
本页包含一个针对 SparseEncoder 模型的知识蒸馏示例。知识蒸馏对于训练最强大的稀疏模型至关重要,因为最有效的稀疏编码器部分或完全是通过强大的教师模型进行蒸馏训练的。
知识蒸馏允许我们将知识从更大、计算成本更高的模型(教师模型)压缩到更小、更高效的稀疏模型(学生模型)中。这种方法可以利用更大模型(包括像 Cross-Encoders 和密集双编码器这样的非稀疏模型)的结果,将知识压缩到我们的小型稀疏模型中,同时保留大部分原始性能。
MarginMSE
训练代码:train_splade_msmarco_margin_mse.py
SparseMarginMSELoss
基于 Hofstätter 等人的论文。与使用 SparseMultipleNegativesRankingLoss
进行训练时类似,我们可以使用三元组:(query, passage1, passage2)
。然而,与 MultipleNegativesRankingLoss
不同,我们为 passage1
和 passage2
使用相似度得分,因此它们不必是严格的正例/负例,两者对于给定的查询都可能相关或不相关。
蒸馏过程通过将知识从强大的教师模型(如 Cross-Encoder 集成模型)转移到我们高效的稀疏编码器学生模型中来实现。我们使用教师模型计算 (query, passage1)
和 (query, passage2)
的 Cross-Encoder 得分。我们在 msmarco-hard-negatives 数据集中为 1.6 亿个这样的数据对提供了得分,该数据集包含了由 BERT 集成 Cross-Encoder 预先计算的得分。然后我们计算距离:CE_distance = CEScore(query, passage1) - CEScore(query, passage2)
。
对于我们的 SparseEncoder(这里是一个 Splade 模型)学生模型训练,我们将 query
、passage1
和 passage2
编码为嵌入,然后测量 (query, passage1)
和 (query, passage2)
之间的点积。同样,我们测量距离:SE_distance = DotScore(query, passage1) - DotScore(query, passage2)
。
知识转移的实现方式是确保 Splade 模型预测的距离与教师 Cross-Encoder 预测的距离相匹配,即我们优化 CE_distance
和 SE_distance
之间的均方误差 (MSE)。这使得稀疏模型能够学习到体量大得多的教师模型的复杂排序行为。