训练概述

为什么要微调?

微调稀疏编码器模型通常能极大地提高模型在您的用例中的性能,因为每个任务对相似性都有不同的概念。例如,给定新闻文章:

  • “苹果发布新款 iPad”

  • “NVIDIA 正在为下一代 GPU 做准备”

那么在以下用例中,我们可能对相似性有不同的概念:

  • 用于将新闻文章分类为经济、体育、技术、政治等的模型,应为这些文本生成相似的嵌入

  • 用于语义文本相似度的模型应为这些文本生成不相似的嵌入,因为它们的含义不同。

  • 用于语义搜索的模型不需要两个文档之间的相似性概念,因为它只应比较查询和文档。

另请参阅训练示例,其中包含您可以采用的针对常见实际应用的众多训练脚本。

训练组件

训练稀疏编码器模型涉及 4 到 6 个组件

模型

稀疏编码器模型由一系列模块稀疏编码器专用模块自定义模块组成,提供了很大的灵活性。如果您想进一步微调稀疏编码器模型(例如,它有一个modules.json 文件),那么您不必担心使用哪些模块。

from sentence_transformers import SparseEncoder

model = SparseEncoder("naver/splade-cocondenser-ensembledistil")

但是,如果您想从另一个检查点训练,或从头开始训练,那么这些是最常见的架构:

Splade 模型使用MLMTransformer,后跟SpladePooling模块。前者加载一个预训练的掩码语言建模 transformer 模型(例如BERTRoBERTaDistilBERTModernBERT等),后者池化 MLMHead 的输出,以生成一个词汇量大小的单一稀疏嵌入。

from sentence_transformers import models, SparseEncoder
from sentence_transformers.sparse_encoder.models import MLMTransformer, SpladePooling

# Initialize MLM Transformer (use a fill-mask model)
mlm_transformer = MLMTransformer("google-bert/bert-base-uncased")

# Initialize SpladePooling module
splade_pooling = SpladePooling(pooling_strategy="max")

# Create the Splade model
model = SparseEncoder(modules=[mlm_transformer, splade_pooling])

如果您为 SparseEncoder 提供填充掩码模型架构,此架构是默认设置,因此使用快捷方式更简单。

from sentence_transformers import SparseEncoder

model = SparseEncoder("google-bert/bert-base-uncased")
# SparseEncoder(
#   (0): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
#   (1): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
# )

免推理 Splade 使用Router模块,其中包含用于查询和文档的不同模块。通常对于这种架构,文档部分是传统的 Splade 架构(MLMTransformer后跟SpladePooling模块),查询部分是SparseStaticEmbedding模块,它只是返回查询中每个标记的预计算分数。

from sentence_transformers import SparseEncoder
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling

# Initialize MLM Transformer for document encoding
doc_encoder = MLMTransformer("google-bert/bert-base-uncased")

# Create a router model with different paths for queries and documents
router = Router.for_query_document(
    query_modules=[SparseStaticEmbedding(tokenizer=doc_encoder.tokenizer, frozen=False)],
    # Document path: full MLM transformer + pooling
    document_modules=[doc_encoder, SpladePooling("max")],
)

# Create the inference-free model
model = SparseEncoder(modules=[router], similarity_fn_name="dot")
# SparseEncoder(
#   (0): Router(
#     (query_0_SparseStaticEmbedding): SparseStaticEmbedding({'frozen': False}, dim:30522, tokenizer: BertTokenizerFast)
#     (document_0_MLMTransformer): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
#     (document_1_SpladePooling): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
#   )
# )

这种架构允许使用轻量级的 SparseStaticEmbedding 方法进行快速查询时处理,该方法可以被训练并视为线性权重,而文档则使用完整的 MLM Transformer 和 SpladePooling 进行处理。

提示

免推理 Splade 对于查询延迟至关重要的搜索应用特别有用,因为它将计算复杂性转移到文档索引阶段,该阶段可以离线完成。

注意

使用Router模块训练模型时,您必须在SparseEncoderTrainingArguments中使用router_mapping参数将训练数据集列映射到正确的路由(“query”或“document”)。例如,如果您的数据集包含["question", "answer"]列,则可以使用以下映射:

args = SparseEncoderTrainingArguments(
    ...,
    router_mapping={
        "question": "query",
        "answer": "document",
    }
)

此外,建议为 SparseStaticEmbedding 模块使用比模型其余部分高得多的学习率。为此,您应该在SparseEncoderTrainingArguments中使用learning_rate_mapping参数将参数模式映射到其学习率。例如,如果您想为 SparseStaticEmbedding 模块使用1e-3的学习率,而为模型其余部分使用2e-5的学习率,您可以这样做:

args = SparseEncoderTrainingArguments(
    ...,
    learning_rate=2e-5,
    learning_rate_mapping={
        r"SparseStaticEmbedding\.*": 1e-3,
    }
)

对比稀疏表示(CSR)模型在密集 Sentence Transformer 模型之上应用SparseAutoEncoder模块,该模型通常由Transformer后跟Pooling模块组成。您可以从头开始初始化一个,如下所示:

from sentence_transformers import models, SparseEncoder
from sentence_transformers.sparse_encoder.models import SparseAutoEncoder

# Initialize transformer (can be any dense encoder model)
transformer = models.Transformer("google-bert/bert-base-uncased")

# Initialize pooling
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")

# Initialize SparseAutoEncoder module
sae = SparseAutoEncoder(
    input_dim=transformer.get_word_embedding_dimension(),
    hidden_dim=4 * transformer.get_word_embedding_dimension(),
    k=256,  # Number of top values to keep
    k_aux=512,  # Number of top values for auxiliary loss
)
# Create the CSR model
model = SparseEncoder(modules=[transformer, pooling, sae])

或者,如果您的基础模型是 1) 一个密集 Sentence Transformer 模型,或 2) 一个非 MLM Transformer 模型(这些模型默认作为 Splade 模型加载),那么此快捷方式将自动为您初始化 CSR 模型。

from sentence_transformers import SparseEncoder

model = SparseEncoder("mixedbread-ai/mxbai-embed-large-v1")
# SparseEncoder(
#   (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertModel'})
#   (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
#   (2): SparseAutoEncoder({'input_dim': 1024, 'hidden_dim': 4096, 'k': 256, 'k_aux': 512, 'normalize': False, 'dead_threshold': 30})
# )

警告

与(免推理)Splade 模型不同,CSR 模型生成的稀疏嵌入与基础模型的词汇量大小不同。这意味着您无法像 Splade 模型那样直接解释嵌入中激活了哪些词,因为在 Splade 模型中,每个维度都对应于词汇表中的特定标记。

除此之外,CSR 模型在使用高维表示(例如 1024-4096 维度)的密集编码器模型上最有效。

数据集

SparseEncoderTrainer使用datasets.Dataset(一个数据集)或datasets.DatasetDict实例(多个数据集,另请参阅多数据集训练)进行训练和评估。

如果您想从Hugging Face Datasets加载数据,那么您应该使用datasets.load_dataset()

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")

print(train_dataset)
"""
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 557850
})
"""

某些数据集(包括sentence-transformers/all-nli)需要您提供一个“子集”以及数据集名称。sentence-transformers/all-nli有 4 个子集,每个子集都有不同的数据格式:pairpair-classpair-scoretriplet

注意

许多开箱即用与 Sentence Transformers 配合使用的 Hugging Face 数据集已标记为sentence-transformers,这使您可以通过浏览https://hugging-face.cn/datasets?other=sentence-transformers轻松找到它们。我们强烈建议您浏览这些数据集,以找到可能对您的任务有用的训练数据集。

如果您有常见文件格式的本地数据,那么您可以使用datasets.load_dataset()轻松加载这些数据

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")

from datasets import load_dataset

dataset = load_dataset("json", data_files="my_file.json")

如果您有需要额外预处理的本地数据,我建议您使用datasets.Dataset.from_dict()和列表字典来初始化您的数据集,如下所示:

from datasets import Dataset

anchors = []
positives = []
# Open a file, do preprocessing, filtering, cleaning, etc.
# and append to the lists

dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
})

字典中的每个键都将成为结果数据集中的一列。

数据集格式

您的数据集格式与您的损失函数匹配(或者您选择与您的数据集格式匹配的损失函数)非常重要。验证数据集格式是否适用于损失函数涉及两个步骤:

  1. 如果您的损失函数根据损失概述表需要一个标签,那么您的数据集必须有一个名为“label”或“score”的列。此列将自动作为标签。

  2. 所有不名为“label”或“score”的列都被视为损失概述表中的输入。剩余列的数量必须与您所选损失的有效输入数量匹配。这些列的名称无关紧要,只有顺序很重要

例如,给定一个包含["text1", "text2", "label"]列的数据集,其中“label”列包含 0 到 1 之间的浮点相似度分数,我们可以将其与SparseCoSENTLossSparseAnglELossSparseCosineSimilarityLoss一起使用,因为它:

  1. 有一个“label”列,这些损失函数需要它。

  2. 有 2 个非标签列,正好是这些损失函数所需的数量。

如果您的列顺序不正确,请务必使用Dataset.select_columns重新排序您的数据集列。例如,如果您的数据集有["good_answer", "bad_answer", "question"]作为列,那么这个数据集技术上可以与需要(锚点、正例、负例)三元组的损失一起使用,但是good_answer列将被视为锚点,bad_answer作为正例,question作为负例。

此外,如果您的数据集有多余的列(例如 sample_id、metadata、source、type),您应该使用Dataset.remove_columns删除它们,否则它们将被用作输入。您还可以使用Dataset.select_columns仅保留所需的列。

损失函数

损失函数量化模型在给定批次数据上的表现,从而使优化器能够更新模型权重以产生更有利(即更低)的损失值。这是训练过程的核心。

遗憾的是,没有一个单一的损失函数适用于所有用例。相反,使用哪个损失函数在很大程度上取决于您的可用数据和目标任务。请参阅数据集格式以了解哪些数据集对哪些损失函数有效。此外,损失概述将是您了解选项的最佳帮手。

警告

要训练SparseEncoder,您需要SpladeLossCSRLoss,具体取决于架构。这些是包装损失,它们在主要损失函数之上添加了稀疏性正则化,主损失函数必须作为参数提供。唯一可以独立使用的损失是SparseMSELoss,因为它执行嵌入级蒸馏,通过直接复制教师的稀疏嵌入来确保稀疏性。

大多数损失函数只需使用您正在训练的SparseEncoder以及一些可选参数进行初始化,例如:

from datasets import load_dataset
from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss

# Load a model to train/finetune
model = SparseEncoder("distilbert/distilbert-base-uncased")

# Initialize the SpladeLoss with a SparseMultipleNegativesRankingLoss
# This loss requires pairs of related texts or triplets
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=5e-5,  # Weight for query loss
    document_regularizer_weight=3e-5,
)

# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
    features: ['query', 'answer'],
    num_rows: 100231
})
"""

训练参数

SparseEncoderTrainingArguments类可用于指定影响训练性能的参数以及定义跟踪/调试参数。虽然它是可选的,但强烈建议尝试各种有用的参数。



以下是SparseEncoderTrainingArguments如何初始化的示例:

args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir="models/splade-distilbert-base-uncased-nq",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="splade-distilbert-base-uncased-nq",  # Will be used in W&B if `wandb` is installed
)

评估器

您可以为SparseEncoderTrainer提供一个eval_dataset以在训练期间获得评估损失,但在训练期间获得更具体的指标也可能很有用。为此,您可以使用评估器在训练之前、期间或之后使用有用的指标评估模型的性能。您可以同时使用eval_dataset和评估器,或其中之一,或两者都不用。它们根据eval_strategyeval_steps训练参数进行评估。

以下是 Sentence Transformers 为稀疏编码器模型实现的所有评估器:

评估器

所需数据

SparseBinaryClassificationEvaluator

带有类标签的对。

SparseEmbeddingSimilarityEvaluator

带有相似度分数的对。

SparseInformationRetrievalEvaluator

查询(qid => 问题)、语料库(cid => 文档)和相关文档(qid => set[cid])。

SparseNanoBEIREvaluator

无需数据。

SparseMSEEvaluator

用教师模型嵌入源句子,用学生模型嵌入目标句子。可以是相同的文本。

SparseRerankingEvaluator

字典列表,例如 {'query': '...', 'positive': [...], 'negative': [...]}

SparseTranslationEvaluator

两种不同语言的句子对。

SparseTripletEvaluator

(锚点,正例,负例)对。

此外,SequentialEvaluator应用于将多个评估器组合成一个评估器,该评估器可以传递给SparseEncoderTrainer

有时您没有所需的评估数据来独自准备这些评估器之一,但您仍然想跟踪模型在一些常见基准上的表现。在这种情况下,您可以将这些评估器与 Hugging Face 的数据一起使用。

from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator

# Initialize the evaluator. Unlike most other evaluators, this one loads the relevant datasets
# directly from Hugging Face, so there's no mandatory arguments
dev_evaluator = SparseNanoBEIREvaluator()
# You can run evaluation like so:
# results = dev_evaluator(model)
from datasets import load_dataset
from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.sparse_encoder.evaluation import SparseEmbeddingSimilarityEvaluator

# Load the STSB dataset (https://hugging-face.cn/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

# Initialize the evaluator
dev_evaluator = SparseEmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
from datasets import load_dataset
from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.sparse_encoder.evaluation import SparseTripletEvaluator

# Load triplets from the AllNLI dataset (https://hugging-face.cn/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")

# Initialize the evaluator
dev_evaluator = SparseTripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_distance_function=SimilarityFunction.DOT,
    name="all-nli-dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)

提示

当在训练期间频繁使用较小的eval_steps进行评估时,请考虑使用微小的eval_dataset以最大程度地减少评估开销。如果您担心评估集大小,90-1-9 的训练-评估-测试划分可以提供平衡,为最终评估保留一个合理大小的测试集。训练后,您可以使用trainer.evaluate(test_dataset)评估模型的测试损失,或使用test_evaluator(model)初始化测试评估器以获取详细的测试指标。

如果您在训练后但在保存模型之前进行评估,您的自动生成的模型卡仍将包含测试结果。

警告

使用分布式训练时,评估器仅在第一个设备上运行,而训练和评估数据集则在所有设备之间共享。

训练器

SparseEncoderTrainer是所有先前组件的集合。我们只需要指定带有模型、训练参数(可选)、训练数据集、评估数据集(可选)、损失函数、评估器(可选)的训练器,然后就可以开始训练了。让我们看看一个包含所有这些组件的脚本:

import logging

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SparseEncoder(
    "distilbert/distilbert-base-uncased",
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="DistilBERT base trained on Natural-Questions tuples",
    )
)

# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]

# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=5e-5,
    document_regularizer_weight=3e-5,
)

# 5. (Optional) Specify training arguments
run_name = "splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
import logging

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
mlm_transformer = MLMTransformer("distilbert/distilbert-base-uncased", tokenizer_args={"model_max_length": 512})
splade_pooling = SpladePooling(
    pooling_strategy="max", word_embedding_dimension=mlm_transformer.get_sentence_embedding_dimension()
)
router = Router.for_query_document(
    query_modules=[SparseStaticEmbedding(tokenizer=mlm_transformer.tokenizer, frozen=False)],
    document_modules=[mlm_transformer, splade_pooling],
)

model = SparseEncoder(
    modules=[router],
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Inference-free SPLADE distilbert-base-uncased trained on Natural-Questions tuples",
    ),
)

# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(train_dataset)
print(train_dataset[0])

# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=0,
    document_regularizer_weight=3e-4,
)

# 5. (Optional) Specify training arguments
run_name = "inference-free-splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    learning_rate_mapping={r"SparseStaticEmbedding\.weight": 1e-3},  # Set a higher learning rate for the SparseStaticEmbedding module
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    router_mapping={"query": "query", "answer": "document"},  # Map the column names to the routes
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

回调

此稀疏编码器训练器集成了对各种transformers.TrainerCallback子类的支持,例如:

有关集成回调以及如何编写自己的回调的更多信息,请参阅 Transformers 回调文档。

多数据集训练

性能最佳的模型是同时使用许多数据集进行训练的。通常,这相当棘手,因为每个数据集都有不同的格式。然而,SparseEncoderTrainer可以训练多个数据集而无需将每个数据集转换为相同的格式。它甚至可以为每个数据集应用不同的损失函数。使用多个数据集进行训练的步骤是:

  • 使用Dataset实例的字典(或DatasetDict)作为train_dataset(可选地也作为eval_dataset)。

  • (可选)使用损失函数字典,将数据集名称映射到损失。仅当您希望对不同数据集使用不同的损失函数时才需要。

每个训练/评估批次将只包含来自一个数据集的样本。从多个数据集中采样批次的顺序由MultiDatasetBatchSamplers枚举定义,该枚举可以通过multi_dataset_batch_sampler传递给SparseEncoderTrainingArguments。有效选项有:

  • MultiDatasetBatchSamplers.ROUND_ROBIN:从每个数据集轮流采样,直到其中一个耗尽。使用此策略,可能不会使用每个数据集中的所有样本,但每个数据集都以相同的频率采样。

  • MultiDatasetBatchSamplers.PROPORTIONAL(默认):按其大小比例从每个数据集采样。使用此策略,每个数据集中的所有样本都将被使用,并且较大的数据集会更频繁地被采样。

训练技巧

稀疏编码器模型有一些您在训练它们时应注意的特点:

  1. 稀疏编码器模型不应仅通过评估分数进行评估,还应通过嵌入的稀疏性进行评估。毕竟,低稀疏性意味着模型嵌入存储成本高且检索速度慢。这也意味着决定稀疏性的参数(例如SpladeLoss中的query_regularizer_weightdocument_regularizer_weight以及CSRLoss中的betagamma)应进行调整,以在性能和稀疏性之间取得良好平衡。每个评估器都会输出active_dimssparsity_ratio指标,可用于评估嵌入的稀疏性。

  2. 不建议在训练前对未经训练的模型使用评估器,因为稀疏性会非常低,因此内存使用量可能会出乎意料地高。

  3. 更强的稀疏编码器模型几乎完全通过从更强的教师模型(例如交叉编码器模型)进行蒸馏来训练,而不是直接从文本对或三元组进行训练。例如,参见SPLADE-v3 论文,其中使用SparseDistillKLDivLossSparseMarginMSELoss进行蒸馏。

  4. 虽然大多数密集嵌入模型都经过训练以与余弦相似度一起使用,但SparseEncoder模型通常经过训练以与点积一起使用来计算相似度。某些损失要求您提供相似度函数,而您在那里使用点积可能会更好。请注意,您通常可以为损失提供model.similaritymodel.similarity_pairwise