采样器

批处理采样器 (BatchSamplers)

class sentence_transformers.training_args.BatchSamplers(value)[源代码]

存储批处理采样器可接受的字符串标识符。

批处理采样器负责在训练期间确定样本如何分组到批次中。有效选项有:

如果你想使用自定义批处理采样器,可以子类化 DefaultBatchSampler,并将该类(而非实例)传递给 SentenceTransformerTrainingArguments(或 CrossEncoderTrainingArguments 等)中的 batch_sampler 参数。或者,你可以传递一个函数,该函数接受 datasetbatch_sizedrop_lastvalid_label_columnsgeneratorseed,并返回一个 DefaultBatchSampler 实例。

用法
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
class sentence_transformers.sampler.DefaultBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[源代码]

该采样器是 SentenceTransformer 库中使用的默认批处理采样器。它等同于 PyTorch 的 BatchSampler。

参数:
  • sampler (SamplerIterable) – 用于从数据集中采样元素的采样器,例如 SubsetRandomSampler。

  • batch_size (int) – 每个批次的样本数。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 用于检查标签的列名列表。在数据集中找到的 valid_label_columns中的第一个列名将被用作标签列。

  • generator (torch.Generator, 可选) – 用于打乱索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[源代码]

该采样器创建的批次中,每个批次包含的样本值都是唯一的,即使跨列也是如此。这在损失函数将批次中的其他样本视为批内负例,并且您希望确保负例不是锚点/正例样本的副本时非常有用。

推荐用于
参数:
  • dataset (Dataset) – 要从中采样的数据集。

  • batch_size (int) – 每个批次的样本数。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 用于检查标签的列名列表。在数据集中找到的 valid_label_columns中的第一个列名将被用作标签列。

  • generator (torch.Generator, 可选) – 用于打乱索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[源代码]

该采样器按标签对样本进行分组,旨在创建批次,使每个批次中样本的标签尽可能同质。该采样器旨在与 Batch...TripletLoss 类一起使用,这些类要求每个批次中每个标签类别至少包含 2 个样本。

推荐用于
参数:
  • dataset (Dataset) – 要从中采样的数据集。

  • batch_size (int) – 每个批次的样本数。必须能被 2 整除。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 用于检查标签的列名列表。在数据集中找到的 valid_label_columns中的第一个列名将被用作标签列。

  • generator (torch.Generator, 可选) – 用于打乱索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

多数据集批处理采样器 (MultiDatasetBatchSamplers)

class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)[源代码]

存储多数据集批处理采样器可接受的字符串标识符。

多数据集批处理采样器负责在训练期间确定从多个数据集中采样批次的顺序。有效选项有:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: 使用 RoundRobinBatchSampler,它以轮询方式从每个数据集中采样,直到其中一个数据集耗尽。使用此策略,很可能不会使用每个数据集的所有样本,但每个数据集的采样是均等的。

  • MultiDatasetBatchSamplers.PROPORTIONAL: [默认] 使用 ProportionalBatchSampler,它根据每个数据集的大小按比例进行采样。使用此策略,将使用每个数据集的所有样本,并且会更频繁地从较大的数据集中采样。

如果你想使用自定义的多数据集批处理采样器,可以子类化 MultiDatasetDefaultBatchSampler,并将该类(而非实例)传递给 SentenceTransformerTrainingArguments(或 CrossEncoderTrainingArguments 等)中的 multi_dataset_batch_sampler 参数。或者,你可以传递一个函数,该函数接受 dataset(一个 ConcatDataset)、batch_samplers(即 ConcatDataset 中每个数据集的批处理采样器列表)、generatorseed,并返回一个 MultiDatasetDefaultBatchSampler 实例。

用法
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import MultiDatasetBatchSamplers
from sentence_transformers.losses import CoSENTLoss
from datasets import Dataset, DatasetDict

model = SentenceTransformer("microsoft/mpnet-base")
train_general = Dataset.from_dict({
    "sentence_A": ["It's nice weather outside today.", "He drove to work."],
    "sentence_B": ["It's so sunny.", "He took the car to the bank."],
    "score": [0.9, 0.4],
})
train_medical = Dataset.from_dict({
    "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
    "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
    "score": [0.8, 0.6, 0.7],
})
train_legal = Dataset.from_dict({
    "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
    "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
    "score": [0.7, 0.8],
})
train_dataset = DatasetDict({
    "general": train_general,
    "medical": train_medical,
    "legal": train_legal,
})

loss = CoSENTLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
class sentence_transformers.sampler.MultiDatasetDefaultBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[源代码]

从多个批处理采样器中生成批次的抽象基础批处理采样器。必须子类化此类以实现特定的采样策略,不能直接使用。

参数:
  • dataset (ConcatDataset) – 多个数据集的串联。

  • batch_samplers (List[BatchSampler]) – 批处理采样器列表,ConcatDataset 中每个数据集对应一个。

  • generator (torch.Generator, 可选) – 用于可复现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[源代码]

以轮询方式从多个批处理采样器中生成批次的批处理采样器,直到其中一个耗尽。使用此采样器,很可能不会使用每个数据集的所有样本,但我们确实确保每个数据集的采样是均等的。

参数:
  • dataset (ConcatDataset) – 多个数据集的串联。

  • batch_samplers (List[BatchSampler]) – 批处理采样器列表,ConcatDataset 中每个数据集对应一个。

  • generator (torch.Generator, 可选) – 用于可复现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[源代码]

根据每个数据集的大小按比例进行采样的批处理采样器,直到所有数据集同时耗尽。使用此采样器,将使用每个数据集的所有样本,并且会更频繁地从较大的数据集中采样。

参数:
  • dataset (ConcatDataset) – 多个数据集的串联。

  • batch_samplers (List[BatchSampler]) – 批处理采样器列表,ConcatDataset 中每个数据集对应一个。

  • generator (torch.Generator, 可选) – 用于可复现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。