DPR-模型

开放域问答的密集段落检索(Dense Passage Retrieval for Open-Domain Question Answering) 一文中,Karpukhin 等人基于 Google 的自然问答(Natural Questions)数据集 训练了模型。

  • facebook-dpr-ctx_encoder-single-nq-base

  • facebook-dpr-question_encoder-single-nq-base

他们还结合自然问答(Natural Questions)、TriviaQA、WebQuestions 和 CuratedTREC 数据集训练了模型。

  • facebook-dpr-ctx_encoder-multiset-base

  • facebook-dpr-question_encoder-multiset-base

其中一个模型用于编码段落,另一个模型用于编码问题/查询。

用法

要编码段落,您需要提供标题(例如维基百科文章标题)和文本段落。这两部分必须用 [SEP] 标记分隔。对于编码段落,我们使用 ctx_encoder

查询则使用 question_encoder 进行编码。

from sentence_transformers import SentenceTransformer, util

passage_encoder = SentenceTransformer("facebook-dpr-ctx_encoder-single-nq-base")

passages = [
    "London [SEP] London is the capital and largest city of England and the United Kingdom.",
    "Paris [SEP] Paris is the capital and most populous city of France.",
    "Berlin [SEP] Berlin is the capital and largest city of Germany by both area and population.",
]

passage_embeddings = passage_encoder.encode(passages)

query_encoder = SentenceTransformer("facebook-dpr-question_encoder-single-nq-base")
query = "What is the capital of England?"
query_embedding = query_encoder.encode(query)

# Important: You must use dot-product, not cosine_similarity
scores = util.dot_score(query_embedding, passage_embeddings)
print("Scores:", scores)

重要提示: 当您使用这些模型时,必须使用点积(例如,在 util.dot_score 中实现的)进行计算,而不是余弦相似度。