Commit

add tests for igr and approx
kisnikser committed Nov 17, 2024
1 parent 0cd18a9 commit 15868e7
Showing 5 changed files with 119 additions and 0 deletions.
Empty file added tests/distributions/__init__.py
22 changes: 22 additions & 0 deletions tests/distributions/test_InvertibleGaussian.py
@@ -0,0 +1,22 @@
import torch
import sys
sys.path.append('../../src')
from relaxit.distributions.InvertibleGaussian import InvertibleGaussian

# Testing reparameterized sampling from the InvertibleGaussian distribution

def test_sample_shape():
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
temperature = torch.tensor([1e-0])
distribution = InvertibleGaussian(loc, scale, temperature)
sample = distribution.rsample()
assert sample.shape == torch.Size([3, 4, 6])

def test_sample_grad():
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
temperature = torch.tensor([1e-0])
distribution = InvertibleGaussian(loc, scale, temperature)
sample = distribution.rsample()
assert sample.requires_grad
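
The two tests above exercise the reparameterization trick: rsample() must return a sample that is a differentiable function of loc and scale, so gradients can propagate back to the parameters. Note also that the asserted shape grows from 5 to 6 in the last dimension, so the distribution's transform evidently maps K Gaussian dimensions to K + 1 output coordinates. The sketch below shows only the generic noise-plus-transform pattern; the transform used here (a plain sigmoid) is a stand-in, and the helper name is hypothetical, not relaxit's implementation.

import torch

def reparameterized_sample(loc, scale, transform):
    # Parameter-free noise; loc and scale enter only through differentiable
    # arithmetic, so autograd reaches them through the sample.
    eps = torch.randn_like(loc)
    z = loc + scale * eps
    return transform(z)

loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
sample = reparameterized_sample(loc, scale, torch.sigmoid)  # stand-in transform
sample.sum().backward()
assert loc.grad is not None and scale.grad is not None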
20 changes: 20 additions & 0 deletions tests/distributions/test_LogisticNormalSoftmax.py
@@ -0,0 +1,20 @@
import torch
import sys
sys.path.append('../../src')
from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax

# Testing reparameterized sampling from the LogisticNormalSoftmax distribution

def test_sample_shape():
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
distribution = LogisticNormalSoftmax(loc, scale)
sample = distribution.rsample()
assert sample.shape == torch.Size([3, 4, 5])

def test_sample_grad():
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
distribution = LogisticNormalSoftmax(loc, scale)
sample = distribution.rsample()
assert sample.requires_grad
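
Unlike the InvertibleGaussian test above, this test expects the sample to keep its last dimension (5 rather than 6): a softmax over the last axis maps K real values to K probabilities, whereas Pyro's stock LogisticNormal uses a stick-breaking transform that returns K + 1 coordinates. A minimal sketch of that shape-and-gradient behaviour, assuming only a softmax applied to a reparameterized Gaussian sample (not necessarily relaxit's exact construction):

import torch

loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
eps = torch.randn_like(loc)
z = loc + scale * eps              # reparameterized Gaussian sample
probs = torch.softmax(z, dim=-1)   # softmax preserves the last dimension
assert probs.shape == torch.Size([3, 4, 5])
assert probs.requires_grad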
28 changes: 28 additions & 0 deletions tests/distributions/test_approx.py
@@ -0,0 +1,28 @@
import torch
import sys
sys.path.append('../../src')
from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
from relaxit.distributions.approx import lognorm_approximation_fn, dirichlet_approximation_fn
from pyro.distributions import Dirichlet

# Testing the two-sided closed-form Laplace bridge approximation between
# the LogisticNormal and Dirichlet distributions

def test_approx():
# Generate a random concentration parameter
concentration = torch.randint(1, 10, (3,), dtype=torch.float)

# Create the Dirichlet distribution
dirichlet_distribution = Dirichlet(concentration)

# Approximate the Dirichlet distribution with a LogisticNormal distribution
lognorm_approximation = lognorm_approximation_fn(dirichlet_distribution)
loc = lognorm_approximation.loc
scale = lognorm_approximation.scale

# Approximate the LogisticNormal distribution with a Dirichlet distribution
dirichlet_approximation = dirichlet_approximation_fn(lognorm_approximation)
concentration_approx = dirichlet_approximation.concentration

# Assert that the original and approximated concentration parameters are close
assert torch.allclose(concentration, concentration_approx)
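
The functions under test map Dirichlet parameters to logistic-normal parameters and back in closed form. For reference, the standard Laplace-bridge formulas from the literature are sketched below; they are algebraically exact round-trip inverses of each other, which is what makes the allclose check above reasonable. The helper names are hypothetical and this is not relaxit's implementation (the library exposes loc and scale attributes, while the sketch works with a diagonal variance).

import torch

def laplace_bridge_forward(alpha):
    # Dirichlet(alpha) -> logistic-normal mean and diagonal variance
    # (standard Laplace-bridge formulas).
    K = alpha.shape[-1]
    log_alpha = alpha.log()
    mu = log_alpha - log_alpha.mean(-1, keepdim=True)
    var = (1.0 / alpha) * (1.0 - 2.0 / K) + (1.0 / alpha).sum(-1, keepdim=True) / K**2
    return mu, var

def laplace_bridge_inverse(mu, var):
    # Logistic-normal mean and diagonal variance -> Dirichlet concentration.
    K = mu.shape[-1]
    return (1.0 - 2.0 / K + torch.exp(mu) * torch.exp(-mu).sum(-1, keepdim=True) / K**2) / var

concentration = torch.tensor([2.0, 3.0, 4.0])
mu, var = laplace_bridge_forward(concentration)
assert torch.allclose(laplace_bridge_inverse(mu, var), concentration)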
49 changes: 49 additions & 0 deletions tests/distributions/test_kl.py
@@ -0,0 +1,49 @@
import torch
import sys
sys.path.append('../../src')
from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
from relaxit.distributions.kl import kl_divergence

# Testing KL-divergence between two InvertibleGaussian distributions

def test_igr_kl_shape():
loc_1 = torch.zeros(3, 4, 5, requires_grad=True)
scale_1 = torch.ones(3, 4, 5, requires_grad=True)
temperature_1 = torch.tensor([1e-0])
dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)

loc_2 = torch.ones(3, 4, 5, requires_grad=True) # ones, not zeros
scale_2 = torch.ones(3, 4, 5, requires_grad=True)
temperature_2 = torch.tensor([1e-2])
dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)

div = kl_divergence(dist_1, dist_2)
assert div.shape == torch.Size([3, 4, 5])

def test_igr_kl_grad():
loc_1 = torch.zeros(3, 4, 5, requires_grad=True)
scale_1 = torch.ones(3, 4, 5, requires_grad=True)
temperature_1 = torch.tensor([1e-0])
dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)

loc_2 = torch.ones(3, 4, 5, requires_grad=True) # ones, not zeros
scale_2 = torch.ones(3, 4, 5, requires_grad=True)
temperature_2 = torch.tensor([1e-2])
dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)

div = kl_divergence(dist_1, dist_2)
assert div.requires_grad

def test_igr_kl_value():
loc_1 = torch.ones(3, 4, 5, requires_grad=True)
scale_1 = torch.ones(3, 4, 5, requires_grad=True)
temperature_1 = torch.tensor([1e-2])
dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)

loc_2 = torch.ones(3, 4, 5, requires_grad=True) # same parameters as dist_1
scale_2 = torch.ones(3, 4, 5, requires_grad=True)
temperature_2 = torch.tensor([1e-2])
dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)

div = kl_divergence(dist_1, dist_2)
assert torch.allclose(div, torch.zeros_like(div))
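
The last test checks the basic identity KL(q || q) = 0, and the shape assertions above show that the divergence is returned per element. For intuition, the sketch below gives the elementwise KL between two univariate Gaussians; whether relaxit's kl_divergence for InvertibleGaussian reduces to exactly this expression (and how the temperature enters) is up to the library, so treat it as illustrative only.

import torch

def gaussian_kl(loc_1, scale_1, loc_2, scale_2):
    # Elementwise KL( N(loc_1, scale_1^2) || N(loc_2, scale_2^2) ).
    return (torch.log(scale_2 / scale_1)
            + (scale_1 ** 2 + (loc_1 - loc_2) ** 2) / (2 * scale_2 ** 2)
            - 0.5)

loc = torch.ones(3, 4, 5)
scale = torch.ones(3, 4, 5)
div = gaussian_kl(loc, scale, loc, scale)  # identical distributions
assert div.shape == torch.Size([3, 4, 5])
assert torch.allclose(div, torch.zeros_like(div))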
