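Recommendation using embeddings and nearest neighbor search

Recommendations are widespread across the web.

  • “Bought that item? Try these similar items.”
  • “Enjoy that book? Try these similar titles.”
  • “Not the help page you were looking for? Try these similar pages.”

This notebook demonstrates how to use embeddings to find similar items to recommend. In particular, we use the AG corpus of news articles as our dataset.

Our model will answer the question: given an article, which other articles are most similar to it?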

import pandas as pd
import pickle

from utils.embeddings_utils import (
    get_embedding,
    distances_from_embeddings,
    tsne_components_from_embeddings,
    chart_from_components,
    indices_of_nearest_neighbors_from_distances,
)

EMBEDDING_MODEL = "text-embedding-3-small"

2. Load data

Next, let's load the AG news data and see what it looks like.

# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
dataset_path = "data/AG_news_samples.csv"
df = pd.read_csv(dataset_path)

n_examples = 5
df.head(n_examples)
   title                                              description                                        label_int  label
0  World Briefings                                    BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime M...  1          World
1  Nvidia Puts a Firewall on a Motherboard (PC Wo...  PC World - Upcoming chip set will include buil...  4          Sci/Tech
2  Olympic joy in Greek, Chinese press                Newspapers in Greece reflect a mixture of exhi...  2          Sports
3  U2 Can iPod with Pictures                          SAN JOSE, Calif. -- Apple Computer (Quote, Cha...  4          Sci/Tech
4  The Dream Factory                                  Any product, any shape, any size -- manufactur...  4          Sci/Tech

Let's take a look at those same examples, but not truncated by ellipses.

# print the title, description, and label of each example
for idx, row in df.head(n_examples).iterrows():
    print("")
    print(f"Title: {row['title']}")
    print(f"Description: {row['description']}")
    print(f"Label: {row['label']}")
Title: World Briefings
Description: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.
Label: World

Title: Nvidia Puts a Firewall on a Motherboard (PC World)
Description: PC World - Upcoming chip set will include built-in security features for your PC.
Label: Sci/Tech

Title: Olympic joy in Greek, Chinese press
Description: Newspapers in Greece reflect a mixture of exhilaration that the Athens Olympics proved successful, and relief that they passed off without any major setback.
Label: Sports

Title: U2 Can iPod with Pictures
Description: SAN JOSE, Calif. -- Apple Computer (Quote, Chart) unveiled a batch of new iPods, iTunes software and promos designed to keep it atop the heap of digital music players.
Label: Sci/Tech

Title: The Dream Factory
Description: Any product, any shape, any size -- manufactured on your desktop! The future is the fabricator. By Bruce Sterling from Wired magazine.
Label: Sci/Tech

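3. Build cache to save embeddings

Before getting embeddings for these articles, let's set up a cache to save the embeddings we generate. In general, it's a good idea to save your embeddings so you can reuse them later. If you don't save them, you'll pay again each time you recompute them.

The cache is a dictionary that maps tuples of (text, model) to an embedding, which is a list of floats. The cache is saved as a Python pickle file.

# establish a cache of embeddings to avoid recomputing
# cache is a dict of tuples (text, model) -> embedding, saved as a pickle file

# set path to embedding cache
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"

# load the cache if it exists, and save a copy to disk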
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    embedding_cache = {}
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)

# define a function to retrieve embeddings from the cache if present, and otherwise request via the API
def embedding_from_string(
    string: str,
    model: str = EMBEDDING_MODEL,
    embedding_cache=embedding_cache
) -> list:
    """Return embedding of given string, using a cache to avoid recomputing."""
    if (string, model) not in embedding_cache:
        embedding_cache[(string, model)] = get_embedding(string, model)
        # write the whole cache back to disk after every new embedding
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[(string, model)]

Let's check that it works by getting an embedding.

# as an example, take the first description from the dataset
example_string = df["description"].values[0]
print(f"\nExample string: {example_string}")

# print the first 10 dimensions of the embedding
example_embedding = embedding_from_string(example_string)
print(f"\nExample embedding: {example_embedding[:10]}...")
Example string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.

Example embedding: [0.0545826330780983, -0.00428084097802639, 0.04785159230232239, 0.01587914116680622, -0.03640881925821304, 0.0143799539655447, -0.014267769642174244, -0.015175441280007362, -0.002344391541555524, 0.011075624264776707]...
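
To see what the cache buys us without spending any API credits, here is a minimal sketch that swaps the real API call for a stand-in. The names fake_get_embedding, cached_embedding, and demo_cache are hypothetical and exist only for this illustration; the stand-in returns a dummy vector and counts how often it is called.

```python
# Hypothetical stand-in for an API-backed embedding call; it only counts invocations.
api_calls = 0


def fake_get_embedding(text: str, model: str) -> list[float]:
    """Return a dummy 3-dimensional 'embedding' and record that a call was made."""
    global api_calls
    api_calls += 1
    return [float(len(text)), float(len(model)), 0.0]


def cached_embedding(text: str, model: str, cache: dict) -> list[float]:
    """Look up (text, model) in the cache and only compute on a miss."""
    key = (text, model)
    if key not in cache:
        cache[key] = fake_get_embedding(text, model)
    return cache[key]


demo_cache: dict = {}
first = cached_embedding("hello", "demo-model", demo_cache)   # miss: computes
second = cached_embedding("hello", "demo-model", demo_cache)  # hit: served from the cache
third = cached_embedding("hello", "other-model", demo_cache)  # miss: the model is part of the key
print(f"Stand-in API calls made: {api_calls}")
```

Because the model name is part of the key, switching embedding models never returns a stale vector from a different model.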

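4. Recommend similar articles based on embeddings

To find similar articles, let's follow a three-step plan:

  1. Get the embeddings of all the article descriptions
  2. Calculate the distance between a source article and all other articles
  3. Print out the other articles closest to the source article
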
def print_recommendations_from_strings(
    strings: list[str],
    index_of_source_string: int,
    k_nearest_neighbors: int = 1,
    model=EMBEDDING_MODEL,
) -> list[int]:
    """Print out the k nearest neighbors of a given string."""
    # get embeddings for all strings
    embeddings = [embedding_from_string(string, model=model) for string in strings]

    # get the embedding of the source string
    query_embedding = embeddings[index_of_source_string]

    # get distances between the source embedding and other embeddings (function from utils.embeddings_utils.py)
    distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")

    # get indices of nearest neighbors (function from utils.embeddings_utils.py)
    indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)

    # print out source string
    query_string = strings[index_of_source_string]
    print(f"Source string: {query_string}")
    # print out its k nearest neighbors
    k_counter = 0
    for i in indices_of_nearest_neighbors:
        # skip any strings that are identical matches to the starting string
        if query_string == strings[i]:
            continue
        # stop after printing out k articles
        if k_counter >= k_nearest_neighbors:
            break
        k_counter += 1

        # print out the similar strings and their distances
        print(
            f"""
        --- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) ---
        String: {strings[i]}
        Distance: {distances[i]:0.3f}"""
        )

    return indices_of_nearest_neighbors
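
The two helpers imported from utils.embeddings_utils do the numerical work in the function above. As a rough, dependency-free sketch of the idea (assuming cosine distance is defined as 1 minus cosine similarity and that neighbors are ordered by ascending distance), the hypothetical functions cosine_distance and nearest_indices below reproduce the same two steps on toy 2-dimensional vectors. They are illustrations, not the library's implementation.

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """Return 1 - cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def nearest_indices(distances: list[float]) -> list[int]:
    """Return indices ordered from smallest to largest distance."""
    return sorted(range(len(distances)), key=lambda i: distances[i])


toy_embeddings = [
    [1.0, 0.0],   # index 0: the query itself
    [0.9, 0.1],   # index 1: points almost the same way as the query
    [0.0, 1.0],   # index 2: orthogonal to the query
    [-1.0, 0.0],  # index 3: points the opposite way
]
toy_distances = [cosine_distance(toy_embeddings[0], e) for e in toy_embeddings]
toy_order = nearest_indices(toy_distances)
print(toy_order)
```

Note that the query is its own nearest neighbor at distance 0, which is why the function above skips strings identical to the source string.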

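5. Example recommendations

Let's look for articles similar to the first one, which was about Tony Blair.

article_descriptions = df["description"].tolist()

tony_blair_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=0,  # articles similar to the first one about Tony Blair
    k_nearest_neighbors=5,  # 5 most similar articles
)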
Source string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.

        --- Recommendation #1 (nearest neighbor 1 of 5) ---
        String: The anguish of hostage Kenneth Bigley in Iraq hangs over Prime Minister Tony Blair today as he faces the twin test of a local election and a debate by his Labour Party about the divisive war.
        Distance: 0.514

        --- Recommendation #2 (nearest neighbor 2 of 5) ---
        String: THE re-election of British Prime Minister Tony Blair would be seen as an endorsement of the military action in Iraq, Prime Minister John Howard said today.
        Distance: 0.516

        --- Recommendation #3 (nearest neighbor 3 of 5) ---
        String: Israel is prepared to back a Middle East conference convened by Tony Blair early next year despite having expressed fears that the British plans were over-ambitious and designed 
        Distance: 0.546

        --- Recommendation #4 (nearest neighbor 4 of 5) ---
        String: Allowing dozens of casinos to be built in the UK would bring investment and thousands of jobs, Tony Blair says.
        Distance: 0.568

        --- Recommendation #5 (nearest neighbor 5 of 5) ---
        String: AFP - A battle group of British troops rolled out of southern Iraq on a US-requested mission to deadlier areas near Baghdad, in a major political gamble for British Prime Minister Tony Blair.
        Distance: 0.579

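Pretty good! All 5 recommendations explicitly mention Tony Blair, although none of them shares the source article's climate-change topic.

Let's see how our recommender does on the second example article, about NVIDIA's new chipset with more security.

chipset_security_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=1,  # let's look at articles similar to the second one about a more secure chipset
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)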
Source string: PC World - Upcoming chip set will include built-in security features for your PC.

        --- Recommendation #1 (nearest neighbor 1 of 5) ---
        String: PC World - Updated antivirus software for businesses adds intrusion prevention features.
        Distance: 0.422

        --- Recommendation #2 (nearest neighbor 2 of 5) ---
        String: PC World - Symantec, McAfee hope raising virus-definition fees will move users to\  suites.
        Distance: 0.518

        --- Recommendation #3 (nearest neighbor 3 of 5) ---
        String: originally offered on notebook PCs -- to its Opteron 32- and 64-bit x86 processors for server applications. The technology will help servers to run 
        Distance: 0.522

        --- Recommendation #4 (nearest neighbor 4 of 5) ---
        String: PC World - Send your video throughout your house--wirelessly--with new gateways and media adapters.
        Distance: 0.532

        --- Recommendation #5 (nearest neighbor 5 of 5) ---
        String: Chips that help a computer's main microprocessors perform specific types of math problems are becoming a big business once again.\
        Distance: 0.532

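From the printed distances, you can see that the #1 recommendation is much closer than all the others (0.422 vs 0.518+). The #1 recommendation also looks very similar to the starting article: it's another article from PC World about increasing computer security. Pretty good!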

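Appendix: Using embeddings in more sophisticated recommenders

A more sophisticated way to build a recommender system is to train a machine learning model that takes in tens or hundreds of signals, such as item popularity or user click data. Even in such a system, embeddings can be a very useful signal into the recommender, especially for items that are being “cold started” with no user data yet (e.g., a brand new product added to the catalog without any clicks yet).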
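
To make the cold-start point concrete, here is a deliberately simplified sketch. It is not a trained model: it blends two signals with a hand-picked weight and falls back to embedding similarity alone when an item has no click data. The function blended_score, the weight of 0.7, and the toy catalog values are all assumptions made up for this illustration.

```python
from typing import Optional


def blended_score(
    embedding_similarity: float,
    click_through_rate: Optional[float],
    weight_on_clicks: float = 0.7,  # hypothetical hand-picked weight, not a learned value
) -> float:
    """Blend content similarity with click data; use similarity alone for cold-start items."""
    if click_through_rate is None:  # cold start: no user data yet
        return embedding_similarity
    return weight_on_clicks * click_through_rate + (1 - weight_on_clicks) * embedding_similarity


# Toy catalog: item name -> (embedding similarity to the source item, click-through rate or None)
toy_catalog = {
    "established item": (0.60, 0.30),
    "brand new item": (0.80, None),
}
toy_scores = {name: blended_score(sim, ctr) for name, (sim, ctr) in toy_catalog.items()}
toy_ranking = sorted(toy_scores, key=toy_scores.get, reverse=True)
print(toy_ranking)
```

In a real system the weights would be learned from data along with the other signals; the point here is only that the embedding signal is available on day one, before any clicks exist.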

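Appendix: Using embeddings to visualize similar articles

To get a sense of what our nearest neighbor recommender is doing, let's visualize the article embeddings. Although we can't plot each of the 1536 dimensions of each embedding vector, we can use techniques like t-SNE or PCA to compress the embeddings down into 2 or 3 dimensions, which we can chart.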
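
The rest of this appendix uses t-SNE, but PCA is easier to show from scratch. Below is a dependency-free sketch of PCA via power iteration on toy 4-dimensional vectors. The helper names (dot, leading_direction, pca_2d) are hypothetical, and the approach assumes the data has at least two directions of non-zero variance. In practice you would more likely reach for a library implementation such as scikit-learn's PCA.

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def leading_direction(rows: list[list[float]], iterations: int = 200) -> list[float]:
    """Approximate the top principal direction of centered rows by power iteration."""
    dim = len(rows[0])
    v = [1.0 / math.sqrt(dim)] * dim  # arbitrary unit-length starting vector
    for _ in range(iterations):
        # Apply X^T X to v without building the covariance matrix explicitly.
        projections = [dot(row, v) for row in rows]
        new_v = [sum(p * row[j] for p, row in zip(projections, rows)) for j in range(dim)]
        norm = math.sqrt(dot(new_v, new_v))
        v = [x / norm for x in new_v]
    return v


def pca_2d(vectors: list[list[float]]) -> list[list[float]]:
    """Project vectors onto their first two principal directions."""
    n, dim = len(vectors), len(vectors[0])
    means = [sum(vec[j] for vec in vectors) / n for j in range(dim)]
    centered = [[vec[j] - means[j] for j in range(dim)] for vec in vectors]
    first = leading_direction(centered)
    # Remove the first direction from the data, then find the next strongest one.
    deflated = [[x - dot(row, first) * f for x, f in zip(row, first)] for row in centered]
    second = leading_direction(deflated)
    return [[dot(row, first), dot(row, second)] for row in centered]


# Toy "embeddings": most variation along axis 0, some along axis 1, very little along axis 2.
toy_vectors = [
    [2.0, 0.0, 0.0, 0.0], [-2.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0], [0.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, 0.1, 0.0], [0.0, 0.0, -0.1, 0.0],
]
toy_2d = pca_2d(toy_vectors)
print(toy_2d[0], toy_2d[2])
```

The two points that differ only along the low-variance third axis collapse onto the origin in 2D, a small-scale version of the information loss discussed later in this appendix.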

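Before visualizing the nearest neighbors, let's visualize all of the article descriptions using t-SNE. Note that t-SNE is not deterministic, meaning that results may vary from run to run.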

# get embeddings for all article descriptions
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 1536-dimensional embeddings into 2 dimensions using t-SNE
tsne_components = tsne_components_from_embeddings(embeddings)
# get the article labels for coloring the chart
labels = df["label"].tolist()

chart_from_components(
    components=tsne_components,
    labels=labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="t-SNE components of article descriptions",
)

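As you can see in the chart above, even the highly compressed embeddings do a good job of clustering article descriptions by category. And it's worth emphasizing: this clustering is done with no knowledge of the labels themselves!

Also, if you look closely at the most egregious outliers, they are often due to mislabeling rather than poor embedding. For example, the majority of the blue World points in the green Sports cluster appear to be Sports stories.

Next, let's recolor the points by whether they are a source article, its nearest neighbors, or other.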

# create labels for the recommended articles
def nearest_neighbor_labels(
    list_of_indices: list[int],
    k_nearest_neighbors: int = 5
) -> list[str]:
    """Return a list of labels to color the k nearest neighbors."""
    labels = ["Other" for _ in list_of_indices]
    # the first index is the source itself, since it has the smallest distance
    source_index = list_of_indices[0]
    labels[source_index] = "Source"
    for i in range(k_nearest_neighbors):
        nearest_neighbor_index = list_of_indices[i + 1]
        labels[nearest_neighbor_index] = f"Nearest neighbor (top {k_nearest_neighbors})"
    return labels


tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)
chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5)

# a 2D chart of nearest neighbors of the Tony Blair article
chart_from_components(
    components=tsne_components,
    labels=tony_blair_labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="Nearest neighbors of the Tony Blair article",
    category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)

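Looking at the 2D chart above, we can see that the articles about Tony Blair are somewhat close together inside of the World news cluster. Interestingly, although the 5 nearest neighbors (red) were closest in high dimensional space, they are not the closest points in this compressed 2D space. Compressing the embeddings down to 2 dimensions discards much of their information, and the nearest neighbors in the 2D space don't seem to be as relevant as those in the full embedding space.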
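
One way to quantify this effect is to compare the k nearest neighbors in the full space with the k nearest neighbors after compression. The sketch below does that on toy 3-dimensional points. As a crude stand-in for t-SNE, it simply drops the third coordinate, so it illustrates the information loss rather than t-SNE itself. The helper k_nearest and the toy points are hypothetical.

```python
import math


def k_nearest(points: list[list[float]], source: int, k: int) -> set[int]:
    """Return the indices of the k points closest (Euclidean) to points[source]."""
    others = [i for i in range(len(points)) if i != source]
    others.sort(key=lambda i: math.dist(points[source], points[i]))
    return set(others[:k])


full_space = [
    [0.0, 0.0, 0.0],    # index 0: source
    [0.1, 0.0, 0.0],    # index 1: close in 3D and in 2D
    [0.0, 0.2, 0.0],    # index 2: close in 3D and in 2D
    [0.05, 0.05, 5.0],  # index 3: far in 3D, but looks close once the third axis is dropped
]
compressed = [p[:2] for p in full_space]  # crude "compression": discard the third coordinate

neighbors_full = k_nearest(full_space, source=0, k=2)
neighbors_2d = k_nearest(compressed, source=0, k=2)
overlap = len(neighbors_full & neighbors_2d) / 2
print(f"Full-space neighbors: {sorted(neighbors_full)}")
print(f"2D neighbors: {sorted(neighbors_2d)}")
print(f"Overlap: {overlap:.0%}")
```

The same comparison could be run on the real data by passing the full embeddings and tsne_components to a function like k_nearest, keeping in mind that the recommender above uses cosine rather than Euclidean distance.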

# a 2D chart of nearest neighbors of the chipset security article
chart_from_components(
    components=tsne_components,
    labels=chipset_security_labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="Nearest neighbors of the chipset security article",
    category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)

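For the chipset security example, the 4 closest nearest neighbors in the full embedding space remain nearest neighbors in this compressed 2D visualization. The fifth is displayed as more distant, despite being closer in the full embedding space.

Should you want to, you can also make an interactive 3D plot of the embeddings with the function chart_from_components_3D. (Doing so will require recomputing the t-SNE components with n_components=3.)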