三维可视化嵌入
该示例使用 PCA 将嵌入的维度从 1536 降至 3。然后,我们可以将数据点可视化在三维图中。小型数据集 dbpedia_samples.jsonl
是通过从 DBpedia 验证数据集 中随机采样 200 个样本而创建的。
1. 加载数据集和查询嵌入
import pandas as pd
samples = pd.read_json("data/dbpedia_samples.jsonl", lines=True)
categories = sorted(samples["category"].unique())
print("DBpedia 样本类别:", samples["category"].value_counts())
samples.head()
DBpedia 样本类别: Artist 21
Film 19
Plant 19
OfficeHolder 18
Company 17
NaturalPlace 16
Athlete 16
Village 12
WrittenWork 11
Building 11
Album 11
Animal 11
EducationalInstitution 10
MeanOfTransportation 8
Name: category, dtype: int64
text | category | |
---|---|---|
0 | Morada Limited is a textile company based in ... | Company |
1 | The Armenian Mirror-Spectator is a newspaper ... | WrittenWork |
2 | Mt. Kinka (金華山 Kinka-zan) also known as Kinka... | NaturalPlace |
3 | Planning the Play of a Bridge Hand is a book ... | WrittenWork |
4 | Wang Yuanping (born 8 December 1976) is a ret... | Athlete |
from utils.embeddings_utils import get_embeddings
# 注意:以下代码将向 /embeddings 发送一个批量大小为 200 的查询
matrix = get_embeddings(samples["text"].to_list(), model="text-embedding-3-small")
2. 降低嵌入维度
from sklearn.decomposition import PCA
pca = PCA(n_components=3)
vis_dims = pca.fit_transform(matrix)
samples["embed_vis"] = vis_dims.tolist()
3. 绘制低维嵌入图
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(projection='3d')
cmap = plt.get_cmap("tab20")
# 单独绘制每个样本类别,以便设置标签名称。
for i, cat in enumerate(categories):
sub_matrix = np.array(samples[samples["category"] == cat]["embed_vis"].to_list())
x=sub_matrix[:, 0]
y=sub_matrix[:, 1]
z=sub_matrix[:, 2]
colors = [cmap(i/len(categories))] * len(sub_matrix)
ax.scatter(x, y, zs=z, zdir='z', c=colors, label=cat)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.legend(bbox_to_anchor=(1.1, 1))