-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathalias_multinomial.py
74 lines (58 loc) · 2.24 KB
/
alias_multinomial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
import pdb
class AliasMethod(object):
    '''Alias sampling method to speed up multinomial sampling.

    The alias method treats multinomial sampling as a combination of uniform
    sampling and bernoulli sampling. It achieves significant acceleration when
    repeatedly sampling from the same multinomial distribution: building the
    table costs O(K) once, after which every sample costs O(1).

    Attributes:
        - prob: 1-D tensor (same dtype/device as the input ``probs``); entry
          ``k`` is the probability of keeping outcome ``k`` once column ``k``
          has been picked uniformly
        - alias: 1-D long tensor (same device as ``probs``); entry ``k`` is
          the outcome returned when column ``k`` is picked but not kept

    Refs:
        - https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
    '''

    def __init__(self, probs):
        """Build the alias table.

        Args:
            - probs: 1-D tensor holding the probability of each outcome

        Raises:
            ValueError: if ``probs`` is not 1-D or has no elements
        """
        if probs.dim() != 1 or probs.numel() == 0:
            raise ValueError(
                'probs must be a non-empty 1-D tensor, got shape {}'.format(
                    tuple(probs.shape)))

        K = probs.numel()

        # Pull the values out as plain Python floats once. Iterating over the
        # tensor directly would turn every multiply/compare below into a
        # 0-d tensor operation.
        scaled = [K * p for p in probs.detach().cpu().tolist()]
        alias = [0] * K

        # Sort the data into the outcomes with probabilities
        # that are larger and smaller than 1/K.
        smaller = []
        larger = []
        for idx, value in enumerate(scaled):
            if value < 1.0:
                smaller.append(idx)
            else:
                larger.append(idx)

        # Loop through and create little binary mixtures that
        # appropriately allocate the larger outcomes over the
        # overall uniform mixture.
        while smaller and larger:
            small = smaller.pop()
            large = larger.pop()

            alias[small] = large
            scaled[large] = (scaled[large] - 1.0) + scaled[small]

            if scaled[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        # Whatever is left over only deviates from 1 by rounding error;
        # pin it to exactly 1 so those columns never fall through to alias.
        for last_one in smaller + larger:
            scaled[last_one] = 1.0

        self.prob = torch.tensor(scaled, dtype=probs.dtype, device=probs.device)
        # Build the index table directly as int64. Routing it through the
        # probability dtype first would round large indices (e.g. anything
        # above 2048 in float16).
        self.alias = torch.tensor(alias, dtype=torch.long, device=probs.device)

    def draw(self, *size):
        """Draw samples from the multinomial distribution.

        Args:
            - size: the output size of samples, given as separate ints
              (e.g. ``draw(5, 3)``)

        Returns:
            A long tensor of shape ``size`` holding sampled outcome indices.
        """
        max_value = self.alias.size(0)

        # Uniformly pick one column of the table per requested sample.
        kk = torch.randint(
            0, max_value, size, dtype=torch.long, device=self.alias.device
        ).view(-1)
        prob = self.prob[kk]
        alias = self.alias[kk]

        # Biased coin per sample: keep the picked column with probability
        # prob[kk], otherwise fall through to its alias.
        keep = torch.bernoulli(prob).bool()

        return torch.where(keep, kk, alias).view(size)