Quora 重复问题
此文件夹包含一个脚本,演示如何训练用于**信息检索**的稀疏编码器。作为一个简单的例子,我们将使用 Quora 重复问题数据集。它包含超过 500,000 个句子,以及超过 400,000 对问题是否重复的成对标注。
在此数据集上训练的模型可用于挖掘重复问题,即,给定大量句子(在此例中为问题),使用稀疏向量相似性识别所有重复对。
训练
选择正确的损失函数对于微调有用的稀疏编码器模型至关重要。对于给定任务,SparseMultipleNegativesRankingLoss
损失函数是一个很好的起点。
有关完整示例,请参阅**training_splade_quora.py**,该脚本利用此损失函数在此数据集上训练 splade 模型。
SparseMultipleNegativesRankingLoss
特别适用于稀疏编码器进行信息检索/语义搜索。一个很好的优点是它只需要正样本对,即我们只需要重复问题对的示例。
使用此损失函数很简单,无需调整任何超参数
from datasets import load_dataset
from sentence_transformers import losses
# Assume 'model' is your SparseEncoder model
full_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train").select(
range(100000)
)
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# => Dataset({
# features: ['anchor', 'positive', 'negative],
# num_rows: 99000
# })
loss = losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
query_regularizer_weight=query_regularizer_weight, # Weight for query loss
document_regularizer_weight=document_regularizer_weight, # Weight for document loss
)
注意
增加批次大小通常会产生更好的结果,因为任务变得更难。从 100 个问题中识别出正确的重复问题比从仅 10 个问题中识别更困难。因此,建议将训练批次大小设置得尽可能大。对于稀疏模型,批次大小也可能受到梯度累积所需内存的限制,因为稀疏表示在反向传播期间是密集的。
注意
SparseMultipleNegativesRankingLoss
仅在 (a_i, b_j) (j != i) 确实是一个负向的、非重复的问题对时才有效。在少数情况下,这个假设是错误的。但在大多数情况下,如果我们抽取两个随机问题,它们不是重复的。如果您的数据集不满足此属性,SparseMultipleNegativesRankingLoss
可能无法很好地工作。