自然语言推断
给定两个句子(前提和假设),自然语言推断 (NLI) 是决定前提是否蕴含假设、它们是否矛盾或它们是否中性的任务。常用的 NLI 数据集是 SNLI 和 MultiNLI。
要在 NLI 上训练 CrossEncoder,请参阅以下示例文件
-
此示例使用
CrossEntropyLoss
来训练 CrossEncoder 模型,以预测“矛盾”、“蕴含”和“中性”中正确类别的最高 logits。
您还可以训练和使用 SentenceTransformer
模型来完成此任务。有关更多详细信息,请参阅 Sentence Transformer > 训练示例 > 自然语言推断。
数据
我们将 SNLI 和 MultiNLI 组合成一个我们称之为 AllNLI 的数据集。这两个数据集包含句子对和以下三个标签之一:蕴含、中性、矛盾
句子 A(前提) | 句子 B(假设) | 标签 |
---|---|---|
一场有多名男性参加的足球比赛。 | 一些男人正在进行一项运动。 | 蕴含 |
一个年长的男人和一个年轻的男人在微笑。 | 两个男人在微笑和大笑,看着在地板上玩耍的猫。 | 中性 |
一名男子正在检查某个东亚国家人物的制服。 | 这个人正在睡觉。 | 矛盾 |
我们以几个不同的子集格式化 AllNLI,这些子集与不同的损失函数兼容。例如,请参阅 AllNLI 的 pair-class 子集。
CrossEntropyLoss
CrossEntropyLoss
是一种相当基础的损失,它将常见的 torch.nn.CrossEntropyLoss
应用于 logits(又名输出、原始预测),这些 logits 是在 1) 将标记化的文本对传递到模型中,以及 2) 对 logits 应用可选激活函数后产生的。如果 CrossEncoder 模型必须预测多个类别,则它非常常用。
推理
您可以使用任何 预训练的 NLI CrossEncoder 模型 执行推理,如下所示
from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/nli-deberta-v3-base")
scores = model.predict([
("A man is eating pizza", "A man eats something"),
("A black race car starts up in front of a crowd of people.", "A man is driving down a lonely road."),
])
# Convert scores to labels
label_mapping = ["contradiction", "entailment", "neutral"]
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
# => ['entailment', 'contradiction']