投机性提示缓存

本食谱演示了“投机性提示缓存”——一种通过在用户仍在制定查询时预热缓存来减少首次令牌时间(TTFT)的模式。

无投机缓存:

  1. 用户键入问题(3 秒)
  2. 用户提交问题
  3. API 将上下文加载到缓存中并生成响应

有投机缓存:

  1. 用户开始键入(缓存预热立即开始)
  2. 用户继续键入(缓存预热在后台完成)
  3. 用户提交问题
  4. API 使用预热的缓存生成响应

设置

首先,让我们安装所需的包:

%pip install anthropic httpx --quiet

注意:您可能需要重新启动内核才能使用更新的包。

import copy
import datetime
import time
import asyncio
import httpx
from anthropic import AsyncAnthropic

# 配置常量
MODEL = "claude-3-5-sonnet-20241022"
SQLITE_SOURCES = {
    "btree.h": "https://sqlite.org/src/raw/18e5e7b2124c23426a283523e5f31a4bff029131b795bb82391f9d2f3136fc50?at=btree.h",
    "btree.c": "https://sqlite.org/src/raw/63ca6b647342e8cef643863cd0962a542f133e1069460725ba4461dcda92b03c?at=btree.c",
}
DEFAULT_CLIENT_ARGS = {
    "system": "您是一位帮助分析数据库内部的专家系统程序员。",
    "max_tokens": 4096,
    "temperature": 0,
    "extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"},
}

辅助函数

让我们设置下载大型上下文和准备消息的函数:

async def get_sqlite_sources() -> dict[str, str]:
    print("正在下载 SQLite 源文件...")

    source_files = {}
    start_time = time.time()

    async with httpx.AsyncClient(timeout=30.0) as client:
        tasks = []

        async def download_file(filename: str, url: str) -> tuple[str, str]:
            response = await client.get(url, follow_redirects=True)
            response.raise_for_status()
            print(f"成功下载 {filename}")
            return filename, response.text

        for filename, url in SQLITE_SOURCES.items():
            tasks.append(download_file(filename, url))

        results = await asyncio.gather(*tasks)
        source_files = dict(results)

    duration = time.time() - start_time
    print(f"在 {duration:.2f} 秒内下载了 {len(source_files)} 个文件")
    return source_files


async def create_initial_message():
    sources = await get_sqlite_sources()
    # 使用源代码作为上下文准备初始消息。
    # 包含时间戳以防止不同运行之间的缓存共享。
    initial_message = {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": f"""
当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

要分析的来源:

btree.h:
```c
{sources["btree.h"]}

btree.c:

{sources["btree.c"]}
```""",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    }
    return initial_message


async def sample_one_token(client: AsyncAnthropic, messages: list):
    """发送单个令牌请求以预热缓存"""
    args = copy.deepcopy(DEFAULT_CLIENT_ARGS)
    args["max_tokens"] = 1
    await client.messages.create(
        messages=messages,
        model=MODEL,
        **args,
    )


def print_query_statistics(response, query_type: str) -> None:
    print(f"\n{query_type} 查询统计:")
    print(f"\t输入令牌数:{response.usage.input_tokens}")
    print(f"\t输出令牌数:{response.usage.output_tokens}")
    print(
        f"\t缓存读取输入令牌数:{getattr(response.usage, 'cache_read_input_tokens', '---')}"
    )
    print(
        f"\t缓存创建输入令牌数:{getattr(response.usage, 'cache_creation_input_tokens', '---')}"
    )

示例 1:标准提示缓存(无投机缓存)

首先,让我们看看标准的提示缓存是如何工作的。用户键入他们的问题,然后我们将整个上下文+问题发送给 API:

async def standard_prompt_caching_demo():
    client = AsyncAnthropic()

    # 准备大型上下文
    initial_message = await create_initial_message()

    # 模拟用户键入时间(在实际应用中,这将是实际的用户输入)
    print("用户正在键入他们的问题...")
    await asyncio.sleep(3)  # 模拟 3 秒键入时间
    user_question = "BtShared 结构的作用是什么?"
    print(f"用户提交:{user_question}")

    # 现在发送完整请求(上下文 + 问题)
    full_message = copy.deepcopy(initial_message)
    full_message["content"].append(
        {"type": "text", "text": f"回答用户的问题:{user_question}"}
    )

    print("\n正在向 API 发送请求...")
    start_time = time.time()

    # 测量首次令牌时间
    first_token_time = None
    async with client.messages.stream(
        messages=[full_message],
        model=MODEL,
        **DEFAULT_CLIENT_ARGS,
    ) as stream:
        async for text in stream.text_stream:
            if first_token_time is None and text.strip():
                first_token_time = time.time() - start_time
                print(f"\n🕐 首次令牌时间:{first_token_time:.2f} 秒")
                break

        # 获取完整响应
        response = await stream.get_final_message()

    total_time = time.time() - start_time
    print(f"总响应时间:{total_time:.2f} 秒")
    print_query_statistics(response, "标准缓存")

    return first_token_time, total_time
# 运行标准演示
standard_ttft, standard_total = await standard_prompt_caching_demo()

正在下载 SQLite 源文件... 成功下载 btree.h 成功下载 btree.c 在 0.30 秒内下载了 2 个文件 用户正在键入他们的问题... 用户提交:BtShared 结构的作用是什么?

正在向 API 发送请求...

🕐 首次令牌时间:20.87 秒 总响应时间:28.32 秒

标准缓存 查询统计: 输入令牌数:22 输出令牌数:362 缓存读取输入令牌数:0 缓存创建输入令牌数:151629

示例 2:投机性提示缓存

现在让我们看看投机性提示缓存如何通过在用户键入时预热缓存来改善 TTFT:

async def speculative_prompt_caching_demo():
    client = AsyncAnthropic()

    # 用户希望与之交互的大量上下文,
    # 在这种情况下是 sqlite b-tree 实现(约 150k 令牌)。
    initial_message = await create_initial_message()

    # 在用户键入时开始投机性缓存
    print("用户正在键入他们的问题...")
    print("🔥 在后台开始缓存预热...")

    # 在用户键入他们的问题时,我们从用户将要交互的上下文中采样单个令牌
    # 并显式开启提示缓存以预热缓存。
    cache_task = asyncio.create_task(sample_one_token(client, [initial_message]))

    # 模拟用户键入时间
    await asyncio.sleep(3)  # 模拟 3 秒键入时间
    user_question = "What is the purpose of the BtShared structure?"
    print(f"用户提交:{user_question}")

    # 确保缓存预热已完成
    await cache_task
    print("✅ 缓存预热已完成!")

    # 准备缓存查询的消息。我们确保
    # 重用与缓存相同的初始消息,以确保我们有缓存命中。
    cached_message = copy.deepcopy(initial_message)
    cached_message["content"].append(
        {"type": "text", "text": f"Answer the user's question: {user_question}"}
    )

    print("\n正在向 API 发送请求(使用预热的缓存)...")
    start_time = time.time()

    # 测量首次令牌时间
    first_token_time = None
    async with client.messages.stream(
        messages=[cached_message],
        model=MODEL,
        **DEFAULT_CLIENT_ARGS,
    ) as stream:
        async for text in stream.text_stream:
            if first_token_time is None and text.strip():
                first_token_time = time.time() - start_time
                print(f"\n🚀 Time to first token: {first_token_time:.2f} seconds")
                break

        # 获取完整响应
        response = await stream.get_final_message()

    total_time = time.time() - start_time
    print(f"总响应时间:{total_time:.2f} 秒")
    print_query_statistics(response, "投机性缓存")

    return first_token_time, total_time
# 运行投机性缓存演示  
speculative_ttft, speculative_total = await speculative_prompt_caching_demo()

正在下载 SQLite 源文件... 成功下载 btree.h 成功下载 btree.c 在 0.36 秒内下载了 2 个文件 用户正在键入他们的问题... 🔥 在后台开始缓存预热... 用户提交:What is the purpose of the BtShared structure? ✅ 缓存预热已完成!

正在向 API 发送请求(使用预热的缓存)...

🚀 Time to first token: 1.94 seconds 总响应时间:8.40 秒

投机性缓存 查询统计: 输入令牌数:22 输出令牌数:330 缓存读取输入令牌数:151629 缓存创建输入令牌数:0

性能比较

让我们比较结果,看看投机性缓存的好处:

print("=" * 60)
print("性能比较")
print("=" * 60)

print(f"\n标准提示缓存:")
print(f"  首次令牌时间:{standard_ttft:.2f} 秒")
print(f"  总响应时间:{standard_total:.2f} 秒")

print(f"\n投机性提示缓存:")
print(f"  首次令牌时间:{speculative_ttft:.2f} 秒")
print(f"  总响应时间:{speculative_total:.2f} 秒")

ttft_improvement = (standard_ttft - speculative_ttft) / standard_ttft * 100
total_improvement = (standard_total - speculative_total) / standard_total * 100

print(f"\n🎯 改进:")
print(f"  TTFT 改进:{ttft_improvement:.1f}% ({standard_ttft - speculative_ttft:.2f} 秒更快)")
print(f"  总时间改进:{total_improvement:.1f}% ({standard_total - speculative_total:.2f} 秒更快)")

============================================================ 性能比较 ============================================================

标准提示缓存: 首次令牌时间:20.87 秒 总响应时间:28.32 秒

投机性提示缓存: 首次令牌时间:1.94 秒 总响应时间:8.40 秒

🎯 改进: TTFT 改进:90.7% (18.93 秒更快) 总时间改进:70.4% (19.92 秒更快)

主要收获

  1. 投机性缓存通过在用户键入时预热缓存,极大地减少了 TTFT
  2. 该模式最有效于大型上下文(>1000 令牌)的场景,这些上下文跨查询重用
  3. 实现简单——只需在用户键入时发送一个令牌请求
  4. 缓存预热与用户输入并行发生,有效地“隐藏”了缓存创建时间

最佳实践

  • 尽早开始缓存预热(例如,当用户聚焦输入字段时)
  • 使用与实际请求完全相同的上下文进行预热,以确保缓存命中
  • 监控 cache_read_input_tokens 以验证缓存命中
  • 添加时间戳以防止跨会话的不必要缓存共享