Python 函数式编程:map/filter/reduce 之外

Python 不是纯函数式语言,但 functools、itertools、偏函数、闭包这些工具用好了能让代码简洁一个量级。

约 1.9k 字 / 阅读约 10 min

Python 函数式编程:map/filter/reduce 之外

Python 不是纯函数式语言,但 functools、itertools、偏函数、闭包这些工具用好了能让代码简洁一个量级。


一等函数:函数是对象

# 函数可以赋值给变量
def square(x):
    """Return ``x`` raised to the second power."""
    return pow(x, 2)

fn = square
print(fn(5))   # 25

# Functions can be stored in a list
operations = [abs, str, type]
for op in operations:
    print(op(-42))   # 42 / -42 / <class 'int'>

# 函数可以作为参数传递
def apply(func, value):
    """Call ``func`` with ``value`` as its only argument and return the result."""
    outcome = func(value)
    return outcome

print(apply(square, 4))    # 16

# 函数可以作为返回值
def multiplier(n):
    """Build a one-argument function that multiplies its input by ``n``."""

    def multiply(x):
        # ``n`` is read from the enclosing call, so every returned function
        # keeps its own factor.
        scaled = x * n
        return scaled

    return multiply

double = multiplier(2)
triple = multiplier(3)
print(double(5))   # 10
print(triple(5))   # 15

高阶函数

map / filter

nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# map:惰性,返回迭代器
squares = list(map(lambda x: x**2, nums))
# 等价于推导式(更 Pythonic)
squares = [x**2 for x in nums]

# filter:惰性
evens = list(filter(lambda x: x % 2 == 0, nums))
# 等价于
evens = [x for x in nums if x % 2 == 0]

sorted / max / min 的 key 参数

words = ["banana", "apple", "cherry", "date"]

# 按长度排序
sorted(words, key=len)              # ['date', 'apple', 'banana', 'cherry']

# 按最后一个字母排序
sorted(words, key=lambda w: w[-1])

# 多条件排序:先按长度,再按字母
sorted(words, key=lambda w: (len(w), w))

# 对象列表排序
people = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
sorted(people, key=lambda p: p["age"])

# 用 operator 模块替代 lambda(性能更好)
from operator import itemgetter, attrgetter
sorted(people, key=itemgetter("age"))

oldest = max(people, key=itemgetter("age"))

reduce:折叠操作

from functools import reduce

nums = [1, 2, 3, 4, 5]

# sum / product
total = reduce(lambda acc, x: acc + x, nums, 0)   # 15
product = reduce(lambda acc, x: acc * x, nums, 1)  # 120

# 用 reduce 展平嵌套列表
nested = [[1,2], [3,4], [5,6]]
flat = reduce(lambda acc, x: acc + x, nested, [])  # [1,2,3,4,5,6]

# 大多数 reduce 场景有更好替代
sum(nums)                            # 用内置 sum
import math; math.prod(nums)        # 用 math.prod(Python 3.8+)
[x for sub in nested for x in sub]  # 用推导式展平

reduce 适合的场景:自定义折叠逻辑,不能用简单内置函数表达时。


lambda 的边界

# Reasonable use of lambda: a short key function
nums.sort(key=lambda x: -x)
# NOTE(review): `pairs` is not defined anywhere in this snippet; presumably a
# list of sequences with at least two items -- confirm before running.
pairs.sort(key=lambda p: p[1])

# 滥用 lambda:逻辑复杂时,给函数一个名字更清晰
def normalize(s: str) -> str:
    """Trim surrounding whitespace, lowercase, and turn hyphens into underscores."""
    trimmed = s.strip()
    lowered = trimmed.lower()
    return lowered.replace("-", "_")

# lambda 无法包含语句,只能是表达式
# 用三元表达式替代 if-else:
fn = lambda x: x if x > 0 else -x

闭包

基础

def make_counter(start=0):
    """Return a zero-argument callable that counts upward from ``start``.

    Every call to the returned function adds one to a value private to that
    function and returns the new value; counters built by separate calls do
    not share state.
    """
    current = start

    def counter():
        # Rebinding a variable of the enclosing function requires ``nonlocal``.
        nonlocal current
        current += 1
        return current

    return counter

c1 = make_counter()
c2 = make_counter(10)

print(c1())  # 1
print(c1())  # 2
print(c2())  # 11(独立的闭包)

Late Binding 陷阱

# 经典 bug:循环中的闭包
fns = [lambda: i for i in range(5)]
print([f() for f in fns])   # [4, 4, 4, 4, 4]  全是 4!

# 原因:lambda 捕获的是变量 i 的引用,不是值
# 所有 lambda 共享同一个 i,循环结束时 i=4

# 修复方法1:默认参数(在定义时绑定值)
fns = [lambda i=i: i for i in range(5)]
print([f() for f in fns])   # [0, 1, 2, 3, 4]

# 修复方法2:工厂函数
def make_fn(i):
    """Return a zero-argument function that always gives back this call's ``i``."""
    # Each call to make_fn has its own ``i``, so the returned function does
    # not share a variable with the caller's loop.
    getter = lambda: i  # noqa: E731
    return getter

fns = [make_fn(i) for i in range(5)]

闭包作工厂函数

def make_validator(min_val, max_val):
    """Build a checker for the closed interval ``[min_val, max_val]``.

    The returned function gives back its argument when it lies inside the
    interval and raises ``ValueError`` otherwise.
    """

    def validate(x):
        if min_val <= x <= max_val:
            return x
        raise ValueError(f"{x} not in [{min_val}, {max_val}]")

    return validate

check_age = make_validator(0, 150)
check_score = make_validator(0, 100)

装饰器

基础装饰器

import functools

def my_decorator(func):
    """Wrap ``func`` so each call prints its name first and its result afterwards.

    The wrapper forwards all positional and keyword arguments and returns
    whatever ``func`` returns.
    """

    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        outcome = func(*args, **kwargs)
        print(f"{func.__name__} returned {outcome!r}")
        return outcome

    # Copy __name__, __doc__ and related metadata from the wrapped function.
    return functools.update_wrapper(wrapper, func)


@my_decorator
def add(a, b):
    """Return ``a + b``."""
    return a + b

add(1, 2)
# Calling add
# add returned 3

计时器装饰器

import time, functools

def timer(func):
    """Wrap ``func`` so each call prints how long it took.

    The duration is measured with ``time.perf_counter`` around the call and
    printed with four decimal places; the wrapped function's return value is
    passed through unchanged.
    """

    def wrapper(*args, **kwargs):
        began = time.perf_counter()
        outcome = func(*args, **kwargs)
        elapsed = time.perf_counter() - began
        print(f"{func.__name__} took {elapsed:.4f}s")
        return outcome

    # Copy __name__, __doc__ and related metadata from the wrapped function.
    return functools.update_wrapper(wrapper, func)


@timer
def slow_function():
    """Sleep for 0.1 seconds so the timer has something to measure."""
    time.sleep(0.1)

重试装饰器(带参数)

def retry(times=3, exceptions=(Exception,), delay=0.5):
    """Build a decorator that re-runs a function when it raises.

    Args:
        times: Maximum number of calls, including the first one. Must be at
            least 1.
        exceptions: Exception class, or tuple of classes, that triggers a
            retry. Any other exception propagates immediately.
        delay: Seconds to sleep between a failed attempt and the next one.

    Returns:
        A decorator. The decorated function returns the first successful
        result; if every attempt fails, the exception from the last attempt
        is re-raised.

    Raises:
        ValueError: If ``times`` is smaller than 1. Without this check the
            wrapped function would never be called and the wrapper would
            silently return ``None``.
    """
    if times < 1:
        raise ValueError(f"times must be >= 1, got {times!r}")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, times + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == times:
                        # Out of attempts: let the last failure propagate
                        # with its original traceback.
                        raise
                    print(f"Attempt {attempt} failed: {e}. Retrying...")
                    time.sleep(delay)
        return wrapper
    return decorator


@retry(times=3, exceptions=(ConnectionError,), delay=1)
def fetch_data(url):
    """Placeholder with an empty body, used to show how ``retry`` is applied."""
    pass

lru_cache:缓存装饰器

from functools import lru_cache

@lru_cache(maxsize=128)
def fibonacci(n):
    """Return the n-th Fibonacci number, with ``fibonacci(0) == 0``.

    Results are memoized by ``lru_cache``, so each distinct ``n`` still in
    the cache is computed only once.
    """
    if n >= 2:
        return fibonacci(n - 1) + fibonacci(n - 2)
    return n

fibonacci(100)   # 瞬间完成
fibonacci.cache_info()  # CacheInfo(hits=..., misses=..., maxsize=128, currsize=...)

functools 模块

from functools import partial, lru_cache, cache, cached_property

# partial:偏函数(固定部分参数)
def power(base, exp):
    """Return ``base`` raised to the power ``exp``."""
    return pow(base, exp)

square = partial(power, exp=2)
cube = partial(power, exp=3)
print(square(5))  # 25
print(cube(3))    # 27

# 实际场景:固定序列化参数
import json
dumps_compact = partial(json.dumps, separators=(",", ":"), ensure_ascii=False)
dumps_pretty = partial(json.dumps, indent=2, ensure_ascii=False)

# cache(Python 3.9+):无界缓存,等同于 lru_cache(maxsize=None)
@cache
def fib(n):
    """Return the n-th Fibonacci number (``fib(0) == 0``), memoized without a size limit."""
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

# cached_property:类属性的惰性计算缓存
class DataSet:
    """Hold a collection of numbers and expose lazily computed summary stats."""

    def __init__(self, data):
        self._data = data

    @cached_property
    def stats(self):
        """Mean and element count of the data, computed on first access only.

        ``cached_property`` stores the result on the instance, so the message
        below is printed once per instance.
        """
        print("Computing stats (expensive)...")
        values = self._data
        total = sum(values)
        count = len(values)
        return {"mean": total / count, "n": count}

ds = DataSet(list(range(1000)))
print(ds.stats)  # 计算一次
print(ds.stats)  # 直接从缓存返回

itertools:迭代器工具箱

import itertools

# chain:连接多个迭代器
list(itertools.chain([1,2], [3,4], [5]))        # [1,2,3,4,5]
list(itertools.chain.from_iterable([[1,2],[3]])) # [1,2,3]

# islice:切片(适用于任意迭代器)
list(itertools.islice(range(100), 5, 10))  # [5, 6, 7, 8, 9]

# groupby:按 key 分组(需要先排序)
data = [("a", 1), ("b", 2), ("a", 3), ("b", 4)]
data.sort(key=lambda x: x[0])
for key, group in itertools.groupby(data, key=lambda x: x[0]):
    print(key, list(group))

# product:笛卡尔积
list(itertools.product([1,2], ["a","b"]))
# [(1,'a'), (1,'b'), (2,'a'), (2,'b')]

# combinations 和 permutations
list(itertools.combinations([1,2,3], 2))   # [(1,2),(1,3),(2,3)]
list(itertools.permutations([1,2,3], 2))   # 6个结果

# accumulate:累积运算
import operator
list(itertools.accumulate([1,2,3,4,5]))               # [1,3,6,10,15]
list(itertools.accumulate([1,2,3,4,5], operator.mul)) # [1,2,6,24,120]
list(itertools.accumulate([3,1,4,1,5,9], max))        # [3,3,4,4,5,9]

# takewhile / dropwhile
list(itertools.takewhile(lambda x: x < 5, [1,2,3,5,6]))  # [1,2,3]
list(itertools.dropwhile(lambda x: x < 5, [1,2,3,5,6]))  # [5,6]

operator 模块:替代简单 lambda

from operator import itemgetter, attrgetter, methodcaller

# itemgetter:替代 lambda x: x[key]
data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
sorted(data, key=itemgetter("age"))

# 多 key 排序
sorted(data, key=itemgetter("name", "age"))

# attrgetter:替代 lambda x: x.attr
from datetime import datetime
dates = [datetime(2024,1,3), datetime(2024,1,1), datetime(2024,1,2)]
sorted(dates, key=attrgetter("day"))

# methodcaller:替代 lambda x: x.method(args)
from operator import methodcaller
words = ["hello", "WORLD", "Python"]
list(map(methodcaller("lower"), words))   # ['hello', 'world', 'python']

生成器与 yield from

# 生成器函数:遇到 yield 暂停,保留状态
def fibonacci():
    """Yield the Fibonacci sequence 0, 1, 1, 2, ... without end."""
    current, following = 0, 1
    while True:
        # Execution pauses here until the consumer asks for the next value.
        yield current
        current, following = following, current + following

gen = fibonacci()
print([next(gen) for _ in range(10)])  # [0,1,1,2,3,5,8,13,21,34]

# yield from:委托给子生成器
def flatten(nested):
    """Yield the non-list items of ``nested`` depth-first, left to right.

    Only ``list`` instances are descended into; any other value (tuples and
    strings included) is yielded as-is.
    """
    for element in nested:
        if not isinstance(element, list):
            yield element
            continue
        # Hand iteration over to a recursive generator for the nested list.
        yield from flatten(element)

list(flatten([1, [2, [3, 4]], [5, 6]]))  # [1, 2, 3, 4, 5, 6]

# 生成器 vs 列表:内存对比
import sys
lst = [x**2 for x in range(1000000)]
gen = (x**2 for x in range(1000000))
print(sys.getsizeof(lst))  # ~8 MB
print(sys.getsizeof(gen))  # ~104 bytes!

# 惰性管道
def read_lines(path):
    """Lazily yield the lines of the file at ``path``, one at a time.

    The file is opened with ``open``'s default mode and encoding, and the
    ``with`` block closes it once the generator is exhausted or closed.
    """
    with open(path) as handle:
        for line in handle:
            yield line

def grep(pattern, lines):
    """Return a lazy iterator over the items of ``lines`` matching ``pattern``.

    An item is kept when ``re.search`` finds the regular expression anywhere
    in it. No line is examined until the result is iterated.
    """
    from re import search

    matching = (text for text in lines if search(pattern, text))
    return matching

def trim(lines):
    """Return a lazy iterator yielding each item of ``lines`` with ``.strip()`` applied."""
    stripped = (text.strip() for text in lines)
    return stripped

# pipeline = trim(grep(r"ERROR", read_lines("/var/log/app.log")))
# 整个管道是惰性的,不会一次性读入内存

不可变数据原则

# 避免副作用:不修改传入的参数

# 有副作用(bad)
def add_item_bad(lst, item):
    """Append ``item`` to ``lst`` in place and return that same object.

    Shown as the anti-pattern: the caller's list is modified as a side effect.
    """
    lst.append(item)     # mutates the list that was passed in from outside!
    return lst

# 无副作用(good)
def add_item_good(lst, item):
    """Return a new list holding the items of ``lst`` followed by ``item``.

    The argument is left untouched because ``+`` builds a fresh list.
    """
    extended = lst + [item]
    return extended

# 对于复杂对象,用 copy
import copy
def process(data):
    """Return a deep copy of ``data`` with its ``"processed"`` key set to ``True``.

    The input object and anything nested inside it are left unmodified.
    """
    clone = copy.deepcopy(data)
    clone["processed"] = True
    return clone

# 函数式管道
from functools import reduce

def pipeline(*funcs):
    """Compose one-argument functions into a single left-to-right function.

    ``pipeline(f, g, h)(x)`` computes ``h(g(f(x)))``.

    Args:
        *funcs: Callables taking one argument. Each receives the previous
            stage's return value.

    Returns:
        A one-argument function that runs the stages in order. With no
        stages it returns its argument unchanged.
    """

    def run(x):
        # Iterating keeps the call depth constant however many stages there
        # are, and an empty ``funcs`` leaves ``x`` as it came in.
        for func in funcs:
            x = func(x)
        return x

    return run

normalize = pipeline(
    str.strip,
    str.lower,
    lambda s: s.replace("-", "_"),
    lambda s: s.replace(" ", "_"),
)

print(normalize("  Hello World  "))  # hello_world