预训练模型
我们已通过我们的 Cross Encoder Hugging Face 组织发布了各种预训练的 Cross Encoder 模型。此外,Hugging Face Hub 上也公开了许多社区 Cross Encoder 模型。
每个模型都可以这样轻松下载和使用
from sentence_transformers import CrossEncoder
import torch
# Load https://hugging-face.cn/cross-encoder/ms-marco-MiniLM-L6-v2
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2", activation_fn=torch.nn.Sigmoid())
scores = model.predict([
("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
("How many people live in Berlin?", "Berlin is well known for its museums."),
])
# => array([0.9998173 , 0.01312432], dtype=float32)
Cross-Encoders 需要文本对作为输入,并输出 0…1 的分数(如果使用 Sigmoid 激活函数)。它们不适用于单个句子,也不计算单个文本的嵌入。
MS MARCO
MS MARCO Passage Retrieval 是一个大型数据集,包含来自 Bing 搜索引擎的真实用户查询以及带注释的相关文本段落。在此数据集上训练的模型作为搜索系统的重排器非常有效。
注意
您可以使用 activation_fn=torch.nn.Sigmoid() 初始化这些模型,以强制模型返回 0 到 1 之间的分数。否则,原始值合理范围可在 -10 到 10 之间。
| 模型名称 | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | 文档/秒 |
|---|---|---|---|
| cross-encoder/ms-marco-TinyBERT-L2-v2 | 69.84 | 32.56 | 9000 |
| cross-encoder/ms-marco-MiniLM-L2-v2 | 71.01 | 34.85 | 4100 |
| cross-encoder/ms-marco-MiniLM-L4-v2 | 73.04 | 37.70 | 2500 |
| cross-encoder/ms-marco-MiniLM-L6-v2 | 74.30 | 39.01 | 1800 |
| cross-encoder/ms-marco-MiniLM-L12-v2 | 74.31 | 39.02 | 960 |
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340 |
有关用法详情,请参阅检索与重排。
SQuAD (QNLI)
QNLI 基于 SQuAD 数据集 (HF),由 GLUE 基准 (HF) 引入。给定维基百科中的一段话,注释者创建了可由该段话回答的问题。如果一段话回答了某个问题,则这些模型会输出更高的分数。
| 模型名称 | QNLI 开发集准确率 |
|---|---|
| cross-encoder/qnli-distilroberta-base | 90.96 |
| cross-encoder/qnli-electra-base | 93.21 |
STSbenchmark
以下模型可以这样使用
from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/stsb-roberta-base")
scores = model.predict([("It's a wonderful day outside.", "It's so sunny today!"), ("It's a wonderful day outside.", "He drove to work earlier.")])
# => array([0.60443085, 0.00240758], dtype=float32)
它们返回一个 0…1 的分数,表示给定句子对的语义相似度。
模型名称 |
STSbenchmark 测试性能 |
|---|---|
85.50 |
|
87.92 |
|
90.17 |
|
91.47 |
Quora 重复问题
这些模型已在 Quora 重复问题数据集上进行了训练。它们可以像 STSb 模型一样使用,并给出 0…1 的分数,表示两个问题是重复问题的概率。
| 模型名称 | 平均精度开发集 |
|---|---|
| cross-encoder/quora-distilroberta-base | 87.48 |
| cross-encoder/quora-roberta-base | 87.80 |
| cross-encoder/quora-roberta-large | 87.91 |
注意
该模型不适用于问题相似度。“如何学习 Java?”和“如何学习 Python?”这两个问题会得到低分,因为它们不是重复问题。对于问题相似度,在 Quora 数据集上训练的 SentenceTransformer 将产生更有意义的结果。
NLI
给定两个句子,它们是相互矛盾、相互包含还是中立?以下模型在 SNLI 和 MultiNLI 数据集上进行了训练。
模型名称 |
MNLI 不匹配集准确率 |
|---|---|
90.04 |
|
88.08 |
|
87.77 |
|
87.55 |
|
87.47 |
|
86.89 |
|
83.98 |
from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/nli-deberta-v3-base")
scores = model.predict([
("A man is eating pizza", "A man eats something"),
("A black race car starts up in front of a crowd of people.", "A man is driving down a lonely road."),
])
# Convert scores to labels
label_mapping = ["contradiction", "entailment", "neutral"]
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
# => ['entailment', 'contradiction']
社区模型
社区中的一些著名模型包括