Python 函数式编程:map/filter/reduce 之外
Python 不是纯函数式语言,但把 functools(含偏函数 partial)、itertools、闭包、装饰器、生成器这些工具用好,能让代码明显更简洁、更易组合。
一等函数:函数是对象
# Functions can be assigned to variables
def square(x):
    """Return x squared."""
    return x ** 2


fn = square
print(fn(5))  # 25

# Functions can be stored in containers such as lists
operations = [abs, str, type]
for op in operations:
    print(op(-42))  # 42 / -42 / <class 'int'>


# Functions can be passed as arguments
def apply(func, value):
    """Call func with value and return whatever it returns."""
    return func(value)


print(apply(square, 4))  # 16
# Functions can be returned from other functions
def multiplier(n):
    """Build and return a one-argument function that multiplies by n."""

    def times_n(x):
        # n is a free variable captured from multiplier's scope
        return x * n

    return times_n


double, triple = multiplier(2), multiplier(3)
print(double(5))  # 10
print(triple(5))  # 15
高阶函数
map / filter
nums = list(range(1, 11))

# map is lazy: it returns an iterator, so list() is needed to materialize it
squares = list(map(lambda n: n ** 2, nums))
# The equivalent comprehension is usually considered more Pythonic
squares = [n ** 2 for n in nums]

# filter is lazy too
evens = list(filter(lambda n: n % 2 == 0, nums))
# Equivalent comprehension
evens = [n for n in nums if n % 2 == 0]
sorted / max / min 的 key 参数
from operator import itemgetter, attrgetter

words = ["banana", "apple", "cherry", "date"]

# Sort by length
sorted(words, key=len)  # ['date', 'apple', 'banana', 'cherry']
# Sort by last letter
sorted(words, key=lambda word: word[-1])
# Several criteria: length first, alphabetical order as the tie-breaker
sorted(words, key=lambda word: (len(word), word))

# Sorting a list of dicts
people = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
sorted(people, key=lambda person: person["age"])

# operator.itemgetter can replace the lambda (usually a bit faster)
sorted(people, key=itemgetter("age"))
oldest = max(people, key=itemgetter("age"))
reduce:折叠操作
import math
from functools import reduce

nums = [1, 2, 3, 4, 5]

# Sum and product written as explicit folds
total = reduce(lambda running, item: running + item, nums, 0)  # 15
product = reduce(lambda running, item: running * item, nums, 1)  # 120

# Flatten one level of nesting with reduce (each step builds a new list)
nested = [[1, 2], [3, 4], [5, 6]]
flat = reduce(lambda running, item: running + item, nested, [])  # [1, 2, 3, 4, 5, 6]

# Most reduce use cases have a clearer alternative
sum(nums)  # built-in sum
math.prod(nums)  # math.prod (Python 3.8+)
[x for sub in nested for x in sub]  # flatten with a comprehension
reduce 适合的场景:需要自定义折叠逻辑,且无法用 sum、math.prod、推导式等更简单的内置手段表达时。
lambda 的边界
# Reasonable use of lambda: a short key function
nums.sort(key=lambda x: -x)
pairs = [("a", 3), ("b", 1), ("c", 2)]  # sample data so the next line runs
pairs.sort(key=lambda p: p[1])  # [('b', 1), ('c', 2), ('a', 3)]

# Abusing lambda: once the logic grows, a named function is clearer than
#     normalize = lambda s: s.strip().lower().replace("-", "_")
def normalize(s: str) -> str:
    """Strip surrounding whitespace, lower-case, and replace '-' with '_'."""
    return s.strip().lower().replace("-", "_")


# A lambda body must be a single expression; statements are not allowed.
# Use a conditional expression in place of an if/else statement:
fn = lambda x: x if x > 0 else -x
闭包
基础
def make_counter(start=0):
    """Return a counter function that increments and returns its count.

    The count is kept in make_counter's scope and captured by the inner
    function as a free variable, so each returned counter is independent.
    """
    current = start

    def counter():
        nonlocal current  # needed to rebind the enclosing variable
        current += 1
        return current

    return counter


c1, c2 = make_counter(), make_counter(10)
print(c1())  # 1
print(c1())  # 2
print(c2())  # 11 (independent closure with its own state)
Late Binding 陷阱
# Classic bug: closures created inside a loop
fns = [lambda: i for i in range(5)]
print([call() for call in fns])  # [4, 4, 4, 4, 4] -- all 4!
# Why: each lambda captures the variable i itself, not its current value.
# They all share one i, which holds 4 once the loop has finished.

# Fix 1: a default argument binds the value at definition time
fns = [lambda i=i: i for i in range(5)]
print([call() for call in fns])  # [0, 1, 2, 3, 4]


# Fix 2: a factory function gives every closure its own scope
def make_fn(i):
    """Return a zero-argument function that always returns i."""
    return lambda: i


fns = [make_fn(i) for i in range(5)]
闭包作工厂函数
def make_validator(min_val, max_val):
    """Return a validator that passes x through when min_val <= x <= max_val.

    Any other value makes the validator raise ValueError.
    """

    def validate(x):
        if min_val <= x <= max_val:
            return x
        raise ValueError(f"{x} not in [{min_val}, {max_val}]")

    return validate


check_age = make_validator(0, 150)
check_score = make_validator(0, 100)
装饰器
基础装饰器
import functools


def my_decorator(func):
    """Wrap func so every call, and the value it returns, is printed."""

    @functools.wraps(func)  # copy __name__, __doc__, etc. from func onto wrapper
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        outcome = func(*args, **kwargs)
        print(f"{func.__name__} returned {outcome!r}")
        return outcome

    return wrapper


@my_decorator
def add(a, b):
    """Return a + b."""
    return a + b


add(1, 2)
# Calling add
# add returned 3
计时器装饰器
import functools
import time


def timer(func):
    """Print how long each call to func takes, then return its result."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        t0 = time.perf_counter()
        value = func(*args, **kwargs)
        elapsed = time.perf_counter() - t0
        print(f"{func.__name__} took {elapsed:.4f}s")
        return value

    return wrapper


@timer
def slow_function():
    time.sleep(0.1)
重试装饰器(带参数)
def retry(times=3, exceptions=(Exception,), delay=0.5):
    """Decorator factory: re-run the wrapped function when it raises.

    Args:
        times: total number of attempts; must be >= 1.
        exceptions: exception class, or tuple of classes, that triggers a
            retry. Any other exception propagates immediately.
        delay: seconds to sleep between attempts (not after the last one).

    Raises:
        ValueError: if times < 1. Without this check the retry loop would
            never run, so the wrapped function would never be called and
            every call would silently return None.
    """
    if times < 1:
        raise ValueError(f"times must be >= 1, got {times!r}")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, times + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == times:
                        raise  # out of attempts: re-raise the last error unchanged
                    print(f"Attempt {attempt} failed: {e}. Retrying...")
                    time.sleep(delay)

        return wrapper

    return decorator


@retry(times=3, exceptions=(ConnectionError,), delay=1)
def fetch_data(url):
    pass
lru_cache:缓存装饰器
from functools import lru_cache


@lru_cache(maxsize=128)
def fibonacci(n):
    """Return the n-th Fibonacci number, memoized by lru_cache."""
    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)


fibonacci(100)  # returns instantly thanks to the cache
fibonacci.cache_info()  # CacheInfo(hits=..., misses=..., maxsize=128, currsize=...)
functools 模块
from functools import partial, lru_cache, cache, cached_property
import json


# partial: pre-fill some arguments of an existing function
def power(base, exp):
    """Return base raised to the power exp."""
    return base ** exp


square = partial(power, exp=2)
cube = partial(power, exp=3)
print(square(5))  # 25
print(cube(3))  # 27

# Real-world use: preset serialization options
dumps_compact = partial(json.dumps, separators=(",", ":"), ensure_ascii=False)
dumps_pretty = partial(json.dumps, indent=2, ensure_ascii=False)


# cache (Python 3.9+): unbounded memoization, same as lru_cache(maxsize=None).
# Being unbounded, it keeps one entry per distinct argument forever.
@cache
def fib(n):
    """Return the n-th Fibonacci number."""
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
# cached_property: computes the value lazily on first access and stores it on
# the *instance* (in its __dict__), so the cache is per object, not per class.
class DataSet:
    """Hold a sequence of numbers and expose lazily computed statistics."""

    def __init__(self, data):
        self._data = data

    @cached_property
    def stats(self):
        """Return {"mean": ..., "n": ...}; computed once per instance.

        Raises ZeroDivisionError if the data is empty.
        """
        print("Computing stats (expensive)...")
        n = len(self._data)
        return {"mean": sum(self._data) / n, "n": n}


ds = DataSet(list(range(1000)))
print(ds.stats)  # computed on first access
print(ds.stats)  # served from the per-instance cache, no recomputation
itertools:迭代器工具箱
import itertools
import operator

# chain: concatenate several iterables
list(itertools.chain([1, 2], [3, 4], [5]))  # [1, 2, 3, 4, 5]
list(itertools.chain.from_iterable([[1, 2], [3]]))  # [1, 2, 3]

# islice: slicing that works on any iterator
list(itertools.islice(range(100), 5, 10))  # [5, 6, 7, 8, 9]

# groupby: groups *consecutive* items with equal keys, so sort by that key first
data = [("a", 1), ("b", 2), ("a", 3), ("b", 4)]
by_letter = operator.itemgetter(0)
data.sort(key=by_letter)
for letter, members in itertools.groupby(data, key=by_letter):
    print(letter, list(members))

# product: Cartesian product
list(itertools.product([1, 2], ["a", "b"]))
# [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]

# combinations (order ignored) and permutations (order matters)
list(itertools.combinations([1, 2, 3], 2))  # [(1, 2), (1, 3), (2, 3)]
list(itertools.permutations([1, 2, 3], 2))  # 6 results

# accumulate: running totals, or a running result of any binary function
list(itertools.accumulate([1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]
list(itertools.accumulate([1, 2, 3, 4, 5], operator.mul))  # [1, 2, 6, 24, 120]
list(itertools.accumulate([3, 1, 4, 1, 5, 9], max))  # [3, 3, 4, 4, 5, 9]

# takewhile / dropwhile: split at the first item that fails the predicate
list(itertools.takewhile(lambda x: x < 5, [1, 2, 3, 5, 6]))  # [1, 2, 3]
list(itertools.dropwhile(lambda x: x < 5, [1, 2, 3, 5, 6]))  # [5, 6]
operator 模块:替代简单 lambda
from datetime import datetime
from operator import attrgetter, itemgetter, methodcaller  # imported once

# itemgetter: replaces lambda x: x[key]
data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
sorted(data, key=itemgetter("age"))
# Several keys: by name first, then by age
sorted(data, key=itemgetter("name", "age"))

# attrgetter: replaces lambda x: x.attr
dates = [datetime(2024, 1, 3), datetime(2024, 1, 1), datetime(2024, 1, 2)]
sorted(dates, key=attrgetter("day"))

# methodcaller: replaces lambda x: x.method(args)
words = ["hello", "WORLD", "Python"]
list(map(methodcaller("lower"), words))  # ['hello', 'world', 'python']
生成器与 yield from
# Generator function: execution pauses at each yield and keeps its local state
def fibonacci():
    """Yield the Fibonacci numbers 0, 1, 1, 2, 3, ... without end."""
    current, following = 0, 1
    while True:
        yield current
        current, following = following, current + following


gen = fibonacci()
print([next(gen) for _ in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]


# yield from: delegate to a sub-generator
def flatten(nested):
    """Recursively yield the leaves of arbitrarily nested lists.

    Only ``list`` instances are descended into; anything else (tuples,
    strings, numbers, ...) is yielded unchanged.
    """
    for element in nested:
        if not isinstance(element, list):
            yield element
            continue
        yield from flatten(element)  # recursive delegation


list(flatten([1, [2, [3, 4]], [5, 6]]))  # [1, 2, 3, 4, 5, 6]
# Generators vs. lists: memory footprint
import sys

lst = [x**2 for x in range(1000000)]
gen = (x**2 for x in range(1000000))
# sys.getsizeof is shallow: for the list it counts only the internal array of
# ~1M object pointers (~8 MB on a 64-bit build), NOT the int objects it
# refers to, which add tens of MB more.
print(sys.getsizeof(lst))  # ~8 MB (pointer array only)
# The generator stores only its suspended state, never the values, so its size
# does not grow with the range; the exact number depends on the Python version.
print(sys.getsizeof(gen))  # on the order of 100-200 bytes
# Lazy pipeline: each stage pulls one line at a time from the previous stage
import re


def read_lines(path, encoding=None):
    """Yield the lines of the file at path one at a time, newlines included.

    encoding=None keeps plain open()'s behaviour (the locale encoding); pass
    e.g. "utf-8" when the file's encoding is known. The file stays open
    until the generator is exhausted or closed.
    """
    with open(path, encoding=encoding) as f:
        yield from f


def grep(pattern, lines):
    """Lazily keep only the lines in which the regex pattern matches."""
    regex = re.compile(pattern)  # compile once; a bad pattern fails here, not mid-stream
    return (line for line in lines if regex.search(line))


def trim(lines):
    """Lazily strip leading and trailing whitespace from each line."""
    return (line.strip() for line in lines)


# pipeline = trim(grep(r"ERROR", read_lines("/var/log/app.log")))
# The whole pipeline is lazy: the file is never read into memory all at once.
不可变数据原则
# Avoid side effects: do not modify the arguments you are given

# With a side effect (bad)
def add_item_bad(lst, item):
    """Append item to lst in place and return that same list object."""
    lst.append(item)  # mutates the caller's list!
    return lst


# Without a side effect (good)
def add_item_good(lst, item):
    """Return a new list equal to lst plus item; lst itself is left untouched."""
    extended = lst + [item]
    return extended


# For nested or complex objects, work on a deep copy
import copy


def process(data):
    """Return a deep copy of data with its "processed" key set to True."""
    result = copy.deepcopy(data)
    result["processed"] = True
    return result
# Function pipeline: compose one-argument functions, applied left to right
from functools import reduce


def pipeline(*funcs):
    """Return a function that passes its argument through funcs in order.

    pipeline(f, g, h)(x) == h(g(f(x))). With no functions the result is the
    identity function; without the identity initializer, reduce() would
    raise TypeError on the empty sequence.
    """
    return reduce(lambda f, g: lambda x: g(f(x)), funcs, lambda x: x)


normalize = pipeline(
    str.strip,
    str.lower,
    lambda s: s.replace("-", "_"),
    lambda s: s.replace(" ", "_"),
)
print(normalize(" Hello World "))  # hello_world