内置算法模块#

1. itertools#

itertools 内置模块包含了大量的函数,这些函数对于组织迭代器和与迭代器进行交互非常有用。

1.1. 简化循环#

from itertools import product

# 扁平化多层嵌套循环
def find_twelve(num_list1, num_list2, num_list3):
    """Search three number lists for one number from each that sum to 12.

    Candidates are examined in the order produced by
    ``itertools.product``. Returns the first matching ``(num1, num2, num3)``
    tuple, or ``None`` when no combination adds up to 12.
    """
    matches = (
        (a, b, c)
        for a, b, c in product(num_list1, num_list2, num_list3)
        if a + b + c == 12
    )
    # next() with a default mirrors "return on first hit, else fall through".
    return next(matches, None)


# 循环内隔行处理
from itertools import islice


def parse_titles(filename):
    """Yield the stripped text of every other line of ``filename``.

    Lines are taken starting from the first one with a step of 2, which
    skips the meaningless '---' separator lines that sit between titles.
    """
    with open(filename, 'r') as fp:
        every_other_line = islice(fp, 0, None, 2)
        yield from (raw.strip() for raw in every_other_line)

1.2. 简化判断#

  • takewhile 替代 break

for user in users:
    # Once the first unqualified user shows up, stop processing the rest.
    if not is_qualified(user):
        break

    # Process the user ... ...

from itertools import takewhile

for user in takewhile(is_qualified, users):
    # Process the user here. takewhile() stops yielding at the first user
    # for which is_qualified() is false, so no explicit break is needed.
    ...  # placeholder: a loop body made only of comments is a syntax error

1.3. 简化生成#

import itertools

# Repeat a single value a fixed number of times.
it = itertools.repeat("hello", 3)
print(list(it))

# Cycle endlessly through a sequence; only the first ten items are taken.
it = itertools.cycle([1, 2])
result = list(itertools.islice(it, 10))
print(result)

# Split one iterable into several independent iterators.
it1, it2, it3 = itertools.tee(["first", "second"], 3)
for branch in (it1, it2, it3):
    print(list(branch))

# Arithmetic progression.
gen = itertools.count(start=1, step=0.5)
next(gen)  # 1
next(gen)  # 1.5

1.4. 合并、组合#

import itertools

# 合并
it = itertools.chain([1, 2, 3], [4, 5, 6])
print([*it])

# 组合
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
sum_reduce = itertools.accumulate(values)
print("Sum: ", list(sum_reduce))


def sum_modulo_20(first, second):
    output = first + second
    return output % 20


modulo_reduce = itertools.accumulate(values, sum_modulo_20)
print("Modulo:", list(modulo_reduce))

# returns the Cartesian product of items from one or more iterators
single = itertools.product([1, 2], repeat=2)
print("Single: ", list(single))

multiple = itertools.product([1, 2], ["a", "b"])
print("Multiple:", list(multiple))

# returns the unique ordered permutations of length N with items from an iterator
it = itertools.permutations([1, 2, 3, 4], 2)
print([*it])

# returns the unordered combinations of length N with unrepeated items from an iterator
it = itertools.combinations([1, 2, 3, 4], 2)
print([*it])

# the same as combinations, but repeated values are allowed
it = itertools.combinations_with_replacement([1, 2, 3, 4], 2)
print([*it])

1.5. 分组#

from operator import itemgetter
from itertools import groupby

# Sample records; note that they are NOT ordered by date.
rows = [
    {
        'address': '5412 N CLARK',
        'date': '07/01/2012'
    },
    {
        'address': '5148 N CLARK',
        'date': '07/04/2012'
    },
    {
        'address': '5800 E 58TH',
        'date': '07/02/2012'
    },
    {
        'address': '2122 N CLARK',
        'date': '07/03/2012'
    },
    {
        'address': '5645 N RAVENSWOOD',
        'date': '07/02/2012'
    },
    {
        'address': '1060 W ADDISON',
        'date': '07/02/2012'
    },
    {
        'address': '4801 N BROADWAY',
        'date': '07/01/2012'
    },
    {
        'address': '1039 W GRANVILLE',
        'date': '07/04/2012'
    },
]

# groupby() only merges *consecutive* items that share a key, so the rows
# must first be ordered by the same key. Without the sort, a date would be
# reported once per run of adjacent rows instead of once overall.
# sorted() is used so the original ``rows`` list is left unmodified.
by_date = itemgetter('date')
for date, items in groupby(sorted(rows, key=by_date), key=by_date):
    print(date)
    for i in items:
        print(' ', i)

1.6. 筛选#

from itertools import compress

# One address per entry in ``counts`` below. Every element needs its own
# trailing comma: adjacent string literals are silently concatenated, which
# would leave the two lists with different lengths.
addresses = [
    '5412 N CLARK',
    '5148 N CLARK',
    '5800 E 58TH',
    '2122 N CLARK',
    '5645 N RAVENSWOOD',
    '1060 W ADDISON',
    '4801 N BROADWAY',
    '1039 W GRANVILLE',
]

counts = [0, 3, 10, 4, 1, 7, 6, 1]
# Boolean selector: True where the matching count is greater than 5.
more5 = [n > 5 for n in counts]
# more5
# [False, False, True, False, False, True, True, False]
# compress() keeps the items whose selector is true.
[*compress(addresses, more5)]
# ['5800 E 58TH', '1060 W ADDISON', '4801 N BROADWAY']

2. collections#

collections 模块提供了多种专用容器类型;本节介绍的 Counter 和 ChainMap 都是类似 dict 的映射类型。

2.1. 计数#

from collections import Counter

# Tally how many times each value occurs.
my_list = [10, 10, 10, 5, 5, 2, 9, 9, 9, 9, 9, 9]
counter = Counter(my_list)
print(counter)  # Counter({9: 6, 10: 3, 5: 2, 2: 1})

# The two most frequent values together with their counts.
most_common = counter.most_common(2)
print(most_common)  # [(9, 6), (10, 3)]

# Counters support arithmetic: + adds counts, - subtracts them and keeps
# only the positive results.
a = Counter([1, 2, 2])
b = Counter([2, 2, 3])
a + b  # Counter({2: 4, 1: 1, 3: 1})
a - b  # Counter({1: 1})

2.2. 合并映射#

from collections import ChainMap

# Two plain dicts that share the key 'z'.
a = {'x': 1, 'z': 3 }
b = {'y': 2, 'z': 4 }

# ChainMap searches its mappings in order, so the first one holding a key wins.
c = ChainMap(a, b)
print(c['x']) # Outputs 1 (from a)
print(c['y']) # Outputs 2 (from b)
print(c['z']) # Outputs 3 (from a)

# A ChainMap references the underlying dicts rather than copying them.
a['x'] = 42
c['x'] # Notice change to merged dicts
# 42

len(c)
# 3
# Iteration scans the mappings last to first, so b's keys come before a's.
list(c.keys())
# ['y', 'z', 'x']
# a['x'] was set to 42 above, and 'z' still resolves to a's value.
list(c.values())
# [2, 3, 42]

# Start from a ChainMap holding a single mapping; writes always go to the
# first mapping in the chain.
values = ChainMap()
values['x'] = 1
# Add a new mapping
values = values.new_child()
values['x'] = 2
# Add a new mapping
values = values.new_child()
values['x'] = 3

values
# ChainMap({'x': 3}, {'x': 2}, {'x': 1})
values['x']
# 3
# Discard last mapping
values = values.parents
values['x']
# 2
# Discard last mapping
values = values.parents
values['x']
# 1
values
# ChainMap({'x': 1})

3. operator#

在函数式编程中,经常需要把算术运算符当作函数使用。

3.1. 函数式编程#

from functools import reduce
from operator import mul


def fact(n):
    """Return the factorial of the non-negative integer ``n``.

    Implemented as a fold of ``operator.mul`` over ``1..n``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("fact() is not defined for negative values")
    # The initial value 1 is the empty product, which makes fact(0) == 1;
    # without it reduce() raises TypeError on the empty range.
    return reduce(mul, range(1, n + 1), 1)

3.2. 排序#

from operator import itemgetter

# Sample user records.
rows = [
    {
        'fname': 'Brian',
        'lname': 'Jones',
        'uid': 1003,
    },
    {
        'fname': 'David',
        'lname': 'Beazley',
        'uid': 1002,
    },
    {
        'fname': 'John',
        'lname': 'Cleese',
        'uid': 1001,
    },
    {
        'fname': 'Big',
        'lname': 'Jones',
        'uid': 1004,
    },
]

# itemgetter() with several keys returns a tuple, giving a compound sort
# key: last name first, then first name.
by_last_then_first = itemgetter('lname', 'fname')
rows_by_lfname = sorted(rows, key=by_last_then_first)

# The same kind of key function works with min() and max().
min(rows, key=itemgetter('uid'))
# {'fname': 'John', 'lname': 'Cleese', 'uid': 1001}
max(rows, key=itemgetter('uid'))
# {'fname': 'Big', 'lname': 'Jones', 'uid': 1004}

3.3. 提取属性#

attrgetter 与 itemgetter 作用类似,它创建的函数根据名称提取对象的属性。

from collections import namedtuple

# Sample input: (name, country code, population, (latitude, longitude)).
# Defined here so the snippet runs on its own; the latitudes match the
# output shown at the end of the example.
metro_data = [
    ('Tokyo', 'JP', 36.933, (35.689722, 139.691667)),
    ('Delhi NCR', 'IN', 21.935, (28.613889, 77.208889)),
    ('Mexico City', 'MX', 20.142, (19.433333, -99.133333)),
    ('New York-Newark', 'US', 20.104, (40.808611, -74.020386)),
    ('Sao Paulo', 'BR', 19.649, (-23.547778, -46.635833)),
]

LatLong = namedtuple('LatLong', 'lat long')
Metropolis = namedtuple('Metropolis', 'name cc pop coord')
# Build nested named tuples, unpacking the coordinate pair in the loop target.
metro_areas = [
    Metropolis(name, cc, pop, LatLong(lat, long))
    for name, cc, pop, (lat, long) in metro_data
]
metro_areas[0]
# Metropolis(name='Tokyo',
#            cc='JP',
#            pop=36.933,
#            coord=LatLong(lat=35.689722, long=139.691667))

metro_areas[0].coord.lat
# 35.689722

from operator import attrgetter

# A dotted name makes attrgetter() navigate into nested objects; several
# names make it return a tuple of the extracted values.
name_lat = attrgetter('name', 'coord.lat')

for city in sorted(metro_areas, key=attrgetter('coord.lat')):
    print(name_lat(city))

# ('Sao Paulo', -23.547778)
# ('Mexico City', 19.433333)
# ('Delhi NCR', 28.613889)
# ('Tokyo', 35.689722)
# ('New York-Newark', 40.808611)

4. heapq#

4.1. 查找#

import heapq

# The N largest / smallest items of a plain list of numbers.
nums = [1, 8, 2, 23, 7, -4, 18, 23, 42, 37, 2]
print(heapq.nlargest(3, nums))  # Prints [42, 37, 23]
print(heapq.nsmallest(3, nums))  # Prints [-4, 1, 2]

# Sample stock holdings.
portfolio = [
    {'name': 'IBM', 'shares': 100, 'price': 91.1},
    {'name': 'AAPL', 'shares': 50, 'price': 543.22},
    {'name': 'FB', 'shares': 200, 'price': 21.09},
    {'name': 'HPQ', 'shares': 35, 'price': 31.75},
    {'name': 'YHOO', 'shares': 45, 'price': 16.35},
    {'name': 'ACME', 'shares': 75, 'price': 115.65},
]

# For richer records, a key function picks the value to compare on.
cheap = heapq.nsmallest(3, portfolio, key=lambda stock: stock['price'])
expensive = heapq.nlargest(3, portfolio, key=lambda stock: stock['price'])

堆最重要的特性就是 heap[0] 总是最小的那个元素。此外,接下来的元素可依次通过 heapq.heappop() 方法轻松找到。该方法会将第一个元素(最小的)弹出,然后以第二小的元素取而代之(这个操作的复杂度是 \(O(\log N)\),其中 \(N\) 代表堆的大小)。

nums = [1, 8, 2, 23, 7, -4, 18, 23, 42, 37, 2]
# heapify() rearranges a list in place, so build the heap from a copy and
# leave ``nums`` untouched.
heap = list(nums)
heapq.heapify(heap)

# Each heappop() removes and returns the smallest remaining item.
heapq.heappop(heap) # -4
heapq.heappop(heap) # 1
heapq.heappop(heap) # 2