Note:为了根据文本文件回答问题,我们建议采用使用嵌入进行问答中的过程。下面的一些代码可能依赖于已弃用的 API 端点

1. 收集关于 2020 年奥运会的维基百科数据

本项目旨在创建一个基于提供的几段文本的问答模型。当答案包含在段落中时,基础 GPT-3 模型在回答问题方面表现良好,但如果答案不包含在内,基础模型往往会尽力回答,这通常会导致捏造的答案。

为了创建一个仅在有足够上下文时才回答问题的模型,我们首先根据文本段落创建一个问题和答案的数据集。为了训练模型仅在答案存在时才回答,我们还添加了对抗性示例,其中问题与上下文不匹配。在这些情况下,我们要求模型输出“没有足够的上下文来回答问题”。

我们将分三个笔记本完成此任务:

  1. 第一个(本)笔记本侧重于收集 GPT-3 在预训练期间未见过​​的近期数据。我们选择了 2020 年奥运会(实际上发生在 2021 年夏季)的主题,并下载了 713 个唯一页面。我们将数据集按各个部分进行了组织,这些部分将作为提问和回答问题的上下文。
  2. 第二个笔记本将利用 Davinci-instruct 根据维基百科的某个部分提出一些问题,并根据该部分回答这些问题。
  3. 第三个笔记本将利用上下文、问题和答案对的数据集来另外创建对抗性问题和上下文对,其中问题不是基于该上下文生成的。在这些情况下,模型将被提示回答“没有足够的上下文来回答问题”。我们还将训练一个判别器模型,该模型预测问题是否可以根据上下文进行回答。

1.1 使用维基百科 API 进行数据提取

提取数据大约需要半小时,处理也可能需要同样长的时间。

import pandas as pd
import wikipedia


def filter_olympic_2020_titles(titles):
    """
    给定一个标题列表,获取与 2020 年奥运会相关的标题
    """
    titles = [title for title in titles if '2020' in title and 'olympi' in title.lower()]

    return titles

def get_wiki_page(title):
    """
    给定一个标题获取维基百科页面
    """
    try:
        return wikipedia.page(title)
    except wikipedia.exceptions.DisambiguationError as e:
        return wikipedia.page(e.options[0])
    except wikipedia.exceptions.PageError as e:
        return None

def recursively_find_all_pages(titles, titles_so_far=set()):
    """
    递归地查找链接到列表中维基百科标题的所有页面
    """
    all_pages = []

    titles = list(set(titles) - titles_so_far)
    titles = filter_olympic_2020_titles(titles)
    titles_so_far.update(titles)
    for title in titles:
        page = get_wiki_page(title)
        if page is None:
            continue
        all_pages.append(page)

        new_pages = recursively_find_all_pages(page.links, titles_so_far)
        for pg in new_pages:
            if pg.title not in [p.title for p in all_pages]:
                all_pages.append(pg)
        titles_so_far.update(page.links)
    return all_pages


pages = recursively_find_all_pages(["2020 Summer Olympics"])
len(pages)
909

1.2 按标题过滤维基百科页面并将它们拆分成节

我们删除不太可能包含文本信息的节,并确保每个节的长度不超过令牌限制。

import re
from typing import Set
from transformers import GPT2TokenizerFast

import numpy as np
from nltk.tokenize import sent_tokenize

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def count_tokens(text: str) -> int:
    """计算字符串中的令牌数"""
    return len(tokenizer.encode(text))

def reduce_long(
    long_text: str, long_text_tokens: bool = False, max_len: int = 590
) -> str:
    """
    通过可能在句子末尾截断,将长文本减少到最多 `max_len` 个令牌
    """
    if not long_text_tokens:
        long_text_tokens = count_tokens(long_text)
    if long_text_tokens > max_len:
        sentences = sent_tokenize(long_text.replace("\n", " "))
        ntokens = 0
        for i, sentence in enumerate(sentences):
            ntokens += 1 + count_tokens(sentence)
            if ntokens > max_len:
                return ". ".join(sentences[:i]) + "."

    return long_text

discard_categories = ['参见', '参考文献', '外部链接', '延伸阅读', "注释",
    "参考书目", "来源", "引文", "文献", "注释", "注释和参考文献",
    "图片库", "引文", "照片", "图库", "注释", "参考文献和来源",
    "参考文献和注释",]


def extract_sections(
    wiki_text: str,
    title: str,
    max_len: int = 1500,
    discard_categories: Set[str] = discard_categories,
) -> str:
    """
    提取维基百科页面的节,丢弃参考文献和其他信息量少的节
    """
    if len(wiki_text) == 0:
        return []

    # 查找所有标题和相应的内​​容
    headings = re.findall("==+ .* ==+", wiki_text)
    for heading in headings:
        wiki_text = wiki_text.replace(heading, "==+ !! ==+")
    contents = wiki_text.split("==+ !! ==+")
    contents = [c.strip() for c in contents]
    assert len(headings) == len(contents) - 1

    cont = contents.pop(0).strip()
    outputs = [(title, "摘要", cont, count_tokens(cont)+4)]

    # 丢弃丢弃类别,考虑树结构
    max_level = 100
    keep_group_level = max_level
    remove_group_level = max_level
    nheadings, ncontents = [], []
    for heading, content in zip(headings, contents):
        plain_heading = " ".join(heading.split(" ")[1:-1])
        num_equals = len(heading.split(" ")[0])
        if num_equals <= keep_group_level:
            keep_group_level = max_level

        if num_equals > remove_group_level:
            if (
                num_equals <= keep_group_level
            ):
                continue
        keep_group_level = max_level
        if plain_heading in discard_categories:
            remove_group_level = num_equals
            keep_group_level = max_level
            continue
        nheadings.append(heading.replace("=", "").strip())
        ncontents.append(content)
        remove_group_level = max_level

    # 计算每个节的令牌数
    ncontent_ntokens = [
        count_tokens(c)

        + 3
        + count_tokens(" ".join(h.split(" ")[1:-1]))
        - (1 if len(c) == 0 else 0)
        for h, c in zip(nheadings, ncontents)
    ]

    # 创建一个元组 (title, section_name, content, number of tokens)
    outputs += [(title, h, c, t) if t<max_len 
                else (title, h, reduce_long(c, max_len), count_tokens(reduce_long(c,max_len))) 
                    for h, c, t in zip(nheadings, ncontents, ncontent_ntokens)]

    return outputs

# 正在处理的示例页面
bermuda_page = get_wiki_page('Bermuda at the 2020 Summer Olympics')
ber = extract_sections(bermuda_page.content, bermuda_page.title)

# 示例节
ber[-1]
('Bermuda at the 2020 Summer Olympics', 'Equestrian', "Bermuda entered one dressage rider into the Olympic competition by finishing in the top four, outside the group selection, of the individual FEI Olympic Rankings for Groups D and E (North, Central, and South America), marking the country's recurrence to the sport after an eight-year absence. The quota was later withdrawn, following an injury of Annabelle Collins' main horse Joyero and a failure to obtain minimum eligibility requirements (MER) aboard a new horse Chuppy Checker.", 104)

1.2.1 我们创建一个数据集,并过滤掉任何少于 40 个令牌的节,因为这些节不太可能包含足够的上下文来提出一个好的问题。

res = []
for page in pages:
    res += extract_sections(page.content, page.title)
df = pd.DataFrame(res, columns=["title", "heading", "content", "tokens"])
df = df[df.tokens>40]
df = df.drop_duplicates(['title','heading'])
df = df.reset_index().drop('index',axis=1) # 重置索引
df.head()
Token indices sequence length is longer than the specified maximum sequence length for this model (1060 > 1024). Running this sequence through the model will result in indexing errors
title heading content tokens
0 2020 Summer Olympics Summary The 2020 Summer Olympics (Japanese: 2020年夏季オリン... 713
1 2020 Summer Olympics Host city selection The International Olympic Committee (IOC) vote... 126
2 2020 Summer Olympics Impact of the COVID-19 pandemic In January 2020, concerns were raised about th... 369
3 2020 Summer Olympics Qualifying event cancellation and postponement Concerns about the pandemic began to affect qu... 298
4 2020 Summer Olympics Effect on doping tests Mandatory doping tests were being severely res... 163

保存节数据集

我们将保存节数据集,供 下一个笔记本 使用。

df.to_csv('olympics-data/olympics_sections.csv', index=False)

1.3(可选)数据探索

df.title.value_counts().head()
Concerns and controversies at the 2020 Summer Olympics    51
United States at the 2020 Summer Olympics                 46
Great Britain at the 2020 Summer Olympics                 42
Canada at the 2020 Summer Olympics                        39
Olympic Games                                             39
Name: title, dtype: int64

似乎有冬季和夏季奥运会 2020。我们选择在数据集中保留一些模糊性和噪声,即使我们只对 2020 年夏季奥运会感兴趣。

df.title.str.contains('Summer').value_counts()
True     3567
False     305
Name: title, dtype: int64
df.title.str.contains('Winter').value_counts()
False    3774
True       98
Name: title, dtype: int64
import pandas as pd
from matplotlib import pyplot as plt

df = pd.read_csv('olympics-data/olympics_sections.csv')
df[['tokens']].hist()
# 添加轴描述和标题
plt.xlabel('令牌数')
plt.ylabel('维基百科节数')
plt.title('维基百科节中令牌数分布')
plt.show()

png

我们可以看到,大多数节都相当短(少于 500 个令牌)。