第6章:树的递归思维
树是算法题中最优雅的数据结构之一。它的递归定义——每个节点都是其子树的根——天然适合递归处理。对于熟悉函数式编程的工程师来说,树的递归处理就像模式匹配一样自然。本章将深入探讨如何用递归思维优雅地解决树的各类问题,从基础的遍历到复杂的路径计算,帮助你建立处理树结构的系统性思维框架。
6.1 递归三要素:终止条件、递归体、返回值
递归处理树的核心在于理解三个要素:何时停止(终止条件)、如何处理当前节点(递归体)、以及返回什么信息(返回值)。这三要素的设计决定了递归的正确性和效率。
终止条件的设计原则
终止条件通常是空节点(null/None),但有时也可能是叶子节点。选择哪种取决于问题的性质:
def recursive_template(node):
# 终止条件1:空节点
if not node:
return base_value # 根据问题定义基值
# 终止条件2:叶子节点(某些问题需要)
if not node.left and not node.right:
return leaf_value
# 递归处理
left_result = recursive_template(node.left)
right_result = recursive_template(node.right)
# 组合结果
return combine(node.val, left_result, right_result)
递归体的逻辑组织
递归体的组织方式决定了信息的流向。主要有三种模式:
- 前序处理:先处理当前节点,再递归子树
- 中序处理:在左右子树递归之间处理当前节点
- 后序处理:先递归子树,再处理当前节点
选择哪种模式取决于信息依赖关系。如果当前节点的处理依赖子树信息,用后序;如果要向子树传递信息,用前序。
返回值的语义设计
返回值是递归函数的"契约"。清晰的返回值语义能让递归逻辑自然流畅:
# 返回值语义示例
def height(node):
"""返回以node为根的子树高度"""
if not node:
return 0 # 空树高度为0
return 1 + max(height(node.left), height(node.right))
def contains(node, target):
"""返回以node为根的子树是否包含target"""
if not node:
return False
if node.val == target:
return True
return contains(node.left, target) or contains(node.right, target)
实例分析:#236 二叉树的最近公共祖先
这道题完美展示了递归三要素的运用。问题是找两个节点p和q的最近公共祖先(LCA)。
def lowestCommonAncestor(root, p, q):
# 终止条件:空节点或找到目标
if not root or root == p or root == q:
return root
# 递归体:在左右子树中查找
left = lowestCommonAncestor(root.left, p, q)
right = lowestCommonAncestor(root.right, p, q)
# 返回值语义:
# - 如果左右都找到,当前节点就是LCA
# - 如果只有一边找到,返回那一边的结果
# - 如果都没找到,返回None
if left and right:
return root
return left if left else right
这个解法的精妙之处在于返回值的多重语义:既表示"是否找到目标节点",又携带"找到的节点或LCA"信息。这种设计避免了额外的状态传递。
实例分析:#114 二叉树展开为链表
将二叉树原地展开为链表,要求用前序遍历的顺序。这题的关键是理解递归的"契约":
def flatten(root):
def flattenTree(node):
"""
返回值:(展开后的头节点, 展开后的尾节点)
契约:将以node为根的子树展开,返回展开后的头尾
"""
if not node:
return None, None
# 如果是叶子节点,头尾都是自己
if not node.left and not node.right:
return node, node
# 递归展开左右子树
left_head, left_tail = None, None
right_head, right_tail = None, None
if node.left:
left_head, left_tail = flattenTree(node.left)
if node.right:
right_head, right_tail = flattenTree(node.right)
# 按前序连接:node -> 左子树 -> 右子树
node.left = None # 清空左指针
if left_head: # 有左子树
node.right = left_head
if right_head: # 也有右子树
left_tail.right = right_head
return node, right_tail
else: # 只有左子树
return node, left_tail
else: # 没有左子树,只有右子树
node.right = right_head
return node, right_tail if right_tail else node
flattenTree(root)
更优雅的解法是利用后序遍历的特性:
def flatten(root):
def helper(node):
"""后序遍历,返回展开后的最后一个节点"""
if not node:
return None
# 先递归处理左右子树
left_tail = helper(node.left)
right_tail = helper(node.right)
# 如果有左子树,需要插入到node和右子树之间
if left_tail:
left_tail.right = node.right
node.right = node.left
node.left = None
# 返回展开后的最后一个节点
if right_tail:
return right_tail
elif left_tail:
return left_tail
else:
return node
helper(root)
6.2 自顶向下vs自底向上
理解这两种递归策略的区别是掌握树算法的关键。它们代表了信息在树中流动的两种基本方向。
自顶向下:携带信息下传
自顶向下的递归从根节点开始,将信息(状态、约束、累积值等)向下传递给子节点。这种方式适合需要"上下文"的问题。
def top_down(node, accumulated_info):
# 终止条件
if not node:
return
# 处理当前节点(使用传下来的信息)
process(node, accumulated_info)
# 准备传给子节点的信息
left_info = compute_left_info(accumulated_info, node)
right_info = compute_right_info(accumulated_info, node)
# 递归传递
top_down(node.left, left_info)
top_down(node.right, right_info)
典型应用场景:
- 路径和问题(累积从根到当前节点的和)
- 深度计算(传递当前深度)
- 约束验证(传递允许的值范围)
自底向上:收集信息上传
自底向上的递归先处理子节点,收集子树的信息,然后在当前节点汇总。这种方式适合需要子树完整信息才能做决策的问题。
def bottom_up(node):
# 终止条件
if not node:
return base_info
# 先递归获取子树信息
left_info = bottom_up(node.left)
right_info = bottom_up(node.right)
# 基于子树信息处理当前节点
current_info = process(node, left_info, right_info)
# 返回当前子树的汇总信息
return current_info
典型应用场景:
- 子树属性计算(高度、节点数、是否平衡)
- 最优解问题(最大路径和、最长路径)
- 子树验证(是否为BST)
实例分析:#543 二叉树的直径
二叉树的直径是任意两节点间最长路径的长度。这是典型的自底向上问题,因为需要知道子树的高度才能计算经过当前节点的最长路径。
def diameterOfBinaryTree(root):
max_diameter = 0
def height(node):
nonlocal max_diameter
# 空节点高度为0
if not node:
return 0
# 递归计算左右子树高度
left_height = height(node.left)
right_height = height(node.right)
# 更新全局最大直径
# 经过当前节点的路径长度 = 左高度 + 右高度
max_diameter = max(max_diameter, left_height + right_height)
# 返回当前子树的高度
return 1 + max(left_height, right_height)
height(root)
return max_diameter
这里的关键洞察:
- 直径可能不经过根节点,所以需要检查每个节点
- 经过某节点的最长路径 = 左子树高度 + 右子树高度
- 使用自底向上递归,在计算高度的同时更新直径
实例分析:#124 二叉树中的最大路径和
找出二叉树中的最大路径和,路径可以从任意节点开始和结束。这题需要仔细设计返回值的语义。
def maxPathSum(root):
max_sum = float('-inf')
def max_gain(node):
"""
返回:从node向下延伸的最大路径和(单边)
副作用:更新全局最大路径和(可能是拐弯的)
"""
nonlocal max_sum
if not node:
return 0
# 递归计算左右子树的最大贡献
# 如果贡献为负,不如不选(取0)
left_gain = max(0, max_gain(node.left))
right_gain = max(0, max_gain(node.right))
# 计算经过当前节点的最大路径和(可以拐弯)
current_max = node.val + left_gain + right_gain
max_sum = max(max_sum, current_max)
# 返回从当前节点向下的最大路径和(不能拐弯)
return node.val + max(left_gain, right_gain)
max_gain(root)
return max_sum
设计要点:
- 返回值语义:只返回"单边"最大值(供上层使用)
- 全局更新:在每个节点更新可能的最大值(包括拐弯路径)
- 负值处理:负贡献直接舍弃(取0)
混合策略:同时进行上传下传
有些复杂问题需要同时进行信息的上传和下传:
def hybrid_recursion(node, down_info):
if not node:
return base_up_info
# 使用下传信息处理当前节点
process_with_context(node, down_info)
# 准备传给子节点的信息
left_down = compute_left_context(down_info, node)
right_down = compute_right_context(down_info, node)
# 递归并收集子树信息
left_up = hybrid_recursion(node.left, left_down)
right_up = hybrid_recursion(node.right, right_down)
# 汇总并返回
return combine(node, left_up, right_up, down_info)
6.3 路径问题:从根到叶、任意路径
路径问题是树算法中的一大类,关键在于理解不同的路径定义和相应的处理策略。
路径的不同定义
树中的路径有多种定义方式,每种定义对应不同的算法策略:
- 根到叶路径:必须从根开始,到叶子结束
- 根到任意节点路径:从根开始,可在任意节点结束
- 任意路径:可从任意节点开始和结束,但必须是父子方向
- 祖先-后代路径:必须是祖先到后代的路径
理解题目的路径定义是解题的第一步。不同定义需要不同的递归策略。
路径和问题的递归处理
处理路径问题的核心技巧是维护"路径状态"。根据路径定义的不同,状态维护方式也不同:
# 策略1:显式维护路径(适合需要完整路径的问题)
def find_paths_explicit(node, current_path, target):
if not node:
return []
# 添加当前节点到路径
current_path.append(node.val)
# 检查是否满足条件
if is_valid_path(current_path, target):
results.append(current_path[:]) # 注意要复制
# 递归
find_paths_explicit(node.left, current_path, target)
find_paths_explicit(node.right, current_path, target)
# 回溯
current_path.pop()
# 策略2:只维护路径和(适合只关心和的问题)
def find_paths_sum(node, current_sum, target):
if not node:
return 0
current_sum += node.val
count = 1 if current_sum == target else 0
count += find_paths_sum(node.left, current_sum, target)
count += find_paths_sum(node.right, current_sum, target)
return count
实例分析:#113 路径总和II
找出所有从根到叶子节点路径总和等于目标值的路径。这是典型的根到叶路径问题。
def pathSum(root, targetSum):
def dfs(node, current_path, remaining):
if not node:
return
# 添加当前节点
current_path.append(node.val)
remaining -= node.val
# 叶子节点且和为目标值
if not node.left and not node.right and remaining == 0:
result.append(current_path[:]) # 复制当前路径
# 递归搜索
dfs(node.left, current_path, remaining)
dfs(node.right, current_path, remaining)
# 回溯
current_path.pop()
result = []
dfs(root, [], targetSum)
return result
优化版本,使用更函数式的风格:
def pathSum(root, targetSum):
if not root:
return []
# 叶子节点
if not root.left and not root.right:
return [[root.val]] if root.val == targetSum else []
# 递归获取子树的所有路径
left_paths = pathSum(root.left, targetSum - root.val)
right_paths = pathSum(root.right, targetSum - root.val)
# 在每条路径前加上当前节点
return [[root.val] + path for path in left_paths + right_paths]
实例分析:#437 路径总和III
找出路径和等于目标值的路径数量,路径可以从任意节点开始和结束(但必须向下)。这是"任意路径"问题的典型代表。
朴素解法:在每个节点都尝试作为起点:
def pathSum(root, targetSum):
def pathsFromNode(node, remaining):
"""从node开始向下的路径中,和为remaining的路径数"""
if not node:
return 0
count = 1 if node.val == remaining else 0
count += pathsFromNode(node.left, remaining - node.val)
count += pathsFromNode(node.right, remaining - node.val)
return count
if not root:
return 0
# 三部分:从当前节点开始的 + 左子树中的 + 右子树中的
return (pathsFromNode(root, targetSum) +
pathSum(root.left, targetSum) +
pathSum(root.right, targetSum))
优化解法:使用前缀和 + 哈希表,类似于数组的子数组和问题:
def pathSum(root, targetSum):
def dfs(node, current_sum):
nonlocal count
if not node:
return
# 更新当前路径和
current_sum += node.val
# 检查是否存在前缀和使得 current_sum - prefix_sum = target
count += prefix_sums.get(current_sum - targetSum, 0)
# 记录当前前缀和
prefix_sums[current_sum] = prefix_sums.get(current_sum, 0) + 1
# 递归处理子树
dfs(node.left, current_sum)
dfs(node.right, current_sum)
# 回溯:移除当前前缀和
prefix_sums[current_sum] -= 1
count = 0
prefix_sums = {0: 1} # 初始化:空路径的和为0
dfs(root, 0)
return count
这个优化的关键洞察:
- 路径和问题可以转化为前缀和问题
- 如果
prefix_sum[j] - prefix_sum[i] = target,则从i+1到j的路径和为target - 使用哈希表记录前缀和出现的次数
- 回溯时要移除当前路径的前缀和,避免影响其他分支
6.4 树的序列化与构造
序列化是将树结构转换为线性表示,反序列化是从线性表示重建树。这类问题考察对树结构和递归的深刻理解。
序列化的不同策略
树的序列化有多种策略,每种策略有其适用场景:
- 前序遍历 + 空节点标记:最直观,易于实现递归
- 层序遍历(BFS):适合完全二叉树,空间效率高
- 前序 + 中序:无需空节点标记,但要求无重复值
- 括号表示法:适合可视化和调试
选择策略时考虑:
- 是否需要保留空节点信息
- 重建的复杂度
- 序列化结果的可读性
从序列重建树的技巧
重建树的核心是确定每个节点的位置和子树范围:
# 技巧1:使用索引或迭代器避免数组切片
def deserialize_with_index(data):
def build():
nonlocal idx
if idx >= len(data) or data[idx] == 'null':
idx += 1
return None
node = TreeNode(data[idx])
idx += 1
node.left = build()
node.right = build()
return node
idx = 0
return build()
# 技巧2:使用队列或栈维护遍历状态
def deserialize_with_queue(data):
queue = collections.deque(data)
def build():
val = queue.popleft()
if val == 'null':
return None
node = TreeNode(val)
node.left = build()
node.right = build()
return node
return build()
实例分析:#297 二叉树的序列化与反序列化
设计算法实现树的序列化和反序列化。这题的关键是选择合适的遍历方式。
前序遍历解法(最直观):
class Codec:
def serialize(self, root):
"""前序遍历序列化"""
def preorder(node):
if not node:
vals.append('null')
return
vals.append(str(node.val))
preorder(node.left)
preorder(node.right)
vals = []
preorder(root)
return ','.join(vals)
def deserialize(self, data):
"""前序遍历反序列化"""
def build():
val = next(vals)
if val == 'null':
return None
node = TreeNode(int(val))
node.left = build()
node.right = build()
return node
vals = iter(data.split(','))
return build()
层序遍历解法(BFS,更符合直觉):
class Codec:
def serialize(self, root):
"""层序遍历序列化"""
if not root:
return ''
result = []
queue = collections.deque([root])
while queue:
node = queue.popleft()
if node:
result.append(str(node.val))
queue.append(node.left)
queue.append(node.right)
else:
result.append('null')
# 去除末尾的null
while result and result[-1] == 'null':
result.pop()
return ','.join(result)
def deserialize(self, data):
"""层序遍历反序列化"""
if not data:
return None
values = data.split(',')
root = TreeNode(int(values[0]))
queue = collections.deque([root])
i = 1
while queue and i < len(values):
node = queue.popleft()
# 处理左子节点
if i < len(values) and values[i] != 'null':
node.left = TreeNode(int(values[i]))
queue.append(node.left)
i += 1
# 处理右子节点
if i < len(values) and values[i] != 'null':
node.right = TreeNode(int(values[i]))
queue.append(node.right)
i += 1
return root
实例分析:#105 从前序与中序遍历序列构造二叉树
给定前序和中序遍历结果,重建二叉树。这题的关键洞察:
- 前序的第一个元素是根
- 在中序中找到根,左边是左子树,右边是右子树
递归解法:
def buildTree(preorder, inorder):
if not preorder:
return None
# 前序的第一个是根
root_val = preorder[0]
root = TreeNode(root_val)
# 在中序中找到根的位置
mid = inorder.index(root_val)
# 递归构建左右子树
# 左子树:前序[1:mid+1],中序[0:mid]
# 右子树:前序[mid+1:],中序[mid+1:]
root.left = buildTree(preorder[1:mid+1], inorder[:mid])
root.right = buildTree(preorder[mid+1:], inorder[mid+1:])
return root
优化版本(避免数组切片,使用索引):
def buildTree(preorder, inorder):
def build(pre_start, pre_end, in_start, in_end):
if pre_start > pre_end:
return None
# 根节点
root_val = preorder[pre_start]
root = TreeNode(root_val)
# 找到根在中序中的位置
mid = inorder_map[root_val]
left_size = mid - in_start
# 递归构建
root.left = build(pre_start + 1, pre_start + left_size,
in_start, mid - 1)
root.right = build(pre_start + left_size + 1, pre_end,
mid + 1, in_end)
return root
# 建立中序值到索引的映射,避免重复查找
inorder_map = {val: i for i, val in enumerate(inorder)}
return build(0, len(preorder) - 1, 0, len(inorder) - 1)
更优雅的迭代器解法:
def buildTree(preorder, inorder):
def build(stop):
if inorder and inorder[-1] != stop:
root_val = preorder.pop()
root = TreeNode(root_val)
root.left = build(root_val)
inorder.pop()
root.right = build(stop)
return root
return None
preorder.reverse()
inorder.reverse()
return build(None)
这个解法的精妙之处:
- 利用了前序和中序的遍历顺序特性
- 使用stop参数标记子树的边界
- 通过反转数组,可以用pop()高效地获取元素
6.5 本章小结
树的递归处理是算法题中最优雅的部分之一。通过本章的学习,我们建立了处理树问题的系统性思维框架。
核心概念回顾
-
递归三要素 - 终止条件:定义递归的边界 - 递归体:处理当前节点的逻辑 - 返回值:向上传递的信息契约
-
信息流向策略 - 自顶向下:携带上下文信息下传 - 自底向上:收集子树信息上传 - 混合策略:同时进行双向信息传递
-
路径问题框架 - 明确路径定义(根到叶、任意路径等) - 选择状态维护方式(显式路径、累积值) - 利用前缀和等技巧优化
-
序列化与重建 - 选择合适的遍历方式 - 使用迭代器或索引避免数组切片 - 理解遍历序列的结构特性
递归模式总结
树的递归有几种常见模式,掌握这些模式能让你快速识别和解决问题:
# 模式1:纯遍历(不需要返回值)
def traverse(node):
if not node:
return
process(node)
traverse(node.left)
traverse(node.right)
# 模式2:分治(需要汇总子树信息)
def divide_conquer(node):
if not node:
return base_value
left_result = divide_conquer(node.left)
right_result = divide_conquer(node.right)
return combine(node.val, left_result, right_result)
# 模式3:路径追踪(需要回溯)
def path_tracking(node, path):
if not node:
return
path.append(node.val)
if is_target(node):
process_path(path)
path_tracking(node.left, path)
path_tracking(node.right, path)
path.pop() # 回溯
# 模式4:全局变量更新
def global_update(node):
nonlocal global_var
if not node:
return local_result
local = compute(node)
global_var = update(global_var, local)
return local
解题思维流程
面对树的问题时,按以下步骤思考:
- 识别问题类型:遍历、查找、路径、构造?
- 确定信息流向:需要上下文还是子树信息?
- 设计返回值语义:返回什么信息给上层?
- 处理边界情况:空节点、叶子节点如何处理?
- 考虑优化空间:能否减少重复计算或空间使用?
6.6 常见陷阱与错误 (Gotchas)
在处理树的递归问题时,即使是经验丰富的工程师也容易犯以下错误。
1. 递归终止条件的遗漏
常见错误:忘记处理空节点,导致空指针异常。
# 错误示例
def sum_tree(node):
return node.val + sum_tree(node.left) + sum_tree(node.right)
# 忘记处理 node 为 None 的情况
# 正确示例
def sum_tree(node):
if not node: # 必须先检查
return 0
return node.val + sum_tree(node.left) + sum_tree(node.right)
2. 返回值语义混淆
常见错误:返回值的含义不清晰,导致逻辑错误。
# 混淆示例
def find_path(node, target):
# 返回值是什么?是否找到?路径本身?节点数?
if not node:
return ???
# 清晰示例
def find_path(node, target):
"""返回是否存在和为target的根到叶路径"""
if not node:
return False # 明确的布尔值
3. 重复计算问题
常见错误:在递归中重复计算相同的子问题。
# 低效示例(计算平衡二叉树)
def isBalanced(root):
def height(node):
if not node:
return 0
return 1 + max(height(node.left), height(node.right))
if not root:
return True
# height被重复计算多次
return (abs(height(root.left) - height(root.right)) <= 1 and
isBalanced(root.left) and
isBalanced(root.right))
# 优化示例
def isBalanced(root):
def check(node):
"""返回 (是否平衡, 高度)"""
if not node:
return True, 0
left_balanced, left_height = check(node.left)
if not left_balanced:
return False, 0
right_balanced, right_height = check(node.right)
if not right_balanced:
return False, 0
return abs(left_height - right_height) <= 1, 1 + max(left_height, right_height)
return check(root)[0]
4. 全局变量的误用
常见错误:在递归函数中不当使用全局变量。
# 错误示例
max_sum = 0 # 全局变量在多次调用时会保留状态
def maxPathSum(root):
def helper(node):
global max_sum # 容易忘记声明global
# ...
# 正确示例
def maxPathSum(root):
max_sum = float('-inf') # 局部变量
def helper(node):
nonlocal max_sum # 使用nonlocal
# ...
helper(root)
return max_sum
5. 引用传递的陷阱
常见错误:Python中列表是引用传递,容易出现意外修改。
# 错误示例
def all_paths(root):
result = []
def dfs(node, path):
if not node:
return
path.append(node.val)
if not node.left and not node.right:
result.append(path) # 错误!path是引用
dfs(node.left, path)
dfs(node.right, path)
path.pop()
dfs(root, [])
return result
# 正确示例
def all_paths(root):
result = []
def dfs(node, path):
if not node:
return
path.append(node.val)
if not node.left and not node.right:
result.append(path[:]) # 正确!复制path
dfs(node.left, path)
dfs(node.right, path)
path.pop()
dfs(root, [])
return result
调试递归的技巧
- 打印递归路径:
def debug_recursion(node, depth=0):
print(' ' * depth + f'Enter: {node.val if node else "None"}')
if not node:
return
result = process(node)
debug_recursion(node.left, depth + 1)
debug_recursion(node.right, depth + 1)
print(' ' * depth + f'Exit: {node.val}, result={result}')
-
使用小数据集:先在纸上画出小树,手动追踪递归过程
-
检查边界条件:特别测试空树、单节点、只有左/右子树的情况
-
可视化递归栈:理解递归的调用顺序和返回顺序
-
使用断言:在关键位置加入断言验证假设
def recursive_func(node):
if not node:
return base_value
left_result = recursive_func(node.left)
right_result = recursive_func(node.right)
# 验证中间结果
assert left_result >= 0, f"Unexpected negative: {left_result}"
return combine(left_result, right_result)
记住:递归的优雅来自于其简洁性,但这种简洁性需要建立在清晰的思维模型之上。多练习、多思考、多总结,你就能掌握树递归的精髓。