Quora 重复问题
此文件夹包含演示如何为**信息检索**任务训练 SentenceTransformers 的脚本。作为一个简单的示例,我们将使用 Quora 重复问题数据集。它包含超过 50 万个句子和超过 40 万个成对的标注,用于判断两个问题是否重复。
在此数据集上训练的模型可用于挖掘重复问题,即,给定一个大型句子集合(在本例中是问题),识别所有重复的句子对。由于 CrossEncoder
模型只处理成对的文本,因此它们最好在使用 SentenceTransformer
模型进行初步筛选后部署。有关如何使用句子转换器在数十万个句子中挖掘重复问题/释义的示例,请参见 Sentence Transformer > 用法 > 释义挖掘。
在初步筛选后,可以使用 CrossEncoder
模型将前 100 个候选重新排序为前 10 个。因为 CrossEncoder
可以在句子对中的句子之间应用注意力机制,所以该模型可以比 SentenceTransformer
给出更好的分数。
要在 Quora 重复问题数据集上训练 CrossEncoder,请参阅以下示例文件:
training_quora_duplicate_questions.py:
此示例使用
BinaryCrossEntropyLoss
来训练 CrossEncoder 模型,使其为相同的问题给出高分,为不同的问题给出低分。
您也可以为此任务训练和使用 SentenceTransformer
模型。更多详情请参见 Sentence Transformer > 训练示例 > Quora 重复问题。
训练
选择正确的损失函数对于微调有用的模型至关重要。对于任何只有一个输出类别(即只输出一个分数)的 CrossEncoder
模型,BinaryCrossEntropyLoss
仍然是一个非常可靠的损失函数。

对于每个问题对,我们将问题 A 和问题 B 输入到基于 BERT 的模型中,然后一个分类器头将来自基于 BERT 模型的中介表示转换为相似度分数。使用此损失函数,我们应用 torch.nn.BCEWithLogitsLoss
,它接受 logits(也称为输出、原始预测)和黄金相似度分数(如果重复则为 1,不重复则为 0),以计算一个表示模型表现如何的损失。然后最小化此损失以提高模型的性能。
推理
您可以使用任何用于重复问题检测的预训练 CrossEncoder 模型进行推理,如下所示:
from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/quora-distilroberta-base')
scores = model.predict([
('What do apples consist of?', 'What are in Apple devices?'),
('How do I get good at programming?', 'How to become a good programmer?')
])
print(scores)
# [0.00056, 0.97536]