Python heapq 模块的实现

无意间看了一下 Python heapq 模块的源代码,发现里面的源代码写得很棒,忍不住拿出来和大家分享一下。:)

heapq 的意思是 heap queue,也就是基于 heap 的 priority queue。说到 priority queue,忍不住吐槽几句。我上学的时候学优先级队列的时候就没有碰到像 wikipedia 上那样透彻的解释,priority queue 并不是具体的某一个数据结构,而是对一类数据结构的概括!比如栈就可以看作是后进入的优先级总是大于先进入的,而队列就可以看成是先进入的优先级一定高于后进来的!这还没完!如果我们是用一个无序的数组实现一个priority queue,对它进行出队操作尼玛不就是选择排序么!尼玛用有序数组实现入队操作尼玛不就是插入排序么!尼玛用堆实现(即本文要介绍的)就是堆排序啊!用二叉树实现就是二叉树排序啊!数据结构课学的这些东西基本上都出来了啊!T_T

heapq 的代码也是用 Python 写的,用到了一些其它 Python 模块,如果你对 itertools 不熟悉,在阅读下面的代码之前请先读文档。heapq 模块主要有5个函数:heappush(),把一个元素放入堆中;heappop(),从堆中取出一个元素;heapify(),把一个列表变成一个堆;nlargest() 和 nsmallest() 分别提供列表中最大或最小的N个元素。

先从简单的看起:

[python]
def heappush(heap, item):
“””Push item onto heap, maintaining the heap invariant.”””
heap.append(item)
_siftdown(heap, 0, len(heap)-1)

def heappop(heap):
“””Pop the smallest item off the heap, maintaining the heap invariant.”””
lastelt = heap.pop() # raises appropriate IndexError if heap is empty
if heap:
returnitem = heap[0]
heap[0] = lastelt
_siftup(heap, 0)
else:
returnitem = lastelt
return returnitem
[/python]

从源代码我们不难看出,这个堆也是用数组实现的,而且是最小堆,即 a[k] <= a[2*k+1] and a[k] startpos:
parentpos = (pos - 1) >> 1
parent = heap[parentpos]
if cmp_lt(newitem, parent):
heap[pos] = parent
pos = parentpos
continue
break
heap[pos] = newitem

def _siftup(heap, pos):
endpos = len(heap)
startpos = pos
newitem = heap[pos]

# Bubble up the smaller child until hitting a leaf.
childpos = 2*pos + 1    # leftmost child position
while childpos &lt; endpos:
    # Set childpos to index of smaller child.
    rightpos = childpos + 1
    if rightpos &lt; endpos and not cmp_lt(heap[childpos], heap[rightpos]):
        childpos = rightpos
    # Move the smaller child up.
    heap[pos] = heap[childpos]
    pos = childpos
    childpos = 2*pos + 1
# The leaf at pos is empty now.  Put newitem there, and bubble it up
# to its final resting place (by sifting its parents down).
heap[pos] = newitem
_siftdown(heap, startpos, pos)

[/python]

上面的代码加上注释很容易理解,不是吗?在此基础上实现 heapify() 就很容易了:

[python]
def heapify(x):
"""Transform list into a heap, in-place, in O(len(x)) time."""
n = len(x)

# Transform bottom-up.  The largest index there&#039;s any point to looking at
# is the largest with a child index in-range, so must have 2*i + 1 &lt; n,
# or i &lt; (n-1)/2\.  If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
# j-1 is the largest, which is n//2 - 1\.  If n is odd = 2*j+1, this is
# (2*j+1-1)/2 = j so j-1 is the largest, and that&#039;s again n//2-1.
for i in reversed(xrange(n//2)):
    _siftup(x, i)

[/python]
这里用了一个技巧,正如注释中所说,其实只要 siftup 后面一半就可以了,前面的一半自然就是一个heap了。heappushpop() 也可以用它来实现了,而且用不着调用 heappush()+heappop():

[python]
def heappushpop(heap, item):
"""Fast version of a heappush followed by a heappop."""
if heap and cmp_lt(heap[0], item):
item, heap[0] = heap[0], item
_siftup(heap, 0)
return item
[/python]

第一眼看到这个函数可能觉得它放进去一个再取出来一个有什么意思嘛!仔细想想它很有用,尤其是在后面实现 nlargest() 的时候:

[python]
def nlargest(n, iterable):
"""Find the n largest elements in a dataset.

Equivalent to:  sorted(iterable, reverse=True)[:n]
&quot;&quot;&quot;
if n &lt; 0:
    return []
it = iter(iterable)
result = list(islice(it, n))
if not result:
    return result
heapify(result)
_heappushpop = heappushpop
for elem in it:
    _heappushpop(result, elem)
result.sort(reverse=True)
return result

[/python]

先从 list 中取出 N 个元素来,然后把这个 list 转化成 heap,把原先的 list 中的所有元素在此 heap 上进行 heappushpop() 操作,最后剩下的一定是最大的!因为你每次 push 进去的不一定是最大的,但你 pop 出来的一定是最小的啊!

但 nsmallest() 的实现就截然不同了:

[python]
def nsmallest(n, iterable):
"""Find the n smallest elements in a dataset.

Equivalent to:  sorted(iterable)[:n]
&quot;&quot;&quot;
if n &lt; 0:
    return []
if hasattr(iterable, &#039;__len__&#039;) and n * 10  Largest of the nsmallest
    for elem in it:
        if cmp_lt(elem, los):
            insort(result, elem)
            pop()
            los = result[-1]
    return result
# An alternative approach manifests the whole iterable in memory but
# saves comparisons by heapifying all at once.  Also, saves time
# over bisect.insort() which has O(n) data movement time for every
# insertion.  Finding the n smallest of an m length iterable requires
#    O(m) + O(n log m) comparisons.
h = list(iterable)
heapify(h)
return map(heappop, repeat(h, min(n, len(h))))

[/python]

这里做了一个优化,如果 N 小于 list 长度的1/10的话,就直接用插入排序(因为之前就是有序的,所以很快)。如果 N 比较大的话,就用 heap 来实现了,把整个 list 作成一个堆,然后 heappop() 出来的前 N 个就是最后的结果了!不得不说最后一行代码写得太精炼了!