Embedding 长度超过模型最大上下文的文本

OpenAI 的 embedding 模型无法嵌入超过最大长度的文本。最大长度因模型而异,并且以 token 而不是字符串长度来衡量。如果您不熟悉 tokenization,请查看 如何使用 tiktoken 计算 token

本 notebook 展示了如何处理长度超过模型最大上下文的文本。我们将演示使用 text-embedding-3-small 的 embedding,但相同的思路可以应用于其他模型和任务。要了解有关 embedding 的更多信息,请查看 OpenAI Embeddings Guide

1. 模型上下文长度

首先,我们选择模型并定义一个函数以从 API 获取 embedding。

from openai import OpenAI
import os
import openai
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

EMBEDDING_MODEL = 'text-embedding-3-small'
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = 'cl100k_base'

# 让我们确保不重试无效请求,因为这是我们要演示的
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.BadRequestError))
def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):
    return client.embeddings.create(input=text_or_tokens, model=model).data[0].embedding

text-embedding-3-small 模型使用 cl100k_base 编码具有 8191 个 token 的上下文长度,我们可以看到超过该限制会导致错误。

long_text = 'AGI ' * 5000
try:
    get_embedding(long_text)
except openai.BadRequestError as e:
    print(e)
Error code: 400 - {'error': {'message': "This model's maximum context length is 8192 tokens, however you requested 10001 tokens (10001 in your prompt; 0 for the completion). Please reduce your prompt; or completion length.", 'type': 'invalid_request_error', 'param': None, 'code': None}}

显然,我们希望避免这些错误,尤其是在以编程方式处理大量 embedding 时。但是,我们仍然可能遇到比最大上下文长度更长的文本。下面我们描述并提供处理这些更长文本的主要方法的配方:(1)简单地将文本截断到允许的最大长度,以及(2)将文本分块并单独嵌入每个块。

1. 截断输入文本

最简单的解决方案是将输入文本截断到允许的最大长度。因为上下文长度是以 token 计算的,所以我们必须先对文本进行 tokenization,然后再进行截断。API 接受文本或 token 形式的输入,因此只要您小心使用适当的编码,就不需要将 token 转换回字符串形式。下面是一个这样的截断函数的示例。

import tiktoken

def truncate_text_tokens(text, encoding_name=EMBEDDING_ENCODING, max_tokens=EMBEDDING_CTX_LENGTH):
    """根据给定的编码将字符串截断为 `max_tokens`。"""
    encoding = tiktoken.get_encoding(encoding_name)
    return encoding.encode(text)[:max_tokens]

我们之前的示例现在可以正常工作,没有错误。

truncated = truncate_text_tokens(long_text)
len(get_embedding(truncated))
1536

2. 分块输入文本

虽然截断有效,但丢弃可能相关的文本是一个明显的缺点。另一种方法是将输入文本分成块,然后单独嵌入每个块。然后,我们可以单独使用块 embedding,或者以某种方式组合它们,例如平均(按每个块的大小加权)。

我们将从 Python 自己的 cookbook 中采用一个函数,该函数将序列分解成块。

from itertools import islice

def batched(iterable, n):
    """将数据分批成 n 个元组。最后一个批次可能较短。"""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n 必须至少为一')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch

现在我们定义一个函数,该函数将字符串编码为 token,然后将其分解成块。

def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator

最后,我们可以编写一个函数,该函数通过将输入 token 分块并单独嵌入每个块来安全地处理 embedding 请求,即使输入文本长度超过最大上下文长度。可以将 average 标志设置为 True 以返回块 embedding 的加权平均值,或设置为 False 以仅返回未修改的块 embedding 列表。

import numpy as np


def len_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))

    if average:
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # 将长度归一化为 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings

再次,我们现在可以处理长输入文本。

average_embedding_vector = len_safe_get_embedding(long_text, average=True)
chunks_embedding_vectors = len_safe_get_embedding(long_text, average=False)

print(f"设置 average=True 为我们的长文本提供了一个单一的 {len(average_embedding_vector)}-维 embedding 向量。")
print(f"设置 average=False 为我们提供了 {len(chunks_embedding_vectors)} 个 embedding 向量,每个向量对应一个块。")
设置 average=True 为我们的长文本提供了一个单一的 1536-维 embedding 向量。
设置 average=False 为我们提供了 2 个 embedding 向量,每个向量对应一个块。

在某些情况下,将块在段落边界或句子边界处拆分可能是有意义的,以帮助保留文本的含义。