Retrieval-Augmented Generation (RAG) with Contextual Retrieval

Note: for more background on contextual retrieval, including additional performance evaluations on a variety of datasets, we recommend reading our accompanying blog post.

Retrieval-Augmented Generation (RAG) enables Claude to leverage your internal knowledge base, codebase, or any other corpus of documents when providing a response. Enterprises are increasingly building RAG applications to improve workflows such as customer support, Q&A over internal company documents, financial and legal analysis, code generation, and more.

In a separate guide, we walked through setting up a basic retrieval system, demonstrated how to evaluate its performance, and then outlined a few techniques for improving it. In this guide, we present one technique for improving retrieval performance: contextual embeddings.

In traditional RAG, documents are typically split into smaller chunks for efficient retrieval. While this approach works well for many applications, it can cause problems when an individual chunk lacks sufficient context. Contextual embeddings address this by adding relevant context to each chunk before embedding it. This improves the quality of each embedded chunk, enabling more accurate retrieval and, in turn, better overall performance. Averaged across all of the data sources we tested, contextual embeddings reduced the top-20-chunk retrieval failure rate by 35%.

The same chunk-specific context can also be used with BM25 search to further improve retrieval performance. We introduce this technique in the "Contextual BM25" section.

In this guide, we'll demonstrate how to build and optimize a contextual retrieval system, using a dataset of 9 codebases as our knowledge base. We'll walk through:

1) Setting up a basic retrieval pipeline to establish a performance baseline.

2) Contextual embeddings: what they are, why they work, and how prompt caching makes them viable for production use cases.

3) Implementing contextual embeddings and demonstrating the performance improvement.

4) Contextual BM25: improving performance further with contextual BM25 hybrid search.

5) Improving performance with reranking.

Evaluation metrics & dataset:

We use a pre-chunked dataset of 9 codebases; each codebase has already been chunked with a basic character-splitting mechanism. Our evaluation dataset contains 248 queries, each paired with a "golden chunk". We'll evaluate performance with a metric called Pass@k, which checks whether the golden chunk appears among the top k chunks retrieved for each query. In this case, contextual embeddings help us improve Pass@10 from roughly 87% to roughly 95%.
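To make the metric concrete, here is a minimal sketch of how Pass@k can be computed for a single query. The helper and the toy chunks below are hypothetical; the full evaluation logic used in this guide appears later in evaluate_retrieval.

from typing import List

def pass_at_k(golden_chunks: List[str], retrieved_chunks: List[str], k: int) -> float:
    """Fraction of a query's golden chunks found among the top-k retrieved chunks."""
    top_k = [chunk.strip() for chunk in retrieved_chunks[:k]]
    found = sum(1 for golden in golden_chunks if golden.strip() in top_k)
    return found / len(golden_chunks)

# Toy example: the single golden chunk shows up at rank 3,
# so Pass@5 = 1.0 while Pass@2 = 0.0 for this query.
golden = ["def foo():\n    return 42"]
retrieved = ["chunk a", "chunk b", "def foo():\n    return 42", "chunk d"]
print(pass_at_k(golden, retrieved, k=5))  # 1.0
print(pass_at_k(golden, retrieved, k=2))  # 0.0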

You can find the code files and their chunks in data/codebase_chunks.json, and the evaluation dataset in data/evaluation_set.jsonl.

Additional notes:

Prompt caching helps manage the cost of this retrieval method. The feature is currently available on Anthropic's first-party (1P) API and is coming soon to third-party (3P) partner environments on AWS Bedrock and GCP Vertex AI. We know many customers leverage AWS Knowledge Bases and the GCP Vertex AI APIs when building RAG solutions, and this method can be used on either platform with some light customization. Consider reaching out to Anthropic or your AWS/GCP account team for guidance!

To make this method easier to use on Bedrock, the AWS team has provided code you can use to implement a Lambda function that adds context to each document. If you deploy this Lambda function, you can select it as a custom chunking option when configuring a Bedrock Knowledge Base. You can find this code in contextual-rag-lambda-function; the main Lambda function code is in lambda_function.py.

Table of Contents

1) Setup

2) Basic RAG

3) Contextual Embeddings

4) Contextual BM25

5) Reranking

Setup

We'll need a few libraries, including:

1) anthropic - to interact with Claude

2) voyageai - to generate high-quality embeddings

3) cohere - for reranking

4) elasticsearch - for performant BM25 search

5) pandas, numpy, matplotlib, and scikit-learn - for data manipulation and visualization

You'll also need API keys from Anthropic, Voyage AI, and Cohere.

!pip install anthropic
!pip install voyageai
!pip install cohere
!pip install elasticsearch
!pip install pandas
!pip install numpy
import os

os.environ['VOYAGE_API_KEY'] = "YOUR KEY HERE"
os.environ['ANTHROPIC_API_KEY'] = "YOUR KEY HERE"
os.environ['COHERE_API_KEY'] = "YOUR KEY HERE"
import anthropic

client = anthropic.Anthropic(
    # This is the default and can be omitted
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)

Initialize the Vector Database Class

In this example we use an in-memory vector database, but for a production application you'll likely want a managed solution.

import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm

class VectorDB:
    def __init__(self, name: str, api_key = None):
        if api_key is None:
            api_key = os.getenv("VOYAGE_API_KEY")
        self.client = voyageai.Client(api_key=api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/vector_db.pkl"

    def load_data(self, dataset: List[Dict[str, Any]]):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
            for doc in dataset:
                for chunk in doc['chunks']:
                    texts_to_embed.append(chunk['content'])
                    metadata.append({
                        'doc_id': doc['doc_id'],
                        'original_uuid': doc['original_uuid'],
                        'chunk_id': chunk['chunk_id'],
                        'original_index': chunk['original_index'],
                        'content': chunk['content']
                    })
                    pbar.update(1)

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        print(f"Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        with tqdm(total=len(texts), desc="Embedding chunks") as pbar:
            result = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                batch_result = self.client.embed(batch, model="voyage-2").embeddings
                result.extend(batch_result)
                pbar.update(len(batch))

        self.embeddings = result
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]

        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)

        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

    def validate_embedded_chunks(self):
        unique_contents = set()
        for meta in self.metadata:
            unique_contents.add(meta['content'])

        print(f"Validation results:")
        print(f"Total embedded chunks: {len(self.metadata)}")
        print(f"Unique embedded contents: {len(unique_contents)}")

        if len(self.metadata) != len(unique_contents):
            print("Warning: There may be duplicate chunks in the embedded data.")
        else:
            print("All embedded chunks are unique.")
# Load your transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the VectorDB
base_db = VectorDB("base_db")

# Load and process the data
base_db.load_data(transformed_dataset)

Basic RAG

To get started, we'll set up a basic RAG pipeline using a bare-bones approach. This is sometimes called "Naive RAG" in industry. A basic RAG pipeline consists of the following 3 steps:

1) Chunk documents by heading - containing only the content from each subheading

2) Embed each document

3) Use cosine similarity to retrieve documents in order to answer the query
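Before running the full evaluation below, it can be useful to sanity-check the pipeline by querying the database directly; the query string here is just an illustrative example.

# Quick sanity check: retrieve the top 3 chunks for an arbitrary query.
# The query text below is illustrative; any natural-language question works.
sample_results = base_db.search("How is the executor for differential fuzzing implemented?", k=3)
for r in sample_results:
    print(f"{r['similarity']:.3f}  doc={r['metadata']['doc_id']}  chunk={r['metadata']['chunk_id']}")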

import json
from typing import List, Dict, Any, Callable, Union
from tqdm import tqdm

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """加载JSONL文件并返回字典列表。"""
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]

def evaluate_retrieval(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:
    total_score = 0
    total_queries = len(queries)

    for query_item in tqdm(queries, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']

        # Find the contents of all golden chunks
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if not golden_doc:
                print(f"警告:找不到UUID {doc_uuid} 的黄金文档")
                continue

            golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
            if not golden_chunk:
                print(f"警告:在文档 {doc_uuid} 中找不到索引 {chunk_index} 的黄金块")
                continue

            golden_contents.append(golden_chunk['content'].strip())

        if not golden_contents:
            print(f"警告:未找到查询的黄金内容:{query}")
            continue

        retrieved_docs = retrieval_function(query, db, k=k)

        # Count how many of the golden chunks appear in the top k retrieved documents
        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['metadata'].get('original_content', doc['metadata'].get('content', '')).strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break

        query_score = chunks_found / len(golden_contents)
        total_score += query_score

    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    return {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }

def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:
    """
    Retrieve relevant documents using either a VectorDB or ContextualVectorDB.

    :param query: the query string
    :param db: a VectorDB or ContextualVectorDB instance
    :param k: the number of top results to retrieve
    :return: a list of retrieved documents
    """
    return db.search(query, k=k)

def evaluate_db(db, original_jsonl_path: str, k):
    # Load the original JSONL data for queries and ground truth
    original_data = load_jsonl(original_jsonl_path)

    # Evaluate retrieval
    results = evaluate_retrieval(original_data, retrieve_base, db, k)
    print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
    print(f"Total Score: {results['average_score']}")
    print(f"Total queries: {results['total_queries']}")
results5 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 5)
results10 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 10)
results20 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 20)
Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 40.70it/s]


Pass@5: 80.92%
Total Score: 0.8091877880184332
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 39.50it/s]


Pass@10: 87.15%
Total Score: 0.8714957757296468
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 39.43it/s]

Pass@20: 90.06%
Total Score: 0.9006336405529954
Total queries: 248

Contextual Embeddings

With basic RAG, each embedded chunk contains potentially useful information, but the chunks lack context. With contextual embeddings, we create a variant of the embedding itself by adding more context to each text chunk before embedding it. Specifically, we use Claude to generate a concise context for each chunk that explains the chunk using the context of the entire document. In the case of our codebase dataset, we can provide both the chunk and the full file that each chunk sits within to an LLM and have it generate the context. We then combine this "context" and the original chunk text into a single block of text before creating each embedding.
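Concretely, the text we embed for each chunk is simply the original chunk followed by the generated context. A minimal sketch with made-up content (in the pipeline the context comes from Claude, as shown in the code below) looks like this:

# Illustrative only: what a contextualized chunk looks like right before embedding.
# Both strings below are made up; in the real pipeline the context is generated by Claude.
chunk = "def run_target(self, input):\n    ...\n    self.secondary.run_target(input)"
generated_context = (
    "This chunk is from the DiffExecutor implementation and shows how the secondary "
    "executor is invoked inside run_target during differential fuzzing."
)
text_to_embed = f"{chunk}\n\n{generated_context}"
print(text_to_embed)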

Additional considerations: cost and latency

The extra work of "situating" each document happens only at ingestion time: it's a cost you pay once when you store each document (and again in the future whenever your knowledge base is updated). There are approaches, such as HyDE (hypothetical document embeddings), that involve performing extra steps to improve the representation of a query before executing the search. These techniques have shown moderate effectiveness, but they significantly increase latency at runtime.

Prompt caching also makes this more cost-effective. Creating contextual embeddings requires us to pass the same document to the model for every chunk we want to generate extra context for. With prompt caching, we can write the whole document to the cache once, and then, because we run all of our ingestion work sequentially, read the document from the cache while generating the context for each chunk within that document (the information you write to the cache has a 5-minute time to live). This means the first time we pass a document to the model we pay a bit more to write it to the cache, but for every subsequent API call that includes that document we receive a 90% discount on all cache-read input tokens. Assuming 800-token chunks, 8k-token documents, 50-token context instructions, and 100 tokens of context per chunk, the cost to generate contextualized chunks is $1.02 per million document tokens.
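As a rough check on that figure, here is the back-of-the-envelope arithmetic under the same assumptions. The per-token prices below are Claude 3 Haiku list prices at the time of writing and should be treated as assumptions; verify them against current pricing.

# Back-of-the-envelope cost per million document tokens, under the assumptions above
# (800-token chunks, 8k-token documents, 50-token instructions, 100 tokens of generated
# context per chunk). Prices are assumptions; check current Claude 3 Haiku pricing.
DOC_TOKENS, CHUNK_TOKENS, INSTR_TOKENS, CONTEXT_TOKENS = 8_000, 800, 50, 100
PRICE_IN, PRICE_OUT = 0.25, 1.25                   # $ per million tokens
PRICE_CACHE_WRITE, PRICE_CACHE_READ = 0.30, 0.03   # $ per million tokens

chunks_per_doc = DOC_TOKENS // CHUNK_TOKENS  # 10 chunks per document
cost_per_doc = (
    DOC_TOKENS * PRICE_CACHE_WRITE / 1e6                               # write the document to the cache once
    + (chunks_per_doc - 1) * DOC_TOKENS * PRICE_CACHE_READ / 1e6       # read it from the cache for the remaining chunks
    + chunks_per_doc * (CHUNK_TOKENS + INSTR_TOKENS) * PRICE_IN / 1e6  # chunk + instructions as regular input
    + chunks_per_doc * CONTEXT_TOKENS * PRICE_OUT / 1e6                # generated context as output
)
print(f"~${cost_per_doc * 1e6 / DOC_TOKENS:.2f} per million document tokens")  # roughly $1.01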

When you load the data into the ContextualVectorDB below, you'll see in the logs just how large this impact is.

Warning: some smaller embedding models have a fixed input-token limit. Contextualizing a chunk makes it longer, so if you notice markedly worse performance from contextualized embeddings, the contextualized chunks are most likely getting truncated.
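One rough way to catch this before embedding is to scan the contextualized texts for anything that may exceed your embedding model's input limit. The sketch below uses a crude characters-per-token heuristic and a made-up token limit, so treat both as assumptions and substitute your model's real tokenizer and limit.

# Rough truncation check: flag contextualized chunks that may exceed the embedding
# model's input limit. The ~4 characters/token heuristic and the 4,000-token limit
# are assumptions; substitute your embedding model's real tokenizer and limit.
APPROX_CHARS_PER_TOKEN = 4
MODEL_TOKEN_LIMIT = 4_000

def flag_possibly_truncated(texts):
    for i, text in enumerate(texts):
        est_tokens = len(text) / APPROX_CHARS_PER_TOKEN
        if est_tokens > MODEL_TOKEN_LIMIT:
            print(f"Chunk {i} may be truncated: ~{est_tokens:.0f} estimated tokens")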

DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

CHUNK_CONTEXT_PROMPT = """
这是我们想要在整个文档中定位的块
<chunk>
{chunk_content}
</chunk>

请提供一个简短的上下文来将此块定位在整个文档中,以改进块的搜索检索。
只回答简洁的上下文,不要回答其他任何内容。
"""

def situate_context(doc: str, chunk: str) -> str:
    response = client.beta.prompt_caching.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        temperature=0.0,
        messages=[
            {
                "role": "user", 
                "content": [
                    {
                        "type": "text",
                        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                        "cache_control": {"type": "ephemeral"} #我们将为整个文档利用提示缓存
                    },
                    {
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                    }
                ]
            }
        ],
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
    )
    return response

# Example usage
jsonl_data = load_jsonl('data/evaluation_set.jsonl')  # load the evaluation set so we can grab an example document and chunk

doc_content = jsonl_data[0]['golden_documents'][0]['content']
chunk_content = jsonl_data[0]['golden_chunks'][0]['content']

response = situate_context(doc_content, chunk_content)
print(f"Situated context: {response.content[0].text}")

# Print cache performance metrics
print(f"Input tokens: {response.usage.input_tokens}")
print(f"Output tokens: {response.usage.output_tokens}")
print(f"Cache creation input tokens: {response.usage.cache_creation_input_tokens}")
print(f"Cache read input tokens: {response.usage.cache_read_input_tokens}")
Situated context: This chunk describes the `DiffExecutor` struct, which is an executor for differential fuzzing. It wraps two executors that are run sequentially with the same input, and also runs the secondary executor in the `run_target` method.
Input tokens: 366
Output tokens: 55
Cache creation input tokens: 3046
Cache read input tokens: 0
import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm
import anthropic
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

class ContextualVectorDB:
    def __init__(self, name: str, voyage_api_key=None, anthropic_api_key=None):
        if voyage_api_key is None:
            voyage_api_key = os.getenv("VOYAGE_API_KEY")
        if anthropic_api_key is None:
            anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")

        self.voyage_client = voyageai.Client(api_key=voyage_api_key)
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/contextual_vector_db.pkl"

        self.token_counts = {
            'input': 0,
            'output': 0,
            'cache_read': 0,
            'cache_creation': 0
        }
        self.token_lock = threading.Lock()

    def situate_context(self, doc: str, chunk: str) -> tuple[str, Any]:
        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        这是我们想要在整个文档中定位的块
        <chunk>
        {chunk_content}
        </chunk>

        请提供一个简短的上下文来将此块定位在整个文档中,以改进块的搜索检索。
        只回答简洁的上下文,不要回答其他任何内容。
        """

        response = self.anthropic_client.beta.prompt_caching.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1000,
            temperature=0.0,
            messages=[
                {
                    "role": "user", 
                    "content": [
                        {
                            "type": "text",
                            "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                            "cache_control": {"type": "ephemeral"} #我们将在整个文档中使用提示缓存
                        },
                        {
                            "type": "text",
                            "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                        },
                    ]
                },
            ],
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
        )
        return response.content[0].text, response.usage

    def load_data(self, dataset: List[Dict[str, Any]], parallel_threads: int = 1):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        def process_chunk(doc, chunk):
            # For each chunk, generate the context
            contextualized_text, usage = self.situate_context(doc['content'], chunk['content'])
            with self.token_lock:
                self.token_counts['input'] += usage.input_tokens
                self.token_counts['output'] += usage.output_tokens
                self.token_counts['cache_read'] += usage.cache_read_input_tokens
                self.token_counts['cache_creation'] += usage.cache_creation_input_tokens

            return {
                # Append the generated context to the original chunk text
                'text_to_embed': f"{chunk['content']}\n\n{contextualized_text}",
                'metadata': {
                    'doc_id': doc['doc_id'],
                    'original_uuid': doc['original_uuid'],
                    'chunk_id': chunk['chunk_id'],
                    'original_index': chunk['original_index'],
                    'original_content': chunk['content'],
                    'contextualized_content': contextualized_text
                }
            }

        print(f"Processing {total_chunks} chunks with {parallel_threads} threads")
        with ThreadPoolExecutor(max_workers=parallel_threads) as executor:
            futures = []
            for doc in dataset:
                for chunk in doc['chunks']:
                    futures.append(executor.submit(process_chunk, doc, chunk))

            for future in tqdm(as_completed(futures), total=total_chunks, desc="Processing chunks"):
                result = future.result()
                texts_to_embed.append(result['text_to_embed'])
                metadata.append(result['metadata'])

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        # Log token usage
        print(f"Contextual Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")
        print(f"Total input tokens without caching: {self.token_counts['input']}")
        print(f"Total output tokens: {self.token_counts['output']}")
        print(f"Total input tokens written to cache: {self.token_counts['cache_creation']}")
        print(f"Total input tokens read from cache: {self.token_counts['cache_read']}")

        total_tokens = self.token_counts['input'] + self.token_counts['cache_read'] + self.token_counts['cache_creation']
        savings_percentage = (self.token_counts['cache_read'] / total_tokens) * 100 if total_tokens > 0 else 0
        print(f"Total input token savings from prompt caching: {savings_percentage:.2f}% of all input tokens used were read from cache.")
        print("Tokens read from cache come at a 90 percent discount!")

    # We use Voyage AI for embeddings here. Read more: https://docs.voyageai.com/docs/embeddings
    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        result = [
            self.voyage_client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.voyage_client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]

        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])
# Load the transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the ContextualVectorDB
contextual_db = ContextualVectorDB("my_contextual_db")

# Load and process the data
# Note: consider increasing the number of parallel threads to run this faster, or decreasing it to avoid hitting API rate limits
contextual_db.load_data(transformed_dataset, parallel_threads=5)
Processing 737 chunks with 5 threads


Processing chunks: 100%|██████████| 737/737 [02:37<00:00,  4.69it/s]


Contextual Vector database loaded and saved. Total chunks processed: 737
Total input tokens without caching: 500383
Total output tokens: 40318
Total input tokens written to cache: 341422
Total input tokens read from cache: 2825073
Total input token savings from prompt caching: 77.04% of all input tokens used were read from cache.
Tokens read from cache come at a 90 percent discount!
r5 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 5)
r10 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 10)
r20 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 20)
Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 39.53it/s]


Pass@5: 86.37%
Total Score: 0.8637192780337941
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 40.05it/s]


Pass@10: 92.81%
Total Score: 0.9280913978494625
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 39.64it/s]

Pass@20: 93.78%
Total Score: 0.9378360215053763
Total queries: 248

Adding a Reranking Step

If you want to improve performance further, we suggest adding a reranking step. When using a reranker, you retrieve more documents from your vector store and then use the reranker to select a subset of them. One common technique is to use reranking as a way to implement high-precision hybrid search: you use a combination of semantic search and keyword-based search in the initial retrieval step (as described earlier in this guide), and then use a reranking step to select only the k most relevant documents from the combined list returned by the semantic-search and keyword-search systems.
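Although this notebook only implements the reranking step, a minimal sketch of the keyword side of such a hybrid setup, indexing the contextualized chunks into Elasticsearch for BM25, might look like the following. It assumes a local Elasticsearch instance at http://localhost:9200 and elasticsearch-py 8.x; the index and field names are illustrative.

from elasticsearch import Elasticsearch, helpers

# Minimal sketch of contextual BM25 (assumes a local Elasticsearch instance at
# http://localhost:9200 and elasticsearch-py 8.x; index/field names are illustrative).
es = Elasticsearch("http://localhost:9200")
index_name = "contextual_bm25"

# Index each chunk's original content together with its generated context,
# so BM25 can match keywords that only appear in the context.
actions = [
    {
        "_index": index_name,
        "_source": {
            "content": meta["original_content"],
            "contextualized_content": meta["contextualized_content"],
            "doc_id": meta["doc_id"],
            "chunk_id": meta["chunk_id"],
        },
    }
    for meta in contextual_db.metadata
]
helpers.bulk(es, actions)
es.indices.refresh(index=index_name)

def bm25_search(query: str, k: int = 20):
    # BM25 keyword search over both the raw chunk text and its generated context
    response = es.search(
        index=index_name,
        query={"multi_match": {"query": query, "fields": ["content", "contextualized_content"]}},
        size=k,
    )
    return [hit["_source"] for hit in response["hits"]["hits"]]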

Below, we'll demonstrate just the reranking step (skipping the hybrid-search technique for now). You'll see that we retrieve 10x more documents than the final number of documents we want, then use a reranking model from Cohere to select the 10 most relevant results from that list. Adding the reranking step delivers a modest additional gain in performance. In our case, Pass@10 improves from 92.81% to 94.79%.

import cohere
import os
import time
import json
from typing import List, Dict, Any, Callable
from tqdm import tqdm

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]

def chunk_to_content(chunk: Dict[str, Any]) -> str:
    original_content = chunk['metadata']['original_content']
    contextualized_content = chunk['metadata']['contextualized_content']
    return f"{original_content}\n\nContext: {contextualized_content}" 

def retrieve_rerank(query: str, db, k: int) -> List[Dict[str, Any]]:
    co = cohere.Client(os.getenv("COHERE_API_KEY"))

    # Retrieve more results than we would normally return
    semantic_results = db.search(query, k=k*10)

    # Pull out the documents to rerank, using the contextualized content
    documents = [chunk_to_content(res) for res in semantic_results]

    response = co.rerank(
        model="rerank-english-v3.0",
        query=query,
        documents=documents,
        top_n=k
    )
    time.sleep(0.1)

    final_results = []
    for r in response.results:
        original_result = semantic_results[r.index]
        final_results.append({
            "chunk": original_result['metadata'],
            "score": r.relevance_score
        })

    return final_results

def evaluate_retrieval_rerank(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:
    total_score = 0
    total_queries = len(queries)

    for query_item in tqdm(queries, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']

        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if golden_doc:
                golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
                if golden_chunk:
                    golden_contents.append(golden_chunk['content'].strip())

        if not golden_contents:
            print(f"警告:未找到查询的黄金内容:{query}")
            continue

        retrieved_docs = retrieval_function(query, db, k)

        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['chunk']['original_content'].strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break

        query_score = chunks_found / len(golden_contents)
        total_score += query_score

    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    return {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }

def evaluate_db_advanced(db, original_jsonl_path, k):
    original_data = load_jsonl(original_jsonl_path)

    def retrieval_function(query, db, k):
        return retrieve_rerank(query, db, k)

    results = evaluate_retrieval_rerank(original_data, retrieval_function, db, k)
    print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
    print(f"Average Score: {results['average_score']}")
    print(f"Total queries: {results['total_queries']}")
    return results
results5 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 5)
results10 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 10)
results20 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 20)
Evaluating retrieval: 100%|██████████| 248/248 [01:22<00:00,  2.99it/s]


Pass@5: 91.24%
Average Score: 0.912442396313364
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [01:34<00:00,  2.63it/s]


Pass@10: 94.79%
Average Score: 0.9479166666666667
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [02:08<00:00,  1.93it/s]

Pass@20: 96.30%
Average Score: 0.9630376344086022
Total queries: 248

Next Steps and Key Takeaways

1) We demonstrated how to use contextual embeddings to improve retrieval performance, then improved it further with reranking (contextual BM25 can be layered in as an additional hybrid-search step).

2) This example used codebases, but these methods also apply to other data types, such as internal company knowledge bases, financial and legal content, educational content, and more.

3) If you're an AWS user, you can get started with the Lambda function in contextual-rag-lambda-function; if you're a GCP user, you can spin up your own Cloud Run instance and follow a similar pattern!