创建自定义模型
Sentence Transformer 模型的结构
Sentence Transformer 模型由按顺序执行的模块集合(文档)组成。最常见的架构是 Transformer 模块、Pooling 模块的组合,以及可选的 Dense 模块和/或 Normalize 模块。
Transformer:此模块负责处理输入文本并生成上下文化的嵌入。Pooling:此模块通过聚合嵌入来降低 Transformer 模块输出的维度。常见的池化策略包括平均池化和 CLS 池化。Dense:此模块包含一个线性层,用于后处理 Pooling 模块的嵌入输出。Normalize:此模块规范化上一层的嵌入。
例如,流行的 all-MiniLM-L6-v2 模型也可以通过初始化构成该模型的 3 个特定模块来加载
from sentence_transformers import models, SentenceTransformer
transformer = models.Transformer("sentence-transformers/all-MiniLM-L6-v2", max_seq_length=256)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, pooling, normalize])
保存 Sentence Transformer 模型
每当保存 Sentence Transformer 模型时,会生成三种类型的文件
modules.json:此文件包含用于重建模型的模块名称、路径和类型列表。config_sentence_transformers.json:此文件包含 Sentence Transformer 模型的一些配置选项,包括保存的提示、模型及其相似性函数以及模型作者使用的 Sentence Transformer 包版本。模块特定文件:每个模块都保存在单独的子文件夹中,子文件夹以模块索引和模型名称命名(例如,
1_Pooling,2_Normalize),但第一个模块如果其save_in_root属性设置为True,则可以保存在根目录中。在 Sentence Transformers 中,Transformer和CLIPModel模块就是这种情况。大多数模块文件夹包含一个config.json(对于Transformer模块为sentence_bert_config.json)文件,该文件存储传递给该模块的关键字参数的默认值。因此,sentence_bert_config.json的{ "max_seq_length": 4096, "do_lower_case": false }
意味着
Transformer模块将使用max_seq_length=4096和do_lower_case=False进行初始化。
因此,如果我对上一个代码片段中的 model 调用 SentenceTransformer.save_pretrained("local-all-MiniLM-L6-v2"),将生成以下文件
local-all-MiniLM-L6-v2/
├── 1_Pooling
│ └── config.json
├── 2_Normalize
├── README.md
├── config.json
├── config_sentence_transformers.json
├── model.safetensors
├── modules.json
├── sentence_bert_config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.txt
这包含一个具有以下内容的 modules.json
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
以及一个具有以下内容的 config_sentence_transformers.json
{
"__version__": {
"sentence_transformers": "3.0.1",
"transformers": "4.43.4",
"pytorch": "2.5.0"
},
"prompts": {},
"default_prompt_name": null,
"similarity_fn_name": null
}
此外,1_Pooling 目录包含 Pooling 模块的配置文件,而 2_Normalize 目录是空的,因为 Normalize 模块不需要任何配置。sentence_bert_config.json 文件包含 Transformer 模块的配置,并且此模块还在根目录中保存了许多与分词器和模型本身相关的文件。
加载 Sentence Transformer 模型
要从已保存的模型目录加载 Sentence Transformer 模型,会读取 modules.json 以确定构成模型的模块。每个模块都使用相应模块目录中存储的配置进行初始化,然后使用加载的模块实例化 SentenceTransformer 类。
来自 Transformers 模型的 Sentence Transformer 模型
当您使用纯 Transformers 模型(例如 BERT、RoBERTa、DistilBERT、T5)初始化 Sentence Transformer 模型时,Sentence Transformers 默认创建 Transformer 模块和平均池化模块。这提供了一种利用预训练语言模型进行句子嵌入的简单方法。
具体来说,这两个片段是相同的
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("bert-base-uncased")
from sentence_transformers import models, SentenceTransformer
transformer = models.Transformer("bert-base-uncased")
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])
高级:自定义模块
输入模块
管道中的第一个模块称为输入模块。它负责对输入文本进行分词并为后续模块生成输入特征。输入模块可以是实现 InputModule 类的任何模块,该类是 Module 类的子类。
它有三个您需要实现的抽象方法
一个
forward()方法,接受一个features字典,其中包含input_ids、attention_mask、token_type_ids、token_embeddings和sentence_embedding等键,具体取决于模块在模型管道中的位置。一个
save()方法,将模块的配置和可选权重保存到提供的目录中。一个
tokenize()方法,接受输入列表并返回一个字典,其中包含input_ids、attention_mask、token_type_ids、pixel_values等键。此字典将传递给模块的forward方法。
可选地,您还可以实现以下方法
一个
load()静态方法,接受model_name_or_path参数、用于从 Hugging Face 加载的关键字参数(subfolder、token、cache_folder等)和模块关键字参数(model_kwargs、trust_remote_code、backend等),并根据该目录或模型名称中的模块配置初始化模块。一个
get_sentence_embedding_dimension()方法,返回模块生成的句子嵌入的维度。如果模块生成嵌入或更新嵌入的维度,则需要此方法。一个
get_max_seq_length()方法,返回模块可以处理的最大序列长度。仅当模块处理输入文本时才需要。
后续模块
管道中的后续模块称为非输入模块。它们负责处理输入模块生成的输入特征并生成最终的句子嵌入。非输入模块可以是实现 Module 类的任何模块。
它有两个您需要实现的抽象方法
一个
forward()方法,接受一个features字典,其中包含input_ids、attention_mask、token_type_ids、token_embeddings和sentence_embedding等键,具体取决于模块在模型管道中的位置。一个
save()方法,将模块的配置和可选权重保存到提供的目录中。
可选地,您还可以实现以下方法
一个
load()静态方法,接受model_name_or_path参数、用于从 Hugging Face 加载的关键字参数(subfolder、token、cache_folder等)和模块关键字参数(model_kwargs、trust_remote_code、backend等),并根据该目录或模型名称中的模块配置初始化模块。一个
get_sentence_embedding_dimension()方法,返回模块生成的句子嵌入的维度。如果模块生成嵌入或更新嵌入的维度,则需要此方法。
示例模块
例如,我们可以通过实现自定义模块来创建自定义池化方法。
# decay_pooling.py
import torch
from sentence_transformers.models import Module
class DecayMeanPooling(Module):
config_keys: list[str] = ["dimension", "decay"]
def __init__(self, dimension: int, decay: float = 0.95, **kwargs) -> None:
super(DecayMeanPooling, self).__init__()
self.dimension = dimension
self.decay = decay
def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
# This module is expected to be used after some modules that provide "token_embeddings"
# and "attention_mask" in the features dictionary.
token_embeddings = features["token_embeddings"]
attention_mask = features["attention_mask"].unsqueeze(-1)
# Apply the attention mask to filter away padding tokens
token_embeddings = token_embeddings * attention_mask
# Calculate mean of token embeddings
sentence_embeddings = token_embeddings.sum(1) / attention_mask.sum(1)
# Apply exponential decay
importance_per_dim = self.decay ** torch.arange(
sentence_embeddings.size(1), device=sentence_embeddings.device
)
features["sentence_embedding"] = sentence_embeddings * importance_per_dim
return features
def get_sentence_embedding_dimension(self) -> int:
return self.dimension
def save(self, output_path, *args, safe_serialization=True, **kwargs) -> None:
self.save_config(output_path)
# The `load` method by default loads the config.json file from the model directory
# and initializes the class with the loaded parameters, i.e. the `config_keys`.
# This works for us, so no need to override it.
注意
建议在 __init__、forward、save、load 和 tokenize 方法中添加 **kwargs,以确保这些方法与 Sentence Transformers 库的未来更新保持兼容。
现在这可以作为 Sentence Transformer 模型中的模块使用
from sentence_transformers import models, SentenceTransformer
from decay_pooling import DecayMeanPooling
transformer = models.Transformer("bert-base-uncased", max_seq_length=256)
decay_mean_pooling = DecayMeanPooling(transformer.get_word_embedding_dimension(), decay=0.99)
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, decay_mean_pooling, normalize])
print(model)
"""
SentenceTransformer(
(0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
(1): DecayMeanPooling()
(2): Normalize()
)
"""
texts = [
"Hello, World!",
"The quick brown fox jumps over the lazy dog.",
"I am a sentence that is used for testing purposes.",
"This is a test sentence.",
"This is another test sentence.",
]
embeddings = model.encode(texts)
print(embeddings.shape)
# [5, 768]
您可以使用 SentenceTransformer.save_pretrained 保存此模型,这将生成一个 modules.json,内容如下
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_DecayMeanPooling",
"type": "decay_pooling.DecayMeanPooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
为确保可以导入 decay_pooling.DecayMeanPooling,您应该将 decay_pooling.py 文件复制到保存模型的目录中。如果您将模型推送到 Hugging Face Hub,那么您还应该将 decay_pooling.py 文件上传到模型的仓库。然后,每个人都可以通过调用 SentenceTransformer("your-username/your-model-id", trust_remote_code=True) 来使用您的自定义模块。
注意
使用存储在 Hugging Face Hub 上的远程代码的自定义模块需要您的用户在加载模型时将 trust_remote_code 指定为 True。这是一种安全措施,可防止远程代码执行攻击。
如果您的模型和自定义建模代码在 Hugging Face Hub 上,那么将您的自定义模块分离到单独的仓库中可能是有意义的。这样,您只需维护一个自定义模块的实现,并且可以在多个模型中重复使用它。您可以通过更新 modules.json 文件中的 type 来实现,以包含存储自定义模块的仓库路径,例如 {repository_id}--{dot_path_to_module}。例如,如果 decay_pooling.py 文件存储在名为 my-user/my-model-implementation 的仓库中,并且模块名为 DecayMeanPooling,那么 modules.json 文件可能如下所示
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_DecayMeanPooling",
"type": "my-user/my-model-implementation--decay_pooling.DecayMeanPooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
高级:自定义模块中的关键字参数传递
如果您希望用户能够通过 SentenceTransformer.encode 方法指定自定义关键字参数,则可以将其名称添加到 modules.json 文件中。例如,如果我的模块在用户指定 task 关键字参数时应表现不同,那么您的 modules.json 可能如下所示
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "custom_transformer.CustomTransformer",
"kwargs": ["task"]
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
然后,您可以在自定义模块的 forward 方法中访问 task 关键字参数
from sentence_transformers.models import Transformer
class CustomTransformer(Transformer):
def forward(self, features: dict[str, torch.Tensor], task: Optional[str] = None, **kwargs) -> dict[str, torch.Tensor]:
if task == "default":
# Do something
else:
# Do something else
return features
这样,用户在调用 SentenceTransformer.encode 时可以指定 task 关键字参数
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
texts = [...]
model.encode(texts, task="default")