训练概览
为什么要进行微调?
微调稀疏编码器模型通常会极大提升模型在您的用例上的性能,因为每个任务都需要不同的相似性概念。例如,给定以下新闻文章:
“Apple launches the new iPad” (苹果发布新款iPad)
“NVIDIA is gearing up for the next GPU generation” (英伟达正为下一代GPU做准备)
那么在以下用例中,我们可能有不同的相似性概念:
用于将新闻文章分类为经济、体育、科技、政治等的模型,应该为这些文本生成相似的嵌入。
用于语义文本相似性的模型,应该为这些文本生成不相似的嵌入,因为它们有不同的含义。
用于语义搜索的模型不需要两个文档之间的相似性概念,因为它只应比较查询和文档。
另请参阅 训练示例,其中包含大量针对常见现实世界应用的训练脚本,您可以直接采用。
训练组件
训练稀疏编码器模型涉及 4 到 6 个组件:
模型
稀疏编码器模型由一系列模块、稀疏编码器特定模块或自定义模块组成,具有很大的灵活性。如果您想进一步微调一个 SparseEncoder 模型(例如,它有一个 modules.json 文件),那么您不必担心使用了哪些模块。
from sentence_transformers import SparseEncoder
model = SparseEncoder("naver/splade-cocondenser-ensembledistil")
但如果您想从另一个检查点或从头开始训练,那么以下是您可以使用的最常见的架构:
Splade 模型使用 MLMTransformer
,后接一个 SpladePooling
模块。前者加载一个预训练的掩码语言建模 transformer 模型(例如 BERT, RoBERTa, DistilBERT, ModernBERT 等),后者对 MLMHead 的输出进行池化,以生成一个与词汇表大小相同的稀疏嵌入。
from sentence_transformers import models, SparseEncoder
from sentence_transformers.sparse_encoder.models import MLMTransformer, SpladePooling
# Initialize MLM Transformer (use a fill-mask model)
mlm_transformer = MLMTransformer("google-bert/bert-base-uncased")
# Initialize SpladePooling module
splade_pooling = SpladePooling(pooling_strategy="max")
# Create the Splade model
model = SparseEncoder(modules=[mlm_transformer, splade_pooling])
如果您向 SparseEncoder 提供一个 fill-mask 模型架构,此架构是默认设置,因此使用快捷方式更简单:
from sentence_transformers import SparseEncoder
model = SparseEncoder("google-bert/bert-base-uncased")
# SparseEncoder(
# (0): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
# (1): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
# )
免推理 Splade 使用一个 Router
模块,为查询和文档使用不同的模块。通常对于这种类型的架构,文档部分是传统的 Splade 架构(一个 MLMTransformer
后接一个 SpladePooling
模块),而查询部分是一个 SparseStaticEmbedding
模块,它只为查询中的每个词元返回一个预计算的分数。
from sentence_transformers import SparseEncoder
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling
# Initialize MLM Transformer for document encoding
doc_encoder = MLMTransformer("google-bert/bert-base-uncased")
# Create a router model with different paths for queries and documents
router = Router.for_query_document(
query_modules=[SparseStaticEmbedding(tokenizer=doc_encoder.tokenizer, frozen=False)],
# Document path: full MLM transformer + pooling
document_modules=[doc_encoder, SpladePooling("max")],
)
# Create the inference-free model
model = SparseEncoder(modules=[router], similarity_fn_name="dot")
# SparseEncoder(
# (0): Router(
# (query_0_SparseStaticEmbedding): SparseStaticEmbedding({'frozen': False}, dim:30522, tokenizer: BertTokenizerFast)
# (document_0_MLMTransformer): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
# (document_1_SpladePooling): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
# )
# )
这种架构允许使用轻量级的 SparseStaticEmbedding 方法进行快速的查询时处理,该方法可以被训练并看作是线性权重,而文档则使用完整的 MLM transformer 和 SpladePooling 进行处理。
提示
免推理 Splade 对于查询延迟至关重要的搜索应用特别有用,因为它将计算复杂性转移到可以离线完成的文档索引阶段。
注意
在使用 Router
模块训练模型时,您必须在 SparseEncoderTrainingArguments
中使用 router_mapping
参数,将训练数据集的列映射到正确的路由(“query”或“document”)。例如,如果您的数据集有 ["question", "answer"]
列,那么您可以使用以下映射:
args = SparseEncoderTrainingArguments(
...,
router_mapping={
"question": "query",
"answer": "document",
}
)
此外,建议为 SparseStaticEmbedding 模块使用比模型其余部分高得多的学习率。为此,您应该在 SparseEncoderTrainingArguments
中使用 learning_rate_mapping
参数,将参数模式映射到它们的学习率。例如,如果您想为 SparseStaticEmbedding 模块使用 1e-3
的学习率,为模型其余部分使用 2e-5
的学习率,您可以这样做:
args = SparseEncoderTrainingArguments(
...,
learning_rate=2e-5,
learning_rate_mapping={
r"SparseStaticEmbedding\.*": 1e-3,
}
)
对比稀疏表示 (CSR) 模型在一个密集的 Sentence Transformer 模型之上应用一个 SparseAutoEncoder
模块,该模型通常由一个 Transformer
后接一个 Pooling
模块组成。您可以像这样从头开始初始化一个:
from sentence_transformers import models, SparseEncoder
from sentence_transformers.sparse_encoder.models import SparseAutoEncoder
# Initialize transformer (can be any dense encoder model)
transformer = models.Transformer("google-bert/bert-base-uncased")
# Initialize pooling
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
# Initialize SparseAutoEncoder module
sae = SparseAutoEncoder(
input_dim=transformer.get_word_embedding_dimension(),
hidden_dim=4 * transformer.get_word_embedding_dimension(),
k=256, # Number of top values to keep
k_aux=512, # Number of top values for auxiliary loss
)
# Create the CSR model
model = SparseEncoder(modules=[transformer, pooling, sae])
或者,如果您的基础模型是 1) 一个密集的 Sentence Transformer 模型或 2) 一个非 MLM Transformer 模型(默认情况下这些模型会作为 Splade 模型加载),那么这个快捷方式将自动为您初始化 CSR 模型:
from sentence_transformers import SparseEncoder
model = SparseEncoder("mixedbread-ai/mxbai-embed-large-v1")
# SparseEncoder(
# (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertModel'})
# (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
# (2): SparseAutoEncoder({'input_dim': 1024, 'hidden_dim': 4096, 'k': 256, 'k_aux': 512, 'normalize': False, 'dead_threshold': 30})
# )
警告
与(免推理)Splade 模型不同,CSR 模型生成的稀疏嵌入大小与基础模型的词汇表大小不同。这意味着您不能像使用 Splade 模型那样直接解释嵌入中激活了哪些词,在 Splade 模型中,每个维度对应词汇表中的一个特定词元。
除此之外,CSR 模型在那些使用高维表示(例如 1024-4096 维)的密集编码器模型上最为有效。
数据集
SparseEncoderTrainer
使用 datasets.Dataset
(单个数据集)或 datasets.DatasetDict
实例(多个数据集,另请参阅 多数据集训练)进行训练和评估。
如果您想从 Hugging Face Datasets 加载数据,那么您应该使用 datasets.load_dataset()
:
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
print(train_dataset)
"""
Dataset({
features: ['anchor', 'positive', 'negative'],
num_rows: 557850
})
"""
一些数据集(包括 sentence-transformers/all-nli)要求您在数据集名称旁边提供一个“子集”。sentence-transformers/all-nli
有 4 个子集,每个子集具有不同的数据格式:pair、pair-class、pair-score、triplet。
注意
许多可以直接与 Sentence Transformers 一起使用的 Hugging Face 数据集都已标记为 sentence-transformers
,您可以通过浏览 https://hugging-face.cn/datasets?other=sentence-transformers 轻松找到它们。我们强烈建议您浏览这些数据集,以找到可能对您的任务有用的训练数据集。
如果您有常见文件格式的本地数据,那么您可以使用 datasets.load_dataset()
轻松加载这些数据:
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
或
from datasets import load_dataset
dataset = load_dataset("json", data_files="my_file.json")
如果您有需要一些额外预处理的本地数据,我建议您使用 datasets.Dataset.from_dict()
和一个列表字典来初始化您的数据集,如下所示:
from datasets import Dataset
anchors = []
positives = []
# Open a file, do preprocessing, filtering, cleaning, etc.
# and append to the lists
dataset = Dataset.from_dict({
"anchor": anchors,
"positive": positives,
})
字典中的每个键都将成为结果数据集中的一列。
数据集格式
重要的是,您的数据集格式要与您的损失函数相匹配(或者您选择一个与您的数据集格式相匹配的损失函数)。验证数据集格式是否适用于损失函数涉及两个步骤:
如果您的损失函数需要一个 *标签* (Label),根据 损失概览 表,那么您的数据集必须有一个**名为“label”或“score”的列**。该列会自动被视为标签。
所有不名为“label”或“score”的列都被视为 *输入* (Inputs),根据 损失概览 表。剩余列的数量必须与您选择的损失的有效输入数量相匹配。这些列的名称**无关紧要**,只有**顺序重要**。
例如,给定一个具有 ["text1", "text2", "label"]
列的数据集,其中“label”列包含 0 到 1 之间的浮点相似度分数,我们可以将其与 SparseCoSENTLoss
、SparseAnglELoss
和 SparseCosineSimilarityLoss
一起使用,因为它:
有一个“label”列,这是这些损失函数所要求的。
有 2 个非标签列,正好是这些损失函数所需的数量。
如果您的数据集列顺序不正确,请务必使用 Dataset.select_columns
重新排序。例如,如果您的数据集有 ["good_answer", "bad_answer", "question"]
作为列,那么理论上该数据集可以用于需要(锚点、正例、负例)三元组的损失,但 good_answer
列将被视为锚点,bad_answer
将被视为正例,question
将被视为负例。
此外,如果您的数据集有无关的列(例如 sample_id、metadata、source、type),您应该使用 Dataset.remove_columns
删除它们,否则它们将被用作输入。您也可以使用 Dataset.select_columns
只保留所需的列。
损失函数
损失函数量化了模型在给定一批数据上的表现,从而允许优化器更新模型权重以产生更优(即更低)的损失值。这是训练过程的核心。
遗憾的是,没有一个单一的损失函数能对所有用例都表现最佳。相反,使用哪种损失函数在很大程度上取决于您可用的数据和您的目标任务。请参阅 数据集格式 了解哪些数据集对哪些损失函数有效。此外,损失概览 将是您了解各种选项的最佳助手。
警告
要训练一个 SparseEncoder
,您需要 SpladeLoss
或 CSRLoss
,具体取决于架构。这些是包装损失,在主损失函数之上添加稀疏性正则化,主损失函数必须作为参数提供。唯一可以独立使用的损失是 SparseMSELoss
,因为它执行嵌入级别的蒸馏,通过直接复制教师模型的稀疏嵌入来确保稀疏性。
大多数损失函数只需使用您正在训练的 SparseEncoder
以及一些可选参数即可初始化,例如:
from datasets import load_dataset
from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss
# Load a model to train/finetune
model = SparseEncoder("distilbert/distilbert-base-uncased")
# Initialize the SpladeLoss with a SparseMultipleNegativesRankingLoss
# This loss requires pairs of related texts or triplets
loss = SpladeLoss(
model=model,
loss=SparseMultipleNegativesRankingLoss(model=model),
query_regularizer_weight=5e-5, # Weight for query loss
document_regularizer_weight=3e-5,
)
# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
features: ['query', 'answer'],
num_rows: 100231
})
"""
训练参数
SparseEncoderTrainingArguments
类可用于指定影响训练性能的参数,以及定义跟踪/调试参数。虽然它是可选的,但强烈建议尝试各种有用的参数。
learning_rate
(学习率) lr_scheduler_type
(学习率调度器类型) warmup_ratio
(预热比例) num_train_epochs
(训练轮数) max_steps
(最大步数) per_device_train_batch_size
(每设备训练批次大小) per_device_eval_batch_size
(每设备评估批次大小) auto_find_batch_size
(自动查找批次大小) fp16
bf16
load_best_model_at_end
(结束时加载最佳模型) metric_for_best_model
(最佳模型指标) gradient_accumulation_steps
(梯度累积步数) gradient_checkpointing
(梯度检查点) eval_accumulation_steps
(评估累积步数) optim
(优化器) batch_sampler
(批次采样器) multi_dataset_batch_sampler
(多数据集批次采样器) prompts
(提示) router_mapping
(路由映射) learning_rate_mapping
(学习率映射)以下是如何初始化 SparseEncoderTrainingArguments
的示例:
args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir="models/splade-distilbert-base-uncased-nq",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # losses that use "in-batch negatives" benefit from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="splade-distilbert-base-uncased-nq", # Will be used in W&B if `wandb` is installed
)
评估器
您可以为 SparseEncoderTrainer
提供一个 eval_dataset
以在训练期间获取评估损失,但在此期间获得更具体的指标也可能很有用。为此,您可以使用评估器在训练前、训练中或训练后评估模型的性能,并获得有用的指标。您既可以同时使用 eval_dataset
和评估器,也可以只使用其中之一,或者都不使用。它们根据 eval_strategy
和 eval_steps
训练参数进行评估。
以下是 Sentence Transformers 为稀疏编码器模型实现的评估器:
评估器 |
所需数据 |
---|---|
带有类别标签的句子对。 |
|
带相似度分数的句子对。 |
|
查询(qid => 问题)、语料库(cid => 文档)和相关文档(qid => set[cid])。 |
|
无需数据。 |
|
源句子供教师模型嵌入,目标句子供学生模型嵌入。可以是相同的文本。 |
|
包含 |
|
两种不同语言的句子对。 |
|
(锚点,正例,负例)对。 |
此外,应使用 SequentialEvaluator
将多个评估器合并为一个评估器,该评估器可以传递给 SparseEncoderTrainer
。
有时您没有所需的评估数据来自己准备这些评估器之一,但您仍然希望跟踪模型在某些常见基准测试上的表现。在这种情况下,您可以使用这些带有 Hugging Face 数据的评估器。
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
# Initialize the evaluator. Unlike most other evaluators, this one loads the relevant datasets
# directly from Hugging Face, so there's no mandatory arguments
dev_evaluator = SparseNanoBEIREvaluator()
# You can run evaluation like so:
# results = dev_evaluator(model)
from datasets import load_dataset
from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.sparse_encoder.evaluation import SparseEmbeddingSimilarityEvaluator
# Load the STSB dataset (https://hugging-face.cn/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
# Initialize the evaluator
dev_evaluator = SparseEmbeddingSimilarityEvaluator(
sentences1=eval_dataset["sentence1"],
sentences2=eval_dataset["sentence2"],
scores=eval_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
from datasets import load_dataset
from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.sparse_encoder.evaluation import SparseTripletEvaluator
# Load triplets from the AllNLI dataset (https://hugging-face.cn/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")
# Initialize the evaluator
dev_evaluator = SparseTripletEvaluator(
anchors=eval_dataset["anchor"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
main_distance_function=SimilarityFunction.DOT,
name="all-nli-dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
提示
当使用较小的 eval_steps
频繁进行训练时评估时,请考虑使用一个微小的 eval_dataset
来最小化评估开销。如果您担心评估集的大小,90-1-9 的训练-评估-测试划分可以提供一个平衡,为最终评估保留一个合理大小的测试集。训练后,您可以使用 trainer.evaluate(test_dataset)
评估模型的测试损失,或使用 test_evaluator(model)
初始化一个测试评估器以获取详细的测试指标。
如果您在训练后但在保存模型前进行评估,您自动生成的模型卡仍将包含测试结果。
警告
使用分布式训练时,评估器仅在第一台设备上运行,而训练和评估数据集则在所有设备间共享。
训练器
SparseEncoderTrainer
是所有先前组件汇集的地方。我们只需为训练器指定模型、训练参数(可选)、训练数据集、评估数据集(可选)、损失函数、评估器(可选),然后就可以开始训练了。让我们看一个将所有这些组件整合在一起的脚本:
import logging
from datasets import load_dataset
from sentence_transformers import (
SparseEncoder,
SparseEncoderModelCardData,
SparseEncoderTrainer,
SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.training_args import BatchSamplers
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
# 1. Load a model to finetune with 2. (Optional) model card data
model = SparseEncoder(
"distilbert/distilbert-base-uncased",
model_card_data=SparseEncoderModelCardData(
language="en",
license="apache-2.0",
model_name="DistilBERT base trained on Natural-Questions tuples",
)
)
# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# 4. Define a loss function
loss = SpladeLoss(
model=model,
loss=SparseMultipleNegativesRankingLoss(model=model),
query_regularizer_weight=5e-5,
document_regularizer_weight=3e-5,
)
# 5. (Optional) Specify training arguments
run_name = "splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
save_total_limit=2,
logging_steps=200,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 8. Evaluate the model performance again after training
dev_evaluator(model)
# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
import logging
from datasets import load_dataset
from sentence_transformers import (
SparseEncoder,
SparseEncoderModelCardData,
SparseEncoderTrainer,
SparseEncoderTrainingArguments,
)
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling
from sentence_transformers.training_args import BatchSamplers
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
# 1. Load a model to finetune with 2. (Optional) model card data
mlm_transformer = MLMTransformer("distilbert/distilbert-base-uncased", tokenizer_args={"model_max_length": 512})
splade_pooling = SpladePooling(
pooling_strategy="max", word_embedding_dimension=mlm_transformer.get_sentence_embedding_dimension()
)
router = Router.for_query_document(
query_modules=[SparseStaticEmbedding(tokenizer=mlm_transformer.tokenizer, frozen=False)],
document_modules=[mlm_transformer, splade_pooling],
)
model = SparseEncoder(
modules=[router],
model_card_data=SparseEncoderModelCardData(
language="en",
license="apache-2.0",
model_name="Inference-free SPLADE distilbert-base-uncased trained on Natural-Questions tuples",
),
)
# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(train_dataset)
print(train_dataset[0])
# 4. Define a loss function
loss = SpladeLoss(
model=model,
loss=SparseMultipleNegativesRankingLoss(model=model),
query_regularizer_weight=0,
document_regularizer_weight=3e-4,
)
# 5. (Optional) Specify training arguments
run_name = "inference-free-splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
learning_rate_mapping={r"SparseStaticEmbedding\.weight": 1e-3}, # Set a higher learning rate for the SparseStaticEmbedding module
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
router_mapping={"query": "query", "answer": "document"}, # Map the column names to the routes
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
save_total_limit=2,
logging_steps=200,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 8. Evaluate the model performance again after training
dev_evaluator(model)
# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
回调
这个稀疏编码器训练器集成了对各种 transformers.TrainerCallback
子类的支持,例如:
SpladeRegularizerWeightSchedulerCallback
用于在训练期间调度SpladeLoss
损失的 lambda 参数。如果安装了
wandb
,WandbCallback
会自动将训练指标记录到 W&B。如果可以访问
tensorboard
,TensorBoardCallback
会将训练指标记录到 TensorBoard。如果安装了
codecarbon
,CodeCarbonCallback
会在训练期间跟踪模型的碳排放。注意:这些碳排放量将被包含在您自动生成的模型卡片中。
有关集成回调以及如何编写您自己的回调的更多信息,请参阅 Transformers 回调文档。
多数据集训练
性能顶尖的模型是使用多个数据集同时训练的。通常情况下,这相当棘手,因为每个数据集的格式都不同。然而,SparseEncoderTrainer
可以在不将每个数据集转换为相同格式的情况下,使用多个数据集进行训练。它甚至可以对每个数据集应用不同的损失函数。使用多个数据集进行训练的步骤如下:
使用一个
Dataset
实例的字典(或一个DatasetDict
)作为train_dataset
(也可选地用于eval_dataset
)。(可选)使用一个损失函数字典,将数据集名称映射到损失。仅当您希望为不同的数据集使用不同的损失函数时才需要。
每个训练/评估批次将只包含来自其中一个数据集的样本。从多个数据集中采样批次的顺序由 MultiDatasetBatchSamplers
枚举定义,该枚举可以通过 multi_dataset_batch_sampler
传递给 SparseEncoderTrainingArguments
。有效选项包括:
MultiDatasetBatchSamplers.ROUND_ROBIN
:从每个数据集中轮流采样,直到其中一个数据集耗尽。使用此策略,很可能不会使用每个数据集的所有样本,但每个数据集的采样机会是均等的。MultiDatasetBatchSamplers.PROPORTIONAL
(默认):按比例从每个数据集中采样。使用此策略,每个数据集的所有样本都会被使用,并且较大的数据集会被更频繁地采样。
训练技巧
稀疏编码器模型在训练时有一些需要注意的特点:
稀疏编码器模型不应仅使用评估分数进行评估,还应考虑嵌入的稀疏性。毕竟,低稀疏性意味着模型嵌入存储成本高,检索速度慢。这也意味着决定稀疏性的参数(例如,
SpladeLoss
中的query_regularizer_weight
、document_regularizer_weight
,以及CSRLoss
中的beta
和gamma
)应进行调整,以在性能和稀疏性之间取得良好平衡。每个 评估器 都会输出active_dims
和sparsity_ratio
指标,可用于评估嵌入的稀疏性。不建议在训练前对未经训练的模型使用评估器,因为稀疏度会非常低,因此内存使用可能会出乎意料地高。
性能更强的稀疏编码器模型几乎完全是通过从更强的教师模型(例如 CrossEncoder 模型)进行蒸馏来训练的,而不是直接从文本对或三元组进行训练。例如,请参阅 SPLADE-v3 论文,该论文使用
SparseDistillKLDivLoss
和SparseMarginMSELoss
进行蒸馏。尽管大多数密集嵌入模型被训练用于余弦相似度,但
SparseEncoder
模型通常被训练用于点积来计算相似度。一些损失函数要求您提供一个相似度函数,在这种情况下,使用点积可能会更好。请注意,您通常可以向损失函数提供model.similarity
或model.similarity_pairwise
。