import jax.numpy as jnp
from distrax import Normal, Bernoulli
from jax.random import split
from jax import vmap
from jax.scipy.special import logsumexp
from einops import reduce, repeat

# Typing
from jax.random import PRNGKeyArray
from jax import Array
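
# Assumed model interface, inferred from how the losses below call the model
# (the model itself is defined elsewhere in the repository):
#
#   x_rec, z, mean, logvar = model(x, K, key=key)
#
# where, for a single input image x of shape (c, h, w),
#   x_rec:  (K, c, h, w)  Bernoulli logits for the K reconstructions
#   z:      (K, l)        latent samples drawn from q(z|x)
#   mean:   (l,)          mean of the approximate posterior q(z|x)
#   logvar: (l,)          log-variance of the approximate posterior
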
def vae_loss(model, x: Array, K: int, key: PRNGKeyArray) -> Array:
    '''Compute the VAE loss: the negative ELBO, with the single-sample ELBO averaged over K samples.'''
    def loss_fn(x: Array, key: PRNGKeyArray):
        x_rec, z, mean, logvar = model(x, K, key=key)

        def log_importance_weight(x_rec, z):
            # Log importance weight: log p(x, z) - log q(z|x)
            log_q_z_x = reduce(Normal(mean, jnp.exp(0.5 * logvar)).log_prob(z), 'l -> ', 'sum')
            log_p_z = reduce(Normal(jnp.zeros_like(mean), jnp.ones_like(logvar)).log_prob(z), 'l -> ', 'sum')
            log_p_x_z = reduce(Bernoulli(logits=x_rec).log_prob(x), 'c h w -> ', 'sum')
            return log_p_x_z + log_p_z - log_q_z_x

        log_iw = vmap(log_importance_weight)(x_rec, z)
        # Average the K single-sample ELBO estimates
        return log_iw.mean()

    keys = split(key, x.shape[0])
    # Mean over the batch
    return -jnp.mean(vmap(loss_fn)(x, keys))

def iwae_loss(model, x: Array, K: int, key: PRNGKeyArray) -> Array:
    '''Compute the IWELBO (importance-weighted ELBO) loss.'''
    def loss_fn(x: Array, key: PRNGKeyArray):
        x_rec, z, mean, logvar = model(x, K, key=key)

        def log_importance_weight(x_rec, z):
            # Log importance weight: log p(x, z) - log q(z|x)
            log_q_z_x = reduce(Normal(mean, jnp.exp(0.5 * logvar)).log_prob(z), 'l -> ', 'sum')
            log_p_z = reduce(Normal(jnp.zeros_like(mean), jnp.ones_like(logvar)).log_prob(z), 'l -> ', 'sum')
            log_p_x_z = reduce(Bernoulli(logits=x_rec).log_prob(x), 'c h w -> ', 'sum')
            return log_p_x_z + log_p_z - log_q_z_x

        log_iw = vmap(log_importance_weight)(x_rec, z)
        # Marginalize over the K importance samples: log(1/K * sum_k w_k)
        return reduce(log_iw, 'k -> ', logsumexp) - jnp.log(K)

    keys = split(key, x.shape[0])
    # Mean over the batch
    return -jnp.mean(vmap(loss_fn)(x, keys))
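
# The reduction in loss_fn above evaluates, per input,
#   log(1/K * sum_k w_k) = logsumexp(log_w) - log K,
# with logsumexp used for numerical stability; einops accepts the callable as
# a custom reduction over the K axis. For K > 1 this importance-weighted bound
# on log p(x) is at least as tight as the averaged ELBO used in vae_loss.
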
def old_iwae_loss(model, x: Array, K: int, key: PRNGKeyArray) -> Array:
    '''Compute the IWELBO loss (older variant: the analytic KL replaces the
    per-sample log-ratio used in iwae_loss, so the two bounds differ).'''
    def loss_fn(x: Array, key: PRNGKeyArray):
        x_rec, _, mean, logvar = model(x, K, key=key)
        # Prior p(z)
        prior = Normal(jnp.zeros_like(mean), jnp.ones_like(logvar))
        # Approximate posterior q_{\phi}(z|x)
        latent = Normal(mean, jnp.exp(0.5 * logvar))
        # Likelihood p_{\theta}(x|z)
        likelihood = Bernoulli(logits=x_rec)
        # KL divergence KL(q_{\phi}(z|x) || p(z))
        kl_div = reduce(latent.kl_divergence(prior), 'n -> ()', 'sum')
        # Repeat samples for broadcasting
        kl_div = repeat(kl_div, '() -> k', k=K)
        xs = repeat(x, 'c h w -> k c h w', k=K)
        # Log-likelihood (reconstruction term), summed per sample
        like = reduce(likelihood.log_prob(xs), 'k c h w -> k', 'sum')
        # Importance-weighted bound: logsumexp over the K samples minus log K
        iw_loss = reduce(like - kl_div, 'k -> ()', logsumexp) - jnp.log(K)
        return -iw_loss

    keys = split(key, x.shape[0])
    # Mean over the batch
    return jnp.mean(vmap(loss_fn)(x, keys))

def old_vae_loss(model, x: Array, K: int, key: PRNGKeyArray) -> Array:
    '''Compute the VAE loss (older variant using the analytic KL).'''
    def loss_fn(x: Array, key: PRNGKeyArray):
        x_rec, _, mean, logvar = model(x, K, key=key)
        # Prior p(z)
        prior = Normal(jnp.zeros_like(mean), jnp.ones_like(logvar))
        # Approximate posterior q_{\phi}(z|x)
        latent = Normal(mean, jnp.exp(0.5 * logvar))
        # Likelihood p_{\theta}(x|z)
        likelihood = Bernoulli(logits=x_rec)
        # KL divergence KL(q_{\phi}(z|x) || p(z))
        kl_div = reduce(latent.kl_divergence(prior), 'n -> ()', 'sum')
        # Log-likelihood; x broadcasts against the K reconstructions
        like = reduce(likelihood.log_prob(x), 'k c h w -> k', 'sum')
        # Negative ELBO per sample
        return -(like - kl_div)

    keys = split(key, x.shape[0])
    # Mean over the batch and the K samples
    return jnp.mean(vmap(loss_fn)(x, keys))
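
# A minimal smoke test (a hedged sketch: `dummy_model` below is a hypothetical
# stand-in that only mimics the interface assumed above; the real model is
# defined elsewhere in the repository).
if __name__ == '__main__':
    from jax.random import PRNGKey

    LATENT, C, H, W = 4, 1, 8, 8

    def dummy_model(x, K, *, key):
        # Ignores x; returns fixed posterior parameters and zero logits.
        mean = jnp.zeros(LATENT)
        logvar = jnp.zeros(LATENT)
        z = Normal(mean, jnp.exp(0.5 * logvar)).sample(seed=key, sample_shape=(K,))
        x_rec = jnp.zeros((K, C, H, W))  # Bernoulli logits for K reconstructions
        return x_rec, z, mean, logvar

    key = PRNGKey(0)
    x = jnp.zeros((2, C, H, W))  # a toy batch of two binary images
    print('VAE loss: ', vae_loss(dummy_model, x, 8, key))
    print('IWAE loss:', iwae_loss(dummy_model, x, 8, key))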