Training Overview

Why Finetune?

Cross Encoder models are very often used as second-stage rerankers in a Retrieve and Rerank search stack. In such a setup, the Cross Encoder reranks the top X candidates from the retriever (which can be a Sentence Transformer model). To avoid the reranker model degrading performance on your use case, finetuning it can be crucial. Rerankers always have a single output label.

Beyond that, Cross Encoder models can also be used as pair classifiers. For example, a model trained on Natural Language Inference data can be used to classify pairs of texts as "contradiction", "entailment", or "neutral". Pair classifiers generally have more than one output label.

See the Training Examples for numerous training scripts for common real-world applications that you can adopt.

Training Components

Training Cross Encoder models involves 4 to 6 components, just like training Sentence Transformer models.

Model

Cross Encoder models are initialized by loading a pretrained transformers model with a sequence classification head. If the model does not have such a head, one is added automatically. Consequently, initializing a Cross Encoder model is rather simple:

from sentence_transformers import CrossEncoder

# This model already has a sequence classification head
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
# And this model does not, so it will be added automatically
model = CrossEncoder("google-bert/bert-base-uncased")

Tip

You can find pretrained reranker models in the Cross Encoder > Pretrained Models documentation.

For other models, the strongest pretrained models are often "encoder models", i.e. models trained to produce meaningful token embeddings for their inputs. Strong candidates can be found on the Hugging Face Hub.

Consider looking for base models designed for your language and/or domain of interest. For example, klue/bert-base will work much better than google-bert/bert-base-uncased for Korean.

Dataset

The CrossEncoderTrainer trains and evaluates using datasets.Dataset (a single dataset) or datasets.DatasetDict instances (multiple datasets; see also Multi-dataset training).

If you want to load data from Hugging Face Datasets, use datasets.load_dataset():

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev")

print(train_dataset)
"""
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942069
})
"""

Some datasets (including sentence-transformers/all-nli) require you to provide a "subset" alongside the dataset name. sentence-transformers/all-nli has 4 subsets, each with a different data format: pair, pair-class, pair-score, and triplet.

Note

Many Hugging Face datasets that work out of the box with Sentence Transformers are tagged with sentence-transformers, so you can easily find them by browsing https://hugging-face.cn/datasets?other=sentence-transformers. We strongly recommend browsing these datasets to find training datasets that might be useful for your tasks.

If you have local data in a common file format, you can easily load it with datasets.load_dataset() as well:

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")

from datasets import load_dataset

dataset = load_dataset("json", data_files="my_file.json")

If you have local data that requires some extra pre-processing, my recommendation is to initialize your dataset with datasets.Dataset.from_dict() and a dictionary of lists, like so:

from datasets import Dataset

anchors = []
positives = []
# Open a file, do preprocessing, filtering, cleaning, etc.
# and append to the lists

dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
})

Each key in the dictionary becomes a column in the resulting dataset.

Dataset Format

It is important that your dataset format matches your loss function (or that you choose a loss function that matches your dataset format and model). Verifying whether a dataset format and model are compatible with a loss function involves three steps:

  1. According to the Loss Overview table, all columns not named "label", "labels", "score", or "scores" are considered *inputs*. The number of remaining columns must match the number of valid inputs for your chosen loss function. The names of these columns are **irrelevant**; only their **order matters**.

  2. If your loss function requires a *label* according to the Loss Overview table, then your dataset must have a **column named "label", "labels", "score", or "scores"**. This column is automatically taken as the label.

  3. The number of model output labels must match what the Loss Overview table requires for the loss.

For example, given a dataset with columns ["text1", "text2", "label"], where the "label" column holds float similarity scores ranging from 0 to 1, and a model that outputs 1 label, we can use it with BinaryCrossEntropyLoss (see the sketch after this list) because:

  1. the dataset has a "label" column, as required by this loss function.

  2. the dataset has 2 non-label columns, exactly the amount required by this loss function.

  3. the model has 1 output label, exactly as required by this loss function.
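
For instance, a minimal sketch of such a compatible setup, using made-up toy data and an example base model, could look like this:

from datasets import Dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss

# Toy dataset: two non-label input columns ("text1", "text2") plus a "label" column with scores in [0, 1]
train_dataset = Dataset.from_dict({
    "text1": ["What are pandas?", "What are pandas?"],
    "text2": ["The giant panda is a bear native to China.", "Mount Vesuvius is a volcano in Italy."],
    "label": [1.0, 0.0],
})

# A reranker with 1 output label, matching what BinaryCrossEntropyLoss expects
model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
loss = BinaryCrossEntropyLoss(model)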

If your columns are not ordered correctly, reorder them with Dataset.select_columns. For example, if your dataset has ["good_answer", "bad_answer", "question"] as columns, then this dataset can technically be used with a loss that requires (anchor, positive, negative) triplets, but the good_answer column will be taken as the anchor, bad_answer as the positive, and question as the negative.

Additionally, if your dataset has extraneous columns (e.g. sample_id, metadata, source, type), you should remove them with Dataset.remove_columns, as they will otherwise be treated as inputs. You can also use Dataset.select_columns to keep only the desired columns. Both operations are shown in the sketch below.
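
A minimal sketch of both operations, using made-up column names:

from datasets import Dataset

dataset = Dataset.from_dict({
    "sample_id": [0, 1],
    "question": ["...", "..."],
    "good_answer": ["...", "..."],
    "bad_answer": ["...", "..."],
})

# Keep only the desired columns, in the order that the loss function expects (anchor, positive, negative)
dataset = dataset.select_columns(["question", "good_answer", "bad_answer"])

# Alternatively, drop just the extraneous columns so they are not treated as inputs
# dataset = dataset.remove_columns(["sample_id"])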

Hard Negatives Mining

The success of training CrossEncoder models often depends on the quality of the *negatives*, i.e. the passages for which the query-negative score should be low. Negatives can be divided into two types:

  • Soft negatives: passages that are completely unrelated.

  • Hard negatives: passages that seem related to the query, but are not.

A concise example:

  • Query: Where was Apple founded?

  • Soft negative: The Cache River Bridge is a Parker pony truss that spans the Cache River between Walnut Ridge and Paragould, Arkansas.

  • Hard negative: The Fuji apple is an apple cultivar developed in the late 1930s, and brought to market in 1962.

The strongest CrossEncoder models are generally trained to recognize hard negatives, so being able to "mine" hard negatives is quite valuable. Sentence Transformers supports a strong mine_hard_negatives() function that can assist, given a dataset of query-answer pairs:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives

# Load the GooAQ dataset: https://hugging-face.cn/datasets/sentence-transformers/gooaq
train_dataset = load_dataset("sentence-transformers/gooaq", split=f"train").select(range(100_000))
print(train_dataset)

# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
    train_dataset,
    embedding_model,
    num_negatives=5,  # How many negatives per question-answer pair
    range_min=10,  # Skip the x most similar samples
    range_max=100,  # Consider only the x most similar samples
    max_score=0.8,  # Only consider samples with a similarity score of at most x
    absolute_margin=0.1,  # Anchor-negative similarity is at least x lower than anchor-positive similarity
    relative_margin=0.1,  # Anchor-negative similarity is at most 1-x times the anchor-positive similarity, e.g. 90%
    sampling_strategy="top",  # Sample the top negatives from the range
    batch_size=4096,  # Use a batch size of 4096 for the embedding model
    output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
    use_faiss=True,  # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_train_dataset)
print(hard_train_dataset[1])

Output of this script:

Dataset({
    features: ['question', 'answer'],
    num_rows: 100000
})

Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 12.74it/s]
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 37.50it/s]
Querying FAISS index: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:18<00:00,  2.66s/it]
Metric       Positive       Negative     Difference
Count         100,000        436,925
Mean           0.5882         0.4040         0.2157
Median         0.5989         0.4024         0.1836
Std            0.1425         0.0905         0.1013
Min           -0.0514         0.1405         0.1014
25%            0.4993         0.3377         0.1352
50%            0.5989         0.4024         0.1836
75%            0.6888         0.4681         0.2699
Max            0.9748         0.7486         0.7545
Skipped 2,420,871 potential negatives (23.97%) due to the absolute_margin of 0.1.
Skipped 43 potential negatives (0.00%) due to the max_score of 0.8.
Could not find enough negatives for 63075 samples (12.62%). Consider adjusting the range_max, range_min, absolute_margin, relative_margin and max_score parameters if you'd like to find more valid negatives.
Dataset({
    features: ['question', 'answer', 'label'],
    num_rows: 536925
})

{
    'question': 'how to transfer bookmarks from one laptop to another?',
    'answer': 'Using an External Drive Just about any external drive, including a USB thumb drive, or an SD card can be used to transfer your files from one laptop to another. Connect the drive to your old laptop; drag your files to the drive, then disconnect it and transfer the drive contents onto your new laptop.',
    'label': 0
}

Loss Function

Loss functions quantify how well a model performs on a given batch of data, allowing the optimizer to update the model weights to produce more favourable (i.e. lower) loss values. They are the core of the training process.

Sadly, no single loss function works best for all use cases. Instead, which loss function to use depends greatly on your available data and on your target task. See Dataset Format to learn which datasets are valid for which loss functions. Additionally, the Loss Overview is the best guide to the various options.

Most loss functions can be initialized with just the CrossEncoder that you're training, plus some optional parameters, e.g.:

from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import MultipleNegativesRankingLoss

# Load a model to train/finetune
model = CrossEncoder("xlm-roberta-base", num_labels=1) # num_labels=1 is for rerankers

# Initialize the MultipleNegativesRankingLoss
# This loss requires pairs of related texts or triplets
loss = MultipleNegativesRankingLoss(model)

# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")

Training Arguments

The CrossEncoderTrainingArguments class can be used to specify parameters that influence training performance, as well as parameters for tracking and debugging. Although optional, some experimentation with these useful arguments is strongly recommended.

Here is an example of how CrossEncoderTrainingArguments can be initialized:

from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir="models/reranker-MiniLM-msmarco-v1",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="reranker-MiniLM-msmarco-v1",  # Will be used in W&B if `wandb` is installed
)

Evaluator

You can provide the CrossEncoderTrainer with an eval_dataset to get the evaluation loss during training, but it can also be useful to get more concrete metrics during training. For this, you can use evaluators to assess the model's performance with useful metrics before, during, or after training. You can use both an eval_dataset and an evaluator, one or the other, or neither. They evaluate based on the eval_strategy and eval_steps Training Arguments.

These are the evaluators that Sentence Transformers implements for Cross Encoder models:

  • CrossEncoderClassificationEvaluator — requires pairs of texts with class labels (binary or multi-class).

  • CrossEncoderCorrelationEvaluator — requires pairs of sentences with similarity scores.

  • CrossEncoderNanoBEIREvaluator — requires no data.

  • CrossEncoderRerankingEvaluator — requires a list of {'query': '...', 'positive': [...], 'negative': [...]} dictionaries. Negatives can be mined with mine_hard_negatives().

Additionally, SequentialEvaluator should be used to combine multiple evaluators into one evaluator that can be passed to the CrossEncoderTrainer, as in the sketch below.
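
A minimal sketch of combining two evaluators, using a toy correlation evaluator with made-up sentence pairs:

from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderCorrelationEvaluator,
    CrossEncoderNanoBEIREvaluator,
)
from sentence_transformers.evaluation import SequentialEvaluator

# Toy correlation evaluator with made-up sentence pairs and similarity scores
correlation_evaluator = CrossEncoderCorrelationEvaluator(
    sentence_pairs=[
        ("A man is eating food.", "A man is eating a meal."),
        ("A man is eating food.", "A plane is taking off."),
    ],
    scores=[0.9, 0.1],
    name="toy-sts",
)
nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco"])

# Combine both into a single evaluator that can be passed to the CrossEncoderTrainer
dev_evaluator = SequentialEvaluator([correlation_evaluator, nano_beir_evaluator])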

Sometimes you don't have the required evaluation data to prepare one of these evaluators on your own, but you still want to track how well the model performs on some common benchmarks. In that case, you can use these evaluators with data from Hugging Face.

from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator

# Load a model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Initialize the evaluator. Unlike most other evaluators, this one loads the relevant datasets
# directly from Hugging Face, so there's no mandatory arguments
dev_evaluator = CrossEncoderNanoBEIREvaluator()
# You can run evaluation like so:
# results = dev_evaluator(model)

Preparing data for CrossEncoderRerankingEvaluator can be difficult, because you need negatives in addition to your query-positive data.

The mine_hard_negatives() function has a convenient include_positives parameter, which can be set to True to also mine for the positive texts. When supplied as documents (which have to be 1. ranked and 2. contain the positives) to CrossEncoderRerankingEvaluator, the evaluator will evaluate not just the reranking performance of the CrossEncoder, but also the original rankings of the embedding model used for mining.

For example:

CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries:  1000     Positives: Min 1.0, Mean 1.0, Max 1.0   Negatives: Min 49.0, Mean 49.1, Max 50.0
          Base  -> Reranked
MAP:      53.28 -> 67.28
MRR@10:   52.40 -> 66.65
NDCG@10:  59.12 -> 71.35

Note that, by default, if you use CrossEncoderRerankingEvaluator with documents, the evaluator will rerank *all* positives, even if they are not among the documents. This is useful for getting a stronger signal out of your evaluator, but it gives a slightly unrealistic picture of performance. After all, the maximum performance is now 100, whereas it is normally bounded by whether the first-stage retriever actually retrieved the positives.

You can enable the realistic behaviour by setting always_rerank_positives=False when initializing CrossEncoderRerankingEvaluator. Repeating the same script with this realistic two-stage performance results in:

CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries:  1000     Positives: Min 1.0, Mean 1.0, Max 1.0   Negatives: Min 49.0, Mean 49.1, Max 50.0
          Base  -> Reranked
MAP:      53.28 -> 66.12
MRR@10:   52.40 -> 65.61
NDCG@10:  59.12 -> 70.10
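
For reference, here is a full example that builds a CrossEncoderRerankingEvaluator with documents mined via mine_hard_negatives():
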
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from sentence_transformers.util import mine_hard_negatives

# Load a model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Load the GooAQ dataset: https://hugging-face.cn/datasets/sentence-transformers/gooaq
full_dataset = load_dataset("sentence-transformers/gooaq", split=f"train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(eval_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 1000
})
"""

# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_eval_dataset = mine_hard_negatives(
    eval_dataset,
    embedding_model,
    corpus=full_dataset["answer"],  # Use the full dataset as the corpus
    num_negatives=50,  # How many negatives per question-answer pair
    batch_size=4096,  # Use a batch size of 4096 for the embedding model
    output_format="n-tuple",  # The output format is (query, positive, negative1, negative2, ...) for the evaluator
    include_positives=True,  # Key: Include the positive answer in the list of negatives
    use_faiss=True,  # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_eval_dataset)
"""
Dataset({
    features: ['question', 'answer', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative_30', 'negative_31', 'negative_32', 'negative_33', 'negative_34', 'negative_35', 'negative_36', 'negative_37', 'negative_38', 'negative_39', 'negative_40', 'negative_41', 'negative_42', 'negative_43', 'negative_44', 'negative_45', 'negative_46', 'negative_47', 'negative_48', 'negative_49', 'negative_50'],
    num_rows: 1000
})
"""

reranking_evaluator = CrossEncoderRerankingEvaluator(
    samples=[
        {
            "query": sample["question"],
            "positive": [sample["answer"]],
            "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
        }
        for sample in hard_eval_dataset
    ],
    batch_size=32,
    name="gooaq-dev",
)
# You can run evaluation like so
results = reranking_evaluator(model)
"""
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries:  1000     Positives: Min 1.0, Mean 1.0, Max 1.0   Negatives: Min 49.0, Mean 49.1, Max 50.0
          Base  -> Reranked
MAP:      53.28 -> 67.28
MRR@10:   52.40 -> 66.65
NDCG@10:  59.12 -> 71.35
"""
# {'gooaq-dev_map': 0.6728370126462222, 'gooaq-dev_mrr@10': 0.6665190476190477, 'gooaq-dev_ndcg@10': 0.7135068904582963, 'gooaq-dev_base_map': 0.5327714512001362, 'gooaq-dev_base_mrr@10': 0.5239674603174603, 'gooaq-dev_base_ndcg@10': 0.5912299141913905}
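
Here is an example using CrossEncoderCorrelationEvaluator:
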
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderCorrelationEvaluator

# Load a model
model = CrossEncoder("cross-encoder/stsb-TinyBERT-L4")

# Load the STSB dataset (https://hugging-face.cn/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
pairs = list(zip(eval_dataset["sentence1"], eval_dataset["sentence2"]))

# Initialize the evaluator
dev_evaluator = CrossEncoderCorrelationEvaluator(
    sentence_pairs=pairs,
    scores=eval_dataset["score"],
    name="sts_dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
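
And here is an example using CrossEncoderClassificationEvaluator:
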
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderClassificationEvaluator

# Load a model
model = CrossEncoder("cross-encoder/nli-deberta-v3-base")

# Load triplets from the AllNLI dataset (https://hugging-face.cn/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split=f"dev[:{max_samples}]")

# Create a list of pairs, and map the labels to the labels that the model knows
pairs = list(zip(eval_dataset["premise"], eval_dataset["hypothesis"]))
label_mapping = {0: 1, 1: 2, 2: 0}
labels = [label_mapping[label] for label in eval_dataset["label"]]

# Initialize the evaluator
cls_evaluator = CrossEncoderClassificationEvaluator(
    sentence_pairs=pairs,
    labels=labels,
    name="all-nli-dev",
)
# You can run evaluation like so:
# results = cls_evaluator(model)

Warning

When using Distributed Training, the evaluator only runs on the first device, unlike the training and evaluation datasets, which are shared across all devices.

Trainer

The CrossEncoderTrainer is where all previous components come together. We only have to specify the trainer with the model, training arguments (optional), training dataset, evaluation dataset (optional), loss function, evaluator (optional), and we can start training. Let's have a look at a script where all of these components come together:

import logging
import traceback

from datasets import load_dataset

from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderModelCardData,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

model_name = "microsoft/MiniLM-L12-H384-uncased"
train_batch_size = 64
num_epochs = 1
num_rand_negatives = 5  # How many random negatives should be used for each question-answer pair

# 1a. Load a model to finetune with 1b. (Optional) model card data
model = CrossEncoder(
    model_name,
    model_card_data=CrossEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MiniLM-L12-H384 trained on GooAQ",
    ),
)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)

# 2. Load the GooAQ dataset: https://hugging-face.cn/datasets/sentence-transformers/gooaq
logging.info("Read the gooaq training dataset")
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
logging.info(train_dataset)
logging.info(eval_dataset)

# 3. Define our training loss.
loss = CachedMultipleNegativesRankingLoss(
    model=model,
    num_negatives=num_rand_negatives,
    mini_batch_size=32,  # Informs the memory usage
)

# 4. Use CrossEncoderNanoBEIREvaluator, a light-weight evaluator for English reranking
evaluator = CrossEncoderNanoBEIREvaluator(
    dataset_names=["msmarco", "nfcorpus", "nq"],
    batch_size=train_batch_size,
)
evaluator(model)

# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"reranker-{short_model_name}-gooaq-cmnrl"
args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=50,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    seed=12,
)

# 6. Create the trainer & start training
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=evaluator,
)
trainer.train()

# 7. Evaluate the final model, useful to include these in the model card
evaluator(model)

# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
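
Below is a more elaborate script that additionally mines hard negatives with mine_hard_negatives() and trains with BinaryCrossEntropyLoss:
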
import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderModelCardData,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderNanoBEIREvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "answerdotai/ModernBERT-base"

    train_batch_size = 64
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives should be mined for each question-answer pair

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="ModernBERT-base trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the GooAQ dataset: https://hugging-face.cn/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives per question-answer pair
        absolute_margin=0,  # Anchor-negative similarity should be at least x lower than anchor-positive similarity
        range_min=0,  # Skip the x most similar samples
        range_max=100,  # Consider only the x most similar samples
        sampling_strategy="top",  # Sample the top negatives from the range
        batch_size=4096,  # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")

    # 3. Define our training loss.
    # pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset as the corpus
        num_negatives=30,  # How many documents to rerank
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
        # Realistic setting: only rerank the positives that the retriever found
        # Set to True to rerank *all* positives
        always_rerank_positives=False,
    )

    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_gooaq-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()

Callbacks

The CrossEncoder trainer integrates support for various transformers.TrainerCallback subclasses, such as:

  • WandbCallback, to automatically log training metrics to W&B if wandb is installed

  • TensorBoardCallback, to log training metrics to TensorBoard if tensorboard is accessible

  • CodeCarbonCallback, to track the carbon emissions of your model during training if codecarbon is installed

    • Note: these carbon emissions will be included in your automatically generated model card.

See the Transformers Callbacks documentation for more information on the integrated callbacks and how to write your own.
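
As a minimal, hypothetical sketch, a custom callback can be passed to the trainer via its callbacks argument:

from transformers import TrainerCallback

class PrintOnEvaluateCallback(TrainerCallback):
    # Hypothetical callback that prints the metrics after every evaluation
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        print(f"Step {state.global_step}: {metrics}")

# Pass it to the trainer alongside the other components, e.g.:
# trainer = CrossEncoderTrainer(..., callbacks=[PrintOnEvaluateCallback()])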

Multi-Dataset Training

The top-performing models are trained using many datasets at once. Normally, this is rather tricky, as each dataset has a different format. However, the CrossEncoderTrainer can train with multiple datasets without having to convert each dataset to the same format. It can even apply different loss functions to each of the datasets. The steps to train with multiple datasets are (a minimal sketch follows this section):

  • Use a dictionary of Dataset instances (or a DatasetDict) as the train_dataset (and optionally also as the eval_dataset).

  • (Optional) Use a dictionary of loss functions mapping dataset names to losses. Only required if you wish to use different loss functions for different datasets.

Each training/evaluation batch will only contain samples from one of the datasets. The order in which batches are sampled from the multiple datasets is defined by the MultiDatasetBatchSamplers enum, which can be passed to the CrossEncoderTrainingArguments via multi_dataset_batch_sampler. Valid options are:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: round-robin sampling from each dataset until one is exhausted. With this strategy, it's likely that not all samples from each dataset are used, but each dataset is sampled from equally.

  • MultiDatasetBatchSamplers.PROPORTIONAL (default): sample from each dataset in proportion to its size. With this strategy, all samples from each dataset are used, and larger datasets are sampled from more frequently.
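
A minimal sketch of multi-dataset training with toy data; the dataset names ("labeled-pairs", "pairs") and their contents are purely illustrative:

from datasets import Dataset
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import MultiDatasetBatchSamplers

model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)

# A dictionary of datasets; each dataset may have a different format
train_dataset = {
    "labeled-pairs": Dataset.from_dict({
        "query": ["What are pandas?"],
        "passage": ["The giant panda is a bear native to China."],
        "label": [1.0],
    }),
    "pairs": Dataset.from_dict({
        "query": ["What are pandas?"],
        "answer": ["The giant panda is a bear native to China."],
    }),
}
# A dictionary of losses, mapping dataset names to loss functions
loss = {
    "labeled-pairs": BinaryCrossEntropyLoss(model),
    "pairs": MultipleNegativesRankingLoss(model),
}

args = CrossEncoderTrainingArguments(
    output_dir="models/multi-dataset-reranker",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,  # the default
)
trainer = CrossEncoderTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
# trainer.train()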

Training Tips

CrossEncoder models have their own unique quirks, so here are some tips to help you out:

  1. CrossEncoder models overfit rather quickly, so it's recommended to use an evaluator like CrossEncoderNanoBEIREvaluator or CrossEncoderRerankingEvaluator together with the load_best_model_at_end and metric_for_best_model training arguments to load the model with the best evaluation performance after training.

  2. CrossEncoders are particularly receptive to strong hard negatives (mine_hard_negatives()). They teach the model to be very strict, which is useful e.g. when distinguishing between passages that answer a question and passages that merely relate to it.

    1. Note that if you only use hard negatives, your model may unexpectedly perform worse on easier tasks. This can mean that reranking the top 200 results from a first-stage retrieval system (e.g. with a SentenceTransformer model) actually gives worse top-10 results than reranking the top 100. Training with random negatives alongside hard negatives can mitigate this.

  3. Don't underestimate BinaryCrossEntropyLoss: it remains a very strong option despite being simpler than the learning-to-rank (LambdaLoss, ListNetLoss) or in-batch negatives (CachedMultipleNegativesRankingLoss, MultipleNegativesRankingLoss) losses, and its data is easy to prepare, especially using mine_hard_negatives().

Deprecated Training

Before the Sentence Transformers v4.0 release, models would be trained with the CrossEncoder.fit() method and a DataLoader of InputExamples, which looked something like this:

from sentence_transformers import CrossEncoder, InputExample
from torch.utils.data import DataLoader

# Define the model. Either from scratch or by loading a pre-trained model
model = CrossEncoder("distilbert/distilbert-base-uncased")

# Define your train examples. You need more than just two examples...
train_examples = [
    InputExample(texts=["What are pandas?", "The giant panda ..."], label=1),
    InputExample(texts=["What's a panda?", "Mount Vesuvius is a ..."], label=0),
]

# Define your train dataset, the dataloader and the train loss
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Tune the model
model.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=100)

Since the v4.0 release, using CrossEncoder.fit() is still possible, but it will initialize a CrossEncoderTrainer behind the scenes. It is recommended to use the Trainer directly, as you will have more control via the CrossEncoderTrainingArguments, but existing training scripts relying on CrossEncoder.fit() should still work.

In case there are issues with the updated CrossEncoder.fit(), you can also get exactly the old behaviour by calling CrossEncoder.old_fit() instead, but this method is planned to be fully deprecated in the future.

Comparison with SentenceTransformer Training

Training CrossEncoder models is very similar to training SentenceTransformer models, with some key differences:

  • Instead of just score and label, columns named scores and labels are also treated as "label columns" in CrossEncoder training. As you can see in the Loss Overview documentation, some losses require specific labels/scores in a column with one of these names.

  • In SentenceTransformer training, you cannot use lists of inputs (e.g. texts) in a column of the training/evaluation dataset(s). For CrossEncoder training, you **can** use (variably sized) lists of texts in a column. This is required for the ListNetLoss class, for example (see the sketch below).

See the Sentence Transformer > Training Overview documentation for more details on training SentenceTransformer models.