今天在网上看题目时,发现一个十分有趣的算法,叫蓄水池算法(Reservoir Sampling),牵扯到一点概率论问题。
题目:给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。请写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。
抽象:从n中取出k个数,n未知大小,保证最后n中每个元素被抽取的概率一样为k/n
。
做法:假设我们从3个数{1,2,3}
中取一个数,那么就要求每个数被抽取的概率为1/3
,我们先读取前2个数{1,2}
,我们以1/2
的概率选取其中一个数,加入选择{1}
,接下来读取数字{3}
,因为要求每个数被选取的概率为1/3
,因此我们以1/3
的概率选取数字{3}
,2/3
的概率选取数字{1}
,那么最终数字1
被选取的概率是2/3 * 1/2 = 1/3
,同理数字{2}
被选取的概率也是1/3
。
将上述做法n
个数选择1
个数,每次要读取第n
个数的时候,以1/n
的概率保留该数,以(n-1)/n
的概率保留前面n-1
个数选取出来的1
个数。
再推广到从n
个数中选取k
个数,假设读取到第n
个数时(n>=k),以k/n
的概率保留概述,以(n-k)/n
的概率保留前n-1
个中选取出来的k
个数。
证明:使用上述步骤,从n
中读取k
个数,在读取n
个元素后,n中每一个被保留下来的概率都是k/n
。假设n = k + i
,那么算法证明的就是第i
轮选取中,前k+i
个数每个数被保留的概率为k/(k+i)
,其中( 0 <= i <= n - k)
。
用数学归纳法来证明:
- 当
i=0
是,每个数字被选取的概率为k/(k+0)=1
,正确; - 假设当
i-1
轮时,结论成立,即前k+i-1
中每个数被保留的概率为k/(k+i-1)
; - 第
i轮
,因为我们以k/(k+i)
的概率选取第k+i
个数,因此其概率为k/(k+i)
,正确;对于前k+i-1
个数中的x
,其被保留的概率由两部分组成:
①:第k+i
个数没有被选取到,则x
被选取的概率是:i/(k+i) * k/(k+i-1)
,其中k/(k+i-1)
是2中
假设的条件,即x
在前k+i-1
中被保留的概率;
②:第k+i
个数被选取到,要替换前k+i-1
中的数,那么x
不被替换的概率是:k/(k+i) * k/(k+i-1) * (k-1)/k
,k/(k+i-1)
同①,k-1/k
是指,在被选取为k
个数之后,不被替换的概率。
因此,总概率为(i/(k+i) * k/(k+i-1)) + (k/(k+i) * k/(k+i-1) * (k-1)/k) = (k/(k+i-1)) * (i/(k+i) + (k-1)/(k+i)) = k/k+i
;
得证。
具体算法步骤:
- 从
n
个数读取前k
个数,保存在集合A
中; - 从第
i
个数开始(k<=i<=n)
,每次以k/i
的概率选择是否保留概述,若保留,将随机替换k中任一数; - 重复2,直到结束。
Leetcode有关于蓄水池算法
的两道题,来看看怎么应用:
- Linked List Random Node
题目要求去链表中的任一节点,使得节点被选取的概率相等,因此我们利用蓄水池法
,遍历节点;
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, x):
# self.val = x
# self.next = None
import random
class Solution:
def __init__(self, head):
"""
@param head The linked list's head.
Note that the head is guaranteed to be not null, so it contains at least one node.
:type head: ListNode
"""
self.head = head
def getRandom(self):
"""
Returns a random node's value.
:rtype: int
"""
node = self.head
select_node = node
i = 1
while node.next:
i += 1
node = node.next
if random.random() <= 1.0/i: # 保证被选取的概率为`1/i`
select_node = node
return select_node.val
- Random Pick Index
题目要求获取指定数字序号,每个数字在该序列中的序号被获取的概率相等,题目有个硬性条件,n
很大,刚好蓄水池法
可以用来解决。
import random
class Solution:
def __init__(self, nums):
"""
:type nums: List[int]
"""
self.nums = nums
def pick(self, target):
"""
:type target: int
:rtype: int
"""
count = 0
select_index = 0
for index,num in enumerate(self.nums):
if num == target: # 遍历列表,当该值与目标值相同时,才进入蓄水池算法
count += 1
if random.random() <= 1.0/count: #满足概率为1/i
select_index = index
return select_index