-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathsimulate.py
35 lines (28 loc) · 1.14 KB
/
simulate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch.distributions import Bernoulli
def simulate_data(model, batch_size=10, n_batch=1, device=None):
    """Simulate data from the VAE model by ancestral sampling.

    Draws from the joint distribution p(z)p(x|z). This is equivalent to
    sampling from p(x)p(z|x), i.e. each z is an exact posterior sample for
    its paired x. Bidirectional Monte Carlo needs such exact posterior
    samples, which is why it only works on simulated data like this.

    Args:
        model: VAE model for simulation. Must expose an integer
            ``latent_dim`` and a ``decode(z)`` method returning the Bernoulli
            logits of x, or a tuple whose first element is those logits.
        batch_size: number of samples per batch.
        n_batch: number of batches to draw.
        device (torch.device): device on which z is sampled; ``None`` uses
            torch's default device.

    Returns:
        Iterator over ``n_batch`` pairs ``(x, z)`` of torch Tensors. All
        batches are drawn eagerly, before the iterator is returned.
    """
    batches = []
    # Simulation needs no gradients: x comes from .sample() and z is a
    # non-grad leaf, so disabling autograd only avoids building a throwaway
    # graph through the decoder for every batch.
    with torch.no_grad():
        for _ in range(n_batch):
            # Assume the VAE prior is a unit Gaussian, p(z) = N(0, I).
            z = torch.randn(size=(batch_size, model.latent_dim), device=device)
            x_logits = model.decode(z)
            # Some decoders return extra outputs; the logits come first.
            if isinstance(x_logits, tuple):
                x_logits = x_logits[0]
            # Pass logits directly instead of an explicit sigmoid(); Bernoulli
            # derives the probabilities itself and samples identically.
            x = Bernoulli(logits=x_logits).sample()
            batches.append((x, z))
    return iter(batches)