diff --git a/tests/distributions/__init__.py b/tests/distributions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/distributions/test_InvertibleGaussian.py b/tests/distributions/test_InvertibleGaussian.py
new file mode 100644
index 0000000..5c37611
--- /dev/null
+++ b/tests/distributions/test_InvertibleGaussian.py
@@ -0,0 +1,22 @@
+import torch
+import sys
+sys.path.append('../../src')
+from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+
+# Testing reparameterized sampling from the InvertibleGaussian distribution
+
+def test_sample_shape():
+    loc = torch.zeros(3, 4, 5, requires_grad=True)
+    scale = torch.ones(3, 4, 5, requires_grad=True)
+    temperature = torch.tensor([1e-0])
+    distribution = InvertibleGaussian(loc, scale, temperature)
+    sample = distribution.rsample()
+    assert sample.shape == torch.Size([3, 4, 6])
+
+def test_sample_grad():
+    loc = torch.zeros(3, 4, 5, requires_grad=True)
+    scale = torch.ones(3, 4, 5, requires_grad=True)
+    temperature = torch.tensor([1e-0])
+    distribution = InvertibleGaussian(loc, scale, temperature)
+    sample = distribution.rsample()
+    assert sample.requires_grad
\ No newline at end of file
diff --git a/tests/distributions/test_LogisticNormalSoftmax.py b/tests/distributions/test_LogisticNormalSoftmax.py
new file mode 100644
index 0000000..f213c46
--- /dev/null
+++ b/tests/distributions/test_LogisticNormalSoftmax.py
@@ -0,0 +1,20 @@
+import torch
+import sys
+sys.path.append('../../src')
+from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
+
+# Testing reparameterized sampling from the LogisticNormalSoftmax distribution
+
+def test_sample_shape():
+    loc = torch.zeros(3, 4, 5, requires_grad=True)
+    scale = torch.ones(3, 4, 5, requires_grad=True)
+    distribution = LogisticNormalSoftmax(loc, scale)
+    sample = distribution.rsample()
+    assert sample.shape == torch.Size([3, 4, 5])
+
+def test_sample_grad():
+    loc = torch.zeros(3, 4, 5, requires_grad=True)
+    scale = torch.ones(3, 4, 5, requires_grad=True)
+    distribution = LogisticNormalSoftmax(loc, scale)
+    sample = distribution.rsample()
+    assert sample.requires_grad
\ No newline at end of file
diff --git a/tests/distributions/test_approx.py b/tests/distributions/test_approx.py
new file mode 100644
index 0000000..3f85662
--- /dev/null
+++ b/tests/distributions/test_approx.py
@@ -0,0 +1,28 @@
+import torch
+import sys
+sys.path.append('../../src')
+from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
+from relaxit.distributions.approx import lognorm_approximation_fn, dirichlet_approximation_fn
+from pyro.distributions import Dirichlet
+
+# Testing the two-sided closed-form Laplace bridge approximation between
+# the LogisticNormal and Dirichlet distributions
+
+def test_approx():
+    # Generate a random concentration parameter
+    concentration = torch.randint(1, 10, (3,), dtype=torch.float)
+
+    # Create the Dirichlet distribution
+    dirichlet_distribution = Dirichlet(concentration)
+
+    # Approximate the Dirichlet distribution with a LogisticNormal distribution
+    lognorm_approximation = lognorm_approximation_fn(dirichlet_distribution)
+    loc = lognorm_approximation.loc
+    scale = lognorm_approximation.scale
+
+    # Approximate the LogisticNormal distribution with a Dirichlet distribution
+    dirichlet_approximation = dirichlet_approximation_fn(lognorm_approximation)
+    concentration_approx = dirichlet_approximation.concentration
+
+    # Assert that the original and approximated concentration parameters are close
+    assert torch.allclose(concentration, concentration_approx)
\ No newline at end of file
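
Aside: the exact round trip that test_approx.py asserts is not a numerical accident. The closed-form Laplace bridge is self-inverse, so mapping a Dirichlet concentration to logistic-normal moments and back recovers the concentration identically. Below is a minimal standalone sketch of the standard Laplace bridge formulas, with hypothetical helper names; the actual relaxit lognorm_approximation_fn and dirichlet_approximation_fn may differ in detail.

import torch

# Sketch of the standard closed-form Laplace bridge (hypothetical helpers;
# the relaxit implementation may differ).
def lognorm_from_dirichlet(concentration):
    # Dirichlet(alpha) -> (loc, scale) of a logistic-normal approximation
    K = concentration.shape[-1]
    log_alpha = concentration.log()
    loc = log_alpha - log_alpha.mean(-1, keepdim=True)
    var = (1.0 - 2.0 / K) / concentration + (1.0 / concentration).sum(-1, keepdim=True) / K**2
    return loc, var.sqrt()

def dirichlet_from_lognorm(loc, scale):
    # (loc, scale) of a logistic-normal -> Dirichlet concentration
    K = loc.shape[-1]
    numerator = 1.0 - 2.0 / K + (loc.exp() / K**2) * (-loc).exp().sum(-1, keepdim=True)
    return numerator / scale**2

# Substituting one map into the other cancels algebraically, which is why
# torch.allclose passes without any tolerance tuning:
alpha = torch.tensor([1.0, 4.0, 7.0])
loc, scale = lognorm_from_dirichlet(alpha)
assert torch.allclose(dirichlet_from_lognorm(loc, scale), alpha)
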
diff --git a/tests/distributions/test_kl.py b/tests/distributions/test_kl.py
new file mode 100644
index 0000000..76b7673
--- /dev/null
+++ b/tests/distributions/test_kl.py
@@ -0,0 +1,49 @@
+import torch
+import sys
+sys.path.append('../../src')
+from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions.kl import kl_divergence
+
+# Testing the KL divergence between two InvertibleGaussian distributions
+
+def test_igr_kl_shape():
+    loc_1 = torch.zeros(3, 4, 5, requires_grad=True)
+    scale_1 = torch.ones(3, 4, 5, requires_grad=True)
+    temperature_1 = torch.tensor([1e-0])
+    dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)
+
+    loc_2 = torch.ones(3, 4, 5, requires_grad=True)  # ones, not zeros
+    scale_2 = torch.ones(3, 4, 5, requires_grad=True)
+    temperature_2 = torch.tensor([1e-2])
+    dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)
+
+    div = kl_divergence(dist_1, dist_2)
+    assert div.shape == torch.Size([3, 4, 5])
+
+def test_igr_kl_grad():
+    loc_1 = torch.zeros(3, 4, 5, requires_grad=True)
+    scale_1 = torch.ones(3, 4, 5, requires_grad=True)
+    temperature_1 = torch.tensor([1e-0])
+    dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)
+
+    loc_2 = torch.ones(3, 4, 5, requires_grad=True)  # ones, not zeros
+    scale_2 = torch.ones(3, 4, 5, requires_grad=True)
+    temperature_2 = torch.tensor([1e-2])
+    dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)
+
+    div = kl_divergence(dist_1, dist_2)
+    assert div.requires_grad
+
+def test_igr_kl_value():
+    loc_1 = torch.ones(3, 4, 5, requires_grad=True)
+    scale_1 = torch.ones(3, 4, 5, requires_grad=True)
+    temperature_1 = torch.tensor([1e-2])
+    dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)
+
+    loc_2 = torch.ones(3, 4, 5, requires_grad=True)  # identical to dist_1
+    scale_2 = torch.ones(3, 4, 5, requires_grad=True)
+    temperature_2 = torch.tensor([1e-2])
+    dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)
+
+    div = kl_divergence(dist_1, dist_2)
+    assert torch.allclose(div, torch.zeros_like(div))
\ No newline at end of file
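
For context, test_igr_kl_value rests on the identity KL(p || p) = 0: both distributions share loc, scale, and temperature, so the divergence must vanish elementwise. The [3, 4, 5] output shape asserted above suggests, but does not guarantee, that relaxit's kl_divergence reduces to the elementwise KL between the underlying diagonal Gaussians, which has the familiar closed form sketched below.

import torch

# Elementwise KL(N(loc_1, scale_1^2) || N(loc_2, scale_2^2)) for diagonal
# Gaussians. Illustrative only: that relaxit's kl_divergence for
# InvertibleGaussian matches this expression is an assumption.
def gaussian_kl(loc_1, scale_1, loc_2, scale_2):
    return (
        (scale_2 / scale_1).log()
        + (scale_1**2 + (loc_1 - loc_2) ** 2) / (2 * scale_2**2)
        - 0.5
    )

# Identical parameters give a zero divergence, mirroring test_igr_kl_value:
loc, scale = torch.ones(3, 4, 5), torch.ones(3, 4, 5)
assert torch.allclose(gaussian_kl(loc, scale, loc, scale), torch.zeros(3, 4, 5))

Note also that all four test modules rely on sys.path.append('../../src') rather than an installed package, so they are presumably meant to be invoked from inside tests/distributions (for example, with python -m pytest .).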