使用 Prompt 进行训练
什么是 Prompt?
许多现代嵌入模型都使用“指令”或“prompt”进行训练,遵循 INSTRUCTOR 论文。这些 prompt 是字符串,前缀于每个要嵌入的文本,允许模型区分不同类型的文本。
例如,mixedbread-ai/mxbai-embed-large-v1 模型使用 Represent this sentence for searching relevant passages:
作为所有查询的 prompt 进行训练。此 prompt 存储在 模型配置 中,prompt 名称为 "query"
,因此用户可以在 model.encode
中指定 prompt_name
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
query_embedding = model.encode("What are Pandas?", prompt_name="query")
# or
# query_embedding = model.encode("What are Pandas?", prompt="Represent this sentence for searching relevant passages: ")
document_embeddings = model.encode([
"Pandas is a software library written for the Python programming language for data manipulation and analysis.",
"Pandas are a species of bear native to South Central China. They are also known as the giant panda or simply panda.",
"Koala bears are not actually bears, they are marsupials native to Australia.",
])
similarity = model.similarity(query_embedding, document_embeddings)
print(similarity)
# => tensor([[0.7594, 0.7560, 0.4674]])
有关使用 prompt 进行推理的更多信息,请参阅 Prompt 模板。
我们为什么要使用 Prompt 进行训练?
INSTRUCTOR 论文 表明,在每个文本前添加 prompt 或指令可以将模型性能平均提高约 6%,尤其是在分类、聚类和语义文本相似性方面 gains 显著。请注意,检索的性能提升明显较低,小型模型分别为 0.4% 和 2.7%,大型模型分别为 0.4% 和 2.7%。

最近,BGE 论文 显示了类似的发现,表明如果查询前缀为 Represent this sentence for searching relevant passages:
,则检索性能提高约 1.4%。作者得出结论,使用指令可能大大有助于任务特定微调的质量。

本质上,只要在训练和推理期间都使用指令或 prompt,就可以提高性能。
我们如何使用 Prompt 进行训练?
自 v3.3.0 Sentence Transformers 版本发布以来,可以使用 SentenceTransformerTrainingArguments
类中的 prompts
参数,使用 prompt 微调嵌入模型。此参数接受 4 种不同的格式
str
:用于所有数据集中的所有列的单个 prompt。例如args = SentenceTransformerTrainingArguments( ..., prompts="text: ", ..., )
Dict[str, str]
:将列名映射到 prompt 的字典,应用于所有数据集。例如args = SentenceTransformerTrainingArguments( ..., prompts={ "query": "query: ", "answer": "document: ", }, ..., )
Dict[str, str]
:将数据集名称映射到 prompt 的字典。仅当您的训练/评估/测试数据集是DatasetDict
或Dataset
字典时,才应使用此项。例如args = SentenceTransformerTrainingArguments( ..., prompts={ "stsb": "Represent this text for semantic similarity search: ", "nq": "Represent this text for retrieval: ", }, ..., )
Dict[str, Dict[str, str]]
:将数据集名称映射到将列名映射到 prompt 的字典的字典。仅当您的训练/评估/测试数据集是DatasetDict
或Dataset
字典时,才应使用此项。例如args = SentenceTransformerTrainingArguments( ..., prompts={ "stsb": { "sentence1": "sts: ", "sentence2": "sts: ", }, "nq": { "query": "query: ", "document": "document: ", }, }, ..., )
此外,一些研究论文 (INSTRUCTOR, NV-Embed) 从平均池化步骤中排除了 prompt,这样 prompt 仅在 Transformer 模块中使用。在 Sentence Transformers 中,可以使用 Pooling
模块中的 include_prompt
参数/属性或通过 SentenceTransformer.set_pooling_include_prompt()
方法配置此项。根据我的个人经验,在池化中包含 prompt 的模型往往表现更好。
训练脚本
请参阅以下脚本,了解如何在实践中使用 prompt 进行训练的示例
training_nq_prompts.py:此脚本使用
CachedMultipleNegativesRankingLoss
损失,在来自 natural-questions 数据集的 10 万个查询-答案对上微调 mpnet-base。模型在训练期间使用NanoBEIREvaluator
进行评估。
此脚本有两个变量影响 1) 是否使用 prompt 以及 2) prompt 是否包含在池化中。我已经在使用各种不同设置的情况下微调了 mpnet-base
和 bert-base-uncased
,在没有额外成本的情况下,NDCG@10
相对提高了 0.66% 和 0.90%。
在各种设置下运行脚本产生了这些检查点
注意
当使用 prompt 进行训练并且在池化中排除这些 prompt 时,mpnet-base
似乎有点不稳定:损失在某些时候会飙升,例如 bert-base-uncased
没有观察到这种效果。
对于这两个模型,使用 prompt 训练的模型在整个训练过程中始终优于基线模型

此外,使用 prompt 训练的模型在自动生成的模型卡中包含训练数据集详细信息中的这些 prompt:tomaarsen/mpnet-base-nq-prompts#natural-questions。
重要提示
如果您使用 prompt 进行训练,那么强烈建议将 prompt 存储在模型配置中,作为 prompt 名称到 prompt 字符串的映射。您可以通过在保存之前使用 prompts
字典初始化 SentenceTransformer
、在保存之前更新已加载模型的 model.prompts
和/或更新已保存模型的 config_sentence_transformers.json 文件来完成此操作。
在模型配置中添加 prompt 后,prompt 训练模型的最终用法变为
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("tomaarsen/mpnet-base-nq-prompts")
query_embedding = model.encode("What are Pandas?", prompt_name="query")
document_embeddings = model.encode([
"Pandas is a software library written for the Python programming language for data manipulation and analysis.",
"Pandas are a species of bear native to South Central China. They are also known as the giant panda or simply panda.",
"Koala bears are not actually bears, they are marsupials native to Australia.",
],
prompt_name="document",
)
similarity = model.similarity(query_embedding, document_embeddings)
print(similarity)
# => tensor([[0.4725, 0.7339, 0.4369]])
在各种设置下运行脚本产生了这些检查点
对于这三个模型,除了第一次评估之外,使用 prompt 训练的模型在整个训练过程中始终优于基线模型。在平均池化中排除 prompt 的模型始终明显比其他两个模型表现更差。

此外,使用 prompt 训练的模型在自动生成的模型卡中包含训练数据集详细信息中的这些 prompt:tomaarsen/bert-base-nq-prompts#natural-questions。
重要提示
如果您使用 prompt 进行训练,那么强烈建议将 prompt 存储在模型配置中,作为 prompt 名称到 prompt 字符串的映射。您可以通过在保存之前使用 prompts
字典初始化 SentenceTransformer
、在保存之前更新已加载模型的 model.prompts
和/或更新已保存模型的 config_sentence_transformers.json 文件来完成此操作。
在模型配置中添加 prompt 后,prompt 训练模型的最终用法变为
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("tomaarsen/bert-base-nq-prompts")
query_embedding = model.encode("What are Pandas?", prompt_name="query")
document_embeddings = model.encode([
"Pandas is a software library written for the Python programming language for data manipulation and analysis.",
"Pandas are a species of bear native to South Central China. They are also known as the giant panda or simply panda.",
"Koala bears are not actually bears, they are marsupials native to Australia.",
],
prompt_name="document",
)
similarity = model.similarity(query_embedding, document_embeddings)
print(similarity)
# => tensor([[0.3955, 0.8226, 0.5706]])