使用嵌入进行分类
文本分类的方法有很多。本笔记本分享了一个使用嵌入进行文本分类的示例。对于许多文本分类任务,我们发现微调模型比嵌入效果更好。请参阅Fine-tuned_classification.ipynb中用于分类的微调模型示例。我们还建议嵌入的维度要少于示例的数量,而我们在这里并未完全实现。
在此文本分类任务中,我们根据评论文本的嵌入来预测食品评论的分数(1 到 5)。我们将数据集分为训练集和测试集,以用于所有后续任务,这样我们就可以在未见过的数据上实际评估性能。数据集是在Get_embeddings_from_dataset Notebook中创建的。
import pandas as pd
import numpy as np
from ast import literal_eval
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
datafile_path = "data/fine_food_reviews_with_embeddings_1k.csv"
df = pd.read_csv(datafile_path)
df["embedding"] = df.embedding.apply(literal_eval).apply(np.array) # convert string to array
# split data into train and test
X_train, X_test, y_train, y_test = train_test_split(
list(df.embedding.values), df.Score, test_size=0.2, random_state=42
)
# train random forest classifier
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)
report = classification_report(y_test, preds)
print(report)
precision recall f1-score support
1 0.90 0.45 0.60 20
2 1.00 0.38 0.55 8
3 1.00 0.18 0.31 11
4 0.88 0.26 0.40 27
5 0.76 1.00 0.86 134
accuracy 0.78 200 macro avg 0.91 0.45 0.54 200 weighted avg 0.81 0.78 0.73 200
我们可以看到,模型已经学会了区分这些类别。5星评论的整体表现最好,这并不奇怪,因为它们在数据集中最常见。
from utils.embeddings_utils import plot_multiclass_precision_recall
plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)
RandomForestClassifier() - 所有类别的平均精度得分:0.90
毫不奇怪,5星和1星评论似乎更容易预测。也许有了更多的数据,2-4星之间的细微差别可以得到更好的预测,但人们在使用中间分数时可能也有更多的个人主观性。