Structured Outputs for Multi-Agent Systems

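In this guide, we explore how to use Structured Outputs to build multi-agent systems.

Structured Outputs is a new capability that builds on JSON mode and function calling to make model outputs follow a strict schema.

By setting the new parameter strict: true, we can guarantee that the response adheres to the provided schema.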
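To make "adheres to the provided schema" concrete, here is a minimal sketch in plain Python of the kind of check that strict mode lets us skip on our side. The helper name `conforms` is hypothetical and is not part of the OpenAI SDK. It only handles flat object schemas like the ones used later in this guide.

```python
def conforms(arguments: dict, schema: dict) -> bool:
    """Check a flat arguments dict against a flat object schema (illustrative only)."""
    type_map = {"string": str, "array": list, "object": dict, "boolean": bool}
    properties = schema.get("properties", {})
    # Every required key must be present.
    if any(key not in arguments for key in schema.get("required", [])):
        return False
    # With additionalProperties set to False, unknown keys are rejected.
    if schema.get("additionalProperties") is False and set(arguments) - set(properties):
        return False
    # Each supplied value must match the declared JSON type.
    for key, value in arguments.items():
        expected = type_map.get(properties.get(key, {}).get("type"))
        if expected is not None and not isinstance(value, expected):
            return False
    return True


schema = {
    "type": "object",
    "properties": {"data": {"type": "string"}},
    "required": ["data"],
    "additionalProperties": False,
}

print(conforms({"data": "a,b\n1,2"}, schema))          # True
print(conforms({"data": "a,b\n1,2", "x": 1}, schema))  # False: unexpected key
print(conforms({}, schema))                            # False: missing key
```

With strict mode enabled, arguments that would fail a check like this should not reach our code in the first place.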

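To demonstrate the capability, we will use it to build a multi-agent system.

Why build a multi-agent system?

When using function calling, performance may suffer as the number of functions (or tools) grows.

To mitigate this, we can group tools logically and have specialized "agents" that each solve a specific task or subtask, which can improve overall system performance.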

Environment setup

from openai import OpenAI
from IPython.display import Image
import json
import pandas as pd
import matplotlib.pyplot as plt
from io import StringIO
import numpy as np
client = OpenAI()
MODEL = "gpt-4o-2024-08-06"

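Agents setup

The use case we will tackle is a data analysis task.

Let us first set up our system of 4 agents:

  1. Triaging agent: decides which agent(s) to call.
  2. Data pre-processing agent: prepares data for analysis, for example by cleaning it up.
  3. Data analysis agent: performs analysis on the data.
  4. Data visualization agent: visualizes the output of the analysis to extract insights.

We will start by defining the system prompt for each of these agents.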

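triaging_system_prompt = """You are a Triaging Agent. Your role is to assess the user's query and route it to the relevant agents. The agents available are:

- Data Processing Agent: Cleans, transforms, and aggregates data.
- Analysis Agent: Performs statistical, correlation, and regression analysis.
- Visualization Agent: Creates bar charts, line charts, and pie charts.

Use the send_query_to_agents tool to forward the user's query to the relevant agents. Also, use the speak_to_user tool to get more information from the user if needed."""

processing_system_prompt = """You are a Data Processing Agent. Your role is to clean, transform, and aggregate data using the following tools:

- clean_data
- transform_data
- aggregate_data"""

analysis_system_prompt = """You are an Analysis Agent. Your role is to perform statistical, correlation, and regression analysis using the following tools:

- stat_analysis
- correlation_analysis
- regression_analysis"""

visualization_system_prompt = """You are a Visualization Agent. Your role is to create bar charts, line charts, and pie charts using the following tools:

- create_bar_chart
- create_line_chart
- create_pie_chart"""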

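We will then define the tools for each agent.

Apart from the triaging agent, each agent is equipped with tools specific to its role:

Data pre-processing agent

  1. Clean data
  2. Transform data
  3. Aggregate data

Data analysis agent

  1. Statistical analysis
  2. Correlation analysis
  3. Regression analysis

Data visualization agent

  1. Create bar chart
  2. Create line chart
  3. Create pie chart
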
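triage_tools = [
    {
        "type": "function",
        "function": {
            "name": "send_query_to_agents",
            "description": "Sends the user query to relevant agents based on their capabilities.",
            # strict belongs inside the "function" object and requires additionalProperties: False
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "agents": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "An array of agent names to send the query to."
                    },
                    "query": {
                        "type": "string",
                        "description": "The user query to send."
                    }
                },
                "required": ["agents", "query"],
                "additionalProperties": False
            }
        }
    }
]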

preprocess_tools = [
    {
        "type": "function",
        "function": {
            "name": "clean_data",
            "description": "Cleans the provided data by removing duplicates and handling missing values.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to clean. Should be in a suitable format such as JSON or CSV."
                    }
                },
                "required": ["data"],
                "additionalProperties": False
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "transform_data",
            "description": "Transforms data based on specified rules.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data to transform. Should be in a suitable format such as JSON or CSV."
                    },
                    "rules": {
                        "type": "string",
                        "description": "Transformation rules to apply, specified in a structured format."
                    }
                },
                "required": ["data", "rules"],
                "additionalProperties": False
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "aggregate_data",
            "description": "Aggregates data by specified columns and operations.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data to aggregate. Should be in a suitable format such as JSON or CSV."
                    },
                    "group_by": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Columns to group by."
                    },
                    "operations": {
                        "type": "string",
                        "description": "Aggregation operations to perform, specified in a structured format."
                    }
                },
                "required": ["data", "group_by", "operations"],
                "additionalProperties": False
            }
        }
    }
]


analysis_tools = [
    {
        "type": "function",
        "function": {
            "name": "stat_analysis",
            "description": "Performs statistical analysis on the given dataset.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to analyze. Should be in a suitable format such as JSON or CSV."
                    }
                },
                "required": ["data"],
                "additionalProperties": False
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "correlation_analysis",
            "description": "Calculates correlation coefficients between variables in the dataset.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to analyze. Should be in a suitable format such as JSON or CSV."
                    },
                    "variables": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of variables to calculate correlations for."
                    }
                },
                "required": ["data", "variables"],
                "additionalProperties": False
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "regression_analysis",
            "description": "Performs regression analysis on the dataset.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The dataset to analyze. Should be in a suitable format such as JSON or CSV."
                    },
                    "dependent_var": {
                        "type": "string",
                        "description": "The dependent variable for regression."
                    },
                    "independent_vars": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of independent variables."
                    }
                },
                "required": ["data", "dependent_var", "independent_vars"],
                "additionalProperties": False
            }
        }
    }
]

visualization_tools = [
    {
        "type": "function",
        "function": {
            "name": "create_bar_chart",
            "description": "Creates a bar chart from the provided data.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data for the bar chart. Should be in a suitable format such as JSON or CSV."
                    },
                    "x": {
                        "type": "string",
                        "description": "Column for the x-axis."
                    },
                    "y": {
                        "type": "string",
                        "description": "Column for the y-axis."
                    }
                },
                "required": ["data", "x", "y"],
                "additionalProperties": False
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "create_line_chart",
            "description": "Creates a line chart from the provided data.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data for the line chart. Should be in a suitable format such as JSON or CSV."
                    },
                    "x": {
                        "type": "string",
                        "description": "Column for the x-axis."
                    },
                    "y": {
                        "type": "string",
                        "description": "Column for the y-axis."
                    }
                },
                "required": ["data", "x", "y"],
                "additionalProperties": False
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "create_pie_chart",
            "description": "Creates a pie chart from the provided data.",
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "The data for the pie chart. Should be in a suitable format such as JSON or CSV."
                    },
                    "labels": {
                        "type": "string",
                        "description": "Column for the labels."
                    },
                    "values": {
                        "type": "string",
                        "description": "Column for the values."
                    }
                },
                "required": ["data", "labels", "values"],
                "additionalProperties": False
            }
        }
    }
]
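
The tool definitions above repeat the same scaffolding many times. As a minimal sketch, a small builder can generate that scaffolding. The helper name `make_tool` is hypothetical. It assumes that every property is required and that `strict` is read from inside the `function` object, which to our knowledge is where the Chat Completions API expects it.

```python
def make_tool(name: str, description: str, properties: dict) -> dict:
    """Build a function-tool definition in which every property is required."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": properties,
                # Strict mode expects every property to be listed as required.
                "required": list(properties),
                "additionalProperties": False,
            },
        },
    }


data_param = {
    "type": "string",
    "description": "The dataset to clean. Should be in a suitable format such as JSON or CSV.",
}
clean_data_tool = make_tool(
    "clean_data",
    "Cleans the provided data by removing duplicates and handling missing values.",
    {"data": data_param},
)
print(clean_data_tool["function"]["parameters"]["required"])  # ['data']
```

A builder like this keeps `required` and `additionalProperties` consistent across tools, which is easy to get wrong when each definition is written by hand.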

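Tool execution

We need to write the code logic to:

  • handle passing the user query to the multi-agent system
  • handle the internal workings of the multi-agent system
  • execute the tool calls

For the sake of brevity, we only define the logic for the tools that are relevant to the user query.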

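# Example query

user_query = """
Below is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.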

house_size (m3), house_price ($)
90, 100
80, 90
100, 120
90, 100
"""

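From the user query, we can infer that the tools we need to call are clean_data, stat_analysis, and create_line_chart.

We will first define the execution function, which runs the tool calls.

It maps each tool call to the corresponding function, then appends the function's output to the conversation history.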
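
One detail of the sample data is worth noting before we wire up the tools: every comma is followed by a space, so a plain comma split leaves a leading blank on the second column name and on each second value. The sketch below uses only the standard library to show the effect of `skipinitialspace=True` together with a simple de-duplication. It is an illustration, not the implementation used in this guide. `pandas.read_csv` accepts a parameter of the same name.

```python
import csv
from io import StringIO

sample = "house_size (m3), house_price ($)\n90, 100\n80, 90\n100, 120\n90, 100\n"

# skipinitialspace=True drops the blank that follows each comma.
rows = list(csv.reader(StringIO(sample), skipinitialspace=True))
header, records = rows[0], rows[1:]

# Remove duplicate rows while keeping first-seen order.
unique_records = list(dict.fromkeys(tuple(record) for record in records))

print(header)          # ['house_size (m3)', 'house_price ($)']
print(unique_records)  # [('90', '100'), ('80', '90'), ('100', '120')]
```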

def clean_data(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    df_deduplicated = df.drop_duplicates()
    return df_deduplicated

def stat_analysis(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    return df.describe()

def plot_line_chart(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")

    x = df.iloc[:, 0]
    y = df.iloc[:, 1]

    coefficients = np.polyfit(x, y, 1)
    polynomial = np.poly1d(coefficients)
    y_fit = polynomial(x)

    plt.figure(figsize=(10, 6))
    plt.plot(x, y, 'o', label='Data Points')
    plt.plot(x, y_fit, '-', label='Best Fit Line')
    plt.title('Line Chart with Best Fit Line')
    plt.xlabel(df.columns[0])
    plt.ylabel(df.columns[1])
    plt.legend()
    plt.grid(True)
    plt.show()

# Define the function to execute the tools
def execute_tool(tool_calls, messages):
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        tool_arguments = json.loads(tool_call.function.arguments)

        if tool_name == 'clean_data':
            # Simulate data cleaning
            cleaned_df = clean_data(tool_arguments['data'])
            cleaned_data = {"cleaned_data": cleaned_df.to_dict()}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(cleaned_data)})
            print('Cleaned data: ', cleaned_df)
        elif tool_name == 'transform_data':
            # Simulate data transformation
            transformed_data = {"transformed_data": "sample_transformed_data"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(transformed_data)})
        elif tool_name == 'aggregate_data':
            # Simulate data aggregation
            aggregated_data = {"aggregated_data": "sample_aggregated_data"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(aggregated_data)})
        elif tool_name == 'stat_analysis':
            # Simulate statistical analysis
            stats_df = stat_analysis(tool_arguments['data'])
            stats = {"stats": stats_df.to_dict()}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(stats)})
            print('Statistical Analysis: ', stats_df)
        elif tool_name == 'correlation_analysis':
            # Simulate correlation analysis
            correlations = {"correlations": "sample_correlations"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(correlations)})
        elif tool_name == 'regression_analysis':
            # Simulate regression analysis
            regression_results = {"regression_results": "sample_regression_results"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(regression_results)})
        elif tool_name == 'create_bar_chart':
            # Simulate bar chart creation
            bar_chart = {"bar_chart": "sample_bar_chart"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(bar_chart)})
        elif tool_name == 'create_line_chart':
            # Simulate line chart creation
            line_chart = {"line_chart": "sample_line_chart"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(line_chart)})
            plot_line_chart(tool_arguments['data'])
        elif tool_name == 'create_pie_chart':
            # Simulate pie chart creation
            pie_chart = {"pie_chart": "sample_pie_chart"}
            messages.append({"role": "tool", "name": tool_name, "content": json.dumps(pie_chart)})
    return messages
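
The mapping from tool name to function can also be written as a lookup table instead of an if/elif chain. The following is a minimal sketch of that alternative, not the approach used in this guide. The names `TOOL_HANDLERS`, `execute_tool_via_table`, and the two `fake_*` functions are hypothetical stand-ins, and the tool call object is imitated with `SimpleNamespace` so the example runs without calling the API.

```python
import json
from types import SimpleNamespace


# Hypothetical stand-ins for the real tool implementations.
def fake_clean_data(data):
    return {"cleaned_data": sorted(set(data.splitlines()))}


def fake_stat_analysis(data):
    return {"row_count": len(data.splitlines())}


TOOL_HANDLERS = {
    "clean_data": fake_clean_data,
    "stat_analysis": fake_stat_analysis,
}


def execute_tool_via_table(tool_calls, messages):
    for tool_call in tool_calls:
        name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments)
        handler = TOOL_HANDLERS.get(name)
        if handler is None:
            # Record the problem instead of silently skipping the call.
            result = {"error": f"no handler registered for {name}"}
        else:
            # The parsed arguments become keyword arguments of the handler.
            result = handler(**arguments)
        messages.append({"role": "tool", "name": name, "content": json.dumps(result)})
    return messages


call = SimpleNamespace(
    function=SimpleNamespace(name="stat_analysis", arguments='{"data": "a\\nb\\nb"}')
)
print(execute_tool_via_table([call], []))
# [{'role': 'tool', 'name': 'stat_analysis', 'content': '{"row_count": 3}'}]
```

Because strict mode constrains the argument names to the schema, passing them straight through as keyword arguments is less fragile than it would be with unvalidated input.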

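Next, we will create the tool handlers for each of the sub-agents.

Each handler passes its own prompt and tool set to the model.

The output is then passed to the execution function, which runs the tool calls.

We also append the messages to the conversation history.

# Define the functions to handle each agent's processing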
def handle_data_processing_agent(query, conversation_messages):
    messages = [{"role": "system", "content": processing_system_prompt}]
    messages.append({"role": "user", "content": query})

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=preprocess_tools,
    )

    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])
    execute_tool(response.choices[0].message.tool_calls, conversation_messages)

def handle_analysis_agent(query, conversation_messages):
    messages = [{"role": "system", "content": analysis_system_prompt}]
    messages.append({"role": "user", "content": query})

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=analysis_tools,
    )

    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])
    execute_tool(response.choices[0].message.tool_calls, conversation_messages)

def handle_visualization_agent(query, conversation_messages):
    messages = [{"role": "system", "content": visualization_system_prompt}]
    messages.append({"role": "user", "content": query})

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=visualization_tools,
    )

    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])
    execute_tool(response.choices[0].message.tool_calls, conversation_messages)
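
The three handlers differ only in their system prompt and tool list, so they can be folded into one function. The sketch below is one possible shape for that, with a hypothetical name `run_agent` and with the client, model, and execution function passed in as arguments. It also guards against `tool_calls` being `None`, which happens when the model replies in plain text instead of calling a tool. A stub client built from `SimpleNamespace` stands in for `OpenAI()` so the example runs offline.

```python
from types import SimpleNamespace


def run_agent(client, model, system_prompt, tools, query, conversation_messages, execute):
    """Send one query to one agent and execute whatever tools it calls."""
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        temperature=0,
        tools=tools,
    )
    # tool_calls is None when the model answers without calling a tool.
    tool_calls = response.choices[0].message.tool_calls or []
    conversation_messages.append([tool_call.function for tool_call in tool_calls])
    execute(tool_calls, conversation_messages)
    return conversation_messages


# Stub client so the sketch runs offline; a real run would pass OpenAI().
fake_call = SimpleNamespace(function=SimpleNamespace(name="clean_data", arguments="{}"))
fake_response = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(tool_calls=[fake_call]))]
)
stub_client = SimpleNamespace(
    chat=SimpleNamespace(completions=SimpleNamespace(create=lambda **kwargs: fake_response))
)

executed = []
history = run_agent(
    stub_client, "stub-model", "You are a stub agent.", [], "clean this", [],
    execute=lambda calls, messages: executed.extend(c.function.name for c in calls),
)
print(executed)  # ['clean_data']
```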

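Finally, we create the overarching function that handles the user query.

This function takes the user query, gets a response from the model, and passes it on to the other agents to execute. In addition, it keeps track of the state of the ongoing conversation.

# Function to handle user input and triaging
def handle_user_message(user_query, conversation_messages=None):
    # Start a fresh history unless the caller passes one in; a mutable default
    # list would be shared across calls.
    if conversation_messages is None:
        conversation_messages = []
    user_message = {"role": "user", "content": user_query}
    conversation_messages.append(user_message)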


    messages = [{"role": "system", "content": triaging_system_prompt}]
    messages.extend(conversation_messages)

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=triage_tools,
    )

    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])

    for tool_call in response.choices[0].message.tool_calls:
        if tool_call.function.name == 'send_query_to_agents':
            agents = json.loads(tool_call.function.arguments)['agents']
            query = json.loads(tool_call.function.arguments)['query']
            for agent in agents:
                if agent == "Data Processing Agent":
                    handle_data_processing_agent(query, conversation_messages)
                elif agent == "Analysis Agent":
                    handle_analysis_agent(query, conversation_messages)
                elif agent == "Visualization Agent":
                    handle_visualization_agent(query, conversation_messages)

    return conversation_messages
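
The routing step compares agent names against exact strings, so a name the model spells differently is silently dropped. As a minimal sketch under stated assumptions, the routing can be expressed as a lookup table that reports unknown names, and the same table can feed an `enum` on the `agents` items so the schema only admits known names. The function `route` and the lambda handlers are hypothetical, and `enum` is used here on the assumption that it is among the JSON Schema keywords that strict mode supports.

```python
import json


def route(arguments_json, handlers):
    """Dispatch one send_query_to_agents call; return agent names with no handler."""
    arguments = json.loads(arguments_json)
    unknown = []
    for agent in arguments["agents"]:
        handler = handlers.get(agent)
        if handler is None:
            unknown.append(agent)
        else:
            handler(arguments["query"])
    return unknown


calls = []
handlers = {
    "Data Processing Agent": lambda query: calls.append(("processing", query)),
    "Analysis Agent": lambda query: calls.append(("analysis", query)),
    "Visualization Agent": lambda query: calls.append(("visualization", query)),
}

# The same keys can constrain the schema of the "agents" parameter.
agents_property = {
    "type": "array",
    "items": {"type": "string", "enum": list(handlers)},
    "description": "An array of agent names to send the query to.",
}

unknown = route('{"agents": ["Analysis Agent", "Data Agent"], "query": "describe"}', handlers)
print(calls)    # [('analysis', 'describe')]
print(unknown)  # ['Data Agent']
```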

Multi-agent system execution

Finally, we run the overarching handle_user_message function on the user query and view the output.

handle_user_message(user_query)
Cleaned data:     house_size (m3)   house_price ($)
0               90               100
1               80                90
2              100               120
Statistical Analysis:         house_size  house_price
count    4.000000     4.000000
mean    90.000000   102.500000
std      8.164966    12.583057
min     80.000000    90.000000
25%     87.500000    97.500000
50%     90.000000   100.000000
75%     92.500000   105.000000
max    100.000000   120.000000

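[Figure: output of plot_line_chart, titled "Line Chart with Best Fit Line", showing the data points and the fitted line]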

[{'role': 'user',
  'content': '\nBelow is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.\n\nhouse_size (m3), house_price ($)\n90, 100\n80, 90\n100, 120\n90, 100\n'},
 [Function(arguments='{"agents": ["Data Processing Agent"], "query": "Remove duplicates from the data: house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='send_query_to_agents'),
  Function(arguments='{"agents": ["Analysis Agent"], "query": "Analyze the statistics of the data: house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='send_query_to_agents'),
  Function(arguments='{"agents": ["Visualization Agent"], "query": "Plot a line chart for the data: house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='send_query_to_agents')],
 [Function(arguments='{"data":"house_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100"}', name='clean_data')],
 {'role': 'tool',
  'name': 'clean_data',
  'content': '{"cleaned_data": {"house_size (m3)": {"0": 90, "1": 80, "2": 100}, " house_price ($)": {"0": 100, "1": 90, "2": 120}}}'},
 [Function(arguments='{"data":"house_size,house_price\\n90,100\\n80,90\\n100,120\\n90,100"}', name='stat_analysis')],
 {'role': 'tool',
  'name': 'stat_analysis',
  'content': '{"stats": {"house_size": {"count": 4.0, "mean": 90.0, "std": 8.16496580927726, "min": 80.0, "25%": 87.5, "50%": 90.0, "75%": 92.5, "max": 100.0}, "house_price": {"count": 4.0, "mean": 102.5, "std": 12.583057392117917, "min": 90.0, "25%": 97.5, "50%": 100.0, "75%": 105.0, "max": 120.0}}}'},
 [Function(arguments='{"data":"house_size,house_price\\n90,100\\n80,90\\n100,120\\n90,100","x":"house_size","y":"house_price"}', name='create_line_chart')],
 {'role': 'tool',
  'name': 'create_line_chart',
  'content': '{"line_chart": "sample_line_chart"}'}]
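
The history returned above mixes plain dicts with lists of SDK `Function` objects, so it cannot be dumped to JSON as is. If the history needs to be stored or sent back to the model, one option is to record each step in the Chat Completions message shape: an assistant message carrying `tool_calls`, followed by tool messages linked through `tool_call_id`. The sketch below assumes that shape. The helper names are hypothetical, and the tool call object and its id are made-up stand-ins.

```python
import json
from types import SimpleNamespace


def assistant_message(tool_calls):
    """Render the model's tool calls as one plain-dict assistant message."""
    return {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
            for tool_call in tool_calls
        ],
    }


def tool_message(tool_call, content):
    """Render a tool result, linked to its call through tool_call_id."""
    return {"role": "tool", "tool_call_id": tool_call.id, "content": content}


# Stand-in for an SDK tool call object; the id value is made up.
call = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(name="create_line_chart", arguments='{"x": "house_size"}'),
)
history = [assistant_message([call]), tool_message(call, '{"line_chart": "sample_line_chart"}')]
print(json.dumps(history, indent=2))
```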

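Conclusion

In this guide, we explored how to leverage Structured Outputs to build more robust multi-agent systems.

Using this new capability ensures that tool calls follow the specified schema, so you do not have to handle edge cases or validate arguments on your side.

This approach applies to many more use cases, and we hope it inspires you to build your own!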