Multiclass Classification for Transactions
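In this notebook we will classify a public transaction dataset into five categories that we have defined in advance. These approaches should be replicable to any multiclass classification use case, particularly where we are trying to fit transactional data into predefined categories. By running through this notebook you should come away with a few approaches for dealing with both labelled and unlabelled datasets.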
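We will take the following approaches in this notebook:
- Zero-shot classification: First we'll do zero-shot classification, using only a prompt for guidance to place each transaction in one of the five predefined categories.
- Classification with embeddings: Next we'll create embeddings for a labelled dataset and then use a traditional classification model to test how well they identify our categories.
- Fine-tuned classification: Finally we'll train a fine-tuned model on the labelled dataset to see how it performs compared with the zero-shot and embeddings-based approaches.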
Setup
%load_ext autoreload
%autoreload
%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers scikit-learn matplotlib plotly pandas scipy
import openai
import pandas as pd
import numpy as np
import json
import os
COMPLETIONS_MODEL = "gpt-4"
os.environ["OPENAI_API_KEY"] = "<your-api-key>"
client = openai.OpenAI()
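Load dataset
We're using a public dataset from the National Library of Scotland containing its transactions over £25,000. The dataset has three features that we'll use:
- Supplier: the name of the supplier
- Description: a text description of the transaction
- Value: the value of the transaction in GBP (the Transaction value (£) column)
Source: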
https://data.nls.uk/data/organisational-data/transactions-over-25k/
transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding='unicode_escape')
print(f"Number of transactions: {len(transactions)}")
print(transactions.head())
Number of transactions: 359
| | Date | Supplier | Description | Transaction value (£) |
---|---|---|---|---|
0 | 21/04/2016 | M & J Ballantyne Ltd | George IV Bridge Work | 35098.0 |
1 | 26/04/2016 | Private Sale | Literary & Archival Items | 30000.0 |
2 | 30/04/2016 | City Of Edinburgh Council | Non Domestic Rates | 40800.0 |
3 | 09/05/2016 | Computacenter Uk | Kelvin Hall | 72835.0 |
4 | 09/05/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 64361.0 |
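Zero-shot Classification
We'll first assess how well the base model classifies these transactions using a simple prompt. We'll provide the model with 5 categories plus a catch-all "Could not classify" category for transactions it cannot place.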
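The prompt asks for one of five labels or the catch-all, but the model replies in free text, so it can be worth mapping each reply onto that closed set before counting results. The sketch below uses a hypothetical normalize_label helper (not part of the notebook) that accepts an exact, case-insensitive match and otherwise falls back to "Could not classify".

```python
# Hypothetical helper: the label list mirrors the five categories named in the prompt,
# plus the catch-all label the prompt asks the model to use.
CATEGORIES = [
    "Building Improvement",
    "Literature & Archive",
    "Utility Bills",
    "Professional Services",
    "Software/IT",
]
FALLBACK = "Could not classify"


def normalize_label(raw_reply: str) -> str:
    """Map a free-text model reply onto the closed label set.

    Matching ignores case, surrounding whitespace and a trailing full stop;
    anything that is not an exact category match falls back to the catch-all.
    """
    cleaned = raw_reply.strip().rstrip(".").strip().lower()
    for category in CATEGORIES:
        if cleaned == category.lower():
            return category
    return FALLBACK


print(normalize_label(" building improvement.\n"))  # Building Improvement
print(normalize_label("Maybe IT?"))                  # Could not classify
```

Exact matching is deliberately strict: a looser rule such as substring matching would rescue replies like "Category: Software/IT" but could also misfire, so it is a judgement call.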
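zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.
You are analysing all transactions over £25,000 in value and classifying them into one of five categories.
The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.
If you can't tell what it is, say Could not classify
Transaction:
Supplier: {}
Description: {}
Value: {}
The classification is:'''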
def format_prompt(transaction):
    # Fill the prompt template with this transaction's supplier, description and value
    return zero_shot_prompt.format(transaction['Supplier'], transaction['Description'], transaction['Transaction value (£)'])

def classify_transaction(transaction):
    prompt = format_prompt(transaction)
    messages = [
        {"role": "system", "content": prompt},
    ]
    completion_response = client.chat.completions.create(
        messages=messages,
        temperature=0,
        max_tokens=5,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        model=COMPLETIONS_MODEL)
    # Keep only the label text, dropping any newlines in the reply
    label = completion_response.choices[0].message.content.replace('\n', '')
    return label
# Get a test transaction
transaction = transactions.iloc[0]
# Use our completion function to return a prediction
print(f"Transaction: {transaction['Supplier']} {transaction['Description']} {transaction['Transaction value (£)']}")
print(f"Classification: {classify_transaction(transaction)}")
Transaction: M & J Ballantyne Ltd George IV Bridge Work 35098.0
Classification: Building Improvement
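Our first attempt is correct: M & J Ballantyne Ltd is a house builder, and the work they performed is indeed Building Improvement.
Let's expand the sample size to 25 and see how it performs, again with just a simple prompt to guide it.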
# Copy the slice so adding a column does not trigger pandas' SettingWithCopyWarning
test_transactions = transactions.iloc[:25].copy()
test_transactions['Classification'] = test_transactions.apply(classify_transaction, axis=1)
test_transactions['Classification'].value_counts()
Classification
Building Improvement 17
Literature & Archive 3
Software/IT 2
Could not classify 2
Utility Bills 1
Name: count, dtype: int64
test_transactions.head(25)
| | Date | Supplier | Description | Transaction value (£) | Classification |
---|---|---|---|---|---|
0 | 21/04/2016 | M & J Ballantyne Ltd | George IV Bridge Work | 35098.0 | Building Improvement |
1 | 26/04/2016 | Private Sale | Literary & Archival Items | 30000.0 | Literature & Archive |
2 | 30/04/2016 | City Of Edinburgh Council | Non Domestic Rates | 40800.0 | Utility Bills |
3 | 09/05/2016 | Computacenter Uk | Kelvin Hall | 72835.0 | Software/IT |
4 | 09/05/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 64361.0 | Building Improvement |
5 | 09/05/2016 | A McGillivray | Causewayside Refurbishment | 53690.0 | Building Improvement |
6 | 16/05/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 365344.0 | Building Improvement |
7 | 23/05/2016 | Computacenter Uk | Kelvin Hall | 26506.0 | Software/IT |
8 | 23/05/2016 | ECG Facilities Service | Facilities Management Charge | 32777.0 | Building Improvement |
9 | 23/05/2016 | ECG Facilities Service | Facilities Management Charge | 32777.0 | Building Improvement |
10 | 30/05/2016 | ALDL | ALDL Charges | 32317.0 | Could not classify |
11 | 10/06/2016 | Wavetek Ltd | Kelvin Hall | 87589.0 | Building Improvement |
12 | 10/06/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 381803.0 | Building Improvement |
13 | 28/06/2016 | ECG Facilities Service | Facilities Management Charge | 32832.0 | Building Improvement |
14 | 30/06/2016 | Glasgow City Council | Kelvin Hall | 1700000.0 | Building Improvement |
15 | 11/07/2016 | Wavetek Ltd | Kelvin Hall | 65692.0 | Building Improvement |
16 | 11/07/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 139845.0 | Building Improvement |
17 | 15/07/2016 | Sotheby'S | Literary & Archival Items | 28500.0 | Literature & Archive |
18 | 18/07/2016 | Christies | Literary & Archival Items | 33800.0 | Literature & Archive |
19 | 25/07/2016 | A McGillivray | Causewayside Refurbishment | 30113.0 | Building Improvement |
20 | 31/07/2016 | ALDL | ALDL Charges | 32317.0 | Could not classify |
21 | 08/08/2016 | ECG Facilities Service | Facilities Management Charge | 32795.0 | Building Improvement |
22 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866.0 | Building Improvement |
23 | 15/08/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 196807.0 | Building Improvement |
24 | 24/08/2016 | ECG Facilities Service | Facilities Management Charge | 32795.0 | Building Improvement |
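Initial results are pretty good even with no labelled examples! The two transactions it could not classify (both ALDL "ALDL Charges") are tougher cases with few clues as to their topic; cleaning up the labelled dataset to provide more examples may give us better performance.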
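One way to act on this is few-shot prompting: show the model a handful of already-labelled transactions before the one to classify. The following is a minimal sketch with a hypothetical build_few_shot_prompt helper; the example rows come from the table above, and the exact prompt layout is an assumption rather than something this notebook evaluates.

```python
from typing import Dict, List


def format_block(t: Dict[str, str]) -> str:
    # Same Supplier / Description / Value layout as the zero-shot prompt
    return f"Transaction:\nSupplier: {t['Supplier']}\nDescription: {t['Description']}\nValue: {t['Value']}"


def build_few_shot_prompt(instructions: str, examples: List[Dict[str, str]], target: Dict[str, str]) -> str:
    """Hypothetical helper: instructions, then labelled examples, then the unlabelled target."""
    parts = [instructions.strip()]
    for example in examples:
        parts.append(format_block(example) + f"\nThe classification is: {example['Classification']}")
    parts.append(format_block(target) + "\nThe classification is:")
    return "\n\n".join(parts)


examples = [
    {"Supplier": "Christies", "Description": "Literary & Archival Items", "Value": "33800.0",
     "Classification": "Literature & Archive"},
    {"Supplier": "City Of Edinburgh Council", "Description": "Non Domestic Rates", "Value": "40800.0",
     "Classification": "Utility Bills"},
]
target = {"Supplier": "ALDL", "Description": "ALDL Charges", "Value": "32317.0"}

prompt = build_few_shot_prompt(
    "Classify each transaction as Building Improvement, Literature & Archive, "
    "Utility Bills, Professional Services or Software/IT.",
    examples,
    target,
)
print(prompt)
```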
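Classification with Embeddings
Let's create embeddings from the small set that we've classified so far. We built a set of labelled examples by running the zero-shot classifier on 101 transactions from our dataset and manually correcting the 15 "Could not classify" results that we got.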
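The manual-correction step itself is not shown in the notebook. As a sketch, assuming the zero-shot labels sit in a Classification column and a reviewer records fixes keyed by supplier, the corrections can be applied to the "Could not classify" rows only. The manual_fixes mapping and its label are placeholders, not the notebook's actual corrections.

```python
import pandas as pd

# Toy stand-in for the zero-shot output; rows are taken from the 25-row table above
labelled = pd.DataFrame({
    "Supplier": ["ALDL", "Sotheby'S", "ALDL"],
    "Description": ["ALDL Charges", "Literary & Archival Items", "ALDL Charges"],
    "Classification": ["Could not classify", "Literature & Archive", "Could not classify"],
})

# Hypothetical reviewer decisions keyed by supplier (placeholder label, not the real correction)
manual_fixes = {"ALDL": "Other"}

# Only touch rows the model could not classify; suppliers without a fix keep the catch-all label
needs_review = labelled["Classification"] == "Could not classify"
labelled.loc[needs_review, "Classification"] = (
    labelled.loc[needs_review, "Supplier"].map(manual_fixes).fillna("Could not classify")
)
print(labelled["Classification"].tolist())
# ['Other', 'Literature & Archive', 'Other']
```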
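Create embeddings
This initial section reuses the approach from the Get_embeddings_from_dataset notebook to create embeddings from a combined field that concatenates all of the features.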
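Building the combined field is a vectorised string concatenation in pandas. One detail matters: each column has to be converted element-wise (for example with .astype(str)), because calling str() on a whole Series produces a single string containing the entire column, which would then be appended to every row. A small sketch with two rows from the labelled data:

```python
import pandas as pd

# Two rows from the labelled table, with stray whitespace added to show the stripping
sample = pd.DataFrame({
    "Supplier": ["Creative Video Productions Ltd ", "John Graham Construction Ltd"],
    "Description": ["Kelvin Hall", " Causewayside Refurbishment"],
    "Transaction value (£)": [26866, 74806],
})

# Element-wise conversion: each row only contains its own supplier, description and value
sample["combined"] = (
    "Supplier: " + sample["Supplier"].str.strip()
    + "; Description: " + sample["Description"].str.strip()
    + "; Value: " + sample["Transaction value (£)"].astype(str).str.strip()
)
print(sample["combined"].tolist())
```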
df = pd.read_csv('./data/labelled_transactions.csv')
df.head()
| | Date | Supplier | Description | Transaction value (£) | Classification |
---|---|---|---|---|---|
0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other |
1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement |
2 | 29/05/2017 | Morris & Spottiswood Ltd | George IV Bridge Work | 56448 | Building Improvement |
3 | 31/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 164691 | Building Improvement |
4 | 24/07/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 27926 | Building Improvement |
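# Convert the value column element-wise; str() on the whole Series would embed the entire column in every row
df['combined'] = "Supplier: " + df['Supplier'].str.strip() + "; Description: " + df['Description'].str.strip() + "; Value: " + df['Transaction value (£)'].astype(str).str.strip()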
df.head(2)
| | Date | Supplier | Description | Transaction value (£) | Classification | combined |
---|---|---|---|---|---|---|
0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... |
1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... |
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))
len(df)
101
embedding_path = './data/transactions_with_embeddings_100.csv'
from utils.embeddings_utils import get_embedding
df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x))
df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x))
df.to_csv(embedding_path)
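Use embeddings for classification
Now that we have our embeddings, let's see whether classifying them into the categories we've named gives us any more success.
For this we'll use the template from the Classification_using_embeddings notebook.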
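Because the embeddings were saved with to_csv, each vector comes back from read_csv as a string, which is why the notebook parses the column with literal_eval before training. The toy round trip below illustrates this with made-up 3-dimensional vectors; writing with index=False also avoids the extra Unnamed: 0 column that appears when the index is saved.

```python
import io
from ast import literal_eval

import numpy as np
import pandas as pd

# Toy 3-dimensional "embeddings"; real embedding vectors are much longer
frame = pd.DataFrame({
    "Classification": ["Other", "Utility Bills"],
    "embedding": [[0.1, -0.2, 0.3], [0.0, 0.5, -0.5]],
})

buffer = io.StringIO()
frame.to_csv(buffer, index=False)  # lists are written out as their string representation
buffer.seek(0)

restored = pd.read_csv(buffer)
print(type(restored["embedding"][0]))  # <class 'str'>

# Parse each string back into a list, then into a NumPy array
restored["embedding"] = restored["embedding"].apply(literal_eval).apply(np.array)
X = np.vstack(restored["embedding"].tolist())  # shape (n_samples, n_dims)
print(X.shape)  # (2, 3)
```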
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from ast import literal_eval
fs_df = pd.read_csv(embedding_path)
fs_df["babbage_similarity"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)
fs_df.head()
| | Unnamed: 0 | Date | Supplier | Description | Transaction value (£) | Classification | combined | n_tokens | babbage_similarity | babbage_search |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... | 136 | [-0.02898375503718853, -0.02881557121872902, 0... | [-0.02879939414560795, -0.02867320366203785, 0... |
1 | 1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.024112487211823463, -0.02881261520087719, ... | [-0.024112487211823463, -0.02881261520087719, ... |
2 | 2 | 29/05/2017 | Morris & Spottiswood Ltd | George IV Bridge Work | 56448 | Building Improvement | Supplier: Morris & Spottiswood Ltd; Descriptio... | 141 | [0.013581369072198868, -0.003978211898356676, ... | [0.013593776151537895, -0.0037341134157031775,... |
3 | 3 | 31/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 164691 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.024112487211823463, -0.02881261520087719, ... | [-0.024112487211823463, -0.02881261520087719, ... |
4 | 4 | 24/07/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 27926 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.02408558875322342, -0.02881370671093464, 0... | [-0.024109570309519768, -0.02880912832915783, ... |
X_train, X_test, y_train, y_test = train_test_split(
list(fs_df.babbage_similarity.values), fs_df.Classification, test_size=0.2, random_state=42
)
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)
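# zero_division=0 reports 0.0 (without a warning) for classes that are never predicted
report = classification_report(y_test, preds, zero_division=0)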
print(report)
precision recall f1-score support
Building Improvement 0.92 1.00 0.96 11
Literature & Archive 1.00 1.00 1.00 3
Other 0.00 0.00 0.00 1
Software/IT 1.00 1.00 1.00 1
Utility Bills 1.00 1.00 1.00 5
accuracy 0.95 21
macro avg 0.78 0.80 0.79 21
weighted avg 0.91 0.95 0.93 21
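This model performs quite strongly, so creating embeddings and using even a simpler classifier looks like an effective approach as well, with the zero-shot classifier helping us do the initial classification of the unlabelled dataset. The test split is small, though: it holds only 21 transactions, with a single example each of Other and Software/IT, and the one Other transaction was predicted as Building Improvement (hence that class's 0.92 precision).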
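The gap between the macro and weighted averages in the report comes from how each one weighs the classes. The macro average treats all five classes equally, so the single missed Other example pulls it down, while the weighted average is dominated by Building Improvement (11 of the 21 test rows). Recomputing both from the rounded per-class F1 scores in the report reproduces the two-decimal values:

```python
# Per-class F1 scores and supports copied from the classification report above
f1 = {"Building Improvement": 0.96, "Literature & Archive": 1.00, "Other": 0.00,
      "Software/IT": 1.00, "Utility Bills": 1.00}
support = {"Building Improvement": 11, "Literature & Archive": 3, "Other": 1,
           "Software/IT": 1, "Utility Bills": 5}

# Macro average: unweighted mean over classes
macro_f1 = sum(f1.values()) / len(f1)
# Weighted average: each class's F1 weighted by its number of true examples
weighted_f1 = sum(f1[c] * support[c] for c in f1) / sum(support.values())
print(round(macro_f1, 2), round(weighted_f1, 2))  # 0.79 0.93
```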
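Let's take it one step further and see whether a fine-tuned model trained on the same labelled dataset gives us comparable results.
Fine-tuned Transaction Classification
For this use case we'll try to improve on the classification above by training a fine-tuned model on the same labelled set of 101 transactions and applying it to a held-out set of transactions the model has not seen during training.
Building a fine-tuned classifier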
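We first need to do some data preparation to get our data ready. This will include the following steps:
- To prepare our training and validation sets, we'll create a set of message sequences. The first message in each sequence is the user prompt formatted with the transaction details, and the final message is the expected classification response from the model.
- Our test set (here, the validation split) will contain the initial user prompt for each transaction along with the corresponding expected class label. We'll then use the fine-tuned model to generate the actual classification for each transaction.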
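Each training example ends up as one JSON object per line (JSONL) holding the message list. The sketch below uses a hypothetical make_training_record helper to show a single record for one of the Literature & Archive transactions; the prompt text is abbreviated here, whereas the notebook uses the full format_prompt output.

```python
import json


def make_training_record(prompt: str, label: str) -> str:
    """Hypothetical helper: one JSONL line with the user prompt followed by the expected answer."""
    record = {"messages": [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": label},
    ]}
    # json.dumps escapes newlines inside strings, so the record stays on a single line
    return json.dumps(record)


line = make_training_record(
    "Supplier: Christies\nDescription: Literary & Archival Items\nValue: 33800.0",
    "Literature & Archive",
)
parsed = json.loads(line)
print(parsed["messages"][-1]["content"])  # Literature & Archive
```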
ft_prep_df = fs_df.copy()
len(ft_prep_df)
101
ft_prep_df.head()
| | Unnamed: 0 | Date | Supplier | Description | Transaction value (£) | Classification | combined | n_tokens | babbage_similarity | babbage_search |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... | 136 | [-0.028885245323181152, -0.028660893440246582,... | [-0.02879939414560795, -0.02867320366203785, 0... |
1 | 1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.024112487211823463, -0.02881261520087719, ... | [-0.02414606139063835, -0.02883070334792137, 0... |
2 | 2 | 29/05/2017 | Morris & Spottiswood Ltd | George IV Bridge Work | 56448 | Building Improvement | Supplier: Morris & Spottiswood Ltd; Descriptio... | 141 | [0.013593776151537895, -0.0037341134157031775,... | [0.013561442494392395, -0.004199974238872528, ... |
3 | 3 | 31/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 164691 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.024112487211823463, -0.02881261520087719, ... | [-0.024112487211823463, -0.02881261520087719, ... |
4 | 4 | 24/07/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 27926 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.024112487211823463, -0.02881261520087719, ... | [-0.024112487211823463, -0.02881261520087719, ... |
classes = list(set(ft_prep_df['Classification']))
class_df = pd.DataFrame(classes).reset_index()
class_df.columns = ['class_id','class']
class_df , len(class_df)
(   class_id                 class
0         0                 Other
1         1  Literature & Archive
2         2           Software/IT
3         3         Utility Bills
4         4  Building Improvement, 5)
ft_df_with_class = ft_prep_df.merge(class_df,left_on='Classification',right_on='class',how='inner')
# Creating a list of messages for the fine-tuning job. The user message is the prompt, and the assistant message is the response from the model
ft_df_with_class['messages'] = ft_df_with_class.apply(lambda x: [{"role": "user", "content": format_prompt(x)}, {"role": "assistant", "content": x['class']}],axis=1)
ft_df_with_class[['messages', 'class']].head()
| | messages | class |
---|---|---|
0 | [{'role': 'user', 'content': 'You are a data e... | Other |
1 | [{'role': 'user', 'content': 'You are a data e... | Building Improvement |
2 | [{'role': 'user', 'content': 'You are a data e... | Building Improvement |
3 | [{'role': 'user', 'content': 'You are a data e... | Building Improvement |
4 | [{'role': 'user', 'content': 'You are a data e... | Building Improvement |
# Create train/validation split
samples = ft_df_with_class["messages"].tolist()
train_df, valid_df = train_test_split(samples, test_size=0.2, random_state=42)
def write_to_jsonl(list_of_messages, filename):
    # Write one {"messages": [...]} object per line (JSONL)
    with open(filename, "w") as f:
        for messages in list_of_messages:
            record = {"messages": messages}
            f.write(json.dumps(record) + "\n")
# Write the train/validation split to jsonl files
train_file_name, valid_file_name = "transactions_grouped_train.jsonl", "transactions_grouped_valid.jsonl"
write_to_jsonl(train_df, train_file_name)
write_to_jsonl(valid_df, valid_file_name)
# Upload the files to OpenAI
train_file = client.files.create(file=open(train_file_name, "rb"), purpose="fine-tune")
valid_file = client.files.create(file=open(valid_file_name, "rb"), purpose="fine-tune")
# Create the fine-tuning job
fine_tuning_job = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=valid_file.id, model="gpt-4o-2024-08-06")
import time
# Poll the job until it finishes; fine_tuned_model stays None until the job has succeeded
status = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
while status.status not in ("succeeded", "failed", "cancelled"):
    time.sleep(60)
    status = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
fine_tuned_model = status.fine_tuned_model
print(f"Fine tuned model id: {fine_tuned_model}")
Fine tuned model id: ft:gpt-4o-2024-08-06:openai::BKr3Xy8U
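Applying the fine-tuned classifier
Now we'll apply our classifier to see how it performs. The 80/20 split leaves 80 training and 21 validation examples, and because many transactions repeat, only 31 of the training examples and 8 of the validation examples are unique, so let's see how the model copes with so little data.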
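With so few held-out examples every error is expensive: on 21 validation rows, one misclassification moves raw accuracy by 1/21, roughly 4.8 percentage points. It is therefore worth looking at the individual mistakes rather than only the headline number. A minimal sketch with a hypothetical summarize_predictions helper, run on toy labels rather than the notebook's results:

```python
from typing import List, Tuple


def summarize_predictions(expected: List[str], predicted: List[str]) -> Tuple[float, List[Tuple[int, str, str]]]:
    """Hypothetical helper: raw accuracy plus every (index, expected, predicted) disagreement."""
    if len(expected) != len(predicted):
        raise ValueError("expected and predicted must have the same length")
    mismatches = [
        (i, exp, pred)
        for i, (exp, pred) in enumerate(zip(expected, predicted))
        if exp.strip() != pred.strip()  # whitespace-insensitive comparison
    ]
    accuracy = 1 - len(mismatches) / len(expected)
    return accuracy, mismatches


# Toy labels: 21 rows with a single error, matching the size of the validation split
expected = ["Building Improvement"] * 20 + ["Other"]
predicted = ["Building Improvement"] * 21
accuracy, mismatches = summarize_predictions(expected, predicted)
print(round(accuracy, 4), mismatches)  # 0.9524 [(20, 'Other', 'Building Improvement')]
```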
# Create a test set with the expected class labels
test_set = pd.read_json(valid_file_name, lines=True)
test_set['expected_class'] = test_set.apply(lambda x: x['messages'][-1]['content'], axis=1)
test_set.head()
| | messages | expected_class |
---|---|---|
0 | [{'role': 'user', 'content': 'You are a data e... | Utility Bills |
1 | [{'role': 'user', 'content': 'You are a data e... | Literature & Archive |
2 | [{'role': 'user', 'content': 'You are a data e... | Literature & Archive |
3 | [{'role': 'user', 'content': 'You are a data e... | Literature & Archive |
4 | [{'role': 'user', 'content': 'You are a data e... | Building Improvement |
# Apply the fine-tuned model to the test set, dropping the final (assistant) message so the model only sees the prompt
test_set['response'] = test_set.apply(lambda x: client.chat.completions.create(model=fine_tuned_model, messages=x['messages'][:-1], temperature=0), axis=1)
test_set['predicted_class'] = test_set.apply(lambda x: x['response'].choices[0].message.content, axis=1)
test_set.head()
| | messages | expected_class | response | predicted_class |
---|---|---|---|---|
0 | [{'role': 'user', 'content': 'You are a data e... | Utility Bills | ChatCompletion(id='chatcmpl-BKrC0S1wQSfM9ZQfcC... | Utility Bills |
1 | [{'role': 'user', 'content': 'You are a data e... | Literature & Archive | ChatCompletion(id='chatcmpl-BKrC1BTr0DagbDkC2s... | Literature & Archive |
2 | [{'role': 'user', 'content': 'You are a data e... | Literature & Archive | ChatCompletion(id='chatcmpl-BKrC1H3ZeIW5cz2Owr... | Literature & Archive |
3 | [{'role': 'user', 'content': 'You are a data e... | Literature & Archive | ChatCompletion(id='chatcmpl-BKrC1wdhaMP0Q7YmYx... | Literature & Archive |
4 | [{'role': 'user', 'content': 'You are a data e... | Building Improvement | ChatCompletion(id='chatcmpl-BKrC20c5pkpngy1xDu... | Building Improvement |
# Calculate the accuracy of the predictions
from sklearn.metrics import f1_score
test_set['result'] = test_set.apply(lambda x: str(x['predicted_class']).strip() == str(x['expected_class']).strip(), axis=1)
print(test_set['result'].value_counts())
print("F1 Score: ", f1_score(test_set['expected_class'], test_set['predicted_class'], average="weighted"))
# The mean of a boolean column is the fraction of correct predictions
print("Raw Accuracy: ", test_set['result'].mean())
result
True 20
False 1
Name: count, dtype: int64
F1 Score: 0.9296066252587991
Raw Accuracy: 0.9523809523809523