预训练模型

我们通过 Cross Encoder Hugging Face 组织发布了各种预训练的交叉编码器 (Cross Encoder) 模型。此外,许多社区交叉编码器模型也已在 Hugging Face Hub 上公开发布。

这些模型中的每一个都可以像这样轻松下载和使用

from sentence_transformers import CrossEncoder
import torch

# Load https://hugging-face.cn/cross-encoder/ms-marco-MiniLM-L6-v2
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2", activation_fn=torch.nn.Sigmoid())
scores = model.predict([
    ("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
    ("How many people live in Berlin?", "Berlin is well known for its museums."),
])
# => array([0.9998173 , 0.01312432], dtype=float32)

交叉编码器需要文本对作为输入,并输出一个 0 到 1 之间的分数(如果使用 Sigmoid 激活函数)。它们不适用于单个句子,也不计算单个文本的嵌入。

MS MARCO

MS MARCO Passage Retrieval 是一个大型数据集,包含来自必应搜索引擎的真实用户查询以及带标注的相关文本段落。在该数据集上训练的模型作为搜索系统的重排序器非常有效。

注意

您可以使用 activation_fn=torch.nn.Sigmoid() 来初始化这些模型,以强制模型返回 0 到 1 之间的分数。否则,原始值可以合理地在 -10 到 10 之间。

模型名称 NDCG@10 (TREC DL 19) MRR@10 (MS Marco Dev) 文档 / 秒
cross-encoder/ms-marco-TinyBERT-L2-v2 69.84 32.56 9000
cross-encoder/ms-marco-MiniLM-L2-v2 71.01 34.85 4100
cross-encoder/ms-marco-MiniLM-L4-v2 73.04 37.70 2500
cross-encoder/ms-marco-MiniLM-L6-v2 74.30 39.01 1800
cross-encoder/ms-marco-MiniLM-L12-v2 74.31 39.02 960
cross-encoder/ms-marco-electra-base 71.99 36.41 340

有关用法详情,请参阅检索与重排序 (Retrieve & Re-Rank)

SQuAD (QNLI)

QNLI 基于 SQuAD 数据集 (HF),并由 GLUE 基准测试 (HF) 引入。给定一段来自维基百科的段落,标注者创建了可以用该段落回答的问题。如果一个段落能回答一个问题,这些模型会输出更高的分数。

模型名称 QNLI 开发集上的准确率
cross-encoder/qnli-distilroberta-base 90.96
cross-encoder/qnli-electra-base 93.21

STSbenchmark

以下模型可以像这样使用

from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/stsb-roberta-base")
scores = model.predict([("It's a wonderful day outside.", "It's so sunny today!"), ("It's a wonderful day outside.", "He drove to work earlier.")])
# => array([0.60443085, 0.00240758], dtype=float32)

它们返回一个 0 到 1 之间的分数,表示给定句子对的语义相似度。

模型名称

STSbenchmark 测试性能

cross-encoder/stsb-TinyBERT-L4

85.50

cross-encoder/stsb-distilroberta-base

87.92

cross-encoder/stsb-roberta-base

90.17

cross-encoder/stsb-roberta-large

91.47

Quora 重复问题

这些模型在 Quora 重复问题数据集上进行了训练。它们可以像 STSb 模型一样使用,并给出一个 0 到 1 之间的分数,表示两个问题是重复问题的概率。

模型名称 开发集上的平均精度
cross-encoder/quora-distilroberta-base 87.48
cross-encoder/quora-roberta-base 87.80
cross-encoder/quora-roberta-large 87.91

注意

该模型不适用于问题相似度。问题“如何学习 Java?”和“如何学习 Python?”会得到低分,因为这些问题不是重复的。对于问题相似度,在 Quora 数据集上训练的 SentenceTransformer 会产生更有意义的结果。

NLI

给定两个句子,它们是相互矛盾、一个蕴含另一个,还是中性关系?以下模型在 SNLIMultiNLI 数据集上进行了训练。

模型名称

MNLI 不匹配集上的准确率

cross-encoder/nli-deberta-v3-base

90.04

cross-encoder/nli-deberta-base

88.08

cross-encoder/nli-deberta-v3-xsmall

87.77

cross-encoder/nli-deberta-v3-small

87.55

cross-encoder/nli-roberta-base

87.47

cross-encoder/nli-MiniLM2-L6-H768

86.89

cross-encoder/nli-distilroberta-base

83.98

from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/nli-deberta-v3-base")
scores = model.predict([
    ("A man is eating pizza", "A man eats something"),
    ("A black race car starts up in front of a crowd of people.", "A man is driving down a lonely road."),
])

# Convert scores to labels
label_mapping = ["contradiction", "entailment", "neutral"]
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
# => ['entailment', 'contradiction']

社区模型

社区中一些值得注意的模型包括