使用嵌入进行分类

文本分类的方法有很多。本笔记本分享了一个使用嵌入进行文本分类的示例。对于许多文本分类任务,我们发现微调模型比嵌入效果更好。请参阅Fine-tuned_classification.ipynb中用于分类的微调模型示例。我们还建议嵌入的维度要少于示例的数量,而我们在这里并未完全实现。

在此文本分类任务中,我们根据评论文本的嵌入来预测食品评论的分数(1 到 5)。我们将数据集分为训练集和测试集,以用于所有后续任务,这样我们就可以在未见过的数据上实际评估性能。数据集是在Get_embeddings_from_dataset Notebook中创建的。

import pandas as pd
import numpy as np
from ast import literal_eval

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

datafile_path = "data/fine_food_reviews_with_embeddings_1k.csv"

df = pd.read_csv(datafile_path)
df["embedding"] = df.embedding.apply(literal_eval).apply(np.array)  # convert string to array

# split data into train and test
X_train, X_test, y_train, y_test = train_test_split(
    list(df.embedding.values), df.Score, test_size=0.2, random_state=42
)

# train random forest classifier
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)

report = classification_report(y_test, preds)
print(report)

precision recall f1-score support

       1       0.90      0.45      0.60        20
       2       1.00      0.38      0.55         8
       3       1.00      0.18      0.31        11
       4       0.88      0.26      0.40        27
       5       0.76      1.00      0.86       134

accuracy 0.78 200 macro avg 0.91 0.45 0.54 200 weighted avg 0.81 0.78 0.73 200

我们可以看到,模型已经学会了区分这些类别。5星评论的整体表现最好,这并不奇怪,因为它们在数据集中最常见。

from utils.embeddings_utils import plot_multiclass_precision_recall

plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)

RandomForestClassifier() - 所有类别的平均精度得分:0.90

png

毫不奇怪,5星和1星评论似乎更容易预测。也许有了更多的数据,2-4星之间的细微差别可以得到更好的预测,但人们在使用中间分数时可能也有更多的个人主观性。