嵌入维基百科文章以供搜索
本笔记本展示了我们如何为搜索准备维基百科文章数据集,该数据集用于 Question_answering_using_embeddings.ipynb。
流程:
- 先决条件:导入库,设置 API 密钥(如果需要)
- 收集:我们下载几百篇关于 2022 年奥运会的维基百科文章
- 分块:文档被分割成简短的、半独立的片段以供嵌入
- 嵌入:每个片段都使用 OpenAI API 进行嵌入
- 存储:嵌入保存在 CSV 文件中(对于大型数据集,请使用向量数据库)
0. 先决条件
导入库
# imports
import mwclient # 用于下载示例维基百科文章
import mwparserfromhell # 用于将维基百科文章分割成节
from openai import OpenAI # 用于生成嵌入
import os # 用于环境变量
import pandas as pd # 用于存储文章节和嵌入的 DataFrame
import re # 用于从维基百科文章中删除 <ref> 链接
import tiktoken # 用于计算 token 数量
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))
使用终端中的 pip install
安装任何缺少的库。例如:
pip install openai
(您也可以在笔记本单元格中使用 !pip install openai
来完成此操作。)
如果您安装了任何库,请务必重启笔记本内核。
设置 API 密钥(如果需要)
请注意,OpenAI 库将尝试从 OPENAI_API_KEY
环境变量读取您的 API 密钥。如果您还没有设置,请按照这些说明设置此环境变量。
1. 收集文档
在此示例中,我们将下载几百篇与 2022 年冬季奥运会相关的维基百科文章。
# 获取关于 2022 年冬季奥运会的维基百科页面
CATEGORY_TITLE = "Category:2022 Winter Olympics"
WIKI_SITE = "en.wikipedia.org"
def titles_from_category(
category: mwclient.listing.Category, max_depth: int
) -> set[str]:
"""返回给定维基类别及其子类别中的页面标题集。"""
titles = set()
for cm in category.members():
if type(cm) == mwclient.page.Page:
# ^type() 用于代替 isinstance() 来捕获匹配但没有继承的
titles.add(cm.name)
elif isinstance(cm, mwclient.listing.Category) and max_depth > 0:
deeper_titles = titles_from_category(cm, max_depth=max_depth - 1)
titles.update(deeper_titles)
return titles
site = mwclient.Site(WIKI_SITE)
category_page = site.pages[CATEGORY_TITLE]
titles = titles_from_category(category_page, max_depth=1)
# ^注意:max_depth=1 表示我们在类别树中只深入一层
print(f"在 {CATEGORY_TITLE} 中找到 {len(titles)} 个文章标题。")
在 Category:2022 Winter Olympics 中找到 179 个文章标题。
2. 分块文档
现在我们有了参考文档,需要为搜索做准备。
由于 GPT 一次只能读取有限的文本量,我们将每个文档分割成足够短的块。
对于这个关于维基百科文章的具体示例,我们将:
- 丢弃不太相关的部分,如“外部链接”和“脚注”
- 通过删除引用标签(例如
<ref>
)、空格和过短的节来清理文本 - 将每篇文章分割成节
- 在每个节的文本前加上标题和副标题,以帮助 GPT 理解上下文
- 如果一个节很长(例如,超过 1600 个 token),我们将递归地将其分割成更小的节,尝试沿着语义边界(如段落)进行分割
# 定义从维基百科页面分割成节的函数
SECTIONS_TO_IGNORE = [
"See also",
"References",
"External links",
"Further reading",
"Footnotes",
"Bibliography",
"Sources",
"Citations",
"Literature",
"Footnotes",
"Notes and references",
"Photo gallery",
"Works cited",
"Photos",
"Gallery",
"Notes",
"References and sources",
"References and notes",
]
def all_subsections_from_section(
section: mwparserfromhell.wikicode.Wikicode,
parent_titles: list[str],
sections_to_ignore: set[str],
) -> list[tuple[list[str], str]]:
"""
从一个维基百科节中,返回所有嵌套子节的扁平列表。
每个子节是一个元组,其中:
- 第一个元素是父标题列表,以页面标题开头
- 第二个元素是子节的文本(但不包括任何子节的文本)
"""
headings = [str(h) for h in section.filter_headings()]
title = headings[0]
if title.strip("=" + " ") in sections_to_ignore:
# ^wiki 标题被包装成 "== Heading =="
return []
titles = parent_titles + [title]
full_text = str(section)
section_text = full_text.split(title)[1]
if len(headings) == 1:
return [(titles, section_text)]
else:
first_subtitle = headings[1]
section_text = section_text.split(first_subtitle)[0]
results = [(titles, section_text)]
for subsection in section.get_sections(levels=[len(titles) + 1]):
results.extend(all_subsections_from_section(subsection, titles, sections_to_ignore))
return results
def all_subsections_from_title(
title: str,
sections_to_ignore: set[str] = SECTIONS_TO_IGNORE,
site_name: str = WIKI_SITE,
) -> list[tuple[list[str], str]]:
"""从维基百科页面标题返回所有嵌套子节的扁平列表。
每个子节是一个元组,其中:
- 第一个元素是父标题列表,以页面标题开头
- 第二个元素是子节的文本(但不包括任何子节的文本)
"""
site = mwclient.Site(site_name)
page = site.pages[title]
text = page.text()
parsed_text = mwparserfromhell.parse(text)
headings = [str(h) for h in parsed_text.filter_headings()]
if headings:
summary_text = str(parsed_text).split(headings[0])[0]
else:
summary_text = str(parsed_text)
results = [([title], summary_text)]
for subsection in parsed_text.get_sections(levels=[2]):
results.extend(all_subsections_from_section(subsection, [title], sections_to_ignore))
return results
# 分割页面成节
# 可能需要 ~1 分钟处理 100 篇文章
wikipedia_sections = []
for title in titles:
wikipedia_sections.extend(all_subsections_from_title(title))
print(f"在 {len(titles)} 篇文章中找到 {len(wikipedia_sections)} 个节。")
在 179 篇文章中找到 1838 个节。
# 清理文本
def clean_section(section: tuple[list[str], str]) -> tuple[list[str], str]:
"""
返回一个清理过的节,其中:
- 移除了 <ref>xyz</ref> 模式
- 移除了首尾空格
"""
titles, text = section
text = re.sub(r"<ref.*?</ref>", "", text)
text = text.strip()
return (titles, text)
wikipedia_sections = [clean_section(ws) for ws in wikipedia_sections]
# 过滤掉短的/空的节
def keep_section(section: tuple[list[str], str]) -> bool:
"""如果节应该被保留,则返回 True,否则返回 False。"""
titles, text = section
if len(text) < 16:
return False
else:
return True
original_num_sections = len(wikipedia_sections)
wikipedia_sections = [ws for ws in wikipedia_sections if keep_section(ws)]
print(f"过滤掉了 {original_num_sections-len(wikipedia_sections)} 个节,剩下 {len(wikipedia_sections)} 个节。")
过滤掉了 89 个节,剩下 1749 个节。
# 打印示例数据
for ws in wikipedia_sections[:5]:
print(ws[0])
display(ws[1][:77] + "...")
print()
['Concerns and controversies at the 2022 Winter Olympics']
'{{Short description|Overview of concerns and controversies surrounding the Ga...'
['Concerns and controversies at the 2022 Winter Olympics', '==Criticism of host selection==']
'American sportscaster [[Bob Costas]] criticized the [[International Olympic C...'
['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Cost and climate===']
'Several cities withdrew their applications during [[Bids for the 2022 Winter ...'
['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Promotional song===']
'Some commentators alleged that one of the early promotional songs for the [[2...'
['Concerns and controversies at the 2022 Winter Olympics', '== Diplomatic boycotts or non-attendance ==']
'<section begin=boycotts />\n[[File:2022 Winter Olympics (Beijing) diplomatic b...'
接下来,我们将递归地分割长节为更小的节。
分割文本成节没有完美的公式。
一些权衡包括:
- 更长的节可能更适合需要更多上下文的问题
- 更长的节可能不利于检索,因为它们可能混合了更多主题
- 更短的节更适合降低成本(成本与 token 数量成正比)
- 更短的节允许检索更多节,这可能有助于召回
- 重叠的节可能有助于防止答案被节边界截断
在这里,我们将使用一种简单的方法,将节限制在每个 1600 个 token,并将过长的节递归地分成两半。为了避免在句子中间分割,我们将尽可能沿着语义边界进行分割。
GPT_MODEL = "gpt-4o-mini" # 仅在选择使用哪个分词器时相关
def num_tokens(text: str, model: str = GPT_MODEL) -> int:
"""返回字符串中的 token 数量。"""
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(text))
def halved_by_delimiter(string: str, delimiter: str = "\n") -> list[str, str]:
"""在一个分隔符上将字符串分成两半,尝试平衡两侧的 token 数量。"""
chunks = string.split(delimiter)
if len(chunks) == 1:
return [string, ""] # 未找到分隔符
elif len(chunks) == 2:
return chunks # 无需搜索中间点
else:
total_tokens = num_tokens(string)
halfway = total_tokens // 2
best_diff = halfway
for i, chunk in enumerate(chunks):
left = delimiter.join(chunks[: i + 1])
left_tokens = num_tokens(left)
diff = abs(halfway - left_tokens)
if diff >= best_diff:
break
else:
best_diff = diff
left = delimiter.join(chunks[:i])
right = delimiter.join(chunks[i:])
return [left, right]
def truncated_string(
string: str,
model: str,
max_tokens: int,
print_warning: bool = True,
) -> str:
"""将字符串截断到最大 token 数量。"""
encoding = tiktoken.encoding_for_model(model)
encoded_string = encoding.encode(string)
truncated_string = encoding.decode(encoded_string[:max_tokens])
if print_warning and len(encoded_string) > max_tokens:
print(f"警告:字符串已从 {len(encoded_string)} 个 token 截断到 {max_tokens} 个 token。")
return truncated_string
def split_strings_from_subsection(
subsection: tuple[list[str], str],
max_tokens: int = 1000,
model: str = GPT_MODEL,
max_recursion: int = 5,
) -> list[str]:
"""
将一个子节分割成一个子节列表,每个子节最多包含 max_tokens。
每个子节是一个元组,包含父标题 [H1, H2, ...] 和文本 (str)。
"""
titles, text = subsection
string = "\n\n".join(titles + [text])
num_tokens_in_string = num_tokens(string)
# 如果长度合适,则返回字符串
if num_tokens_in_string <= max_tokens:
return [string]
# 如果递归 X 次后仍未找到分割点,则直接截断
elif max_recursion == 0:
return [truncated_string(string, model=model, max_tokens=max_tokens)]
# 否则,分成两半并递归处理
else:
titles, text = subsection
for delimiter in ["\n\n", "\n", ". "]:
left, right = halved_by_delimiter(text, delimiter=delimiter)
if left == "" or right == "":
# 如果任一半为空,则尝试使用更精细的分隔符
continue
else:
# 对每一半进行递归处理
results = []
for half in [left, right]:
half_subsection = (titles, half)
half_strings = split_strings_from_subsection(
half_subsection,
max_tokens=max_tokens,
model=model,
max_recursion=max_recursion - 1,
)
results.extend(half_strings)
return results
# 否则未找到分割点,直接截断(这种情况非常罕见)
return [truncated_string(string, model=model, max_tokens=max_tokens)]
# 将节分割成块
MAX_TOKENS = 1600
wikipedia_strings = []
for section in wikipedia_sections:
wikipedia_strings.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))
print(f"{len(wikipedia_sections)} 个维基百科节被分割成 {len(wikipedia_strings)} 个字符串。")
1749 个维基百科节被分割成 2052 个字符串。
# 打印示例数据
print(wikipedia_strings[1])
Concerns and controversies at the 2022 Winter Olympics
==Criticism of host selection==
American sportscaster [[Bob Costas]] criticized the [[International Olympic Committee]]'s (IOC) decision to award the games to China saying "The IOC deserves all of the disdain and disgust that comes their way for going back to China yet again" referencing China's human rights record.
After winning two gold medals and returning to his home country of Sweden skater [[Nils van der Poel]] criticized the IOC's selection of China as the host saying "I think it is extremely irresponsible to give it to a country that violates human rights as blatantly as the Chinese regime is doing." He had declined to criticize China before leaving for the games saying "I don't think it would be particularly wise for me to criticize the system I'm about to transition to, if I want to live a long and productive life."
3. 嵌入文档块
现在我们已经将库分割成更短的独立字符串,我们可以为每个字符串计算嵌入。
(对于大型嵌入任务,请使用类似 api_request_parallel_processor.py 的脚本来并行化请求,同时进行限流以保持在速率限制内。)
EMBEDDING_MODEL = "text-embedding-3-small"
BATCH_SIZE = 1000 # 您每次请求最多可以提交 2048 个嵌入输入
embeddings = []
for batch_start in range(0, len(wikipedia_strings), BATCH_SIZE):
batch_end = batch_start + BATCH_SIZE
batch = wikipedia_strings[batch_start:batch_end]
print(f"Batch {batch_start} to {batch_end-1}")
response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
for i, be in enumerate(response.data):
assert i == be.index # 再次检查嵌入是否与输入顺序相同
batch_embeddings = [e.embedding for e in response.data]
embeddings.extend(batch_embeddings)
df = pd.DataFrame({"text": wikipedia_strings, "embedding": embeddings})
Batch 0 to 999
Batch 1000 to 1999
Batch 2000 to 2999
4. 存储文档块和嵌入
因为这个示例只使用了几千个字符串,所以我们将它们存储在 CSV 文件中。
(对于更大的数据集,请使用向量数据库,它将更具性能。)
# 保存文档块和嵌入
SAVE_PATH = "data/winter_olympics_2022.csv"
df.to_csv(SAVE_PATH, index=False)