Quora 重复问题

此文件夹包含演示如何训练 SentenceTransformers 用于信息检索的脚本。作为一个简单的例子,我们将使用 Quora 重复问题数据集。它包含超过 500,000 个句子,以及超过 400,000 个成对注释,指示两个问题是否重复。

在此数据集上训练的模型可用于挖掘重复问题,即,给定大量句子(在本例中为问题),识别所有重复的句子对。由于 CrossEncoder 模型仅适用于文本对,因此最好在使用 SentenceTransformer 模型进行初始过滤后部署它们。有关如何使用 sentence transformers 挖掘数十万句子中的重复问题/释义的示例,请参阅 Sentence Transformer > 用法 > 释义挖掘

在初始过滤之后,可以使用 CrossEncoder 模型将例如前 100 个候选重新排序为例如前 10 个。 因为 CrossEncoder 可以跨句子对应用注意力机制,所以该模型可以给出比 SentenceTransformer 更好的分数。

要在 Quora 重复问题数据集上训练 CrossEncoder,请参阅以下示例文件

您还可以训练和使用 SentenceTransformer 模型来完成此任务。 有关更多详细信息,请参阅 Sentence Transformer > 训练示例 > Quora 重复问题

训练

选择正确的损失函数对于微调有用的模型至关重要。对于只有一个输出类(即仅输出一个分数)的任何 CrossEncoder 模型,BinaryCrossEntropyLoss 仍然是一个非常可靠的损失函数。

CrossEncoder architecture

对于每个问题对,我们将问题 A 和问题 B 传递到基于 BERT 的模型中,之后分类器头部将来自基于 BERT 的模型的中间表示转换为相似度分数。使用此损失函数,我们应用 torch.nn.BCEWithLogitsLoss,它接受 logits(也称为输出、原始预测)和黄金相似度分数(如果重复则为 1,如果不重复则为 0)来计算损失,表示模型表现如何。然后最小化此损失以提高模型的性能。

推理

您可以像这样使用任何 预训练的用于重复问题检测的 CrossEncoder 模型 执行推理

from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/quora-distilroberta-base')
scores = model.predict([
    ('What do apples consist of?', 'What are in Apple devices?'),
    ('How do I get good at programming?', 'How to become a good programmer?')
])
print(scores)
# [0.00056, 0.97536]