采样器

批次采样器

class sentence_transformers.training_args.BatchSamplers(value)[source]

存储批次采样器可接受的字符串标识符。

批次采样器负责在训练期间确定样本如何分组到批次中。有效选项是

如果您想使用自定义批次采样器,可以子类化 DefaultBatchSampler 并将该类(而不是实例)传递给 SentenceTransformerTrainingArguments(或 CrossEncoderTrainingArguments 等)中的 batch_sampler 参数。或者,您可以传递一个函数,该函数接受 datasetbatch_sizedrop_lastvalid_label_columnsgeneratorseed,并返回 DefaultBatchSampler 实例。

用法
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
class sentence_transformers.sampler.DefaultBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]

此采样器是 SentenceTransformer 库中使用的默认批次采样器。它等同于 PyTorch BatchSampler。

参数:
  • 采样器 (SamplerIterable) – 用于从数据集中采样元素的采样器,例如 SubsetRandomSampler。

  • batch_size (int) – 每个批次的样本数。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 要检查标签的列名列表。在数据集中找到的 valid_label_columns 中的第一个列名将用作标签列。

  • generator (torch.Generator, 可选) – 用于洗牌索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]

此采样器创建批次,使得每个批次包含的样本值在不同列之间也是唯一的。当损失函数将批次中的其他样本视为批次内负样本时,这很有用,并且您希望确保负样本不是锚点/正样本的重复。

推荐用于
参数:
  • dataset (Dataset) – 要从中采样的数据集。

  • batch_size (int) – 每个批次的样本数。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 要检查标签的列名列表。在数据集中找到的 valid_label_columns 中的第一个列名将用作标签列。

  • generator (torch.Generator, 可选) – 用于洗牌索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]

此采样器按标签对样本进行分组,旨在创建批次,使得每个批次中的样本标签尽可能同质。此采样器旨在与 Batch...TripletLoss 类一起使用,这些类要求每个批次中每个标签类别至少包含 2 个示例。

推荐用于
参数:
  • dataset (Dataset) – 要从中采样的数据集。

  • batch_size (int) – 每个批次的样本数。必须能被 2 整除。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 要检查标签的列名列表。在数据集中找到的 valid_label_columns 中的第一个列名将用作标签列。

  • generator (torch.Generator, 可选) – 用于洗牌索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

多数据集批次采样器

class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)[source]

存储多数据集批次采样器可接受的字符串标识符。

多数据集批次采样器负责确定在训练期间从多个数据集中采样批次的顺序。有效选项是

  • MultiDatasetBatchSamplers.ROUND_ROBIN: 使用 RoundRobinBatchSampler,它从每个数据集轮流采样,直到其中一个耗尽。使用此策略,可能不会使用每个数据集中的所有样本,但会平等地从每个数据集中采样。

  • MultiDatasetBatchSamplers.PROPORTIONAL: [默认] 使用 ProportionalBatchSampler,它按数据集大小比例从每个数据集中采样。使用此策略,会使用每个数据集中的所有样本,并且更频繁地从较大的数据集中采样。

如果您想使用自定义多数据集批次采样器,可以子类化 MultiDatasetDefaultBatchSampler 并将该类(而不是实例)传递给 SentenceTransformerTrainingArguments 中的 multi_dataset_batch_sampler 参数。(或 CrossEncoderTrainingArguments 等)。或者,您可以传递一个函数,该函数接受 dataset(一个 ConcatDataset)、batch_samplers(即 ConcatDataset 中每个数据集的批次采样器列表)、generatorseed,并返回 MultiDatasetDefaultBatchSampler 实例。

用法
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import MultiDatasetBatchSamplers
from sentence_transformers.losses import CoSENTLoss
from datasets import Dataset, DatasetDict

model = SentenceTransformer("microsoft/mpnet-base")
train_general = Dataset.from_dict({
    "sentence_A": ["It's nice weather outside today.", "He drove to work."],
    "sentence_B": ["It's so sunny.", "He took the car to the bank."],
    "score": [0.9, 0.4],
})
train_medical = Dataset.from_dict({
    "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
    "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
    "score": [0.8, 0.6, 0.7],
})
train_legal = Dataset.from_dict({
    "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
    "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
    "score": [0.7, 0.8],
})
train_dataset = DatasetDict({
    "general": train_general,
    "medical": train_medical,
    "legal": train_legal,
})

loss = CoSENTLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
class sentence_transformers.sampler.MultiDatasetDefaultBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[source]

从多个批次采样器生成批次的抽象基批次采样器。此类必须被子类化以实现特定的采样策略,不能直接使用。

参数:
  • dataset (ConcatDataset) – 多个数据集的串联。

  • batch_samplers (List[BatchSampler]) – 批次采样器列表,ConcatDataset 中每个数据集一个。

  • generator (torch.Generator, 可选) – 用于可重现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[source]

批次采样器,从多个批次采样器中以轮询方式生成批次,直到其中一个耗尽。使用此采样器,每个数据集中的所有样本不太可能都被使用,但我们确实确保每个数据集都平等地被采样。

参数:
  • dataset (ConcatDataset) – 多个数据集的串联。

  • batch_samplers (List[BatchSampler]) – 批次采样器列表,ConcatDataset 中每个数据集一个。

  • generator (torch.Generator, 可选) – 用于可重现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[source]

批次采样器,按数据集大小比例从每个数据集采样,直到所有数据集同时耗尽。使用此采样器,每个数据集中的所有样本都将被使用,并且更频繁地从较大的数据集中采样。

参数:
  • dataset (ConcatDataset) – 多个数据集的串联。

  • batch_samplers (List[BatchSampler]) – 批次采样器列表,ConcatDataset 中每个数据集一个。

  • generator (torch.Generator, 可选) – 用于可重现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。