Skip to content

Commit

Permalink
add new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vepricov committed Nov 23, 2024
1 parent dd93fb0 commit ff5de67
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
26 changes: 26 additions & 0 deletions tests/distributions/test_GaussianRelaxedBernoulli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.GaussianRelaxedBernoulli import GaussianRelaxedBernoulli

# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution

def test_sample_shape():
loc = torch.tensor([0.])
scale = torch.tensor([1.])

distr = GaussianRelaxedBernoulli(loc = loc, scale=scale)
samples = distr.rsample(sample_shape = torch.Size([3]))
assert samples.shape == torch.Size([3, 1])

def test_sample_grad():
loc = torch.tensor([0.], requires_grad=True)
scale = torch.tensor([1.], requires_grad=True)
distr = GaussianRelaxedBernoulli(loc = loc, scale=scale)
samples = distr.rsample(sample_shape = torch.Size([3]))

assert samples.requires_grad == True

if __name__ == "__main__":
test_sample_shape()
test_sample_grad()
21 changes: 21 additions & 0 deletions tests/distributions/test_GumbelSoftmaxTopK.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK

# Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution

def test_sample_shape():
a = torch.tensor([[1, 2, 3]])
distribution = GumbelSoftmaxTopK(a, K=2)
sample = distribution.rsample()
assert sample.shape == torch.Size([2])
print("$")

def test_log_prob():
a = torch.tensor([1., 2., 3.])
distribution = GumbelSoftmaxTopK(a, K=1)
value = 1
log_prob_my = distribution.log_prob(value, shape=torch.Size([]))
log_prob_true = torch.log(a[value] / 6.)
assert log_prob_my - log_prob_true < 1e-6
34 changes: 17 additions & 17 deletions tests/test_simple.py → tests/distributions/test_HardConcrete.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
import sys, os
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
from relaxit.distributions import (
HardConcrete,
GaussianRelaxedBernoulli
)
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.HardConcrete import HardConcrete

def test_GaussianRelaxedBernoulli():
loc = torch.tensor([0.], requires_grad=True)
scale = torch.tensor([1.], requires_grad=True)
# Testing reparameterized sampling from the HardConcrete distribution

### rsample test ###
distr = GaussianRelaxedBernoulli(loc = loc, scale=scale)
def test_sample_shape():
alpha = torch.tensor([1.])
beta = torch.tensor([2.])
gamma = torch.tensor([-3.])
xi = torch.tensor([4.])
distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples = distr.rsample(sample_shape = torch.Size([3]))

assert samples.shape == torch.Size([3, 1])
assert samples.requires_grad == True
print("GaussianRelaxedBernoulli is OK")

def test_HardConcrete():
def test_sample_grad():
alpha = torch.tensor([1.], requires_grad=True)
beta = torch.tensor([2.], requires_grad=True)
gamma = torch.tensor([-3.], requires_grad=True)
xi = torch.tensor([4.], requires_grad=True)

distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples = distr.rsample(sample_shape = torch.Size([3]))
assert samples.shape == torch.Size([3, 1])

assert samples.requires_grad == True
print("HardConcrete is OK")

if __name__ == "__main__":
test_sample_shape()
test_sample_grad()

0 comments on commit ff5de67

Please sign in to comment.