第6章:树的递归思维

树是算法题中最优雅的数据结构之一。它的递归定义——每个节点都是其子树的根——天然适合递归处理。对于熟悉函数式编程的工程师来说,树的递归处理就像模式匹配一样自然。本章将深入探讨如何用递归思维优雅地解决树的各类问题,从基础的遍历到复杂的路径计算,帮助你建立处理树结构的系统性思维框架。

6.1 递归三要素:终止条件、递归体、返回值

递归处理树的核心在于理解三个要素:何时停止(终止条件)、如何处理当前节点(递归体)、以及返回什么信息(返回值)。这三要素的设计决定了递归的正确性和效率。

终止条件的设计原则

终止条件通常是空节点(null/None),但有时也可能是叶子节点。选择哪种取决于问题的性质:

def recursive_template(node):
    # 终止条件1:空节点
    if not node:
        return base_value  # 根据问题定义基值

    # 终止条件2:叶子节点(某些问题需要)
    if not node.left and not node.right:
        return leaf_value

    # 递归处理
    left_result = recursive_template(node.left)
    right_result = recursive_template(node.right)

    # 组合结果
    return combine(node.val, left_result, right_result)

递归体的逻辑组织

递归体的组织方式决定了信息的流向。主要有三种模式:

  1. 前序处理:先处理当前节点,再递归子树
  2. 中序处理:在左右子树递归之间处理当前节点
  3. 后序处理:先递归子树,再处理当前节点

选择哪种模式取决于信息依赖关系。如果当前节点的处理依赖子树信息,用后序;如果要向子树传递信息,用前序。

返回值的语义设计

返回值是递归函数的"契约"。清晰的返回值语义能让递归逻辑自然流畅:

# 返回值语义示例
def height(node):
    """返回以node为根的子树高度"""
    if not node:
        return 0  # 空树高度为0
    return 1 + max(height(node.left), height(node.right))

def contains(node, target):
    """返回以node为根的子树是否包含target"""
    if not node:
        return False
    if node.val == target:
        return True
    return contains(node.left, target) or contains(node.right, target)

实例分析:#236 二叉树的最近公共祖先

这道题完美展示了递归三要素的运用。问题是找两个节点p和q的最近公共祖先(LCA)。

def lowestCommonAncestor(root, p, q):
    # 终止条件:空节点或找到目标
    if not root or root == p or root == q:
        return root

    # 递归体:在左右子树中查找
    left = lowestCommonAncestor(root.left, p, q)
    right = lowestCommonAncestor(root.right, p, q)

    # 返回值语义:
    # - 如果左右都找到,当前节点就是LCA
    # - 如果只有一边找到,返回那一边的结果
    # - 如果都没找到,返回None
    if left and right:
        return root
    return left if left else right

这个解法的精妙之处在于返回值的多重语义:既表示"是否找到目标节点",又携带"找到的节点或LCA"信息。这种设计避免了额外的状态传递。

实例分析:#114 二叉树展开为链表

将二叉树原地展开为链表,要求用前序遍历的顺序。这题的关键是理解递归的"契约":

def flatten(root):
    def flattenTree(node):
        """
        返回值:(展开后的头节点, 展开后的尾节点)
        契约:将以node为根的子树展开,返回展开后的头尾
        """
        if not node:
            return None, None

        # 如果是叶子节点,头尾都是自己
        if not node.left and not node.right:
            return node, node

        # 递归展开左右子树
        left_head, left_tail = None, None
        right_head, right_tail = None, None

        if node.left:
            left_head, left_tail = flattenTree(node.left)
        if node.right:
            right_head, right_tail = flattenTree(node.right)

        # 按前序连接:node -> 左子树 -> 右子树
        node.left = None  # 清空左指针

        if left_head:  # 有左子树
            node.right = left_head
            if right_head:  # 也有右子树
                left_tail.right = right_head
                return node, right_tail
            else:  # 只有左子树
                return node, left_tail
        else:  # 没有左子树,只有右子树
            node.right = right_head
            return node, right_tail if right_tail else node

    flattenTree(root)

更优雅的解法是利用后序遍历的特性:

def flatten(root):
    def helper(node):
        """后序遍历,返回展开后的最后一个节点"""
        if not node:
            return None

        # 先递归处理左右子树
        left_tail = helper(node.left)
        right_tail = helper(node.right)

        # 如果有左子树,需要插入到node和右子树之间
        if left_tail:
            left_tail.right = node.right
            node.right = node.left
            node.left = None

        # 返回展开后的最后一个节点
        if right_tail:
            return right_tail
        elif left_tail:
            return left_tail
        else:
            return node

    helper(root)

6.2 自顶向下vs自底向上

理解这两种递归策略的区别是掌握树算法的关键。它们代表了信息在树中流动的两种基本方向。

自顶向下:携带信息下传

自顶向下的递归从根节点开始,将信息(状态、约束、累积值等)向下传递给子节点。这种方式适合需要"上下文"的问题。

def top_down(node, accumulated_info):
    # 终止条件
    if not node:
        return

    # 处理当前节点(使用传下来的信息)
    process(node, accumulated_info)

    # 准备传给子节点的信息
    left_info = compute_left_info(accumulated_info, node)
    right_info = compute_right_info(accumulated_info, node)

    # 递归传递
    top_down(node.left, left_info)
    top_down(node.right, right_info)

典型应用场景:

  • 路径和问题(累积从根到当前节点的和)
  • 深度计算(传递当前深度)
  • 约束验证(传递允许的值范围)

自底向上:收集信息上传

自底向上的递归先处理子节点,收集子树的信息,然后在当前节点汇总。这种方式适合需要子树完整信息才能做决策的问题。

def bottom_up(node):
    # 终止条件
    if not node:
        return base_info

    # 先递归获取子树信息
    left_info = bottom_up(node.left)
    right_info = bottom_up(node.right)

    # 基于子树信息处理当前节点
    current_info = process(node, left_info, right_info)

    # 返回当前子树的汇总信息
    return current_info

典型应用场景:

  • 子树属性计算(高度、节点数、是否平衡)
  • 最优解问题(最大路径和、最长路径)
  • 子树验证(是否为BST)

实例分析:#543 二叉树的直径

二叉树的直径是任意两节点间最长路径的长度。这是典型的自底向上问题,因为需要知道子树的高度才能计算经过当前节点的最长路径。

def diameterOfBinaryTree(root):
    max_diameter = 0

    def height(node):
        nonlocal max_diameter

        # 空节点高度为0
        if not node:
            return 0

        # 递归计算左右子树高度
        left_height = height(node.left)
        right_height = height(node.right)

        # 更新全局最大直径
        # 经过当前节点的路径长度 = 左高度 + 右高度
        max_diameter = max(max_diameter, left_height + right_height)

        # 返回当前子树的高度
        return 1 + max(left_height, right_height)

    height(root)
    return max_diameter

这里的关键洞察:

  1. 直径可能不经过根节点,所以需要检查每个节点
  2. 经过某节点的最长路径 = 左子树高度 + 右子树高度
  3. 使用自底向上递归,在计算高度的同时更新直径

实例分析:#124 二叉树中的最大路径和

找出二叉树中的最大路径和,路径可以从任意节点开始和结束。这题需要仔细设计返回值的语义。

def maxPathSum(root):
    max_sum = float('-inf')

    def max_gain(node):
        """
        返回:从node向下延伸的最大路径和(单边)
        副作用:更新全局最大路径和(可能是拐弯的)
        """
        nonlocal max_sum

        if not node:
            return 0

        # 递归计算左右子树的最大贡献
        # 如果贡献为负,不如不选(取0)
        left_gain = max(0, max_gain(node.left))
        right_gain = max(0, max_gain(node.right))

        # 计算经过当前节点的最大路径和(可以拐弯)
        current_max = node.val + left_gain + right_gain
        max_sum = max(max_sum, current_max)

        # 返回从当前节点向下的最大路径和(不能拐弯)
        return node.val + max(left_gain, right_gain)

    max_gain(root)
    return max_sum

设计要点:

  1. 返回值语义:只返回"单边"最大值(供上层使用)
  2. 全局更新:在每个节点更新可能的最大值(包括拐弯路径)
  3. 负值处理:负贡献直接舍弃(取0)

混合策略:同时进行上传下传

有些复杂问题需要同时进行信息的上传和下传:

def hybrid_recursion(node, down_info):
    if not node:
        return base_up_info

    # 使用下传信息处理当前节点
    process_with_context(node, down_info)

    # 准备传给子节点的信息
    left_down = compute_left_context(down_info, node)
    right_down = compute_right_context(down_info, node)

    # 递归并收集子树信息
    left_up = hybrid_recursion(node.left, left_down)
    right_up = hybrid_recursion(node.right, right_down)

    # 汇总并返回
    return combine(node, left_up, right_up, down_info)

6.3 路径问题:从根到叶、任意路径

路径问题是树算法中的一大类,关键在于理解不同的路径定义和相应的处理策略。

路径的不同定义

树中的路径有多种定义方式,每种定义对应不同的算法策略:

  1. 根到叶路径:必须从根开始,到叶子结束
  2. 根到任意节点路径:从根开始,可在任意节点结束
  3. 任意路径:可从任意节点开始和结束,但必须是父子方向
  4. 祖先-后代路径:必须是祖先到后代的路径

理解题目的路径定义是解题的第一步。不同定义需要不同的递归策略。

路径和问题的递归处理

处理路径问题的核心技巧是维护"路径状态"。根据路径定义的不同,状态维护方式也不同:

# 策略1:显式维护路径(适合需要完整路径的问题)
def find_paths_explicit(node, current_path, target):
    if not node:
        return []

    # 添加当前节点到路径
    current_path.append(node.val)

    # 检查是否满足条件
    if is_valid_path(current_path, target):
        results.append(current_path[:])  # 注意要复制

    # 递归
    find_paths_explicit(node.left, current_path, target)
    find_paths_explicit(node.right, current_path, target)

    # 回溯
    current_path.pop()

# 策略2:只维护路径和(适合只关心和的问题)
def find_paths_sum(node, current_sum, target):
    if not node:
        return 0

    current_sum += node.val
    count = 1 if current_sum == target else 0

    count += find_paths_sum(node.left, current_sum, target)
    count += find_paths_sum(node.right, current_sum, target)

    return count

实例分析:#113 路径总和II

找出所有从根到叶子节点路径总和等于目标值的路径。这是典型的根到叶路径问题。

def pathSum(root, targetSum):
    def dfs(node, current_path, remaining):
        if not node:
            return

        # 添加当前节点
        current_path.append(node.val)
        remaining -= node.val

        # 叶子节点且和为目标值
        if not node.left and not node.right and remaining == 0:
            result.append(current_path[:])  # 复制当前路径

        # 递归搜索
        dfs(node.left, current_path, remaining)
        dfs(node.right, current_path, remaining)

        # 回溯
        current_path.pop()

    result = []
    dfs(root, [], targetSum)
    return result

优化版本,使用更函数式的风格:

def pathSum(root, targetSum):
    if not root:
        return []

    # 叶子节点
    if not root.left and not root.right:
        return [[root.val]] if root.val == targetSum else []

    # 递归获取子树的所有路径
    left_paths = pathSum(root.left, targetSum - root.val)
    right_paths = pathSum(root.right, targetSum - root.val)

    # 在每条路径前加上当前节点
    return [[root.val] + path for path in left_paths + right_paths]

实例分析:#437 路径总和III

找出路径和等于目标值的路径数量,路径可以从任意节点开始和结束(但必须向下)。这是"任意路径"问题的典型代表。

朴素解法:在每个节点都尝试作为起点:

def pathSum(root, targetSum):
    def pathsFromNode(node, remaining):
        """从node开始向下的路径中,和为remaining的路径数"""
        if not node:
            return 0

        count = 1 if node.val == remaining else 0
        count += pathsFromNode(node.left, remaining - node.val)
        count += pathsFromNode(node.right, remaining - node.val)

        return count

    if not root:
        return 0

    # 三部分:从当前节点开始的 + 左子树中的 + 右子树中的
    return (pathsFromNode(root, targetSum) + 
            pathSum(root.left, targetSum) + 
            pathSum(root.right, targetSum))

优化解法:使用前缀和 + 哈希表,类似于数组的子数组和问题:

def pathSum(root, targetSum):
    def dfs(node, current_sum):
        nonlocal count

        if not node:
            return

        # 更新当前路径和
        current_sum += node.val

        # 检查是否存在前缀和使得 current_sum - prefix_sum = target
        count += prefix_sums.get(current_sum - targetSum, 0)

        # 记录当前前缀和
        prefix_sums[current_sum] = prefix_sums.get(current_sum, 0) + 1

        # 递归处理子树
        dfs(node.left, current_sum)
        dfs(node.right, current_sum)

        # 回溯:移除当前前缀和
        prefix_sums[current_sum] -= 1

    count = 0
    prefix_sums = {0: 1}  # 初始化:空路径的和为0
    dfs(root, 0)
    return count

这个优化的关键洞察:

  1. 路径和问题可以转化为前缀和问题
  2. 如果 prefix_sum[j] - prefix_sum[i] = target,则从i+1到j的路径和为target
  3. 使用哈希表记录前缀和出现的次数
  4. 回溯时要移除当前路径的前缀和,避免影响其他分支

6.4 树的序列化与构造

序列化是将树结构转换为线性表示,反序列化是从线性表示重建树。这类问题考察对树结构和递归的深刻理解。

序列化的不同策略

树的序列化有多种策略,每种策略有其适用场景:

  1. 前序遍历 + 空节点标记:最直观,易于实现递归
  2. 层序遍历(BFS):适合完全二叉树,空间效率高
  3. 前序 + 中序:无需空节点标记,但要求无重复值
  4. 括号表示法:适合可视化和调试

选择策略时考虑:

  • 是否需要保留空节点信息
  • 重建的复杂度
  • 序列化结果的可读性

从序列重建树的技巧

重建树的核心是确定每个节点的位置和子树范围:

# 技巧1:使用索引或迭代器避免数组切片
def deserialize_with_index(data):
    def build():
        nonlocal idx
        if idx >= len(data) or data[idx] == 'null':
            idx += 1
            return None

        node = TreeNode(data[idx])
        idx += 1
        node.left = build()
        node.right = build()
        return node

    idx = 0
    return build()

# 技巧2:使用队列或栈维护遍历状态
def deserialize_with_queue(data):
    queue = collections.deque(data)

    def build():
        val = queue.popleft()
        if val == 'null':
            return None

        node = TreeNode(val)
        node.left = build()
        node.right = build()
        return node

    return build()

实例分析:#297 二叉树的序列化与反序列化

设计算法实现树的序列化和反序列化。这题的关键是选择合适的遍历方式。

前序遍历解法(最直观):

class Codec:
    def serialize(self, root):
        """前序遍历序列化"""
        def preorder(node):
            if not node:
                vals.append('null')
                return
            vals.append(str(node.val))
            preorder(node.left)
            preorder(node.right)

        vals = []
        preorder(root)
        return ','.join(vals)

    def deserialize(self, data):
        """前序遍历反序列化"""
        def build():
            val = next(vals)
            if val == 'null':
                return None

            node = TreeNode(int(val))
            node.left = build()
            node.right = build()
            return node

        vals = iter(data.split(','))
        return build()

层序遍历解法(BFS,更符合直觉):

class Codec:
    def serialize(self, root):
        """层序遍历序列化"""
        if not root:
            return ''

        result = []
        queue = collections.deque([root])

        while queue:
            node = queue.popleft()
            if node:
                result.append(str(node.val))
                queue.append(node.left)
                queue.append(node.right)
            else:
                result.append('null')

        # 去除末尾的null
        while result and result[-1] == 'null':
            result.pop()

        return ','.join(result)

    def deserialize(self, data):
        """层序遍历反序列化"""
        if not data:
            return None

        values = data.split(',')
        root = TreeNode(int(values[0]))
        queue = collections.deque([root])
        i = 1

        while queue and i < len(values):
            node = queue.popleft()

            # 处理左子节点
            if i < len(values) and values[i] != 'null':
                node.left = TreeNode(int(values[i]))
                queue.append(node.left)
            i += 1

            # 处理右子节点
            if i < len(values) and values[i] != 'null':
                node.right = TreeNode(int(values[i]))
                queue.append(node.right)
            i += 1

        return root

实例分析:#105 从前序与中序遍历序列构造二叉树

给定前序和中序遍历结果,重建二叉树。这题的关键洞察:

  • 前序的第一个元素是根
  • 在中序中找到根,左边是左子树,右边是右子树

递归解法:

def buildTree(preorder, inorder):
    if not preorder:
        return None

    # 前序的第一个是根
    root_val = preorder[0]
    root = TreeNode(root_val)

    # 在中序中找到根的位置
    mid = inorder.index(root_val)

    # 递归构建左右子树
    # 左子树:前序[1:mid+1],中序[0:mid]
    # 右子树:前序[mid+1:],中序[mid+1:]
    root.left = buildTree(preorder[1:mid+1], inorder[:mid])
    root.right = buildTree(preorder[mid+1:], inorder[mid+1:])

    return root

优化版本(避免数组切片,使用索引):

def buildTree(preorder, inorder):
    def build(pre_start, pre_end, in_start, in_end):
        if pre_start > pre_end:
            return None

        # 根节点
        root_val = preorder[pre_start]
        root = TreeNode(root_val)

        # 找到根在中序中的位置
        mid = inorder_map[root_val]
        left_size = mid - in_start

        # 递归构建
        root.left = build(pre_start + 1, pre_start + left_size, 
                         in_start, mid - 1)
        root.right = build(pre_start + left_size + 1, pre_end,
                          mid + 1, in_end)

        return root

    # 建立中序值到索引的映射,避免重复查找
    inorder_map = {val: i for i, val in enumerate(inorder)}
    return build(0, len(preorder) - 1, 0, len(inorder) - 1)

更优雅的迭代器解法:

def buildTree(preorder, inorder):
    def build(stop):
        if inorder and inorder[-1] != stop:
            root_val = preorder.pop()
            root = TreeNode(root_val)
            root.left = build(root_val)
            inorder.pop()
            root.right = build(stop)
            return root
        return None

    preorder.reverse()
    inorder.reverse()
    return build(None)

这个解法的精妙之处:

  1. 利用了前序和中序的遍历顺序特性
  2. 使用stop参数标记子树的边界
  3. 通过反转数组,可以用pop()高效地获取元素

6.5 本章小结

树的递归处理是算法题中最优雅的部分之一。通过本章的学习,我们建立了处理树问题的系统性思维框架。

核心概念回顾

  1. 递归三要素 - 终止条件:定义递归的边界 - 递归体:处理当前节点的逻辑 - 返回值:向上传递的信息契约

  2. 信息流向策略 - 自顶向下:携带上下文信息下传 - 自底向上:收集子树信息上传 - 混合策略:同时进行双向信息传递

  3. 路径问题框架 - 明确路径定义(根到叶、任意路径等) - 选择状态维护方式(显式路径、累积值) - 利用前缀和等技巧优化

  4. 序列化与重建 - 选择合适的遍历方式 - 使用迭代器或索引避免数组切片 - 理解遍历序列的结构特性

递归模式总结

树的递归有几种常见模式,掌握这些模式能让你快速识别和解决问题:

# 模式1:纯遍历(不需要返回值)
def traverse(node):
    if not node:
        return
    process(node)
    traverse(node.left)
    traverse(node.right)

# 模式2:分治(需要汇总子树信息)
def divide_conquer(node):
    if not node:
        return base_value
    left_result = divide_conquer(node.left)
    right_result = divide_conquer(node.right)
    return combine(node.val, left_result, right_result)

# 模式3:路径追踪(需要回溯)
def path_tracking(node, path):
    if not node:
        return
    path.append(node.val)
    if is_target(node):
        process_path(path)
    path_tracking(node.left, path)
    path_tracking(node.right, path)
    path.pop()  # 回溯

# 模式4:全局变量更新
def global_update(node):
    nonlocal global_var
    if not node:
        return local_result
    local = compute(node)
    global_var = update(global_var, local)
    return local

解题思维流程

面对树的问题时,按以下步骤思考:

  1. 识别问题类型:遍历、查找、路径、构造?
  2. 确定信息流向:需要上下文还是子树信息?
  3. 设计返回值语义:返回什么信息给上层?
  4. 处理边界情况:空节点、叶子节点如何处理?
  5. 考虑优化空间:能否减少重复计算或空间使用?

6.6 常见陷阱与错误 (Gotchas)

在处理树的递归问题时,即使是经验丰富的工程师也容易犯以下错误。

1. 递归终止条件的遗漏

常见错误:忘记处理空节点,导致空指针异常。

# 错误示例
def sum_tree(node):
    return node.val + sum_tree(node.left) + sum_tree(node.right)
    # 忘记处理 node 为 None 的情况

# 正确示例
def sum_tree(node):
    if not node:  # 必须先检查
        return 0
    return node.val + sum_tree(node.left) + sum_tree(node.right)

2. 返回值语义混淆

常见错误:返回值的含义不清晰,导致逻辑错误。

# 混淆示例
def find_path(node, target):
    # 返回值是什么?是否找到?路径本身?节点数?
    if not node:
        return ???

# 清晰示例
def find_path(node, target):
    """返回是否存在和为target的根到叶路径"""
    if not node:
        return False  # 明确的布尔值

3. 重复计算问题

常见错误:在递归中重复计算相同的子问题。

# 低效示例(计算平衡二叉树)
def isBalanced(root):
    def height(node):
        if not node:
            return 0
        return 1 + max(height(node.left), height(node.right))

    if not root:
        return True

    # height被重复计算多次
    return (abs(height(root.left) - height(root.right)) <= 1 and 
            isBalanced(root.left) and 
            isBalanced(root.right))

# 优化示例
def isBalanced(root):
    def check(node):
        """返回 (是否平衡, 高度)"""
        if not node:
            return True, 0

        left_balanced, left_height = check(node.left)
        if not left_balanced:
            return False, 0

        right_balanced, right_height = check(node.right)
        if not right_balanced:
            return False, 0

        return abs(left_height - right_height) <= 1, 1 + max(left_height, right_height)

    return check(root)[0]

4. 全局变量的误用

常见错误:在递归函数中不当使用全局变量。

# 错误示例
max_sum = 0  # 全局变量在多次调用时会保留状态

def maxPathSum(root):
    def helper(node):
        global max_sum  # 容易忘记声明global
        # ...

# 正确示例
def maxPathSum(root):
    max_sum = float('-inf')  # 局部变量

    def helper(node):
        nonlocal max_sum  # 使用nonlocal
        # ...

    helper(root)
    return max_sum

5. 引用传递的陷阱

常见错误:Python中列表是引用传递,容易出现意外修改。

# 错误示例
def all_paths(root):
    result = []

    def dfs(node, path):
        if not node:
            return
        path.append(node.val)
        if not node.left and not node.right:
            result.append(path)  # 错误!path是引用
        dfs(node.left, path)
        dfs(node.right, path)
        path.pop()

    dfs(root, [])
    return result

# 正确示例
def all_paths(root):
    result = []

    def dfs(node, path):
        if not node:
            return
        path.append(node.val)
        if not node.left and not node.right:
            result.append(path[:])  # 正确!复制path
        dfs(node.left, path)
        dfs(node.right, path)
        path.pop()

    dfs(root, [])
    return result

调试递归的技巧

  1. 打印递归路径
def debug_recursion(node, depth=0):
    print('  ' * depth + f'Enter: {node.val if node else "None"}')
    if not node:
        return
    result = process(node)
    debug_recursion(node.left, depth + 1)
    debug_recursion(node.right, depth + 1)
    print('  ' * depth + f'Exit: {node.val}, result={result}')
  1. 使用小数据集:先在纸上画出小树,手动追踪递归过程

  2. 检查边界条件:特别测试空树、单节点、只有左/右子树的情况

  3. 可视化递归栈:理解递归的调用顺序和返回顺序

  4. 使用断言:在关键位置加入断言验证假设

def recursive_func(node):
    if not node:
        return base_value

    left_result = recursive_func(node.left)
    right_result = recursive_func(node.right)

    # 验证中间结果
    assert left_result >= 0, f"Unexpected negative: {left_result}"

    return combine(left_result, right_result)

记住:递归的优雅来自于其简洁性,但这种简洁性需要建立在清晰的思维模型之上。多练习、多思考、多总结,你就能掌握树递归的精髓。