Multiclass Classification for Transactions

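In this notebook we classify a public dataset of transactions into five predefined categories. The same approaches should carry over to any multiclass classification use case, particularly one that fits transactional data into predefined categories. By the end you will have several approaches for handling both labelled and unlabelled datasets.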

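We take the following approaches in this notebook:

  • Zero-shot classification: First we classify with no labelled examples, using only a prompt to guide the model in placing each transaction into one of the five predefined categories.
  • Classification with embeddings: Next we create embeddings for a labelled dataset and train a traditional classification model on them to test how well it identifies our categories.
  • Fine-tuned classification: Finally we fine-tune a model on the labelled dataset to see how it compares with the zero-shot and embedding-based approaches.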

Setup

%load_ext autoreload
%autoreload
%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers scikit-learn matplotlib plotly pandas scipy
import openai
import pandas as pd
import numpy as np
import json
import os

COMPLETIONS_MODEL = "gpt-4"
os.environ["OPENAI_API_KEY"] = "<your-api-key>"
client = openai.OpenAI()

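Load dataset

We use a public transaction dataset from the National Library of Scotland that lists transactions over £25,000. We will use three of its features:

  • Supplier: the name of the supplier
  • Description: a text description of the transaction
  • Value: the value of the transaction in GBP (the Transaction value (£) column)

Source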

https://data.nls.uk/data/organisational-data/transactions-over-25k/

transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')
print(f"Number of transactions: {len(transactions)}")
print(transactions.head())
Number of transactions: 359
          Date                      Supplier                 Description  \
0  21/04/2016          M & J Ballantyne Ltd       George IV Bridge Work   
1  26/04/2016                  Private Sale   Literary & Archival Items   
2  30/04/2016     City Of Edinburgh Council         Non Domestic Rates    
3  09/05/2016              Computacenter Uk                 Kelvin Hall   
4  09/05/2016  John Graham Construction Ltd  Causewayside Refurbishment

   Transaction value (£)  
0                35098.0  
1                30000.0  
2                40800.0  
3                72835.0  
4                64361.0

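Zero-shot Classification

We first assess how the base model performs at classifying these transactions with a simple prompt. We give the model five categories plus a catch-all of "Could not classify" for transactions it cannot place.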

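zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.
You are analysing all transactions over £25,000 in value and classifying them into one of five categories.
The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.
If you can't tell what it is, say Could not classify.

Transaction:

Supplier: {}
Description: {}
Value: {}

The classification is:'''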

def format_prompt(transaction):
    return zero_shot_prompt.format(transaction['Supplier'], transaction['Description'], transaction['Transaction value (£)'])

def classify_transaction(transaction):
    prompt = format_prompt(transaction)
    messages = [
        {"role": "system", "content": prompt},
    ]
    completion_response = client.chat.completions.create(
                            messages=messages,
                            temperature=0,
                            max_tokens=5,
                            top_p=1,
                            frequency_penalty=0,
                            presence_penalty=0,
                            model=COMPLETIONS_MODEL)
    label = completion_response.choices[0].message.content.replace('\n','')
    return label
# Get a test transaction
transaction = transactions.iloc[0]
# Use our completion function to return a prediction
print(f"Transaction: {transaction['Supplier']} {transaction['Description']} {transaction['Transaction value (£)']}")
print(f"Classification: {classify_transaction(transaction)}")
Transaction: M & J Ballantyne Ltd George IV Bridge Work 35098.0
Classification: Building Improvement

Our first attempt is correct: M & J Ballantyne Ltd are a house builder, and the work they carried out is indeed Building Improvement.
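Because the model returns free text, it can help to map each response onto the allowed labels before counting results. The sketch below is one possible way to do that. The helper name normalize_label and its matching rules are assumptions for illustration and are not part of the original notebook; the category names are taken from the prompt above.

```python
# Allowed labels, as listed in the zero-shot prompt; FALLBACK is the catch-all.
CATEGORIES = [
    "Building Improvement",
    "Literature & Archive",
    "Utility Bills",
    "Professional Services",
    "Software/IT",
]
FALLBACK = "Could not classify"


def normalize_label(raw: str) -> str:
    """Map a raw model response onto one allowed label (hypothetical helper)."""
    # Drop surrounding whitespace, quotes and a trailing full stop, ignore case.
    cleaned = raw.strip().strip(".\"'").casefold()
    for category in CATEGORIES:
        if cleaned == category.casefold():
            return category
    # Accept a response cut short by a low max_tokens, but only if it is
    # an unambiguous prefix of exactly one category.
    matches = [c for c in CATEGORIES if cleaned and c.casefold().startswith(cleaned)]
    return matches[0] if len(matches) == 1 else FALLBACK


print(normalize_label(" building improvement.\n"))
print(normalize_label("Literature &"))
print(normalize_label("Insurance"))
```

Anything that does not match falls back to the catch-all, so unexpected responses show up for manual review instead of silently becoming new categories.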

Let's expand the sample size to 25 and see how it performs, again with just a simple prompt to guide it.

# Take an explicit copy so the new column is not written to a slice of transactions
test_transactions = transactions.iloc[:25].copy()
test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x),axis=1)
test_transactions['Classification'].value_counts()
Classification
Building Improvement    17
Literature & Archive     3
Software/IT              2
Could not classify       2
Utility Bills            1
Name: count, dtype: int64
test_transactions.head(25)
    Date        Supplier                        Description                   Transaction value (£)  Classification
0   21/04/2016  M & J Ballantyne Ltd            George IV Bridge Work         35098.0                Building Improvement
1   26/04/2016  Private Sale                    Literary & Archival Items     30000.0                Literature & Archive
2   30/04/2016  City Of Edinburgh Council       Non Domestic Rates            40800.0                Utility Bills
3   09/05/2016  Computacenter Uk                Kelvin Hall                   72835.0                Software/IT
4   09/05/2016  John Graham Construction Ltd    Causewayside Refurbishment    64361.0                Building Improvement
5   09/05/2016  A McGillivray                   Causewayside Refurbishment    53690.0                Building Improvement
6   16/05/2016  John Graham Construction Ltd    Causewayside Refurbishment    365344.0               Building Improvement
7   23/05/2016  Computacenter Uk                Kelvin Hall                   26506.0                Software/IT
8   23/05/2016  ECG Facilities Service          Facilities Management Charge  32777.0                Building Improvement
9   23/05/2016  ECG Facilities Service          Facilities Management Charge  32777.0                Building Improvement
10  30/05/2016  ALDL                            ALDL Charges                  32317.0                Could not classify
11  10/06/2016  Wavetek Ltd                     Kelvin Hall                   87589.0                Building Improvement
12  10/06/2016  John Graham Construction Ltd    Causewayside Refurbishment    381803.0               Building Improvement
13  28/06/2016  ECG Facilities Service          Facilities Management Charge  32832.0                Building Improvement
14  30/06/2016  Glasgow City Council            Kelvin Hall                   1700000.0              Building Improvement
15  11/07/2016  Wavetek Ltd                     Kelvin Hall                   65692.0                Building Improvement
16  11/07/2016  John Graham Construction Ltd    Causewayside Refurbishment    139845.0               Building Improvement
17  15/07/2016  Sotheby'S                       Literary & Archival Items     28500.0                Literature & Archive
18  18/07/2016  Christies                       Literary & Archival Items     33800.0                Literature & Archive
19  25/07/2016  A McGillivray                   Causewayside Refurbishment    30113.0                Building Improvement
20  31/07/2016  ALDL                            ALDL Charges                  32317.0                Could not classify
21  08/08/2016  ECG Facilities Service          Facilities Management Charge  32795.0                Building Improvement
22  15/08/2016  Creative Video Productions Ltd  Kelvin Hall                   26866.0                Building Improvement
23  15/08/2016  John Graham Construction Ltd    Causewayside Refurbishment    196807.0               Building Improvement
24  24/08/2016  ECG Facilities Service          Facilities Management Charge  32795.0                Building Improvement

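The initial results are pretty good even with no labelled examples! The transactions it could not classify were the harder cases, with few clues as to their topic, but we may get better performance if we clean up the labelled dataset to provide more examples.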

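Classification with Embeddings

Let's create embeddings from the small set we have classified so far. We built a set of labelled examples by running the zero-shot classifier on 101 transactions from our dataset and manually correcting the 15 "Could not classify" results we got.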

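Create embeddings

This initial section reuses the approach from the Get_embeddings_from_dataset notebook to create embeddings from a combined field that concatenates all of our features.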

df = pd.read_csv('./data/labelled_transactions.csv')
df.head()
    Date        Supplier                        Description                   Transaction value (£)  Classification
0   15/08/2016  Creative Video Productions Ltd  Kelvin Hall                   26866                  Other
1   29/05/2017  John Graham Construction Ltd    Causewayside Refurbishment    74806                  Building Improvement
2   29/05/2017  Morris & Spottiswood Ltd        George IV Bridge Work         56448                  Building Improvement
3   31/05/2017  John Graham Construction Ltd    Causewayside Refurbishment    164691                 Building Improvement
4   24/07/2017  John Graham Construction Ltd    Causewayside Refurbishment    27926                  Building Improvement
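# astype(str) converts each value on its own row; calling str() on the whole column would paste the entire Series printout into every row
df['combined'] = "Supplier: " + df['Supplier'].str.strip() + "; Description: " + df['Description'].str.strip() + "; Value: " + df['Transaction value (£)'].astype(str).str.strip()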
df.head(2)
Date Supplier Description Transaction value (£) Classification combined
0 15/08/2016 Creative Video Productions Ltd Kelvin Hall 26866 Other Supplier: Creative Video Productions Ltd; Desc...
1 29/05/2017 John Graham Construction Ltd Causewayside Refurbishment 74806 Building Improvement Supplier: John Graham Construction Ltd; Descri...
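To see what a single combined value is meant to look like, here is a minimal per-row sketch. The function name combine_fields is an assumption for illustration; the field names and the example row are taken from the table above.

```python
def combine_fields(supplier: str, description: str, value) -> str:
    """Build the text that gets embedded for one transaction (hypothetical helper)."""
    return (
        f"Supplier: {supplier.strip()}; "
        f"Description: {description.strip()}; "
        f"Value: {value}"
    )


# One row from the labelled dataset shown above, with stray whitespace added
# to show that it is stripped.
row = {
    "Supplier": "John Graham Construction Ltd",
    "Description": "Causewayside Refurbishment ",
    "Transaction value (£)": 74806,
}
text = combine_fields(row["Supplier"], row["Description"], row["Transaction value (£)"])
print(text)
print(len(text.split()), "whitespace-separated words")
```

Each combined value should contain only that row's supplier, description and value, so it stays short. A quick length check like the last line can catch a combined field that accidentally includes data from more than one row.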
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))
len(df)
101
embedding_path = './data/transactions_with_embeddings_100.csv'
from utils.embeddings_utils import get_embedding
df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x))
df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x))
df.to_csv(embedding_path)

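Use embeddings for classification

Now that we have our embeddings, let's see whether classifying them into the categories we have named gives us any more success.

For this we will use a template from the Classification_using_embeddings notebook.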

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from ast import literal_eval

fs_df = pd.read_csv(embedding_path)
fs_df["babbage_similarity"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)
fs_df.head()
Unnamed: 0 Date Supplier Description Transaction value (£) Classification combined n_tokens babbage_similarity babbage_search
0 0 15/08/2016 Creative Video Productions Ltd Kelvin Hall 26866 Other Supplier: Creative Video Productions Ltd; Desc... 136 [-0.02898375503718853, -0.02881557121872902, 0... [-0.02879939414560795, -0.02867320366203785, 0...
1 1 29/05/2017 John Graham Construction Ltd Causewayside Refurbishment 74806 Building Improvement Supplier: John Graham Construction Ltd; Descri... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
2 2 29/05/2017 Morris & Spottiswood Ltd George IV Bridge Work 56448 Building Improvement Supplier: Morris & Spottiswood Ltd; Descriptio... 141 [0.013581369072198868, -0.003978211898356676, ... [0.013593776151537895, -0.0037341134157031775,...
3 3 31/05/2017 John Graham Construction Ltd Causewayside Refurbishment 164691 Building Improvement Supplier: John Graham Construction Ltd; Descri... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
4 4 24/07/2017 John Graham Construction Ltd Causewayside Refurbishment 27926 Building Improvement Supplier: John Graham Construction Ltd; Descri... 140 [-0.02408558875322342, -0.02881370671093464, 0... [-0.024109570309519768, -0.02880912832915783, ...
X_train, X_test, y_train, y_test = train_test_split(
    list(fs_df.babbage_similarity.values), fs_df.Classification, test_size=0.2, random_state=42
)

clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)

report = classification_report(y_test, preds)
print(report)
                      precision    recall  f1-score   support

Building Improvement       0.92      1.00      0.96        11
Literature & Archive       1.00      1.00      1.00         3
               Other       0.00      0.00      0.00         1
         Software/IT       1.00      1.00      1.00         1
       Utility Bills       1.00      1.00      1.00         5

            accuracy                           0.95        21
           macro avg       0.78      0.80      0.79        21
        weighted avg       0.91      0.95      0.93        21


/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

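Performance for this model is pretty strong, so creating embeddings and using a simpler classifier is also an effective approach, with the zero-shot classifier helping us do the initial classification of the unlabelled dataset.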
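The report above shows a macro average well below the weighted average, which comes from the single "Other" test sample. The sketch below recomputes both precision averages by hand. The predictions are an assumption inferred from the report (every class is predicted correctly except the one "Other" sample, taken to be predicted as Building Improvement); they are not the classifier's actual output.

```python
from collections import Counter


def per_class_precision(y_true, y_pred):
    """Precision per label; 0.0 when a label is never predicted."""
    predicted = Counter(y_pred)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {
        label: (correct[label] / predicted[label] if predicted[label] else 0.0)
        for label in sorted(set(y_true))
    }


# Supports copied from the classification report above.
y_true = (
    ["Building Improvement"] * 11
    + ["Literature & Archive"] * 3
    + ["Other"]
    + ["Software/IT"]
    + ["Utility Bills"] * 5
)
# Assumed predictions: the lone "Other" sample is labelled Building Improvement.
y_pred = ["Building Improvement" if label == "Other" else label for label in y_true]

precision = per_class_precision(y_true, y_pred)
support = Counter(y_true)

macro = sum(precision.values()) / len(precision)  # every class counts equally
weighted = sum(precision[c] * support[c] for c in precision) / len(y_true)

print(round(macro, 2), round(weighted, 2))  # 0.78 0.91
```

The macro average treats the one-sample "Other" class the same as the eleven-sample Building Improvement class, so a single miss costs a fifth of the score, while the weighted average barely moves.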

Let's take it one step further and see whether a fine-tuned model trained on the same labelled dataset gives us comparable results.

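Fine-tuned Transaction Classification

For this use case we will try to improve on the embedding-based classification above by fine-tuning a model on the same 101 labelled transactions and applying it to a group of unseen transactions.

Building Fine-tuned Classifier

We first need to do some data preparation to get our data ready. This involves the following steps:

  • To prepare our training and validation sets, we create a set of message sequences. The first message in each sequence is a user prompt formatted with the details of the transaction, and the final message is the classification response we expect from the model.
  • Our test set contains the initial user prompt for each transaction along with the corresponding expected class label. We then use the fine-tuned model to generate the actual classification for each transaction.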
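The two steps above can be sketched with plain dictionaries. The helper names to_training_example and to_test_case and the example prompt text are assumptions for illustration; the record layout (a "messages" list with a user prompt followed by an assistant label) follows the description above.

```python
import json


def to_training_example(prompt: str, label: str) -> dict:
    """One fine-tuning record: the user prompt followed by the expected label."""
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": label},
        ]
    }


def to_test_case(example: dict):
    """Split a record into (prompt messages, expected label) for evaluation."""
    messages = example["messages"]
    return messages[:-1], messages[-1]["content"]


# Illustrative prompt; the notebook builds the real one with format_prompt.
example = to_training_example(
    "Supplier: Christies\nDescription: Literary & Archival Items\nValue: 33800.0\n\nThe classification is:",
    "Literature & Archive",
)

line = json.dumps(example)  # one line of the JSONL file
prompt_messages, expected = to_test_case(json.loads(line))
print(len(prompt_messages), expected)
```

Dropping the final assistant message at test time means the model sees exactly the prompt it was trained on, without the answer.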
ft_prep_df = fs_df.copy()
len(ft_prep_df)
101
ft_prep_df.head()
Unnamed: 0 Date Supplier Description Transaction value (£) Classification combined n_tokens babbage_similarity babbage_search
0 0 15/08/2016 Creative Video Productions Ltd Kelvin Hall 26866 Other Supplier: Creative Video Productions Ltd; Desc... 136 [-0.028885245323181152, -0.028660893440246582,... [-0.02879939414560795, -0.02867320366203785, 0...
1 1 29/05/2017 John Graham Construction Ltd Causewayside Refurbishment 74806 Building Improvement Supplier: John Graham Construction Ltd; Descri... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.02414606139063835, -0.02883070334792137, 0...
2 2 29/05/2017 Morris & Spottiswood Ltd George IV Bridge Work 56448 Building Improvement Supplier: Morris & Spottiswood Ltd; Descriptio... 141 [0.013593776151537895, -0.0037341134157031775,... [0.013561442494392395, -0.004199974238872528, ...
3 3 31/05/2017 John Graham Construction Ltd Causewayside Refurbishment 164691 Building Improvement Supplier: John Graham Construction Ltd; Descri... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
4 4 24/07/2017 John Graham Construction Ltd Causewayside Refurbishment 27926 Building Improvement Supplier: John Graham Construction Ltd; Descri... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
classes = list(set(ft_prep_df['Classification']))
class_df = pd.DataFrame(classes).reset_index()
class_df.columns = ['class_id','class']
class_df  , len(class_df)

(   class_id                 class
 0         0                 Other
 1         1  Literature & Archive
 2         2           Software/IT
 3         3         Utility Bills
 4         4  Building Improvement,
 5)

ft_df_with_class = ft_prep_df.merge(class_df,left_on='Classification',right_on='class',how='inner')

# Creating a list of messages for the fine-tuning job. The user message is the prompt, and the assistant message is the response from the model
ft_df_with_class['messages'] = ft_df_with_class.apply(lambda x: [{"role": "user", "content": format_prompt(x)}, {"role": "assistant", "content": x['class']}],axis=1)
ft_df_with_class[['messages', 'class']].head()
   messages                                           class
0  [{'role': 'user', 'content': 'You are a data e...  Other
1  [{'role': 'user', 'content': 'You are a data e...  Building Improvement
2  [{'role': 'user', 'content': 'You are a data e...  Building Improvement
3  [{'role': 'user', 'content': 'You are a data e...  Building Improvement
4  [{'role': 'user', 'content': 'You are a data e...  Building Improvement
# Create train/validation split
samples = ft_df_with_class["messages"].tolist()
train_df, valid_df = train_test_split(samples, test_size=0.2, random_state=42)

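def write_to_jsonl(list_of_messages, filename):
    # Write one JSON object per line, each holding a single conversation
    with open(filename, "w") as f:
        for messages in list_of_messages:
            record = {"messages": messages}  # named record to avoid shadowing the built-in object
            f.write(json.dumps(record) + "\n")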
# Write the train/validation split to jsonl files
train_file_name, valid_file_name = "transactions_grouped_train.jsonl", "transactions_grouped_valid.jsonl"
write_to_jsonl(train_df, train_file_name)
write_to_jsonl(valid_df, valid_file_name)
# Upload the files to OpenAI
train_file = client.files.create(file=open(train_file_name, "rb"), purpose="fine-tune")
valid_file = client.files.create(file=open(valid_file_name, "rb"), purpose="fine-tune")
# Create the fine-tuning job
fine_tuning_job = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=valid_file.id, model="gpt-4o-2024-08-06")
# Get the fine-tuning job status and model name
status = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
# Once the fine-tuning job is complete, you can retrieve the model name from the job status
fine_tuned_model = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).fine_tuned_model
print(f"Fine tuned model id: {fine_tuned_model}")
Fine tuned model id: ft:gpt-4o-2024-08-06:openai::BKr3Xy8U
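A fine-tuning job takes time to finish, and fine_tuned_model stays empty until it does, so in practice the retrieve call needs to be repeated. The sketch below shows one way to poll. The helper name wait_for_job and its parameters are assumptions, and the demo uses stand-in job objects instead of calling the API; with the real client you would pass client.fine_tuning.jobs.retrieve as the retrieve argument.

```python
import time
from types import SimpleNamespace

# Statuses after which a fine-tuning job will not change any further.
TERMINAL_STATUSES = {"succeeded", "failed", "cancelled"}


def wait_for_job(retrieve, job_id, poll_seconds=30.0, max_polls=240, sleep=time.sleep):
    """Call retrieve(job_id) until the job reaches a terminal status (hypothetical helper)."""
    for _ in range(max_polls):
        job = retrieve(job_id)
        if job.status in TERMINAL_STATUSES:
            return job
        sleep(poll_seconds)
    raise TimeoutError(f"job {job_id} did not finish after {max_polls} polls")


# Stand-in responses so the sketch runs without network access.
fake_jobs = iter([
    SimpleNamespace(status="running", fine_tuned_model=None),
    SimpleNamespace(status="succeeded", fine_tuned_model="ft:example-model"),
])
job = wait_for_job(
    lambda job_id: next(fake_jobs),  # replace with client.fine_tuning.jobs.retrieve
    "ftjob-example",                 # made-up id for the demo
    sleep=lambda seconds: None,      # skip real waiting in the demo
)
print(job.status, job.fine_tuned_model)
```

Checking the final status before reading fine_tuned_model avoids passing an empty model name to later calls when a job fails or is cancelled.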

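Applying Fine-tuned Classifier

Now we apply our classifier to see how it performs. We only had 31 unique observations in our training set and 8 in our validation set, so let's see how the performance is.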

# Create a test set with the expected class labels
test_set = pd.read_json(valid_file_name, lines=True)
test_set['expected_class'] = test_set.apply(lambda x: x['messages'][-1]['content'], axis=1)
test_set.head()
   messages                                           expected_class
0  [{'role': 'user', 'content': 'You are a data e...  Utility Bills
1  [{'role': 'user', 'content': 'You are a data e...  Literature & Archive
2  [{'role': 'user', 'content': 'You are a data e...  Literature & Archive
3  [{'role': 'user', 'content': 'You are a data e...  Literature & Archive
4  [{'role': 'user', 'content': 'You are a data e...  Building Improvement
# Apply the fine-tuned model to the test set, sending every message except the expected answer
test_set['response'] = test_set.apply(lambda x: client.chat.completions.create(model=fine_tuned_model, messages=x['messages'][:-1], temperature=0),axis=1)
test_set['predicted_class'] = test_set.apply(lambda x: x['response'].choices[0].message.content, axis=1)

test_set.head()
messages expected_class response predicted_class
0 [{'role': 'user', 'content': 'You are a data e... Utility Bills ChatCompletion(id='chatcmpl-BKrC0S1wQSfM9ZQfcC... Utility Bills
1 [{'role': 'user', 'content': 'You are a data e... Literature & Archive ChatCompletion(id='chatcmpl-BKrC1BTr0DagbDkC2s... Literature & Archive
2 [{'role': 'user', 'content': 'You are a data e... Literature & Archive ChatCompletion(id='chatcmpl-BKrC1H3ZeIW5cz2Owr... Literature & Archive
3 [{'role': 'user', 'content': 'You are a data e... Literature & Archive ChatCompletion(id='chatcmpl-BKrC1wdhaMP0Q7YmYx... Literature & Archive
4 [{'role': 'user', 'content': 'You are a data e... Building Improvement ChatCompletion(id='chatcmpl-BKrC20c5pkpngy1xDu... Building Improvement
# Calculate the accuracy of the predictions
from sklearn.metrics import f1_score
test_set['result'] = test_set.apply(lambda x: str(x['predicted_class']).strip() == str(x['expected_class']).strip(), axis = 1)
test_set['result'].value_counts()

print(test_set['result'].value_counts())

print("F1 Score: ", f1_score(test_set['expected_class'], test_set['predicted_class'], average="weighted"))
print("Raw Accuracy: ", test_set['result'].value_counts()[True] / len(test_set))
result
True     20
False     1
Name: count, dtype: int64
F1 Score:  0.9296066252587991
Raw Accuracy:  0.9523809523809523