如何使用 Claude 进行 SQL 查询

在本笔记本中,我们将探讨如何使用 Claude 根据自然语言问题生成 SQL 查询。我们将设置一个测试数据库,向 Claude 提供模式,并演示它如何理解和转换人类语言为 SQL 查询。

设置

首先,让我们安装必要的库,并使用我们的 API 密钥设置 Anthropic 客户端。

# 安装必要的库
%pip install anthropic
# 导入所需的库
from anthropic import Anthropic
import sqlite3

# 设置 Anthropic API 客户端
client = Anthropic()
MODEL_NAME = "claude-3-opus-20240229"

创建测试数据库

我们将使用 SQLite 创建一个测试数据库并填充示例数据:

# 连接到测试数据库(如果不存在则创建)
conn = sqlite3.connect("test_db.db")
cursor = conn.cursor()

# 创建一个示例表
cursor.execute("""
    CREATE TABLE IF NOT EXISTS employees (
        id INTEGER PRIMARY KEY,
        name TEXT,
        department TEXT,
        salary INTEGER
    )
""")

# 插入示例数据
sample_data = [
    (1, "John Doe", "Sales", 50000),
    (2, "Jane Smith", "Engineering", 75000),
    (3, "Mike Johnson", "Sales", 60000),
    (4, "Emily Brown", "Engineering", 80000),
    (5, "David Lee", "Marketing", 55000)
]
cursor.executemany("INSERT INTO employees VALUES (?, ?, ?, ?)", sample_data)
conn.commit()

使用 Claude 生成 SQL 查询

现在,让我们定义一个函数,将自然语言问题发送给 Claude 并获取生成的 SQL 查询:

# 定义一个将查询发送给 Claude 并获取响应的函数
def ask_claude(query, schema):
    prompt = f"""这是数据库的模式:

{schema}

给定此模式,你能否输出一个 SQL 查询来回答以下问题?只输出 SQL 查询,不要输出其他任何内容。

问题: {query}
"""

    response = client.messages.create(
        model=MODEL_NAME,
        max_tokens=2048,
        messages=[{
            "role": 'user', "content":  prompt
        }]
    )
    return response.content[0].text

我们将检索数据库模式并将其格式化为字符串:

# 获取数据库模式
schema = cursor.execute("PRAGMA table_info(employees)").fetchall()
schema_str = "CREATE TABLE EMPLOYEES (\n" + "\n".join([f"{col[1]} {col[2]}" for col in schema]) + "\n)"
print(schema_str)
CREATE TABLE EMPLOYEES (
id INTEGER
name TEXT
department TEXT
salary INTEGER
)

现在,让我们提供一个示例自然语言问题并将其发送给 Claude:

# 示例自然语言问题
question = "工程部门的员工姓名和薪水是多少?"
# 将问题发送给 Claude 并获取 SQL 查询
sql_query = ask_claude(question, schema_str)
print(sql_query)
SELECT name, salary
FROM EMPLOYEES
WHERE department = 'Engineering';

执行生成的 SQL 查询

最后,我们将在测试数据库上执行生成的 SQL 查询并打印结果:

# 执行 SQL 查询并打印结果
results = cursor.execute(sql_query).fetchall()

for row in results:
    print(row)
('Jane Smith', 75000)
('Emily Brown', 80000)

别忘了在完成后关闭数据库连接:

# 关闭数据库连接
conn.close()