Quora 重复问题
此文件夹包含一个脚本,演示如何训练用于信息检索的稀疏编码器。作为一个简单的例子,我们将使用Quora 重复问题数据集。它包含超过 500,000 个句子,其中有超过 400,000 对注释,说明两个问题是否重复。
在此数据集上训练的模型可用于挖掘重复问题,即,给定一大组句子(在本例中为问题),使用稀疏向量相似性识别所有重复对。
训练
选择正确的损失函数对于微调有用的稀疏编码器模型至关重要。对于给定任务,SparseMultipleNegativesRankingLoss 损失函数是一个很好的开始。
有关完整示例,请参阅 training_splade_quora.py,它利用该损失在此数据集上训练 splade 模型。
SparseMultipleNegativesRankingLoss 特别适用于稀疏编码器的信息检索/语义搜索。一个很好的优点是它只需要正向对,即我们只需要重复问题的示例。
使用此损失很容易,并且不需要调整任何超参数
from datasets import load_dataset
from sentence_transformers import losses
# Assume 'model' is your SparseEncoder model
full_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train").select(
range(100000)
)
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# => Dataset({
# features: ['anchor', 'positive', 'negative],
# num_rows: 99000
# })
loss = losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
query_regularizer_weight=query_regularizer_weight, # Weight for query loss
document_regularizer_weight=document_regularizer_weight, # Weight for document loss
)
注意
增加批次大小通常会产生更好的结果,因为任务会变得更难。从 100 个问题中识别出正确的重复问题比从 10 个问题中识别出来更困难。因此,建议将训练批次大小设置得尽可能大。对于稀疏模型,批次大小也可能受到梯度累积所需内存的限制,因为稀疏表示在反向传播过程中是密集的。
注意
SparseMultipleNegativesRankingLoss 仅在 (a_i, b_j) (j != i) 实际上是负的、非重复的问题对时才有效。在少数情况下,这个假设是错误的。但在大多数情况下,如果我们随机抽取两个问题,它们不是重复的。如果您的数据集不能满足此属性,SparseMultipleNegativesRankingLoss 可能无法很好地工作。