第11章:区间DP与状态压缩

章节大纲

11.1 区间DP:戳气球、矩阵链乘法

  • 区间DP的核心思想:枚举分割点
  • 经典问题:戳气球 (#312)
  • 矩阵链乘法与多边形三角剖分 (#1039)
  • 区间DP的实现技巧与边界处理

11.2 数位DP:统计特定范围内的数

  • 数位DP的基本框架:记忆化搜索
  • 数字1的个数 (#233)
  • 最大为N的数字组合 (#902)
  • 数位DP的通用模板与优化

11.3 状态压缩:用二进制表示集合

  • 位运算基础与集合表示
  • 我能赢吗 (#464)
  • 访问所有节点的最短路径 (#847)
  • 状态压缩DP的空间与时间权衡

11.4 概率DP与期望DP

  • 概率DP的状态定义与转移
  • 骑士在棋盘上的概率 (#688)
  • 新21点 (#837)
  • 期望的线性性质与计算技巧

本章小结

  • 四种高级DP技巧的适用场景
  • 关键公式与模板总结

常见陷阱与错误

  • 区间DP的边界处理错误
  • 数位DP的前导零和limit标记
  • 状态压缩的位运算错误
  • 概率计算的精度问题

开篇段落

本章我们将探讨动态规划的四种高级技巧。这些技巧虽然在面试中出现频率不如基础DP高,但它们展示了DP思想的深度和灵活性。区间DP通过枚举分割点处理区间问题;数位DP用记忆化搜索统计特定范围内的数;状态压缩将集合编码为整数以处理指数级状态空间;概率DP则将不确定性纳入状态转移。掌握这些技巧不仅能解决特定类型的难题,更能深化对动态规划本质的理解。

11.1 区间DP:戳气球、矩阵链乘法

区间DP是处理区间相关问题的经典方法。其核心思想是:对于区间 [i, j],枚举所有可能的分割点 k,将问题分解为子区间 [i, k] 和 [k, j] 的组合。这种方法特别适用于"最后一步"思维——思考最后执行的操作是什么,然后递归处理剩余部分。

戳气球问题 (#312)

戳气球是区间DP的经典题目。给定 n 个气球,编号为 0 到 n-1,每个气球上标有一个数字 nums[i]。戳破第 i 个气球可以获得 nums[i-1] * nums[i] * nums[i+1] 枚硬币。求戳破所有气球能获得的最大硬币数。

关键洞察:如果我们按照戳气球的顺序思考,问题会变得复杂,因为戳破一个气球会改变相邻气球的邻居关系。但如果我们反向思考——哪个气球是最后被戳破的,问题就变得清晰了。

 dp[i][j] = 戳破区间 (i, j) 内所有气球能获得的最大硬币数
注意是开区间不包括 i  j

对于区间 (i, j)假设最后戳破的是气球 k (i < k < j)

- 此时只剩下气球 i, k, j
- 戳破 k 获得 nums[i] * nums[k] * nums[j] 枚硬币
- 加上之前戳破 (i, k)  (k, j) 内气球的收益

状态转移方程
dp[i][j] = max(dp[i][k] + dp[k][j] + nums[i] * nums[k] * nums[j])
           for all k in (i, j)

为了处理边界,我们在数组两端添加虚拟气球,值为1:

def maxCoins(nums):
    # 添加虚拟气球
    nums = [1] + nums + [1]
    n = len(nums)
    dp = [[0] * n for _ in range(n)]

    # 枚举区间长度
    for length in range(3, n + 1):  # 至少需要3个点(包括两端)
        for i in range(n - length + 1):
            j = i + length - 1
            # 枚举分割点
            for k in range(i + 1, j):
                dp[i][j] = max(dp[i][j], 
                             dp[i][k] + dp[k][j] + nums[i] * nums[k] * nums[j])

    return dp[0][n - 1]

矩阵链乘法与多边形三角剖分

多边形三角剖分问题 (#1039) 本质上是矩阵链乘法的变体。给定一个凸多边形的顶点值,将其三角剖分,使得所有三角形的权值之和最小。

这个问题的结构与戳气球类似:

  • 对于多边形的边 (i, j),枚举与其构成三角形的第三个顶点 k
  • 三角形 (i, k, j) 将多边形分割为三部分:三角形本身、多边形 (i, k)、多边形 (k, j)
def minScoreTriangulation(values):
    n = len(values)
    dp = [[0] * n for _ in range(n)]

    # 枚举区间长度
    for length in range(3, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            # 枚举分割点
            for k in range(i + 1, j):
                score = values[i] * values[k] * values[j]
                dp[i][j] = min(dp[i][j], dp[i][k] + dp[k][j] + score)

    return dp[0][n - 1]

区间DP的实现技巧

  1. 枚举顺序:区间DP通常按区间长度从小到大枚举,确保计算当前区间时,所有子区间已经计算完毕。

  2. 边界处理: - 长度为1或2的区间通常作为base case - 有时需要添加虚拟元素简化边界处理

  3. 空间优化:区间DP通常需要O(n²)空间,某些问题可以通过滚动数组优化,但会增加实现复杂度。

  4. 记忆化搜索 vs 迭代:区间DP可以用记忆化搜索实现,代码更直观:

def solve_interval_dp_memo(arr):
    n = len(arr)
    memo = {}

    def dp(i, j):
        if j - i <= 1:  # base case
            return 0
        if (i, j) in memo:
            return memo[(i, j)]

        result = float('inf')  # or -float('inf') for max
        for k in range(i + 1, j):
            result = min(result, dp(i, k) + dp(k, j) + cost(i, k, j))

        memo[(i, j)] = result
        return result

    return dp(0, n - 1)

11.2 数位DP:统计特定范围内的数

数位DP是一种处理"统计满足特定条件的数字个数"问题的技巧。典型问题包括:统计 [L, R] 范围内包含特定数字的个数、数位和为特定值的个数等。其核心思想是按位构造数字,使用记忆化搜索避免重复计算。

数位DP的基本框架

数位DP的通用模板基于深度优先搜索,逐位构造数字:

def digit_dp(n):
    # 将数字转换为数位数组
    digits = []
    while n:
        digits.append(n % 10)
        n //= 10
    digits.reverse()

    memo = {}

    def dfs(pos, state, tight, lead):
        """
        pos: 当前处理的数位位置
        state: 问题相关的状态(如数位和、是否包含某数字等)
        tight: 是否贴着上界(true表示前面的数位都等于上界对应数位)
        lead: 是否有前导零
        """
        if pos == len(digits):
            # 到达末尾,根据state判断是否满足条件
            return 1 if check_valid(state) else 0

        if not tight and not lead and (pos, state) in memo:
            return memo[(pos, state)]

        # 确定当前位可以填的数字范围
        limit = digits[pos] if tight else 9
        result = 0

        for digit in range(0, limit + 1):
            # 跳过前导零的情况(可选)
            if lead and digit == 0:
                result += dfs(pos + 1, state, False, True)
            else:
                new_state = update_state(state, digit)
                new_tight = tight and (digit == limit)
                result += dfs(pos + 1, new_state, new_tight, False)

        if not tight and not lead:
            memo[(pos, state)] = result
        return result

    return dfs(0, initial_state, True, True)

数字1的个数 (#233)

计算 1 到 n 中数字 1 出现的总次数。这是数位DP的经典应用:

def countDigitOne(n):
    if n <= 0:
        return 0

    digits = []
    while n:
        digits.append(n % 10)
        n //= 10
    digits.reverse()

    memo = {}

    def dfs(pos, count, tight):
        """
        pos: 当前数位位置
        count: 已经出现的1的个数
        tight: 是否贴着上界
        """
        if pos == len(digits):
            return count

        if not tight and (pos, count) in memo:
            return memo[(pos, count)]

        limit = digits[pos] if tight else 9
        result = 0

        for digit in range(0, limit + 1):
            # 如果当前位是1,增加计数
            new_count = count + (1 if digit == 1 else 0)
            new_tight = tight and (digit == limit)
            result += dfs(pos + 1, new_count, new_tight)

        if not tight:
            memo[(pos, count)] = result
        return result

    return dfs(0, 0, True)

但这个问题有更巧妙的数学解法,通过分析每一位上1出现的规律:

def countDigitOne_math(n):
    count = 0
    i = 1  # 从个位开始

    while i <= n:
        # 将n分为高位、当前位、低位三部分
        high = n // (i * 10)
        cur = (n // i) % 10
        low = n % i

        if cur == 0:
            # 当前位为0,1的个数由高位决定
            count += high * i
        elif cur == 1:
            # 当前位为1,需要加上低位部分
            count += high * i + low + 1
        else:
            # 当前位大于1
            count += (high + 1) * i

        i *= 10

    return count

最大为N的数字组合 (#902)

给定数字数组 digits 和上界 n,返回可以用 digits 中的数字组成的小于等于 n 的正整数个数。

def atMostNGivenDigitSet(digits, n):
    s = str(n)
    k = len(s)
    dp = [0] * k + [1]  # dp[i] = 从第i位开始能组成的数字个数

    # 从后往前处理每一位
    for i in range(k - 1, -1, -1):
        # 对于第i位,尝试填入每个可用数字
        for d in digits:
            if d < s[i]:
                # 如果小于当前位,后面可以任意填
                dp[i] += len(digits) ** (k - i - 1)
            elif d == s[i]:
                # 如果等于当前位,需要考虑后续限制
                dp[i] += dp[i + 1]
            # 如果大于当前位,不能使用(会超过n)

    # 加上位数更少的所有可能
    for i in range(1, k):
        dp[0] += len(digits) ** i

    return dp[0]

数位DP的优化技巧

  1. 状态设计:只记录必要的信息,避免状态爆炸。例如,如果只关心数位和模3的余数,状态只需要记录0、1、2三种情况。

  2. 前导零处理:某些问题需要特殊处理前导零(如统计不含0的数),需要额外的lead标记。

  3. tight优化:当tight=False时,后续所有状态都不再受限,可以直接使用预计算的结果。

  4. 对称性利用:对于统计 [L, R] 范围的问题,通常转化为 solve(R) - solve(L-1)。

11.3 状态压缩:用二进制表示集合

状态压缩DP利用二进制数表示集合,将指数级的集合状态压缩到整数中。这种技巧适用于元素数量较少(通常不超过20)但需要枚举所有子集的问题。通过位运算,我们可以高效地进行集合操作。

位运算基础与集合表示

在状态压缩中,我们用二进制数的每一位表示集合中某个元素是否存在:

# 集合操作的位运算实现
def bit_operations():
    # 假设有n个元素,用n位二进制表示集合
    n = 5

    # 空集
    empty_set = 0

    # 全集
    full_set = (1 << n) - 1  # 11111

    # 单元素集合 {i}
    singleton = 1 << i

    # 添加元素i到集合s
    s_with_i = s | (1 << i)

    # 从集合s删除元素i
    s_without_i = s & ~(1 << i)

    # 检查元素i是否在集合s中
    contains_i = (s >> i) & 1

    # 集合的并、交、差
    union = s1 | s2
    intersection = s1 & s2
    difference = s1 & ~s2

    # 枚举集合s的所有子集
    subset = s
    while subset:
        # 处理子集subset
        process(subset)
        subset = (subset - 1) & s
    # 别忘了空集
    process(0)

    # 计算集合大小(元素个数)
    size = bin(s).count('1')
    # 或使用Brian Kernighan算法
    count = 0
    temp = s
    while temp:
        temp &= temp - 1
        count += 1

我能赢吗 (#464)

两个玩家轮流从1到maxChoosableInteger中选择一个数(不能重复选择),谁先使得累计和达到desiredTotal谁就赢。判断先手是否必胜。

这是典型的博弈论+状态压缩问题:

def canIWin(maxChoosableInteger, desiredTotal):
    # 特殊情况处理
    if desiredTotal <= 0:
        return True

    total_sum = (1 + maxChoosableInteger) * maxChoosableInteger // 2
    if total_sum < desiredTotal:
        return False

    memo = {}

    def dfs(used, current_total):
        """
        used: 位掩码表示已使用的数字
        current_total: 当前累计和
        返回: 当前玩家是否必胜
        """
        if used in memo:
            return memo[used]

        # 尝试选择每个未使用的数字
        for i in range(maxChoosableInteger):
            if used & (1 << i):  # 数字i+1已使用
                continue

            # 选择数字i+1
            if current_total + i + 1 >= desiredTotal:
                # 当前玩家直接获胜
                memo[used] = True
                return True

        # 检查是否存在一种选择使得对手必败
        for i in range(maxChoosableInteger):
            if used & (1 << i):
                continue

            # 选择数字i+1后,对手的状态
            new_used = used | (1 << i)
            new_total = current_total + i + 1

            if not dfs(new_used, new_total):
                # 对手必败,则当前玩家必胜
                memo[used] = True
                return True

        # 所有选择都导致对手必胜,当前玩家必败
        memo[used] = False
        return False

    return dfs(0, 0)

访问所有节点的最短路径 (#847)

给定一个无向连通图,求访问所有节点的最短路径长度(可以重复访问节点和边)。

这是旅行商问题(TSP)的变体,使用状态压缩DP+BFS:

def shortestPathLength(graph):
    n = len(graph)

    # 特殊情况
    if n == 1:
        return 0

    # 状态:(当前节点, 已访问节点的位掩码)
    # dp[node][mask] = 到达node且已访问mask中节点的最短路径

    # 使用BFS找最短路径
    from collections import deque

    # 初始状态:从每个节点开始
    queue = deque()
    visited = set()

    for i in range(n):
        mask = 1 << i
        queue.append((i, mask, 0))  # (节点, 访问状态, 步数)
        visited.add((i, mask))

    target_mask = (1 << n) - 1  # 所有节点都访问

    while queue:
        node, mask, dist = queue.popleft()

        if mask == target_mask:
            return dist

        # 访问所有邻居
        for neighbor in graph[node]:
            new_mask = mask | (1 << neighbor)

            if (neighbor, new_mask) not in visited:
                visited.add((neighbor, new_mask))
                queue.append((neighbor, new_mask, dist + 1))

    return -1  # 不应该到达这里

也可以用Floyd-Warshall预处理+状态压缩DP:

def shortestPathLength_dp(graph):
    n = len(graph)

    # Floyd-Warshall计算任意两点最短路
    dist = [[float('inf')] * n for _ in range(n)]

    for i in range(n):
        dist[i][i] = 0
        for j in graph[i]:
            dist[i][j] = 1

    for k in range(n):
        for i in range(n):
            for j in range(n):
                dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j])

    # dp[mask][i] = 访问mask中的节点,最后停在i的最短路径
    dp = [[float('inf')] * n for _ in range(1 << n)]

    # 初始化:从每个节点开始
    for i in range(n):
        dp[1 << i][i] = 0

    # 状态转移
    for mask in range(1 << n):
        for last in range(n):
            if not (mask & (1 << last)):
                continue
            if dp[mask][last] == float('inf'):
                continue

            # 尝试访问下一个节点
            for next_node in range(n):
                if mask & (1 << next_node):
                    continue

                new_mask = mask | (1 << next_node)
                dp[new_mask][next_node] = min(
                    dp[new_mask][next_node],
                    dp[mask][last] + dist[last][next_node]
                )

    # 找到访问所有节点的最短路径
    full_mask = (1 << n) - 1
    return min(dp[full_mask])

状态压缩的空间与时间权衡

状态压缩DP的复杂度分析:

  • 时间复杂度:通常为 O(2^n × n × f),其中f是转移的复杂度
  • 空间复杂度:O(2^n × s),其中s是每个状态需要记录的信息

优化技巧:

  1. 子集枚举优化:枚举集合s的所有子集的复杂度是O(3^n)而不是O(4^n)

  2. 滚动数组:如果只依赖上一层状态,可以用滚动数组降低空间复杂度

  3. 剪枝:很多状态实际上不可达,可以通过BFS或其他方式只计算可达状态

11.4 概率DP与期望DP

概率DP处理包含随机因素的问题,计算某个事件发生的概率或某个随机变量的期望值。关键是正确定义状态和理解概率的转移关系。

概率DP的状态定义与转移

概率DP的核心原则:

  1. 状态通常包含"位置"和"步数"等信息
  2. 转移基于概率的加法和乘法原理
  3. 期望值满足线性性质:E[X + Y] = E[X] + E[Y]

骑士在棋盘上的概率 (#688)

在n×n的棋盘上,骑士从(row, column)开始,走K步后仍在棋盘上的概率。

def knightProbability(n, k, row, column):
    # 骑士的8个移动方向
    moves = [(2, 1), (2, -1), (-2, 1), (-2, -1),
             (1, 2), (1, -2), (-1, 2), (-1, -2)]

    # dp[step][i][j] = 走step步后在位置(i,j)的概率
    dp = [[[0] * n for _ in range(n)] for _ in range(k + 1)]
    dp[0][row][column] = 1  # 初始位置概率为1

    for step in range(k):
        for i in range(n):
            for j in range(n):
                if dp[step][i][j] == 0:
                    continue

                # 从(i,j)出发,等概率地走向8个方向
                for di, dj in moves:
                    ni, nj = i + di, j + dj
                    if 0 <= ni < n and 0 <= nj < n:
                        # 每个方向的概率是1/8
                        dp[step + 1][ni][nj] += dp[step][i][j] / 8

    # 第K步后仍在棋盘上的总概率
    result = 0
    for i in range(n):
        for j in range(n):
            result += dp[k][i][j]

    return result

优化:使用滚动数组降低空间复杂度:

def knightProbability_optimized(n, k, row, column):
    moves = [(2, 1), (2, -1), (-2, 1), (-2, -1),
             (1, 2), (1, -2), (-1, 2), (-1, -2)]

    # 只需要保存当前步和下一步的概率
    dp = [[0] * n for _ in range(n)]
    dp[row][column] = 1

    for _ in range(k):
        dp_next = [[0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                if dp[i][j] == 0:
                    continue
                for di, dj in moves:
                    ni, nj = i + di, j + dj
                    if 0 <= ni < n and 0 <= nj < n:
                        dp_next[ni][nj] += dp[i][j] / 8
        dp = dp_next

    return sum(sum(row) for row in dp)

新21点 (#837)

Alice从0分开始,每次随机获得[1, maxPts]范围内的分数。当分数≥k时停止抽取。求最终分数≤n的概率。

这题的关键是理解游戏规则和状态转移:

def new21Game(n, k, maxPts):
    if k == 0 or n >= k + maxPts:
        return 1.0

    # dp[i] = 达到分数i的概率
    dp = [0.0] * (n + 1)
    dp[0] = 1.0

    # window_sum = 当前可以转移到i的所有概率之和
    window_sum = 1.0  # 初始只有dp[0] = 1
    result = 0.0

    for i in range(1, n + 1):
        # 从[max(0, i-maxPts), min(i-1, k-1)]转移而来
        dp[i] = window_sum / maxPts

        if i < k:
            # 还没到k,可以继续抽
            window_sum += dp[i]
        else:
            # 到达或超过k,停止抽取,计入结果
            result += dp[i]

        # 维护滑动窗口
        if i >= maxPts:
            window_sum -= dp[i - maxPts]

    return result

期望DP的计算技巧

期望值的计算通常有两种方式:

  1. 正向计算:从初始状态出发,计算到达各个状态的期望步数
def expected_steps_forward():
    # E[steps to state] = sum(P(prev) * (E[steps to prev] + 1))
    # 其中P(prev)是从prev转移到当前状态的概率
    pass
  1. 逆向计算:从目标状态反推,计算从各个状态到达目标的期望步数
def expected_steps_backward():
    # E[steps from state] = 1 + sum(P(next) * E[steps from next])
    # 其中P(next)是从当前状态转移到next的概率

    # 例:掷骰子到达目标的期望步数
    def dice_expected_steps(target):
        # E[i] = 从位置i到达target的期望步数
        E = [0] * (target + 1)

        for i in range(target - 1, -1, -1):
            E[i] = 1  # 至少需要掷一次
            for dice in range(1, 7):  # 骰子1-6
                if i + dice <= target:
                    E[i] += E[i + dice] / 6
                else:
                    # 超过目标,需要重来
                    E[i] += E[i] / 6

        return E[0]

概率DP的常见模式

  1. 马尔可夫链:状态转移只依赖当前状态,不依赖历史

  2. 吸收态:一旦到达就不再离开的状态(如游戏结束)

  3. 条件概率:P(A|B) = P(A∩B) / P(B)

  4. 全概率公式:P(A) = Σ P(A|Bi) × P(Bi)

精度处理注意事项:

  • 使用浮点数时注意精度损失
  • 某些题目可以用分数或保持分子分母分别计算
  • 避免除以很小的数导致数值不稳定

本章小结

本章介绍了四种高级动态规划技巧,每种都有其独特的应用场景和思维模式:

关键概念总结

  1. 区间DP - 核心思想:枚举分割点,将大区间分解为子区间 - 状态定义:dp[i][j] 表示区间 [i, j] 的最优解 - 转移方程:dp[i][j] = opt(dp[i][k] + dp[k][j] + cost(i,k,j)) - 实现要点:按区间长度从小到大计算

  2. 数位DP - 核心思想:按位构造数字,记忆化搜索避免重复 - 关键参数:位置pos、状态state、上界限制tight、前导零lead - 通用框架:DFS + 记忆化 - 优化技巧:利用数学规律直接计算

  3. 状态压缩DP - 核心思想:用二进制表示集合,压缩状态空间 - 位运算技巧:集合的增删查、子集枚举 - 复杂度:时间 O(2^n × n × f),空间 O(2^n × s) - 适用范围:元素数量少但需要枚举所有子集

  4. 概率/期望DP - 核心思想:状态包含概率信息,转移遵循概率规则 - 期望计算:正向累加或逆向递推 - 关键公式:全概率公式、条件概率、期望的线性性 - 注意事项:精度控制、吸收态处理

关键公式汇总

区间DP
dp[i][j] = max/min{dp[i][k] + dp[k][j] + merge_cost(i,k,j)}
         for k in (i, j)

数位DP
count(n) = dfs(0, init_state, tight=True, lead=True)
其中 dfs(pos, state, tight, lead) 递归构造每一位

状态压缩
对于集合S子集枚举
for subset = S; subset > 0; subset = (subset - 1) & S

概率DP
P(state) = Σ P(prev_state) × P(transition)
E[state] = 1 + Σ P(next) × E[next]

选择策略

  • 数据规模提示
  • n ≤ 20:考虑状态压缩
  • n ≤ 300且涉及区间:考虑区间DP
  • 数字范围问题:考虑数位DP
  • 包含随机因素:考虑概率DP

  • 问题特征识别

  • "最后一个操作"思维 → 区间DP
  • "统计满足条件的数" → 数位DP
  • "选择子集"且规模小 → 状态压缩
  • "概率"、"期望" → 概率DP

常见陷阱与错误

区间DP常见错误

  1. 边界处理错误
# 错误:忘记处理长度为1或2的base case
for length in range(3, n + 1):  # 应该从2开始
    ...

# 正确:明确base case
for length in range(2, n + 1):
    if length == 2:
        # 特殊处理
  1. 枚举顺序错误
# 错误:随意枚举可能导致依赖的子问题未计算
for i in range(n):
    for j in range(i + 1, n):  # 错误!

# 正确:按长度枚举确保子问题已解决
for length in range(2, n + 1):
    for i in range(n - length + 1):
        j = i + length - 1
  1. 分割点范围错误
# 错误:包含端点
for k in range(i, j + 1):  # 错误!

# 正确:不包含端点(开区间)
for k in range(i + 1, j):

数位DP常见错误

  1. 前导零处理不当
# 错误:忽略前导零导致计数错误
if digit == 0:
    count += 1  # 000也被计数了

# 正确:使用lead标记
if lead and digit == 0:
    continue  # 跳过前导零
  1. tight标记更新错误
# 错误:tight一旦为False就永远False
new_tight = False  # 错误!

# 正确:只有当前位达到上界且之前也tight才继续tight
new_tight = tight and (digit == limit)
  1. 记忆化条件不完整
# 错误:tight状态下也记忆化
memo[(pos, state)] = result  # 错误!

# 正确:只在非tight状态下记忆化
if not tight and not lead:
    memo[(pos, state)] = result

状态压缩常见错误

  1. 位运算优先级错误
# 错误:忘记位运算优先级低于比较运算符
if s & 1 << i:  # 实际是 s & (1 << i),可能不是预期

# 正确:使用括号明确优先级
if s & (1 << i):
    ...
  1. 子集枚举遗漏空集
# 错误:漏掉空集
subset = s
while subset:
    process(subset)
    subset = (subset - 1) & s  # 漏掉了0

# 正确:单独处理空集
subset = s
while subset:
    process(subset)
    subset = (subset - 1) & s
process(0)  # 处理空集
  1. 状态空间爆炸
# 错误:n太大导致内存溢出
dp = [[0] * m for _ in range(1 << 25)]  # 2^25太大!

# 正确:评估状态空间,必要时用字典
if n <= 20:
    dp = {}  # 使用字典只存储访问过的状态

概率DP常见错误

  1. 概率计算错误
# 错误:概率相加超过1
for next_state in possible_states:
    prob[next_state] += prob[current]  # 可能超过1

# 正确:确保转移概率和为1
for next_state in possible_states:
    prob[next_state] += prob[current] * transition_prob
  1. 精度损失
# 错误:连续除法导致精度损失
result = 1
for i in range(n):
    result = result * p / q  # 精度逐渐损失

# 正确:延迟除法或使用分数
from fractions import Fraction
result = Fraction(1, 1)
  1. 期望计算方向错误
# 错误:正向计算期望时忘记累加
E[state] = 1 + P(next) * E[next]  # 应该是累加

# 正确:累加所有可能的转移
E[state] = 1 + sum(P(next) * E[next] for next in next_states)

调试技巧

  1. 打印中间状态:对于DP问题,打印dp数组帮助理解状态转移
  2. 验证边界条件:手动计算小规模输入,验证base case
  3. 检查转移完整性:确保所有可能的转移都被考虑
  4. 使用断言:在关键位置添加断言验证不变量
  5. 递归改迭代:记忆化搜索更容易调试,调通后再改为迭代