数学重学 - 12 递归与数学归纳法

这是 数学重学路线图 阶段二的子页面

📌 递归是编程的灵魂之一,数学归纳法是递归的数学双胞胎。理解了这对关系,分治、动态规划、树遍历都是自然延伸。


一、递归直觉:从生活到代码

1.1 三个比喻

俄罗斯套娃:打开一个娃娃,里面还有一个更小的,直到最小的那个(base case)

查字典:查一个词 → 解释里有不认识的词 → 再查 → 直到全认识为止

排队数人数:你不知道自己是第几个,就问前面的人”你是第几个?”,前面的人再问更前面的,直到第一个人说”我是第1个”,然后答案一路回传

1.2 递归三要素

Base Case(基础情况):最小的套娃,不需要再递归,直接返回结果

递归步骤:把大问题拆成更小的同类问题

收敛性:每次递归必须离 base case 更近,否则无限递归 → 栈溢出

1.3 最简单的递归

1
2
3
4
5
6
7
8
9
10
def countdown(n):
"""倒计时"""
if n <= 0: # base case
print("发射!")
return
print(n)
countdown(n - 1) # 递归步骤:n在减小,趋向base case

countdown(5)
# 输出: 5 4 3 2 1 发射!

二、数学归纳法:递归的数学双胞胎

2.1 原理

数学归纳法证明命题P(n)对所有自然数n成立,分两步:

第一步(基底):证明 P(1) 成立 ← 对应递归的 base case

第二步(归纳):假设 P(k) 成立,证明 P(k+1) 也成立 ← 对应递归步骤

直觉:多米诺骨牌。第一块倒了(base case),且每一块能推倒下一块(归纳步骤),则所有骨牌都会倒

2.2 经典例子:证明 1+2+…+n = n(n+1)/2

基底:n=1 时,左边=1,右边=1×2/2=1,成立 ✓

归纳假设:假设 n=k 时成立,即 1+2+…+k = k(k+1)/2

归纳步骤:证明 n=k+1 时也成立

左边 = 1+2+…+k+(k+1) = k(k+1)/2 + (k+1) (用了归纳假设)

= (k+1)(k/2 + 1)

= (k+1)(k+2)/2

右边 = (k+1)((k+1)+1)/2 = (k+1)(k+2)/2

左边 = 右边 ✓

证毕。

2.3 归纳法 vs 递归:对照表

数学归纳法 递归
基底 P(1) base case
假设 P(k) 成立 信任递归调用会返回正确结果
由 P(k) 推出 P(k+1) 用子问题的结果构建当前结果
从小到大证明 从大到小调用,从小到大返回

2.4 用 Python 验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def sum_formula(n):
"""公式法"""
return n * (n + 1) // 2

def sum_recursive(n):
"""递归法"""
if n == 1:
return 1
return n + sum_recursive(n - 1)

# 验证两者一致
for i in range(1, 101):
assert sum_formula(i) == sum_recursive(i), f"不一致: n={i}"
print("全部通过!公式和递归结果一致")

三、经典递推关系

3.1 阶乘:n! = n × (n-1)!

定义:n! = n × (n-1) × (n-2) × … × 2 × 1,且 0! = 1

递推:n! = n × (n-1)!,base case: 0! = 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def factorial_recursive(n):
"""递归阶乘"""
if n <= 1:
return 1
return n * factorial_recursive(n - 1)

def factorial_iterative(n):
"""迭代阶乘"""
result = 1
for i in range(2, n + 1):
result *= i
return result

print(factorial_recursive(10)) # 3628800
print(factorial_iterative(10)) # 3628800

# Python内置
import math
print(math.factorial(10)) # 3628800

3.2 斐波那契:F(n) = F(n-1) + F(n-2)

F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5, F(6)=8 …

自然界:向日葵螺旋、贝壳螺线、兔子繁殖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# 朴素递归:O(2^n),极慢
def fib_naive(n):
if n <= 1:
return n
return fib_naive(n-1) + fib_naive(n-2)

# 记忆化递归:O(n),用空间换时间
from functools import lru_cache

@lru_cache(maxsize=None)
def fib_memo(n):
if n <= 1:
return n
return fib_memo(n-1) + fib_memo(n-2)

# 迭代法:O(n)时间,O(1)空间
def fib_iter(n):
if n <= 1:
return n
a, b = 0, 1
for _ in range(2, n+1):
a, b = b, a + b
return b

# 对比耗时
import time

start = time.time()
print(f"fib_naive(35) = {fib_naive(35)}")
print(f"朴素递归耗时: {time.time()-start:.3f}s")

start = time.time()
print(f"fib_memo(35) = {fib_memo(35)}")
print(f"记忆化耗时: {time.time()-start:.6f}s")

start = time.time()
print(f"fib_iter(35) = {fib_iter(35)}")
print(f"迭代耗时: {time.time()-start:.6f}s")

3.3 汉诺塔:T(n) = 2T(n-1) + 1

规则:n个盘子从A柱移到C柱,借助B柱,大盘不能放在小盘上面

解法:把上面n-1个移到B → 把最大的移到C → 把n-1个从B移到C

步数:T(n) = 2^n - 1(指数增长,64个盘子需要 1.8×10^19 步)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def hanoi(n, source="A", target="C", auxiliary="B"):
"""汉诺塔递归解法"""
if n == 1:
print(f" 盘子1: {source}{target}")
return 1

# 1. 把上面n-1个从source移到auxiliary
moves = hanoi(n-1, source, auxiliary, target)
# 2. 把最大的从source移到target
print(f" 盘子{n}: {source}{target}")
moves += 1
# 3. 把n-1个从auxiliary移到target
moves += hanoi(n-1, auxiliary, target, source)

return moves

print("3个盘子的汉诺塔:")
total = hanoi(3)
print(f"总步数: {total}") # 7 = 2^3 - 1

四、递归 vs 迭代

4.1 什么时候用递归

问题本身有递归结构:树遍历、目录遍历、JSON解析

分治问题:归并排序、快速排序

回溯问题:N皇后、组合枚举

代码简洁性比性能更重要时

4.2 什么时候用迭代

简单的线性递推(如斐波那契、阶乘)

性能敏感场景(递归有函数调用开销)

递归深度可能很大时(Python默认限制1000层)

4.3 栈溢出风险

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import sys
print(f"Python默认递归限制: {sys.getrecursionlimit()}") # 1000

# 可以修改,但不推荐设太大
# sys.setrecursionlimit(10000)

# 危险示例(不要运行):
# def infinite():
# return infinite() # RecursionError!

# 安全做法:改写为迭代
def safe_sum(n):
"""用迭代代替递归,避免栈溢出"""
total = 0
for i in range(1, n+1):
total += i
return total

print(safe_sum(1_000_000)) # 没问题

4.4 递归调用栈可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def factorial_trace(n, depth=0):
"""带缩进的递归追踪"""
indent = " " * depth
print(f"{indent}→ factorial({n})")

if n <= 1:
print(f"{indent}← 返回 1")
return 1

result = n * factorial_trace(n - 1, depth + 1)
print(f"{indent}← 返回 {result}")
return result

factorial_trace(5)
# → factorial(5)
# → factorial(4)
# → factorial(3)
# → factorial(2)
# → factorial(1)
# ← 返回 1
# ← 返回 2
# ← 返回 6
# ← 返回 24
# ← 返回 120

五、尾递归

5.1 概念

尾递归:递归调用是函数的最后一个操作,不需要回传再计算

普通递归:return n * f(n-1) → 需要等 f(n-1) 返回后再乘n

尾递归:return f(n-1, acc*n) → 中间结果通过参数传递,不需要回传

5.2 对比示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 普通递归:需要保存每一层的 n
def fact_normal(n):
if n <= 1:
return 1
return n * fact_normal(n - 1) # 等返回后还要乘n

# 尾递归:中间结果通过accumulator传递
def fact_tail(n, acc=1):
if n <= 1:
return acc
return fact_tail(n - 1, acc * n) # 最后一步就是递归调用

print(fact_normal(10)) # 3628800
print(fact_tail(10)) # 3628800

5.3 注意

Python 不支持尾递归优化(CPython不会自动把尾递归变成循环)

所以在Python中尾递归的写法只是思想上的优化,不能真正防止栈溢出

真正需要深递归时,应该手动改写为迭代

1
2
3
4
5
6
7
8
9
10
# 手动将尾递归转为迭代(通用方法)
def fact_loop(n):
acc = 1
while n > 1:
acc *= n
n -= 1
return acc

print(fact_loop(10)) # 3628800
# 这和尾递归的逻辑完全对应,但不会栈溢出

六、分治法:递归的重要应用

6.1 分治三步

分(Divide):把问题分成若干个子问题

治(Conquer):递归解决每个子问题

合(Combine):合并子问题的结果

6.2 归并排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def merge_sort(arr):
"""归并排序:分治经典"""
if len(arr) <= 1: # base case
return arr

# 分
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])

# 合
return merge(left, right)

def merge(left, right):
"""合并两个有序数组"""
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result

data = [38, 27, 43, 3, 9, 82, 10]
print(f"排序前: {data}")
print(f"排序后: {merge_sort(data)}")
# 时间复杂度:O(n log n),空间复杂度:O(n)

6.3 二分查找(迭代版 vs 递归版)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def binary_search_recursive(arr, target, low=0, high=None):
"""递归二分查找"""
if high is None:
high = len(arr) - 1
if low > high:
return -1 # 未找到

mid = (low + high) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
return binary_search_recursive(arr, target, mid+1, high)
else:
return binary_search_recursive(arr, target, low, mid-1)

def binary_search_iterative(arr, target):
"""迭代二分查找(推荐)"""
low, high = 0, len(arr) - 1
while low <= high:
mid = (low + high) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
low = mid + 1
else:
high = mid - 1
return -1

sorted_arr = [2, 5, 8, 12, 16, 23, 38, 56, 72, 91]
print(binary_search_recursive(sorted_arr, 23)) # 5
print(binary_search_iterative(sorted_arr, 23)) # 5

七、安全应用

7.1 递归目录遍历扫描

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os

def scan_directory(path, extensions=None, max_depth=10, depth=0):
"""递归扫描目录,查找指定类型文件"""
if depth > max_depth: # 防止符号链接循环导致无限递归
print(f" [WARN] 达到最大深度: {path}")
return []

results = []
try:
for entry in os.scandir(path):
if entry.is_file():
if extensions is None or entry.name.endswith(tuple(extensions)):
results.append(entry.path)
elif entry.is_dir(follow_symlinks=False): # 不跟踪符号链接
results.extend(
scan_directory(entry.path, extensions, max_depth, depth+1)
)
except PermissionError:
print(f" [WARN] 无权限: {path}")

return results

# 示例:查找所有 .conf 和 .yaml 配置文件
# configs = scan_directory("/etc", extensions=[".conf", ".yaml"])
# print(f"找到 {len(configs)} 个配置文件")

7.2 XML炸弹 / JSON嵌套攻击

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# XML炸弹(Billion Laughs Attack)
# 原理:用实体引用递归膨胀
# <?xml version="1.0"?>
# <!DOCTYPE lolz [
# <!ENTITY lol "lol">
# <!ENTITY lol2 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
# <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
# ...
# ]>
# 每层10倍膨胀,9层 = 10^9 个"lol" → 内存爆炸

# JSON嵌套攻击:超深嵌套导致解析器栈溢出
import json

def check_json_depth(obj, max_depth=100, current=0):
"""检查JSON嵌套深度,防止攻击"""
if current > max_depth:
raise ValueError(f"JSON嵌套过深: {current} > {max_depth}")

if isinstance(obj, dict):
for v in obj.values():
check_json_depth(v, max_depth, current + 1)
elif isinstance(obj, list):
for item in obj:
check_json_depth(item, max_depth, current + 1)

# 正常JSON
normal = {"a": {"b": {"c": 1}}}
check_json_depth(normal) # OK

# 构造深层嵌套(测试用)
deep = {"a": None}
current = deep
for _ in range(200):
current["a"] = {"a": None}
current = current["a"]

try:
check_json_depth(deep)
except ValueError as e:
print(f"检测到攻击: {e}")

7.3 防御建议

始终设置递归深度限制(max_depth参数)

不跟踪符号链接(follow_symlinks=False)

使用 defusedxml 库解析 XML(防止实体膨胀攻击)

对用户输入的JSON设置大小和深度限制


八、大数据应用

8.1 MapReduce 的分治本质

MapReduce 本质就是分治法:

Map(分):把大数据集分成小块,每块独立处理

Shuffle(重组):按key重新分组

Reduce(合):对每组数据进行聚合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from collections import defaultdict
from functools import reduce

# 模拟 MapReduce 词频统计

def mapper(text):
"""Map阶段:每个词映射为 (word, 1)"""
return [(word.lower(), 1) for word in text.split()]

def shuffle(mapped_results):
"""Shuffle阶段:按key分组"""
grouped = defaultdict(list)
for results in mapped_results:
for key, value in results:
grouped[key].append(value)
return grouped

def reducer(key, values):
"""Reduce阶段:聚合"""
return (key, sum(values))

# 模拟多个数据分片
texts = [
"hello world hello python",
"python is great hello",
"world python hello great"
]

# 执行 MapReduce
mapped = [mapper(t) for t in texts] # Map
grouped = shuffle(mapped) # Shuffle
result = [reducer(k, v) for k, v in grouped.items()] # Reduce

for word, count in sorted(result, key=lambda x: -x[1]):
print(f" {word}: {count}")
# hello: 4, python: 3, world: 2, great: 2, is: 1

8.2 大数据排序

大文件排序 = 外部归并排序 = 分治法

把大文件切成能放进内存的小块 → 分别排序 → 多路归并

Spark的sortBy底层就是这个思路


九、后端应用

9.1 树形菜单渲染

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def build_tree(items, parent_id=None):
"""从平铺数据构建树形结构(数据库常见场景)"""
tree = []
for item in items:
if item["parent_id"] == parent_id:
children = build_tree(items, item["id"])
node = {**item, "children": children}
tree.append(node)
return tree

# 数据库中的菜单表(平铺)
menus = [
{"id": 1, "name": "系统管理", "parent_id": None},
{"id": 2, "name": "用户管理", "parent_id": 1},
{"id": 3, "name": "角色管理", "parent_id": 1},
{"id": 4, "name": "用户列表", "parent_id": 2},
{"id": 5, "name": "用户详情", "parent_id": 2},
{"id": 6, "name": "日志管理", "parent_id": None},
{"id": 7, "name": "操作日志", "parent_id": 6},
]

tree = build_tree(menus)

def print_tree(nodes, indent=0):
for node in nodes:
print(" " * indent + f"├─ {node['name']}")
if node["children"]:
print_tree(node["children"], indent + 1)

print_tree(tree)
# ├─ 系统管理
# ├─ 用户管理
# ├─ 用户列表
# ├─ 用户详情
# ├─ 角色管理
# ├─ 日志管理
# ├─ 操作日志

9.2 JSON 递归解析(深度遍历)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def find_all_keys(obj, target_key):
"""递归查找JSON中所有匹配的key"""
results = []

if isinstance(obj, dict):
for key, value in obj.items():
if key == target_key:
results.append(value)
results.extend(find_all_keys(value, target_key))
elif isinstance(obj, list):
for item in obj:
results.extend(find_all_keys(item, target_key))

return results

# 示例:从复杂嵌套JSON中提取所有email
data = {
"users": [
{"name": "Alice", "email": "alice@example.com", "contacts": [
{"name": "Bob", "email": "bob@example.com"}
]},
{"name": "Charlie", "email": "charlie@example.com"}
],
"admin": {"email": "admin@example.com"}
}

emails = find_all_keys(data, "email")
print(f"所有邮箱: {emails}")
# ['alice@example.com', 'bob@example.com', 'charlie@example.com', 'admin@example.com']

9.3 递归权限继承

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def get_inherited_permissions(role_tree, role_id, visited=None):
"""递归获取角色的继承权限(含父角色权限)"""
if visited is None:
visited = set()

if role_id in visited: # 防止循环继承
return set()
visited.add(role_id)

role = role_tree.get(role_id)
if not role:
return set()

# 当前角色的权限
perms = set(role.get("permissions", []))

# 加上父角色的权限(递归)
parent_id = role.get("parent_id")
if parent_id:
perms |= get_inherited_permissions(role_tree, parent_id, visited)

return perms

role_tree = {
"super_admin": {"permissions": ["manage_system"], "parent_id": None},
"admin": {"permissions": ["manage_users", "view_logs"], "parent_id": "super_admin"},
"editor": {"permissions": ["edit_content"], "parent_id": "admin"},
}

perms = get_inherited_permissions(role_tree, "editor")
print(f"editor 最终权限: {perms}")
# {'edit_content', 'manage_users', 'view_logs', 'manage_system'}

十、练习题

练习1:递归基础

写一个递归函数,计算 1^2 + 2^2 + … + n^2

提示:base case 是 n=1 返回 1

答案:

1
2
3
4
5
6
7
8
9
def sum_of_squares(n):
if n == 1:
return 1
return n**2 + sum_of_squares(n-1)

print(sum_of_squares(5)) # 1+4+9+16+25 = 55

# 公式法验证:n(n+1)(2n+1)/6
print(5*6*11//6) # 55

练习2:数学归纳法

用数学归纳法证明:1^2 + 2^2 + … + n^2 = n(n+1)(2n+1)/6

答案:

基底:n=1,左边=1,右边=1×2×3/6=1 ✓

归纳假设:假设 n=k 时成立

归纳步骤:

左边 = k(k+1)(2k+1)/6 + (k+1)^2

= (k+1)[k(2k+1)/6 + (k+1)]

= (k+1)[k(2k+1) + 6(k+1)] / 6

= (k+1)(2k^2 + 7k + 6) / 6

= (k+1)(k+2)(2k+3) / 6

右边 = (k+1)((k+1)+1)(2(k+1)+1)/6 = (k+1)(k+2)(2k+3)/6 ✓

练习3:递归 → 迭代改写

把斐波那契的递归版本改写成迭代版本,支持 n=1000000

答案:

1
2
3
4
5
6
7
8
9
10
11
12
def fib_iter(n):
if n <= 1:
return n
a, b = 0, 1
for _ in range(2, n+1):
a, b = b, a + b
return b

# 递归版 n=1000000 会栈溢出
# 迭代版可以轻松处理
result = fib_iter(1_000_000)
print(f"fib(1000000) 有 {len(str(result))} 位数")

练习4:递归扁平化

写一个递归函数,把任意嵌套的列表展平为一维列表

输入:[1, [2, [3, 4], 5], [6, 7]]

输出:[1, 2, 3, 4, 5, 6, 7]

答案:

1
2
3
4
5
6
7
8
9
10
11
def flatten(lst):
result = []
for item in lst:
if isinstance(item, list):
result.extend(flatten(item))
else:
result.append(item)
return result

print(flatten([1, [2, [3, 4], 5], [6, 7]]))
# [1, 2, 3, 4, 5, 6, 7]

练习5:安全实战 - 递归目录大小统计

写一个递归函数统计指定目录的总大小(含所有子目录),设置最大深度为5

答案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import os

def dir_size(path, max_depth=5, depth=0):
if depth > max_depth:
return 0
total = 0
try:
for entry in os.scandir(path):
if entry.is_file(follow_symlinks=False):
total += entry.stat().st_size
elif entry.is_dir(follow_symlinks=False):
total += dir_size(entry.path, max_depth, depth+1)
except PermissionError:
pass
return total

# size = dir_size("/tmp")
# print(f"/tmp 大小: {size / 1024 / 1024:.2f} MB")

十一、本页小结

递归三要素:base case + 递归步骤 + 收敛性

数学归纳法 = 递归的数学版:基底对应base case,归纳步骤对应递归调用

经典递推:阶乘O(n)、斐波那契(朴素O(2^n)→记忆化O(n))、汉诺塔O(2^n)

实战中优先用迭代,递归用于树/图/嵌套结构

安全注意:限制递归深度、不跟踪符号链接、防范嵌套攻击

下一篇:13-排列组合与计数


上一章 目录 下一章
11-集合论与数据操作 数学重学路线图 13-排列组合与计数