Showing 3 changed files with 64 additions and 17 deletions.
@@ -0,0 +1,26 @@

```python
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.GaussianRelaxedBernoulli import GaussianRelaxedBernoulli

# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution

def test_sample_shape():
    loc = torch.tensor([0.])
    scale = torch.tensor([1.])

    distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
    samples = distr.rsample(sample_shape=torch.Size([3]))
    assert samples.shape == torch.Size([3, 1])

def test_sample_grad():
    loc = torch.tensor([0.], requires_grad=True)
    scale = torch.tensor([1.], requires_grad=True)
    distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
    samples = distr.rsample(sample_shape=torch.Size([3]))

    assert samples.requires_grad == True

if __name__ == "__main__":
    test_sample_shape()
    test_sample_grad()
```
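The gradient test above only asserts that `requires_grad` propagates through `rsample`. A short, hedged sketch of a stronger check (not part of this commit; it reuses the constructor and `rsample` signature shown above and assumes `relaxit` is importable, e.g. via the same `sys.path` trick):

```python
import torch
from relaxit.distributions.GaussianRelaxedBernoulli import GaussianRelaxedBernoulli

# Same parameters as test_sample_grad: both leaf tensors require gradients.
loc = torch.tensor([0.], requires_grad=True)
scale = torch.tensor([1.], requires_grad=True)

distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
samples = distr.rsample(sample_shape=torch.Size([3]))

# Because rsample is reparameterized, a scalar loss built from the samples
# backpropagates into loc and scale; both .grad fields should be populated.
samples.sum().backward()
print(loc.grad, scale.grad)
```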
@@ -0,0 +1,21 @@

```python
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK

# Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution

def test_sample_shape():
    a = torch.tensor([[1, 2, 3]])
    distribution = GumbelSoftmaxTopK(a, K=2)
    sample = distribution.rsample()
    assert sample.shape == torch.Size([2])
    print("$")

def test_log_prob():
    a = torch.tensor([1., 2., 3.])
    distribution = GumbelSoftmaxTopK(a, K=1)
    value = 1
    log_prob_my = distribution.log_prob(value, shape=torch.Size([]))
    log_prob_true = torch.log(a[value] / 6.)
    assert log_prob_my - log_prob_true < 1e-6
```
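In `test_log_prob`, the reference value treats `a` as unnormalized probabilities: for `K=1` and `value = 1`, `p = a[1] / a.sum() = 2/6`, so the expected log-probability is `log(2/6) ≈ -1.0986`. A hedged sketch of the same check with an explicit sum and a symmetric tolerance (illustration only, reusing the call signature from the test above):

```python
import torch
from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK

a = torch.tensor([1., 2., 3.])
distribution = GumbelSoftmaxTopK(a, K=1)

value = 1
# Reference value: a is treated as unnormalized probabilities, so
# p(value) = a[value] / a.sum() = 2 / 6 and log p ≈ -1.0986.
log_prob_true = torch.log(a[value] / a.sum())
log_prob_my = distribution.log_prob(value, shape=torch.Size([]))

# Same comparison as test_log_prob, but tolerant to either sign of the error.
assert torch.isclose(log_prob_my, log_prob_true, atol=1e-6)
```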
tests/test_simple.py → tests/distributions/test_HardConcrete.py (17 additions, 17 deletions)
The old combined smoke test (`test_GaussianRelaxedBernoulli` and `test_HardConcrete`, each ending in a "... is OK" print) is replaced by HardConcrete-specific tests; the GaussianRelaxedBernoulli checks now live in the dedicated test file above. The resulting file:

```python
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.HardConcrete import HardConcrete

# Testing reparameterized sampling from the HardConcrete distribution

def test_sample_shape():
    alpha = torch.tensor([1.])
    beta = torch.tensor([2.])
    gamma = torch.tensor([-3.])
    xi = torch.tensor([4.])
    distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
    samples = distr.rsample(sample_shape=torch.Size([3]))

    assert samples.shape == torch.Size([3, 1])

def test_sample_grad():
    alpha = torch.tensor([1.], requires_grad=True)
    beta = torch.tensor([2.], requires_grad=True)
    gamma = torch.tensor([-3.], requires_grad=True)
    xi = torch.tensor([4.], requires_grad=True)

    distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
    samples = distr.rsample(sample_shape=torch.Size([3]))
    assert samples.shape == torch.Size([3, 1])

    assert samples.requires_grad == True

if __name__ == "__main__":
    test_sample_shape()
    test_sample_grad()
```
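One property these tests do not cover: if `HardConcrete` follows the usual stretch-and-clamp construction (an assumption, not something this diff confirms), samples lie in `[0, 1]` with point masses at the endpoints. A hedged sketch that only inspects this, reusing the constructor shown in the tests above:

```python
import torch
from relaxit.distributions.HardConcrete import HardConcrete

# Parameters copied from test_sample_shape above.
alpha = torch.tensor([1.])
beta = torch.tensor([2.])
gamma = torch.tensor([-3.])
xi = torch.tensor([4.])

distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples = distr.rsample(sample_shape=torch.Size([1000]))

# If the stretch-and-clamp assumption holds, min/max stay inside [0, 1]
# and a visible fraction of samples sits exactly at 0 or 1.
print("min:", samples.min().item(), "max:", samples.max().item())
print("fraction at {0, 1}:", ((samples == 0) | (samples == 1)).float().mean().item())
```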