创建自定义模型
Sentence Transformer 模型的结构
Sentence Transformer 模型由一系列按顺序执行的模块(文档)组成。最常见的架构是 Transformer
模块、Pooling
模块以及可选的 Dense
模块和/或 Normalize
模块的组合。
Transformer
:此模块负责处理输入文本并生成上下文嵌入。Pooling
:此模块通过聚合嵌入来降低 Transformer 模块输出的维度。常见的池化策略包括均值池化和 CLS 池化。Dense
:此模块包含一个线性层,用于后处理来自 Pooling 模块的嵌入输出。Normalize
:此模块标准化来自上一层的嵌入。
例如,流行的 all-MiniLM-L6-v2 模型也可以通过初始化构成该模型的 3 个特定模块来加载
from sentence_transformers import models, SentenceTransformer
transformer = models.Transformer("sentence-transformers/all-MiniLM-L6-v2", max_seq_length=256)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, pooling, normalize])
保存 Sentence Transformer 模型
每当保存 Sentence Transformer 模型时,都会生成三种类型的文件
modules.json
:此文件包含用于重建模型的模块名称、路径和类型的列表。config_sentence_transformers.json
:此文件包含 Sentence Transformer 模型的一些配置选项,包括已保存的 prompts、模型的相似度函数以及模型作者使用的 Sentence Transformer 包版本。模块特定文件:每个模块都保存在以模块索引和模型名称命名的单独子文件夹中(例如,
1_Pooling
、2_Normalize
),除非第一个模块的save_in_root
属性设置为True
时,可以保存在根目录中。在 Sentence Transformers 中,Transformer
和CLIPModel
模块就是这种情况。大多数模块文件夹都包含一个config.json
文件(或Transformer
模块的sentence_bert_config.json
文件),该文件存储传递给该模块的关键字参数的默认值。因此,sentence_bert_config.json
文件{ "max_seq_length": 4096, "do_lower_case": false }
意味着
Transformer
模块将使用max_seq_length=4096
和do_lower_case=False
进行初始化。
因此,如果我在前一个代码片段中的 model
上调用 SentenceTransformer.save_pretrained("local-all-MiniLM-L6-v2")
,则会生成以下文件
local-all-MiniLM-L6-v2/
├── 1_Pooling
│ └── config.json
├── 2_Normalize
├── README.md
├── config.json
├── config_sentence_transformers.json
├── model.safetensors
├── modules.json
├── sentence_bert_config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.txt
这包含一个包含以下内容的 modules.json
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
以及一个包含以下内容的 config_sentence_transformers.json
{
"__version__": {
"sentence_transformers": "3.0.1",
"transformers": "4.43.4",
"pytorch": "2.5.0"
},
"prompts": {},
"default_prompt_name": null,
"similarity_fn_name": null
}
此外,1_Pooling
目录包含 Pooling
模块的配置文件,而 2_Normalize
目录为空,因为 Normalize
模块不需要任何配置。sentence_bert_config.json
文件包含 Transformer
模块的配置,并且该模块还在根目录中保存了许多与 tokenizer 和模型本身相关的文件。
加载 Sentence Transformer 模型
要从已保存的模型目录加载 Sentence Transformer 模型,将读取 modules.json
以确定构成模型的模块。每个模块都使用存储在相应模块目录中的配置进行初始化,之后使用加载的模块实例化 SentenceTransformer 类。
来自 Transformers 模型的 Sentence Transformer 模型
当您使用纯 Transformers 模型(例如,BERT、RoBERTa、DistilBERT、T5)初始化 Sentence Transformer 模型时,Sentence Transformers 默认创建 Transformer 模块和 Mean Pooling 模块。这提供了一种利用预训练语言模型进行句子嵌入的简单方法。
具体来说,以下两个代码片段是相同的
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("bert-base-uncased")
from sentence_transformers import models, SentenceTransformer
transformer = models.Transformer("bert-base-uncased")
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])
高级:自定义模块
要创建自定义 Sentence Transformer 模型,您可以子类化 PyTorch 的 torch.nn.Module
类并实现以下方法
torch.nn.Module.forward()
方法,该方法接受一个features
字典,其中包含诸如input_ids
、attention_mask
、token_type_ids
、token_embeddings
和sentence_embedding
等键,具体取决于模块在模型 pipeline 中的位置。一个
save
方法,该方法接受save_dir
参数并将模块的配置保存到该目录。一个
load
静态方法,该方法接受load_dir
参数并根据该目录中的模块配置初始化模块。(如果是第一个模块)
get_max_seq_length
方法,该方法返回模块可以处理的最大序列长度。仅当模块处理输入文本时才需要。(如果是第一个模块)
tokenize
方法,该方法接受输入列表并返回一个字典,其中包含诸如input_ids
、attention_mask
、token_type_ids
、pixel_values
等键。此字典将传递给模块的forward
方法。(可选)
get_sentence_embedding_dimension
方法,该方法返回模块生成的句子嵌入的维度。仅当模块生成嵌入或更新嵌入的维度时才需要。(可选)
get_config_dict
方法,该方法返回包含模块配置的字典。此方法可用于将模块的配置保存到磁盘,并将模块配置保存在模型卡中。
例如,我们可以通过实现自定义模块来创建自定义池化方法。
# decay_pooling.py
import json
import os
import torch
import torch.nn as nn
class DecayMeanPooling(nn.Module):
def __init__(self, dimension: int, decay: float = 0.95) -> None:
super(DecayMeanPooling, self).__init__()
self.dimension = dimension
self.decay = decay
def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict [str, torch.Tensor]:
token_embeddings = features["token_embeddings"]
attention_mask = features["attention_mask"].unsqueeze(-1)
# Apply the attention mask to filter away padding tokens
token_embeddings = token_embeddings * attention_mask
# Calculate mean of token embeddings
sentence_embeddings = token_embeddings.sum(1) / attention_mask.sum(1)
# Apply exponential decay
importance_per_dim = self.decay ** torch.arange(sentence_embeddings. size(1), device=sentence_embeddings.device)
features["sentence_embedding"] = sentence_embeddings * importance_per_dim
return features
def get_config_dict(self) -> dict[str, float]:
return {"dimension": self.dimension, "decay": self.decay}
def get_sentence_embedding_dimension(self) -> int:
return self.dimension
def save(self, save_dir: str, **kwargs) -> None:
with open(os.path.join(save_dir, "config.json"), "w") as fOut:
json.dump(self.get_config_dict(), fOut, indent=4)
def load(load_dir: str, **kwargs) -> "DecayMeanPooling":
with open(os.path.join(load_dir, "config.json")) as fIn:
config = json.load(fIn)
return DecayMeanPooling(**config)
注意
建议将 **kwargs
添加到 __init__
、forward
、save
、load
和 tokenize
方法,以确保这些方法与 Sentence Transformers 库的未来更新兼容。
注意
如果您的模块是第一个模块,那么如果您希望在保存时将模块保存在根目录中,则可以在模块的类定义中设置 save_in_root = True
。请注意,与子目录不同,根目录在加载模块之前不会从 Hugging Face Hub 下载。因此,模块应首先检查所需文件是否在本地存在,否则使用 huggingface_hub.hf_hub_download()
下载它们。
现在可以将其用作 Sentence Transformer 模型中的模块
from sentence_transformers import models, SentenceTransformer
from decay_pooling import DecayMeanPooling
transformer = models.Transformer("bert-base-uncased", max_seq_length=256)
decay_mean_pooling = DecayMeanPooling(transformer.get_word_embedding_dimension(), decay=0.99)
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, decay_mean_pooling, normalize])
print(model)
"""
SentenceTransformer(
(0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel
(1): DecayMeanPooling()
(2): Normalize()
)
"""
texts = [
"Hello, World!",
"The quick brown fox jumps over the lazy dog.",
"I am a sentence that is used for testing purposes.",
"This is a test sentence.",
"This is another test sentence.",
]
embeddings = model.encode(texts)
print(embeddings.shape)
# [5, 384]
您可以使用 SentenceTransformer.save_pretrained
保存此模型,从而生成以下 modules.json
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_DecayMeanPooling",
"type": "decay_pooling.DecayMeanPooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
为了确保可以导入 decay_pooling.DecayMeanPooling
,您应该将 decay_pooling.py
文件复制到保存模型的目录。如果您将模型推送到 Hugging Face Hub,那么您还应该将 decay_pooling.py
文件上传到模型的存储库。然后,每个人都可以通过调用 SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
来使用您的自定义模块。
注意
使用 Hugging Face Hub 上存储的远程代码的自定义模块需要您的用户在加载模型时将 trust_remote_code
指定为 True
。这是一种安全措施,旨在防止远程代码执行攻击。
如果您的模型和自定义建模代码在 Hugging Face Hub 上,那么将您的自定义模块分离到单独的存储库可能更有意义。这样,您只需维护自定义模块的一个实现,并且可以在多个模型中重用它。您可以通过更新 modules.json
文件中的 type
以包含存储自定义模块的存储库的路径(如 {repository_id}--{dot_path_to_module}
)来做到这一点。例如,如果 decay_pooling.py
文件存储在名为 my-user/my-model-implementation
的存储库中,并且模块名为 DecayMeanPooling
,则 modules.json
文件可能如下所示
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_DecayMeanPooling",
"type": "my-user/my-model-implementation--decay_pooling.DecayMeanPooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
高级:自定义模块中的关键字参数传递
如果您希望用户能够通过 SentenceTransformer.encode
方法指定自定义关键字参数,那么您可以将它们的名称添加到 modules.json
文件中。例如,如果我的模块在用户指定 task_type
关键字参数时应表现不同,那么您的 modules.json
可能如下所示
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "custom_transformer.CustomTransformer",
"kwargs": ["task_type"]
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
然后,您可以在自定义模块的 forward
方法中访问 task_type
关键字参数
from sentence_transformers.models import Transformer
class CustomTransformer(Transformer):
def forward(self, features: dict[str, torch.Tensor], task_type: Optional[str] = None) -> dict[str, torch.Tensor]:
if task_type == "default":
# Do something
else:
# Do something else
return features
这样,用户在调用 SentenceTransformer.encode
时可以指定 task_type
关键字参数
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
texts = [...]
model.encode(texts, task_type="default")