DPR-模型
在 开放域问答的密集段落检索(Dense Passage Retrieval for Open-Domain Question Answering) 一文中,Karpukhin 等人基于 Google 的自然问答(Natural Questions)数据集 训练了模型。
facebook-dpr-ctx_encoder-single-nq-base
facebook-dpr-question_encoder-single-nq-base
他们还结合自然问答(Natural Questions)、TriviaQA、WebQuestions 和 CuratedTREC 数据集训练了模型。
facebook-dpr-ctx_encoder-multiset-base
facebook-dpr-question_encoder-multiset-base
其中一个模型用于编码段落,另一个模型用于编码问题/查询。
用法
要编码段落,您需要提供标题(例如维基百科文章标题)和文本段落。这两部分必须用 [SEP]
标记分隔。对于编码段落,我们使用 ctx_encoder。
查询则使用 question_encoder 进行编码。
from sentence_transformers import SentenceTransformer, util
passage_encoder = SentenceTransformer("facebook-dpr-ctx_encoder-single-nq-base")
passages = [
"London [SEP] London is the capital and largest city of England and the United Kingdom.",
"Paris [SEP] Paris is the capital and most populous city of France.",
"Berlin [SEP] Berlin is the capital and largest city of Germany by both area and population.",
]
passage_embeddings = passage_encoder.encode(passages)
query_encoder = SentenceTransformer("facebook-dpr-question_encoder-single-nq-base")
query = "What is the capital of England?"
query_embedding = query_encoder.encode(query)
# Important: You must use dot-product, not cosine_similarity
scores = util.dot_score(query_embedding, passage_embeddings)
print("Scores:", scores)
重要提示: 当您使用这些模型时,必须使用点积(例如,在 util.dot_score
中实现的)进行计算,而不是余弦相似度。