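Structured Outputs for Multi-Agent Systems
In this guide, we explore how to use Structured Outputs to build multi-agent systems.
Structured Outputs is a new capability that builds on JSON mode and function calling to make model outputs follow a strict schema.
With the new strict: true parameter, the response is guaranteed to conform to the provided schema.
To demonstrate the capability, we will use it to build a multi-agent system.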
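As a rough illustration of what strict mode expects from a tool schema, the sketch below checks two documented requirements locally: every object sets additionalProperties to false, and every property is listed in required. The helper name strict_schema_problems is hypothetical and not part of the OpenAI SDK; the API itself performs the authoritative validation and enforces further rules that this sketch does not cover.

```python
def strict_schema_problems(schema, path="parameters"):
    """Return reasons why `schema` would not satisfy two strict-mode rules.

    Hypothetical local helper (not an OpenAI SDK function). It only checks that
    objects set additionalProperties to False and list every property as required.
    """
    problems = []
    if schema.get("type") == "object":
        properties = schema.get("properties", {})
        if schema.get("additionalProperties") is not False:
            problems.append(f"{path}: additionalProperties must be False")
        missing = sorted(set(properties) - set(schema.get("required", [])))
        if missing:
            problems.append(f"{path}: not listed in required: {missing}")
        # Recurse into nested property schemas.
        for name, sub_schema in properties.items():
            problems.extend(strict_schema_problems(sub_schema, f"{path}.{name}"))
    elif schema.get("type") == "array" and isinstance(schema.get("items"), dict):
        problems.extend(strict_schema_problems(schema["items"], f"{path}[]"))
    return problems


good = {
    "type": "object",
    "properties": {"data": {"type": "string"}},
    "required": ["data"],
    "additionalProperties": False,
}
loose = {
    "type": "object",
    "properties": {"data": {"type": "string"}, "rules": {"type": "string"}},
    "required": ["data"],
}

print(strict_schema_problems(good))   # no problems reported
print(strict_schema_problems(loose))  # two problems reported
```

A check like this can run in a unit test before any request is sent, so a schema mistake surfaces locally instead of as an API error.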
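Why build a multi-agent system?
With function calling, performance can suffer as the number of functions (or tools) grows.
To mitigate this, we can group tools logically and have specialized "agents" solve specific tasks or subtasks, which should improve overall system performance.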
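To make the grouping idea concrete, here is a minimal sketch that partitions the nine tools used later in this guide among three agents. The AGENT_TOOL_GROUPS mapping and the tools_visible_to helper are illustrative names, not part of any library; the point is only that each model call sees three tool definitions rather than nine.

```python
# Hypothetical grouping of the nine tools used later in this guide.
AGENT_TOOL_GROUPS = {
    "Data Processing Agent": ["clean_data", "transform_data", "aggregate_data"],
    "Analysis Agent": ["stat_analysis", "correlation_analysis", "regression_analysis"],
    "Visualization Agent": ["create_bar_chart", "create_line_chart", "create_pie_chart"],
}


def tools_visible_to(agent_name):
    """Return only the tool names that the given agent is allowed to call."""
    return AGENT_TOOL_GROUPS[agent_name]


total_tools = sum(len(names) for names in AGENT_TOOL_GROUPS.values())
largest_group = max(len(names) for names in AGENT_TOOL_GROUPS.values())
print(total_tools, largest_group)
print(tools_visible_to("Analysis Agent"))
```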
Environment setup
from openai import OpenAI
from IPython.display import Image
import json
import pandas as pd
import matplotlib.pyplot as plt
from io import StringIO
import numpy as np
client = OpenAI()
MODEL = "gpt-4o-2024-08-06"
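Agent setup
The use case we will tackle is a data analysis task.
Let's first set up our 4-agent system:
- Triaging agent: decides which agent(s) to call.
- Data pre-processing agent: prepares data for analysis, for example by cleaning it up.
- Data analysis agent: performs analysis on the data.
- Data visualization agent: visualizes the output of the analysis to extract insights.
We will start by defining the system prompt for each agent.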
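triaging_system_prompt = """You are a Triaging Agent. Your role is to assess the user's query and route it to the relevant agents. The agents available are:
- Data Processing Agent: Cleans, transforms, and aggregates data.
- Analysis Agent: Performs statistical, correlation, and regression analysis.
- Visualization Agent: Creates bar charts, line charts, and pie charts.
Use the send_query_to_agents tool to forward the user's query to the relevant agents. Also, use the speak_to_user tool to get more information from the user if needed."""
processing_system_prompt = """You are a Data Processing Agent. Your role is to clean, transform, and aggregate data using the following tools:
- clean_data
- transform_data
- aggregate_data"""
analysis_system_prompt = """You are an Analysis Agent. Your role is to perform statistical, correlation, and regression analysis using the following tools:
- stat_analysis
- correlation_analysis
- regression_analysis"""
visualization_system_prompt = """You are a Visualization Agent. Your role is to create bar charts, line charts, and pie charts using the following tools:
- create_bar_chart
- create_line_chart
- create_pie_chart"""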
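Next, we define the tools for each agent.
Apart from the triaging agent, each agent is equipped with tools specific to its role:
Data pre-processing agent
- Clean data
- Transform data
- Aggregate data
Data analysis agent
- Statistical analysis
- Correlation analysis
- Regression analysis
Data visualization agent
- Create bar chart
- Create line chart
- Create pie chart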
triage_tools = [
    {
        "type": "function",
        "function": {
            "name": "send_query_to_agents",
            "description": "Sends the user query to relevant agents based on their capabilities.",
            "parameters": {
                "type": "object",
                "properties": {
                    "agents": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "An array of agent names to send the query to."
                    },
                    "query": {
                        "type": "string",
                        "description": "The user query to send."
                    }
                },
                "required": ["agents", "query"],
                "additionalProperties": False
            },
            "strict": True
        }
    }
]
preprocess_tools = [
    {
        "type": "function",
        "function": {
            "name": "clean_data",
            "description": "Cleans the provided data by removing duplicates and handling missing values.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to clean. Should be in a suitable format such as JSON or CSV."
                    }
                },
                "required": ["data"],
                "additionalProperties": False
            },
            "strict": True
        }
    },
    {
        "type": "function",
        "function": {
            "name": "transform_data",
            "description": "Transforms data based on specified rules.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data to transform. Should be in a suitable format such as JSON or CSV."
                    },
                    "rules": {
                        "type": "string",
                        "description": "Transformation rules to apply, specified in a structured format."
                    }
                },
                "required": ["data", "rules"],
                "additionalProperties": False
            },
            "strict": True
        }
    },
    {
        "type": "function",
        "function": {
            "name": "aggregate_data",
            "description": "Aggregates data by specified columns and operations.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data to aggregate. Should be in a suitable format such as JSON or CSV."
                    },
                    "group_by": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Columns to group by."
                    },
                    "operations": {
                        "type": "string",
                        "description": "Aggregation operations to perform, specified in a structured format."
                    }
                },
                "required": ["data", "group_by", "operations"],
                "additionalProperties": False
            },
            "strict": True
        }
    }
]
analysis_tools = [
    {
        "type": "function",
        "function": {
            "name": "stat_analysis",
            "description": "Performs statistical analysis on the given dataset.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to analyze. Should be in a suitable format such as JSON or CSV."
                    }
                },
                "required": ["data"],
                "additionalProperties": False
            },
            "strict": True
        }
    },
    {
        "type": "function",
        "function": {
            "name": "correlation_analysis",
            "description": "Calculates correlation coefficients between variables in the dataset.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to analyze. Should be in a suitable format such as JSON or CSV."
                    },
                    "variables": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of variables to calculate correlations for."
                    }
                },
                "required": ["data", "variables"],
                "additionalProperties": False
            },
            "strict": True
        }
    },
    {
        "type": "function",
        "function": {
            "name": "regression_analysis",
            "description": "Performs regression analysis on the dataset.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to analyze. Should be in a suitable format such as JSON or CSV."
                    },
                    "dependent_var": {
                        "type": "string",
                        "description": "The dependent variable for regression."
                    },
                    "independent_vars": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of independent variables."
                    }
                },
                "required": ["data", "dependent_var", "independent_vars"],
                "additionalProperties": False
            },
            "strict": True
        }
    }
]
visualization_tools = [
    {
        "type": "function",
        "function": {
            "name": "create_bar_chart",
            "description": "Creates a bar chart from the provided data.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data for the bar chart. Should be in a suitable format such as JSON or CSV."
                    },
                    "x": {
                        "type": "string",
                        "description": "Column for the x-axis."
                    },
                    "y": {
                        "type": "string",
                        "description": "Column for the y-axis."
                    }
                },
                "required": ["data", "x", "y"],
                "additionalProperties": False
            },
            "strict": True
        }
    },
    {
        "type": "function",
        "function": {
            "name": "create_line_chart",
            "description": "Creates a line chart from the provided data.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data for the line chart. Should be in a suitable format such as JSON or CSV."
                    },
                    "x": {
                        "type": "string",
                        "description": "Column for the x-axis."
                    },
                    "y": {
                        "type": "string",
                        "description": "Column for the y-axis."
                    }
                },
                "required": ["data", "x", "y"],
                "additionalProperties": False
            },
            "strict": True
        }
    },
    {
        "type": "function",
        "function": {
            "name": "create_pie_chart",
            "description": "Creates a pie chart from the provided data.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data for the pie chart. Should be in a suitable format such as JSON or CSV."
                    },
                    "labels": {
                        "type": "string",
                        "description": "Column for the labels."
                    },
                    "values": {
                        "type": "string",
                        "description": "Column for the values."
                    }
                },
                "required": ["data", "labels", "values"],
                "additionalProperties": False
            },
            "strict": True
        }
    }
]
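The tool definitions above all share the same wrapper, so the boilerplate could be generated. The sketch below is one way to do that; make_strict_tool is a hypothetical convenience helper, not an OpenAI SDK function. It marks every property as required, sets additionalProperties to False, and places strict inside the function object, which is where the Chat Completions API reference documents that field.

```python
def make_strict_tool(name, description, properties):
    """Build a function-tool definition in which every property is required.

    Hypothetical convenience helper, not part of the OpenAI SDK.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": properties,
                # Strict mode expects every property to be listed here.
                "required": list(properties),
                "additionalProperties": False,
            },
        },
    }


data_param = {
    "type": "string",
    "description": "The dataset to clean. Should be in a suitable format such as JSON or CSV.",
}
clean_data_tool = make_strict_tool(
    "clean_data",
    "Cleans the provided data by removing duplicates and handling missing values.",
    {"data": data_param},
)
print(clean_data_tool["function"]["parameters"]["required"])
```

Because required is derived from the properties dict, the two cannot drift apart when a parameter is added later.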
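Tool execution
We need to write the code logic to:
- handle passing the user query to the multi-agent system
- handle the internal workings of the multi-agent system
- execute the tool calls
For brevity, we only define the logic for the tools relevant to the user query.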
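# Example query
user_query = """
Below is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.

house_size (m3), house_price ($)
90, 100
80, 90
100, 120
90, 100
"""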
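From the user query, we can infer that the tools we need to call are clean_data, stat_analysis, and create_line_chart.
We will first define the execution function that runs tool calls.
It maps each tool call to the corresponding function, then appends the function's output to the conversation history.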
def clean_data(data):
    # Parse the CSV text and drop exact duplicate rows
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    df_deduplicated = df.drop_duplicates()
    return df_deduplicated

def stat_analysis(data):
    # Parse the CSV text and return summary statistics per column
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    return df.describe()

def plot_line_chart(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    # Use the first column as x and the second as y
    x = df.iloc[:, 0]
    y = df.iloc[:, 1]
    # Fit a degree-1 polynomial (straight line) to the points
    coefficients = np.polyfit(x, y, 1)
    polynomial = np.poly1d(coefficients)
    y_fit = polynomial(x)
    plt.figure(figsize=(10, 6))
    plt.plot(x, y, 'o', label='Data Points')
    plt.plot(x, y_fit, '-', label='Best Fit Line')
    plt.title('Line Chart with Best Fit Line')
    plt.xlabel(df.columns[0])
    plt.ylabel(df.columns[1])
    plt.legend()
    plt.grid(True)
    plt.show()
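One detail about the parsing helpers above: the sample data puts a space after each comma, so with default settings pandas keeps that space in the second column name (it appears as " house_price ($)" in the tool output later in this guide). The sketch below shows one possible way to avoid that with read_csv's skipinitialspace option. The function name clean_data_trimmed is hypothetical, and this is a variation on the guide's helper rather than the guide's own implementation.

```python
from io import StringIO

import pandas as pd


def clean_data_trimmed(data):
    """Variant of clean_data that ignores spaces following each delimiter."""
    df = pd.read_csv(StringIO(data), sep=",", skipinitialspace=True)
    return df.drop_duplicates()


raw = "house_size (m3), house_price ($)\n90, 100\n80, 90\n100, 120\n90, 100"

# Default parsing keeps the leading space in the second header.
default_columns = list(pd.read_csv(StringIO(raw)).columns)
trimmed = clean_data_trimmed(raw)

print(default_columns)
print(list(trimmed.columns))
print(len(trimmed))
```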
# Define the function that executes the tools
def execute_tool(tool_calls, messages):
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        tool_arguments = json.loads(tool_call.function.arguments)

        if tool_name == 'clean_data':
            # Simulate data cleaning
            cleaned_df = clean_data(tool_arguments['data'])
            cleaned_data = {"cleaned_data": cleaned_df.to_dict()}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(cleaned_data)})
            print('Cleaned data: ', cleaned_df)
        elif tool_name == 'transform_data':
            # Simulate data transformation
            transformed_data = {"transformed_data": "sample_transformed_data"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(transformed_data)})
        elif tool_name == 'aggregate_data':
            # Simulate data aggregation
            aggregated_data = {"aggregated_data": "sample_aggregated_data"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(aggregated_data)})
        elif tool_name == 'stat_analysis':
            # Simulate statistical analysis
            stats_df = stat_analysis(tool_arguments['data'])
            stats = {"stats": stats_df.to_dict()}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(stats)})
            print('Statistical Analysis: ', stats_df)
        elif tool_name == 'correlation_analysis':
            # Simulate correlation analysis
            correlations = {"correlations": "sample_correlations"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(correlations)})
        elif tool_name == 'regression_analysis':
            # Simulate regression analysis
            regression_results = {"regression_results": "sample_regression_results"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(regression_results)})
        elif tool_name == 'create_bar_chart':
            # Simulate bar chart creation
            bar_chart = {"bar_chart": "sample_bar_chart"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(bar_chart)})
        elif tool_name == 'create_line_chart':
            # Simulate line chart creation
            line_chart = {"line_chart": "sample_line_chart"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(line_chart)})
            plot_line_chart(tool_arguments['data'])
        elif tool_name == 'create_pie_chart':
            # Simulate pie chart creation
            pie_chart = {"pie_chart": "sample_pie_chart"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(pie_chart)})
    return messages
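The if/elif chain above can also be written as a dispatch table, which keeps the name-to-function mapping in one place. The following is a minimal sketch of that alternative, not the guide's implementation: run_tool_calls and the registry dict are hypothetical names, the handlers are placeholders, and the SimpleNamespace object only imitates the name and arguments attributes that the SDK's tool call objects expose.

```python
import json
from types import SimpleNamespace


def run_tool_calls(tool_calls, messages, registry):
    """Look up each tool by name in `registry` and append its result as a tool message."""
    for tool_call in tool_calls:
        name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments)
        handler = registry.get(name)
        if handler is None:
            result = {"error": f"unknown tool: {name}"}
        else:
            # Argument names match the tool schema, so they can be passed as keywords.
            result = handler(**arguments)
        messages.append({"role": "tool", "name": name, "content": json.dumps(result)})
    return messages


# Placeholder handlers standing in for real implementations.
registry = {
    "transform_data": lambda data, rules: {"transformed_data": "sample_transformed_data"},
    "aggregate_data": lambda data, group_by, operations: {"aggregated_data": "sample_aggregated_data"},
}

# Stand-in for a tool call object returned by the API.
fake_call = SimpleNamespace(
    function=SimpleNamespace(
        name="transform_data",
        arguments=json.dumps({"data": "a,b\n1,2", "rules": "{}"}),
    )
)

history = run_tool_calls([fake_call], [], registry)
print(history[0]["content"])
```

Passing the parsed arguments as keywords relies on the schema guarantee: if the arguments always match the declared properties, the handler signature can mirror the schema directly.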
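Next, we create the tool handlers for each of the sub-agents.
Each handler passes its own prompt and tool set to the model.
The output is then passed to the execution function, which runs the tool calls.
We also append the messages to the conversation history.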
# Define the functions that handle each agent's processing
def handle_data_processing_agent(query, conversation_messages):
    messages = [{"role": "system", "content": processing_system_prompt}]
    messages.append({"role": "user", "content": query})
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=preprocess_tools,
    )
    # tool_calls is None when the model answers without calling a tool
    tool_calls = response.choices[0].message.tool_calls or []
    conversation_messages.append([tool_call.function for tool_call in tool_calls])
    execute_tool(tool_calls, conversation_messages)

def handle_analysis_agent(query, conversation_messages):
    messages = [{"role": "system", "content": analysis_system_prompt}]
    messages.append({"role": "user", "content": query})
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=analysis_tools,
    )
    tool_calls = response.choices[0].message.tool_calls or []
    conversation_messages.append([tool_call.function for tool_call in tool_calls])
    execute_tool(tool_calls, conversation_messages)

def handle_visualization_agent(query, conversation_messages):
    messages = [{"role": "system", "content": visualization_system_prompt}]
    messages.append({"role": "user", "content": query})
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=visualization_tools,
    )
    tool_calls = response.choices[0].message.tool_calls or []
    conversation_messages.append([tool_call.function for tool_call in tool_calls])
    execute_tool(tool_calls, conversation_messages)
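The three handlers above differ only in their system prompt and tool set, so they could be produced by a single factory. The sketch below shows that idea under stated assumptions: make_agent_handler is a hypothetical name, and the completion call and tool executor are passed in as parameters so the example runs without network access. In the guide's setting those two arguments would be client.chat.completions.create and execute_tool.

```python
from types import SimpleNamespace


def make_agent_handler(system_prompt, tools, create_completion, run_tools,
                       model="gpt-4o-2024-08-06"):
    """Return a handler(query, conversation_messages) bound to one prompt and tool set."""
    def handle(query, conversation_messages):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ]
        response = create_completion(
            model=model, messages=messages, temperature=0, tools=tools
        )
        # tool_calls is None when the model replies without calling a tool.
        tool_calls = response.choices[0].message.tool_calls or []
        conversation_messages.append([call.function for call in tool_calls])
        run_tools(tool_calls, conversation_messages)
    return handle


# Stand-in for the API call: always "calls" stat_analysis.
def fake_create(**kwargs):
    call = SimpleNamespace(
        function=SimpleNamespace(name="stat_analysis", arguments='{"data": "a\\n1"}')
    )
    message = SimpleNamespace(tool_calls=[call])
    return SimpleNamespace(choices=[SimpleNamespace(message=message)])


executed = []


def record_tools(tool_calls, conversation_messages):
    # Stand-in for execute_tool: only records which tools were requested.
    executed.extend(call.function.name for call in tool_calls)


handle_analysis = make_agent_handler("You are an Analysis Agent.", [], fake_create, record_tools)
history = []
handle_analysis("Analyze this data", history)
print(executed)
```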
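Finally, we create the overarching function that handles the user query.
This function takes the user query, gets a response from the model, and passes it to the other agents for execution. It also keeps track of the ongoing conversation state.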
# Function to handle user input and triaging
def handle_user_message(user_query, conversation_messages=None):
    # Start a fresh history unless the caller passes one in; a mutable default
    # argument would be shared across calls.
    if conversation_messages is None:
        conversation_messages = []
    user_message = {"role": "user", "content": user_query}
    conversation_messages.append(user_message)
    messages = [{"role": "system", "content": triaging_system_prompt}]
    messages.extend(conversation_messages)
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=triage_tools,
    )
    # tool_calls is None when the model answers without calling a tool
    tool_calls = response.choices[0].message.tool_calls or []
    conversation_messages.append([tool_call.function for tool_call in tool_calls])
    for tool_call in tool_calls:
        if tool_call.function.name == 'send_query_to_agents':
            arguments = json.loads(tool_call.function.arguments)
            agents = arguments['agents']
            query = arguments['query']
            for agent in agents:
                if agent == "Data Processing Agent":
                    handle_data_processing_agent(query, conversation_messages)
                elif agent == "Analysis Agent":
                    handle_analysis_agent(query, conversation_messages)
                elif agent == "Visualization Agent":
                    handle_visualization_agent(query, conversation_messages)
    return conversation_messages
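The routing above matches agent names as free-form strings, so a name the code does not recognize is silently skipped. One possible tightening, sketched below, is to list the allowed names with the JSON Schema enum keyword (which the Structured Outputs documentation lists among its supported features) and to route through a dict that reports anything left unhandled. AGENT_NAMES, agents_property, and route_query are hypothetical names introduced for this example.

```python
AGENT_NAMES = ["Data Processing Agent", "Analysis Agent", "Visualization Agent"]

# Schema fragment for the `agents` argument with the allowed values spelled out.
agents_property = {
    "type": "array",
    "items": {"type": "string", "enum": AGENT_NAMES},
    "description": "An array of agent names to send the query to.",
}


def route_query(agents, query, conversation_messages, handlers):
    """Call the handler registered for each agent; return names that had no handler."""
    unrouted = []
    for agent in agents:
        handler = handlers.get(agent)
        if handler is None:
            unrouted.append(agent)
        else:
            handler(query, conversation_messages)
    return unrouted


# Placeholder handlers that only record which agent was invoked.
calls = []
handlers = {
    name: (lambda query, history, name=name: calls.append(name))
    for name in AGENT_NAMES
}

leftover = route_query(["Analysis Agent", "Reporting Agent"], "Analyze this", [], handlers)
print(calls, leftover)
```

With the enum in place the model should not be able to emit a name outside the list under strict mode, so the unrouted list mainly guards against the handler dict and the schema drifting apart.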
Multi-agent system execution
Finally, we run the overarching handle_user_message function on the user query and view the output.
handle_user_message(user_query)
Cleaned data: house_size (m3) house_price ($)
0 90 100
1 80 90
2 100 120
Statistical Analysis: house_size house_price
count 4.000000 4.000000
mean 90.000000 102.500000
std 8.164966 12.583057
min 80.000000 90.000000
25% 87.500000 97.500000
50% 90.000000 100.000000
75% 92.500000 105.000000
max 100.000000 120.000000
[{'role': 'user',
'content': '\nBelow is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.\n\nhouse_size (m3), house_price ($)\n90, 100\n80, 90\n100, 120\n90, 100\n'},
[Function(arguments='{"agents": ["Data Processing Agent"], "query": "Remove duplicates from the data: house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='send_query_to_agents'),
Function(arguments='{"agents": ["Analysis Agent"], "query": "Analyze the statistics of the data: house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='send_query_to_agents'),
Function(arguments='{"agents": ["Visualization Agent"], "query": "Plot a line chart for the data: house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='send_query_to_agents')],
[Function(arguments='{"data":"house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='clean_data')],
{'role': 'tool',
'name': 'clean_data',
'content': '{"cleaned_data": {"house_size (m3)": {"0": 90, "1": 80, "2": 100}, " house_price ($)": {"0": 100, "1": 90, "2": 120}}}'},
[Function(arguments='{"data":"house_size,house_price\\n90,100\\n80,90\\n100,120\\n90,100"}', name='stat_analysis')],
{'role': 'tool',
'name': 'stat_analysis',
'content': '{"stats": {"house_size": {"count": 4.0, "mean": 90.0, "std": 8.16496580927726, "min": 80.0, "25%": 87.5, "50%": 90.0, "75%": 92.5, "max": 100.0}, "house_price": {"count": 4.0, "mean": 102.5, "std": 12.583057392117917, "min": 90.0, "25%": 97.5, "50%": 100.0, "75%": 105.0, "max": 120.0}}}'},
[Function(arguments='{"data":"house_size,house_price\\n90,100\\n80,90\\n100,120\\n90,100","x":"house_size","y":"house_price"}', name='create_line_chart')],
{'role': 'tool',
'name': 'create_line_chart',
'content': '{"line_chart": "sample_line_chart"}'}]
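The returned history mixes user messages, lists of Function objects, and tool messages whose content is a JSON string. As a small sketch of how the tool results could be read back, the hypothetical helper tool_results below collects the decoded content by tool name. The history literal is a shortened stand-in that reuses the clean_data content from the output above.

```python
import json


def tool_results(conversation_messages):
    """Collect decoded tool outputs by tool name, skipping non-tool entries."""
    results = {}
    for entry in conversation_messages:
        if isinstance(entry, dict) and entry.get("role") == "tool":
            results[entry["name"]] = json.loads(entry["content"])
    return results


history = [
    {"role": "user", "content": "..."},
    ["placeholder for a list of Function objects"],
    {
        "role": "tool",
        "name": "clean_data",
        "content": '{"cleaned_data": {"house_size (m3)": {"0": 90, "1": 80, "2": 100}, '
                   '" house_price ($)": {"0": 100, "1": 90, "2": 120}}}',
    },
]

cleaned = tool_results(history)["clean_data"]["cleaned_data"]
print(sorted(cleaned["house_size (m3)"].values()))
```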
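Conclusion
In this guide, we explored how to leverage Structured Outputs to build more robust multi-agent systems.
Using this new capability ensures that tool calls follow the specified schema, so you no longer need to handle malformed arguments or validate their structure on your side.
This can be applied to many more use cases, and we hope you can take inspiration from it to build your own!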