以下是代码搜索示例的逐字翻译:

使用嵌入进行代码搜索

本笔记本展示了如何使用 Ada 嵌入来实现语义代码搜索。在本演示中,我们使用了我们自己的 openai-python 代码仓库。我们实现了一个简单的文件解析和从 Python 文件中提取函数的方法,这些函数可以被嵌入、索引和查询。

辅助函数

我们首先设置一些简单的解析函数,允许我们从代码库中提取重要信息。

import pandas as pd
from pathlib import Path

DEF_PREFIXES = ['def ', 'async def ']
NEWLINE = '\n'

def get_function_name(code):
    """
    从以 'def' 或 'async def' 开头的行中提取函数名。
    """
    for prefix in DEF_PREFIXES:
        if code.startswith(prefix):
            return code[len(prefix): code.index('(')]


def get_until_no_space(all_lines, i):
    """
    获取直到找到函数定义之外的行。
    """
    ret = [all_lines[i]]
    for j in range(i + 1, len(all_lines)):
        if len(all_lines[j]) == 0 or all_lines[j][0] in [' ', '\t', ')']:
            ret.append(all_lines[j])
        else:
            break
    return NEWLINE.join(ret)


def get_functions(filepath):
    """
    获取 Python 文件中的所有函数。
    """
    with open(filepath, 'r') as file:
        all_lines = file.read().replace('\r', NEWLINE).split(NEWLINE)
        for i, l in enumerate(all_lines):
            for prefix in DEF_PREFIXES:
                if l.startswith(prefix):
                    code = get_until_no_space(all_lines, i)
                    function_name = get_function_name(code)
                    yield {
                        'code': code,
                        'function_name': function_name,
                        'filepath': filepath,
                    }
                    break


def extract_functions_from_repo(code_root):
    """
    从仓库中提取所有 .py 函数。
    """
    code_files = list(code_root.glob('**/*.py'))

    num_files = len(code_files)
    print(f'总共有多少个 .py 文件: {num_files}')

    if num_files == 0:
        print('请验证 openai-python 仓库是否存在并且 code_root 设置正确。')
        return None

    all_funcs = [
        func
        for code_file in code_files
        for func in get_functions(str(code_file))
    ]

    num_funcs = len(all_funcs)
    print(f'提取的总函数数: {num_funcs}')

    return all_funcs

数据加载

我们将首先加载 openai-python 文件夹,并使用我们上面定义的函数提取所需信息。

# 将用户根目录设置为 'openai-python' 仓库
root_dir = Path.home()

# 假设 'openai-python' 仓库存在于用户的根目录中
code_root = root_dir / 'openai-python'

# 从仓库中提取所有函数
all_funcs = extract_functions_from_repo(code_root)
总共有多少个 .py 文件: 51
提取的总函数数: 97

现在我们有了内容,可以将其传递给 text-embedding-3-small 模型并获取向量嵌入。

from utils.embeddings_utils import get_embedding

df = pd.DataFrame(all_funcs)
df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, model='text-embedding-3-small'))
df['filepath'] = df['filepath'].map(lambda x: Path(x).relative_to(code_root))
df.to_csv("data/code_search_openai-python.csv", index=False)
df.head()
code function_name filepath code_embedding
0 def _console_log_level(): if openai.log i... _console_log_level openai/util.py [0.005937571171671152, 0.05450401455163956, 0....
1 def log_debug(message, **params): msg = l... log_debug openai/util.py [0.017557814717292786, 0.05647840350866318, -0...
2 def log_info(message, **params): msg = lo... log_info openai/util.py [0.022524144500494003, 0.06219055876135826, -0...
3 def log_warn(message, **params): msg = lo... log_warn openai/util.py [0.030524108558893204, 0.0667714849114418, -0....
4 def logfmt(props): def fmt(key, val): ... logfmt openai/util.py [0.05337328091263771, 0.03697286546230316, -0....

测试

让我们用一些简单的查询来测试我们的端点。如果您熟悉 openai-python 仓库,您会发现我们能够轻松地仅通过简单的英文描述找到我们想要的函数。

我们定义了一个 search_functions 方法,它接受包含嵌入的数据、一个查询字符串以及一些其他配置选项。搜索我们数据库的过程如下:

  1. 我们首先使用 text-embedding-3-small 嵌入我们的查询字符串(code_query)。这样做的原因是,像“一个反转字符串的函数”这样的查询字符串和一个像“def reverse(string): return string[::-1]”这样的函数在嵌入后会非常相似。
  2. 然后,我们计算我们的查询字符串嵌入与数据库中所有数据点之间的余弦相似度。这给出了每个数据点与我们查询的距离。
  3. 最后,我们按数据点与查询字符串的距离对所有数据点进行排序,并返回函数参数中请求的结果数量。
from utils.embeddings_utils import cosine_similarity

def search_functions(df, code_query, n=3, pprint=True, n_lines=7):
    embedding = get_embedding(code_query, model='text-embedding-3-small')
    df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))

    res = df.sort_values('similarities', ascending=False).head(n)

    if pprint:
        for r in res.iterrows():
            print(f"{r[1].filepath}:{r[1].function_name}  score={round(r[1].similarities, 3)}")
            print("\n".join(r[1].code.split("\n")[:n_lines]))
            print('-' * 70)

    return res
res = search_functions(df, 'fine-tuning input data validation logic', n=3)
openai/validators.py:format_inferrer_validator  score=0.453
def format_inferrer_validator(df):
    """
    此验证器将推断数据的可能微调格式,并在分类时将其显示给用户。
    它还将建议使用 ada 并解释训练/验证拆分的优点。
    """
    ft_type = infer_task_type(df)
    immediate_msg = None
----------------------------------------------------------------------
openai/validators.py:infer_task_type  score=0.37
def infer_task_type(df):
    """
    从数据中推断可能的微调任务类型
    """
    CLASSIFICATION_THRESHOLD = 3  # 每个类的最小平均实例数
    if sum(df.prompt.str.len()) == 0:
        return "open-ended generation"
----------------------------------------------------------------------
openai/validators.py:apply_validators  score=0.369
def apply_validators(
    df,
    fname,
    remediation,
    validators,
    auto_accept,
    write_out_file_func,
----------------------------------------------------------------------
res = search_functions(df, 'find common suffix', n=2, n_lines=10)
openai/validators.py:get_common_xfix  score=0.487
def get_common_xfix(series, xfix="suffix"):
    """
    查找系列中所有值的最长公共后缀或前缀
    """
    common_xfix = ""
    while True:
        common_xfixes = (
            series.str[-(len(common_xfix) + 1) :]
            if xfix == "suffix"
            else series.str[: len(common_xfix) + 1]
----------------------------------------------------------------------
openai/validators.py:common_completion_suffix_validator  score=0.449
def common_completion_suffix_validator(df):
    """
    此验证器将在分类或条件生成的情况下,建议在完成时添加一个公共后缀(如果尚不存在)。
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None

    ft_type = infer_task_type(df)
----------------------------------------------------------------------
res = search_functions(df, 'Command line interface for fine-tuning', n=1, n_lines=20)
openai/cli.py:tools_register  score=0.391
def tools_register(parser):
    subparsers = parser.add_subparsers(
        title="Tools", help="Convenience client side tools"
    )

    def help(args):
        parser.print_help()

    parser.set_defaults(func=help)

    sub = subparsers.add_parser("fine_tunes.prepare_data")
    sub.add_argument(
        "-f",
        "--file",
        required=True,
        help="JSONL, JSON, CSV, TSV, TXT 或 XLSX 文件,包含要分析的提示-完成示例。"
        "这应该是本地文件路径。",
    )
    sub.add_argument(
        "-q",
----------------------------------------------------------------------