自然语言推断
自然语言推理(NLI)任务是给定两个句子(前提和假设),判断前提是否蕴含假设、两者是否矛盾或是否中立。常用的 NLI 数据集有 SNLI 和 MultiNLI。
要训练一个用于 NLI 任务的 CrossEncoder,请参考以下示例文件:
-
此示例使用
CrossEntropyLoss
来训练 CrossEncoder 模型,使其为“矛盾”、“蕴含”和“中立”这几个类别中的正确类别预测出最高的 logit 值。
您也可以训练并使用 SentenceTransformer
模型来完成此任务。更多详情请参阅Sentence Transformer > 训练示例 > 自然语言推理。
数据
我们将 SNLI 和 MultiNLI 合并成一个我们称之为 AllNLI 的数据集。这两个数据集包含句子对和三个标签之一:蕴含、中立、矛盾。
句子 A (前提) | 句子 B (假设) | 标签 |
---|---|---|
一场有多个男性参与的足球比赛。 | 一些男人在进行一项运动。 | 蕴含 |
一个年长和一个年轻的男人在微笑。 | 两个男人在笑着看地板上玩的猫。 | 中立 |
一个男人在某个东亚国家检查一个人物的制服。 | 那个男人在睡觉。 | 矛盾 |
我们将 AllNLI 格式化为几个不同的子集,以兼容不同的损失函数。例如,请参见 AllNLI 的 pair-class 子集。
CrossEntropyLoss (交叉熵损失)
CrossEntropyLoss
是一个相当基础的损失函数,它将常见的 torch.nn.CrossEntropyLoss
应用于以下过程产生的 logits(也称为输出、原始预测)上:1) 将分词后的文本对传入模型,以及 2) 对 logits 应用可选的激活函数。如果 CrossEncoder 模型需要预测的类别不止一个,那么这个损失函数就非常常用。
推理
您可以使用任何预训练的 NLI CrossEncoder 模型进行推理,如下所示:
from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/nli-deberta-v3-base")
scores = model.predict([
("A man is eating pizza", "A man eats something"),
("A black race car starts up in front of a crowd of people.", "A man is driving down a lonely road."),
])
# Convert scores to labels
label_mapping = ["contradiction", "entailment", "neutral"]
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
# => ['entailment', 'contradiction']