diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 45dddf0..0cb2bc9 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -37,7 +37,7 @@ jobs:
- name: Build Docs
run: |
- sphinx-build -b html docs public
+ sphinx-build -b html docs/source public
touch public/.nojekyll
      - name: Deploy 🚀
diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 60c5975..8fe1ceb 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -1,4 +1,4 @@
-name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
+name: Publish Python 🐍 distribution 📦 to PyPI
on: push
@@ -90,28 +90,4 @@ jobs:
run: >-
gh release upload
'${{ github.ref_name }}' dist/**
- --repo '${{ github.repository }}'
-
- publish-to-testpypi:
-  name: Publish Python 🐍 distribution 📦 to TestPyPI
- needs:
- - build
- runs-on: ubuntu-latest
-
- environment:
- name: testpypi
- url: https://test.pypi.org/p/relaxit
-
- permissions:
- id-token: write # IMPORTANT: mandatory for trusted publishing
-
- steps:
- - name: Download all the dists
- uses: actions/download-artifact@v4
- with:
- name: python-package-distributions
- path: dist/
-    - name: Publish distribution 📦 to TestPyPI
- uses: pypa/gh-action-pypi-publish@release/v1
- with:
- repository-url: https://test.pypi.org/legacy/
\ No newline at end of file
+ --repo '${{ github.repository }}'
\ No newline at end of file
diff --git a/README.md b/README.md
index 7b8f1e4..e9c38a1 100644
--- a/README.md
+++ b/README.md
@@ -52,9 +52,9 @@
6. [Tests](https://github.com/intsystems/discrete-variables-relaxation/tree/main/tests)
## 💡 Motivation
-For lots of mathematical problems we need an ability to sample discrete random variables.
-The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible.
-Thus we use different relaxation method.
+For lots of mathematical problems we need the ability to sample discrete random variables.
+The problem is that, due to the continuous nature of deep learning optimization, the use of truly discrete random variables is infeasible.
+Thus we use different relaxation methods.
One of them, [Concrete distribution](https://arxiv.org/abs/1611.00712) or [Gumbel-Softmax](https://arxiv.org/abs/1611.01144) (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages.
In this project we implement different alternatives to it.
@@ -65,11 +65,11 @@ In this project we implement different alternatives to it.
- [x] [Relaxed Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/GaussianRelaxedBernoulli.py), also see [📝 paper](http://proceedings.mlr.press/v119/yamada20a/yamada20a.pdf)
- [x] [Correlated relaxed Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/CorrelatedRelaxedBernoulli.py), also see [📝 paper](https://openreview.net/pdf?id=oDFvtxzPOx)
- [x] [Gumbel-softmax TOP-K](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/GumbelSoftmaxTopK.py), also see [📝 paper](https://arxiv.org/pdf/1903.06059)
-- [x] [Straight-Through Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StraightThroughBernoulli.py), also see [📝 paper](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
+- [x] [Straight-Through Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StraightThroughBernoulli.py), also see [📝 paper](https://arxiv.org/abs/1910.02176)
+- [x] [Stochastic Times Smooth](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StochasticTimesSmooth.py), also see [📝 paper](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
- [x] [Invertible Gaussian](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/InvertibleGaussian.py) with [KL implemented](https://github.com/intsystems/discrete-variables-relaxation/blob/f398ebbbac703582de392bc33d89b55c6c99ea68/src/relaxit/distributions/kl.py#L7), also see [📝 paper](https://arxiv.org/abs/1912.09588)
- [x] [Hard Concrete](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/HardConcrete.py), also see [📝 paper](https://arxiv.org/pdf/1712.01312)
-- [x] [REINFORCE](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/CorrelatedRelaxedBernoulli.py), also see [📺 slides](http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf)
-- [x] [Logit-Normal](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/LogisticNormalSoftmax.py) and [Laplace-form approximation of Dirichlet](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/approx.py), also see [ℹ️ wiki](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [💻 stackexchange](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)
+- [x] [Logistic-Normal](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/LogisticNormalSoftmax.py) and [Laplace-form approximation of Dirichlet](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/approx.py), also see [ℹ️ wiki](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [💻 stackexchange](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)
## 🛠️ Install
@@ -116,13 +116,15 @@ print('sample.requires_grad:', sample.requires_grad)
|  |  |  |
| [](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/laplace-bridge.ipynb) | [](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/reinforce.ipynb) | [](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/demo.ipynb) |
-For demonstration purposes, we divide our algorithms in three different groups. Each group relates to the particular demo code:
+For demonstration purposes, we divide our algorithms into three[^*] different groups. Each group relates to a particular demo code:
- [Laplace bridge between Dirichlet and LogisticNormal distributions](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/laplace-bridge.ipynb)
- [REINFORCE](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/reinforce.ipynb)
- [Other relaxation methods](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/demo.ipynb)
We describe our demo experiments [here](https://github.com/intsystems/discrete-variables-relaxation/tree/main/demo).
+[^*]: We also implement the REINFORCE algorithm as a *score function* estimator alternative to our relaxation methods, which are inherently *pathwise derivative* estimators. It is implemented only for the demo experiments and is not included in the package source code.
+
## 📚 Stack
Some of the alternatives for GS were implemented in [pyro](https://docs.pyro.ai/en/dev/distributions.html), so we base our library on their codebase.
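
The README footnote above contrasts *score function* (REINFORCE) estimators with the *pathwise derivative* estimators that the package implements. A minimal illustrative sketch of the two estimator families, using only built-in PyTorch distributions (not part of this change set; the objective `f` is an arbitrary placeholder):

```python
import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

# Score-function (REINFORCE) estimator: differentiate the log-probability of a hard sample.
theta = torch.tensor(0.3, requires_grad=True)
d = Bernoulli(probs=theta)
x = d.sample()                             # hard 0/1 draw, no gradient path through x
f = (x - 1.0) ** 2                         # arbitrary downstream objective
(f.detach() * d.log_prob(x)).backward()    # surrogate loss: f * log p(x)
print(theta.grad)

# Pathwise estimator: differentiate through a relaxed, reparameterized sample.
theta2 = torch.tensor(0.3, requires_grad=True)
relaxed = RelaxedBernoulli(temperature=torch.tensor(0.5), probs=theta2)
y = relaxed.rsample()                      # soft sample carries the gradient
((y - 1.0) ** 2).backward()
print(theta2.grad)
```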
diff --git a/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_reconstruction.png b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_reconstruction.png
new file mode 100644
index 0000000..df4f0d3
Binary files /dev/null and b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_reconstruction.png differ
diff --git a/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_sample.png b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_sample.png
new file mode 100644
index 0000000..f3211f0
Binary files /dev/null and b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_sample.png differ
diff --git a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png
index 69eee89..00ea80e 100644
Binary files a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png and b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png differ
diff --git a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png
index a44c7ae..94b8bf0 100644
Binary files a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png and b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png differ
diff --git a/demo/README.md b/demo/README.md
index 85845e1..94f8bdb 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -1,5 +1,5 @@
# Demo experiments code
-This repository contains our demo code for various experiments. The main demo code can be found in the notebook `demo/demo.ipynb`. Open the notebook and run the cells to see the demonstration in action. For additional experiments, refer to the section [Additional experiments](#experiments). Before starting any experiments, ensure you follow all the installation steps outlined in the [Installation](#installation) section.
+This repository contains our demo code for various experiments. The main demo code can be found in the notebook `demo.ipynb`. Open the notebook and run the cells to see the demonstration in action. For additional experiments, refer to the section [Additional experiments](#experiments). Before starting any experiments, ensure you follow all the installation steps outlined in the [Installation](#installation) section.
## Installation
@@ -27,33 +27,36 @@ For additional demo experiments, we have implemented Variational Autoencoders (V
1. **Train and save the models:**
To run the additional demo code with VAEs, you need to train all the models and save their results. Execute the following commands:
```bash
- # VAE with Gaussian Bernoulli latent space
- python vae_gaussian_bernoulli.py
-
# VAE with Correlated Bernoulli latent space
python vae_correlated_bernoulli.py
-
+
+ # VAE with Gaussian Bernoulli latent space
+ python vae_gaussian_bernoulli.py
+
+ # VAE with Gumbel-Softmax top-K latent space
+ python vae_gumbel_softmax_topk.py
+
# VAE with Hard Concrete latent space
python vae_hard_concrete.py
- # VAE with Straight Through Bernoullii latent space
- python vae_straight_through_bernoulli.py
-
# VAE with Invertible Gaussian latent space
python vae_invertible_gaussian.py
- # VAE with Gumbel Softmax TopK latent space
- python vae_gumbel_softmax_topk.py
+ # VAE with Stochastic Times Smooth latent space
+ python vae_stochastic_times_smooth.py
+
+    # VAE with Straight Through Bernoulli latent space
+ python vae_straight_through_bernoulli.py
```
2. **View the results:**
- After completing the training and testing of all the models, you can find the results of sampling and reconstruction methods in the directory `demo/results`.
+ After completing the training and testing of all the models, you can find the results of sampling and reconstruction methods in the directory `results`.
 Moreover, we conducted experiments with the Laplace bridge between LogisticNormal and Dirichlet distributions. We use the two-sided Laplace bridge to approximate:
-- Dirichlet using logisticNormal
-- LogisticNormal using Dirichlet
+- Dirichlet using Logistic-Normal
+- Logistic-Normal using Dirichlet
-These experiments aim to find the best parameters to make the distributions nearly identical on the simplex. The experiments can be found in the notebook `demo/laplace-bridge.ipynb`.
+These experiments aim to find the best parameters to make the distributions nearly identical on the simplex. The experiments can be found in the notebook `laplace-bridge.ipynb`.
-Furthermore, the Reinforce algorithm is applied in the [Acrobot environment](https://www.gymlibrary.dev/environments/classic_control/acrobot/). Detailed experiments can be viewed in the notebook `demo/reinforce.ipynb`. A script `demo/reinforce.py` can also be used for training.
+Furthermore, the REINFORCE algorithm is applied in the [Acrobot environment](https://www.gymlibrary.dev/environments/classic_control/acrobot/). Detailed experiments can be viewed in the notebook `reinforce.ipynb`. The script `reinforce.py` can also be used for training.
diff --git a/demo/requirements.txt b/demo/requirements.txt
index 546ee61..f468961 100644
--- a/demo/requirements.txt
+++ b/demo/requirements.txt
@@ -1,11 +1,10 @@
-gym==0.26.2
-numpy==1.24.1
-pyro_ppl==1.9.1
-setuptools==65.5.0
-torch==2.5.1
-torchvision==0.20.1
-matplotlib==3.9.2
-networkx==3.3
-tqdm==4.66.5
-pillow==10.4.0
-relaxit==0.1.2
\ No newline at end of file
+gym>=0.26.2
+numpy>=1.24.1
+torchvision>=0.20.1
+matplotlib>=3.9.2
+networkx>=3.3
+tqdm>=4.66.5
+pillow>=10.4.0
+relaxit==1.0.1
+ftfy
+regex
\ No newline at end of file
diff --git a/demo/vae_stochastic_times_smooth.py b/demo/vae_stochastic_times_smooth.py
new file mode 100644
index 0000000..deb10e3
--- /dev/null
+++ b/demo/vae_stochastic_times_smooth.py
@@ -0,0 +1,252 @@
+import os
+import argparse
+import numpy as np
+import torch
+import sys
+import torch.utils.data
+from torch import nn, optim
+from torch.nn import functional as F
+from torchvision import datasets, transforms
+from torchvision.utils import save_image
+
+from relaxit.distributions import StochasticTimesSmooth
+
+
+def parse_arguments() -> argparse.Namespace:
+ """
+ Parse command line arguments.
+
+ Returns:
+ argparse.Namespace: Parsed command line arguments.
+ """
+ parser = argparse.ArgumentParser(description="VAE MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=128,
+ metavar="N",
+ help="input batch size for training (default: 128)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log_interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ return parser.parse_args()
+
+
+args = parse_arguments()
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+torch.manual_seed(args.seed)
+
+device = torch.device("cuda" if args.cuda else "cpu")
+
+os.makedirs("./results/vae_stochastic_times_smooth", exist_ok=True)
+
+kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
+train_loader = torch.utils.data.DataLoader(
+ datasets.MNIST(
+ "./data", train=True, download=True, transform=transforms.ToTensor()
+ ),
+ batch_size=args.batch_size,
+ shuffle=True,
+ **kwargs
+)
+test_loader = torch.utils.data.DataLoader(
+ datasets.MNIST("./data", train=False, transform=transforms.ToTensor()),
+ batch_size=args.batch_size,
+ shuffle=True,
+ **kwargs
+)
+
+steps = 0
+
+
+class VAE(nn.Module):
+ """
+ Variational Autoencoder (VAE) with StochasticTimesSmooth distribution.
+ """
+
+ def __init__(self) -> None:
+ super(VAE, self).__init__()
+
+ self.fc1 = nn.Linear(784, 400)
+ self.fc2 = nn.Linear(400, 20)
+ self.fc3 = nn.Linear(20, 400)
+ self.fc4 = nn.Linear(400, 784)
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Encode the input by passing through the encoder network
+ and return the latent code.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Latent code.
+ """
+ h1 = F.relu(self.fc1(x))
+ return self.fc2(h1)
+
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
+ """
+ Decode the latent code by passing through the decoder network
+ and return the reconstructed input.
+
+ Args:
+ z (torch.Tensor): Latent code.
+
+ Returns:
+ torch.Tensor: Reconstructed input.
+ """
+ h3 = F.relu(self.fc3(z))
+ return torch.sigmoid(self.fc4(h3))
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass through the VAE.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Reconstructed input and latent code.
+ """
+ logits = self.encode(x.view(-1, 784))
+ q_z = StochasticTimesSmooth(logits=logits)
+ probs = q_z.probs
+ z = q_z.rsample()
+ return self.decode(z), probs
+
+
+model = VAE().to(device)
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+
+def loss_function(
+ recon_x: torch.Tensor,
+ x: torch.Tensor,
+ probs: torch.Tensor,
+ prior: float = 0.5,
+ eps: float = 1e-10,
+) -> torch.Tensor:
+ """
+ Compute the loss function for the VAE.
+
+ Args:
+ recon_x (torch.Tensor): Reconstructed input.
+ x (torch.Tensor): Original input.
+ probs (torch.Tensor): Probabilities for Bernoulli distribution in latent space.
+ prior (float): Prior probability.
+ eps (float): Small value to avoid log(0).
+
+ Returns:
+ torch.Tensor: Loss value.
+ """
+ BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
+ t1 = probs * ((probs + eps) / prior).log()
+ t2 = (1 - probs) * ((1 - probs + eps) / (1 - prior)).log()
+ KLD = torch.sum(t1 + t2, dim=-1).sum()
+
+ return BCE + KLD
+
+
+def train(epoch: int) -> None:
+ """
+ Train the VAE for one epoch.
+
+ Args:
+ epoch (int): Current epoch number.
+ """
+ global steps
+ model.train()
+ train_loss = 0
+ for batch_idx, (data, _) in enumerate(train_loader):
+ data = data.to(device)
+ optimizer.zero_grad()
+ recon_batch, probs = model(data)
+ loss = loss_function(recon_batch, data, probs)
+ loss.backward()
+ train_loss += loss.item()
+ optimizer.step()
+
+ if batch_idx % args.log_interval == 0:
+ print(
+ "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+ epoch,
+ batch_idx * len(data),
+ len(train_loader.dataset),
+ 100.0 * batch_idx / len(train_loader),
+ loss.item() / len(data),
+ )
+ )
+
+ steps += 1
+
+ print(
+ "====> Epoch: {} Average loss: {:.4f}".format(
+ epoch, train_loss / len(train_loader.dataset)
+ )
+ )
+
+
+def test(epoch: int) -> None:
+ """
+ Test the VAE for one epoch.
+
+ Args:
+ epoch (int): Current epoch number.
+ """
+ model.eval()
+ test_loss = 0
+ with torch.no_grad():
+ for i, (data, _) in enumerate(test_loader):
+ data = data.to(device)
+ recon_batch, probs = model(data)
+ test_loss += loss_function(recon_batch, data, probs).item()
+ if i == 0:
+ n = min(data.size(0), 8)
+ comparison = torch.cat(
+ [data[:n], recon_batch.view(args.batch_size, 1, 28, 28)[:n]]
+ )
+ save_image(
+ comparison.cpu(),
+ "results/vae_stochastic_times_smooth/reconstruction_"
+ + str(epoch)
+ + ".png",
+ nrow=n,
+ )
+
+ test_loss /= len(test_loader.dataset)
+ print("====> Test set loss: {:.4f}".format(test_loss))
+
+
+if __name__ == "__main__":
+ for epoch in range(1, args.epochs + 1):
+ train(epoch)
+ test(epoch)
+ with torch.no_grad():
+ sample = np.random.binomial(1, 0.5, size=(64, 20))
+ sample = torch.from_numpy(np.float32(sample)).to(device)
+ sample = model.decode(sample).cpu()
+ save_image(
+ sample.view(64, 1, 28, 28),
+ "results/vae_stochastic_times_smooth/sample_" + str(epoch) + ".png",
+ )
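
The `loss_function` above combines the reconstruction BCE with a closed-form KL term between the `Bernoulli(probs)` posterior and a fixed `Bernoulli(prior)` prior. A small sanity check of that term against `torch.distributions.kl_divergence` (a sketch with hypothetical `probs`, not part of the diff; `eps` only guards against `log(0)`):

```python
import torch
from torch.distributions import Bernoulli, kl_divergence

probs = torch.rand(4, 20)   # hypothetical posterior probabilities for a batch of 4
prior, eps = 0.5, 1e-10

# Closed-form KL(Bernoulli(probs) || Bernoulli(prior)), as written in loss_function.
t1 = probs * ((probs + eps) / prior).log()
t2 = (1 - probs) * ((1 - probs + eps) / (1 - prior)).log()
manual_kld = (t1 + t2).sum(dim=-1)

# Reference value from torch.distributions.
ref_kld = kl_divergence(Bernoulli(probs), Bernoulli(torch.full_like(probs, prior))).sum(dim=-1)
print(torch.allclose(manual_kld, ref_kld, atol=1e-4))  # expected: True
```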
diff --git a/demo/vae_straight_through_bernoulli.py b/demo/vae_straight_through_bernoulli.py
index 6d5cf69..335fe11 100644
--- a/demo/vae_straight_through_bernoulli.py
+++ b/demo/vae_straight_through_bernoulli.py
@@ -129,14 +129,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
Returns:
tuple[torch.Tensor, torch.Tensor]: Reconstructed input and latent code.
"""
- a = self.encode(x.view(-1, 784))
- a = a.float()
- q_z = StraightThroughBernoulli(a)
-
+ logits = self.encode(x.view(-1, 784))
+ q_z = StraightThroughBernoulli(logits=logits)
+ probs = q_z.probs
z = q_z.rsample()
- z = z.float()
-
- return self.decode(z), a
+ return self.decode(z), probs
model = VAE().to(device)
@@ -146,7 +143,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def loss_function(
recon_x: torch.Tensor,
x: torch.Tensor,
- a: torch.Tensor,
+ probs: torch.Tensor,
prior: float = 0.5,
eps: float = 1e-10,
) -> torch.Tensor:
@@ -156,7 +153,7 @@ def loss_function(
Args:
recon_x (torch.Tensor): Reconstructed input.
x (torch.Tensor): Original input.
- a (torch.Tensor): Latent code.
+ probs (torch.Tensor): Probabilities for Bernoulli distribution in latent space.
prior (float): Prior probability.
eps (float): Small value to avoid log(0).
@@ -164,9 +161,8 @@ def loss_function(
torch.Tensor: Loss value.
"""
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
- q_z = torch.sigmoid(a)
- t1 = q_z * ((q_z + eps) / prior).log()
- t2 = (1 - q_z) * ((1 - q_z + eps) / (1 - prior)).log()
+ t1 = probs * ((probs + eps) / prior).log()
+ t2 = (1 - probs) * ((1 - probs + eps) / (1 - prior)).log()
KLD = torch.sum(t1 + t2, dim=-1).sum()
return BCE + KLD
@@ -185,8 +181,8 @@ def train(epoch: int) -> None:
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
- recon_batch, q_z = model(data)
- loss = loss_function(recon_batch, data, q_z)
+ recon_batch, probs = model(data)
+ loss = loss_function(recon_batch, data, probs)
loss.backward()
train_loss += loss.item()
optimizer.step()
@@ -223,8 +219,8 @@ def test(epoch: int) -> None:
with torch.no_grad():
for i, (data, _) in enumerate(test_loader):
data = data.to(device)
- recon_batch, q_z = model(data)
- test_loss += loss_function(recon_batch, data, q_z).item()
+ recon_batch, probs = model(data)
+ test_loss += loss_function(recon_batch, data, probs).item()
if i == 0:
n = min(data.size(0), 8)
comparison = torch.cat(
diff --git a/docs/conf.py b/docs/conf.py
deleted file mode 100644
index a07925a..0000000
--- a/docs/conf.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-import os
-import sys
-
-from relaxit import __version__
-
-
-# -- Project information -----------------------------------------------------
-
-project = "Just Relax It"
-copyright = "2024, Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov"
-author = "Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov"
-
-version = __version__
-master_doc = "index"
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- 'sphinx.ext.napoleon',
- 'sphinx.ext.duration',
- 'sphinx.ext.doctest',
- 'sphinx.ext.autodoc',
- 'myst_parser'
-]
-highlight_language = 'python'
-
-autodoc_mock_imports = ["numpy", "scipy", "sklearn"]
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ["_templates"]
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
-
-html_extra_path = []
-
-html_context = {
- "display_github": True, # Integrate GitHub
- "github_user": "Intelligent-Systems-Phystech", # Username
- "github_repo": "discrete-variables-relaxation", # Repo name
- "github_version": "main", # Version
- "conf_py_path": "./doc/", # Path in the checkout to the docs root
-}
-
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = "sphinx_rtd_theme"
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ["_static"]
diff --git a/docs/make_docs.sh b/docs/make_docs.sh
index a88d3f6..4fd437a 100644
--- a/docs/make_docs.sh
+++ b/docs/make_docs.sh
@@ -1,2 +1,7 @@
-rm -rf docs/source
-sphinx-apidoc -o docs/source/ src/relaxit
\ No newline at end of file
+#!/bin/bash
+
+# Generate API documentation
+# sphinx-apidoc -o source ../src/relaxit
+
+# Build the documentation
+sphinx-build -b html source build/html
\ No newline at end of file
diff --git a/docs/source/_static/css/relaxit.css b/docs/source/_static/css/relaxit.css
new file mode 100644
index 0000000..18802fe
--- /dev/null
+++ b/docs/source/_static/css/relaxit.css
@@ -0,0 +1,29 @@
+@import url("theme.css");
+
+.wy-side-nav-search {
+ background-color: #565656;
+}
+
+.wy-side-nav-search a {
+ margin: 0
+}
+
+.wy-side-nav-search > div.version {
+ color: #f26822;
+}
+
+.wy-nav-top {
+ background: #404040;
+}
+
+.wy-menu-vertical li.on a, .wy-menu-vertical li.current>a {
+ background: #ccc;
+}
+
+.wy-side-nav-search input[type=text] {
+ border-color: #313131;
+}
+
+.wy-side-nav-search>a img.logo, .wy-side-nav-search .wy-dropdown>a img.logo {
+ max-width: 40%;
+}
\ No newline at end of file
diff --git a/docs/source/_static/img/logo-small.png b/docs/source/_static/img/logo-small.png
new file mode 100644
index 0000000..10e24a5
Binary files /dev/null and b/docs/source/_static/img/logo-small.png differ
diff --git a/docs/source/_static/img/logo.png b/docs/source/_static/img/logo.png
new file mode 100644
index 0000000..ca4401e
Binary files /dev/null and b/docs/source/_static/img/logo.png differ
diff --git a/docs/source/_static/img/overview.png b/docs/source/_static/img/overview.png
new file mode 100644
index 0000000..a82bab5
Binary files /dev/null and b/docs/source/_static/img/overview.png differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..2a67f96
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,71 @@
+# Configuration file for the Sphinx documentation builder.
+
+import os
+import sys
+
+# -- Path setup --------------------------------------------------------------
+
+sys.path.insert(0, os.path.abspath('../../src'))
+
+from relaxit import __version__
+
+# -- Project information -----------------------------------------------------
+
+project = "Just Relax It"
+copyright = "2024, Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov"
+author = "Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov"
+
+version = __version__
+master_doc = "index"
+
+# -- General configuration ---------------------------------------------------
+
+extensions = [
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.duration',
+ 'sphinx.ext.doctest',
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.viewcode',
+ 'sphinx.ext.githubpages',
+ 'myst_parser'
+]
+
+highlight_language = 'python'
+
+autodoc_mock_imports = ["numpy", "scipy", "sklearn"]
+
+templates_path = ["_templates"]
+exclude_patterns = []
+html_extra_path = []
+
+html_context = {
+ "display_github": True,
+ "github_user": "Intelligent-Systems-Phystech",
+ "github_repo": "discrete-variables-relaxation",
+ "github_version": "main",
+ "conf_py_path": "./doc/",
+}
+
+# -- Options for HTML output -------------------------------------------------
+
+html_logo = "_static/img/logo-small.png"
+
+html_theme = "sphinx_rtd_theme"
+
+html_theme_options = {
+ "navigation_depth": 3,
+ "logo_only": True,
+}
+
+html_static_path = ["_static"]
+html_css_files = [
+ 'css/relaxit.css',
+]
+
+# -- Options for intersphinx extension ---------------------------------------
+
+intersphinx_mapping = {
+ 'python': ('https://docs.python.org/3/', None),
+ 'numpy': ('https://numpy.org/doc/stable/', None),
+ 'torch': ('https://pytorch.org/docs/stable/', None),
+}
diff --git a/docs/index.rst b/docs/source/index.rst
similarity index 60%
rename from docs/index.rst
rename to docs/source/index.rst
index 114f005..e49515f 100644
--- a/docs/index.rst
+++ b/docs/source/index.rst
@@ -3,29 +3,37 @@
Just Relax It
=============
-.. image:: ../assets/logo.png
+.. image:: _static/img/logo.png
:width: 200
:align: center
+.. raw:: html
+
+
+
"Just Relax It" is a cutting-edge Python library designed to streamline the optimization of discrete probability distributions in neural networks, offering a suite of advanced relaxation techniques compatible with PyTorch.
Motivation
----------
-For lots of mathematical problems we need an ability to sample discrete random variables.
-The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible.
-Thus we use different relaxation method.
-One of them, `Concrete distribution <https://arxiv.org/abs/1611.00712>`_ or `Gumbel-Softmax <https://arxiv.org/abs/1611.01144>`_ (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages.
-In this project we implement different alternatives to it.
+For lots of mathematical problems we need the ability to sample discrete random variables.
+The problem is that, due to the continuous nature of deep learning optimization, the use of truly discrete random variables is infeasible.
+Thus we use different relaxation methods.
+One of them, `Concrete distribution <https://arxiv.org/abs/1611.00712>`_ or `Gumbel-Softmax <https://arxiv.org/abs/1611.01144>`_ (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages.
+In this project we implement different alternatives to it.
-.. image:: ../assets/overview.png
+.. image:: _static/img/overview.png
:width: 600
:align: center
+.. raw:: html
+
+
+
.. toctree::
:maxdepth: 1
:caption: Guidelines
-
+
install
quickstart
@@ -33,6 +41,4 @@ In this project we implement different alternatives to it.
:maxdepth: 1
:caption: Code
- source/modules
- source/relaxit.distributions
-
+ relaxit.distributions
diff --git a/docs/install.md b/docs/source/install.md
similarity index 100%
rename from docs/install.md
rename to docs/source/install.md
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
index 41f9d2c..4fec1e2 100644
--- a/docs/source/modules.rst
+++ b/docs/source/modules.rst
@@ -2,6 +2,6 @@ relaxit
=======
.. toctree::
- :maxdepth: 4
+ :maxdepth: 1
relaxit
diff --git a/docs/quickstart.md b/docs/source/quickstart.md
similarity index 99%
rename from docs/quickstart.md
rename to docs/source/quickstart.md
index 74caf23..ec67fea 100644
--- a/docs/quickstart.md
+++ b/docs/source/quickstart.md
@@ -16,4 +16,4 @@ distribution = InvertibleGaussian(loc, scale, temperature)
sample = distribution.rsample()
print('sample.shape:', sample.shape)
print('sample.requires_grad:', sample.requires_grad)
-```
\ No newline at end of file
+```
diff --git a/docs/source/relaxit.distributions.rst b/docs/source/relaxit.distributions.rst
index 4b41e2e..33ce68d 100644
--- a/docs/source/relaxit.distributions.rst
+++ b/docs/source/relaxit.distributions.rst
@@ -1,85 +1,63 @@
-relaxit.distributions package
-=============================
+relaxit.distributions
+=====================
-Submodules
-----------
-
-relaxit.distributions.CorrelatedRelaxedBernoulli module
--------------------------------------------------------
-
-.. automodule:: relaxit.distributions.CorrelatedRelaxedBernoulli
+.. automodule:: relaxit.distributions
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.GaussianRelaxedBernoulli module
------------------------------------------------------
+Distribution Classes
+--------------------
-.. automodule:: relaxit.distributions.GaussianRelaxedBernoulli
+.. autoclass:: relaxit.distributions.CorrelatedRelaxedBernoulli.CorrelatedRelaxedBernoulli
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.GumbelSoftmaxTopK module
-----------------------------------------------
-
-.. automodule:: relaxit.distributions.GumbelSoftmaxTopK
+.. autoclass:: relaxit.distributions.GaussianRelaxedBernoulli.GaussianRelaxedBernoulli
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.HardConcrete module
------------------------------------------
-
-.. automodule:: relaxit.distributions.HardConcrete
+.. autoclass:: relaxit.distributions.GumbelSoftmaxTopK.GumbelSoftmaxTopK
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.InvertibleGaussian module
------------------------------------------------
-
-.. automodule:: relaxit.distributions.InvertibleGaussian
+.. autoclass:: relaxit.distributions.HardConcrete.HardConcrete
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.LogisticNormalSoftmax module
---------------------------------------------------
-
-.. automodule:: relaxit.distributions.LogisticNormalSoftmax
+.. autoclass:: relaxit.distributions.InvertibleGaussian.InvertibleGaussian
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.StraightThroughBernoulli module
------------------------------------------------------
-
-.. automodule:: relaxit.distributions.StraightThroughBernoulli
+.. autoclass:: relaxit.distributions.LogisticNormalSoftmax.LogisticNormalSoftmax
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.approx module
------------------------------------
-
-.. automodule:: relaxit.distributions.approx
+.. autoclass:: relaxit.distributions.StochasticTimesSmooth.StochasticTimesSmooth
:members:
:undoc-members:
:show-inheritance:
-relaxit.distributions.kl module
--------------------------------
-
-.. automodule:: relaxit.distributions.kl
+.. autoclass:: relaxit.distributions.StraightThroughBernoulli.StraightThroughBernoulli
:members:
:undoc-members:
:show-inheritance:
-Module contents
+Utility Modules
---------------
-.. automodule:: relaxit.distributions
+.. automodule:: relaxit.distributions.approx
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. automodule:: relaxit.distributions.kl
:members:
:undoc-members:
:show-inheritance:
diff --git a/docs/source/relaxit.rst b/docs/source/relaxit.rst
index 4ce6ca4..06b6fdb 100644
--- a/docs/source/relaxit.rst
+++ b/docs/source/relaxit.rst
@@ -1,11 +1,11 @@
-relaxit package
-===============
+relaxit
+=======
Subpackages
-----------
.. toctree::
- :maxdepth: 4
+ :maxdepth: 1
relaxit.distributions
diff --git a/setup.py b/setup.py
index 07c2af9..b159d33 100644
--- a/setup.py
+++ b/setup.py
@@ -26,5 +26,5 @@
url="https://github.com/intsystems/discrete-variables-relaxation",
package_dir= {"": "src"},
packages=find_packages(where="src"),
- install_requires=["pyro-ppl==1.9.1"],
+ install_requires=["pyro-ppl>=1.9.1"],
)
\ No newline at end of file
diff --git a/src/relaxit/_version.py b/src/relaxit/_version.py
index 0058b93..ff1068c 100644
--- a/src/relaxit/_version.py
+++ b/src/relaxit/_version.py
@@ -1 +1 @@
-__version__ = "1.0.1"
\ No newline at end of file
+__version__ = "1.1.0"
\ No newline at end of file
diff --git a/src/relaxit/distributions/GumbelSoftmaxTopK.py b/src/relaxit/distributions/GumbelSoftmaxTopK.py
index b0769ce..c2ac519 100644
--- a/src/relaxit/distributions/GumbelSoftmaxTopK.py
+++ b/src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -1,13 +1,13 @@
import torch
import torch.nn.functional as F
-from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution
from torch.distributions import constraints
+from torch.distributions.utils import probs_to_logits, logits_to_probs
class GumbelSoftmaxTopK(TorchDistribution):
r"""
- Implimentation of the Gaussian-soft max topK trick from https://arxiv.org/pdf/1903.06059
+    Implementation of the Gumbel-Softmax top-K trick from https://arxiv.org/pdf/1903.06059.
:param a: logits.
:type a: torch.Tensor
@@ -50,10 +50,14 @@ def __init__(
:param validate_args: Whether to validate arguments.
:type validate_args: bool
"""
- if probs is None and logits is None:
- raise ValueError("Pass `probs` or `logits`!")
- elif probs is None:
- self.probs = logits / logits.sum(dim=-1, keepdim=True)
+ if (probs is None) == (logits is None):
+ raise ValueError("Pass `probs` or `logits`, but not both of them!")
+ elif probs is not None:
+ self.probs = probs
+ self.logits = probs_to_logits(probs)
+ else:
+ self.logits = logits
+ self.probs = logits_to_probs(logits)
self.K = K.int() # Ensure K is a int tensor
self.tau = tau
self.hard = hard
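
With this change the constructor accepts exactly one of `probs` or `logits` and derives the other via `probs_to_logits` / `logits_to_probs`. A minimal usage sketch mirroring the updated tests (assuming the package is installed as `relaxit`):

```python
import torch
from relaxit.distributions import GumbelSoftmaxTopK

logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
# Pass either `probs` or `logits`, but not both.
distribution = GumbelSoftmaxTopK(logits=logits, K=torch.tensor(2), tau=torch.tensor(0.1))
sample = distribution.rsample()            # relaxed top-K selection over 5 categories
print(sample.shape, sample.requires_grad)  # torch.Size([5]) True
```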
diff --git a/src/relaxit/distributions/InvertibleGaussian.py b/src/relaxit/distributions/InvertibleGaussian.py
index 8b9ff04..a6cf742 100644
--- a/src/relaxit/distributions/InvertibleGaussian.py
+++ b/src/relaxit/distributions/InvertibleGaussian.py
@@ -6,7 +6,7 @@
class InvertibleGaussian(TorchDistribution):
"""
- Invertible Gaussian distribution class inheriting from Pyro's TorchDistribution.
+ Invertible Gaussian distribution, as it was presented in https://arxiv.org/abs/1912.09588.
Parameters:
- loc (Tensor): The mean (mu) of the normal distribution.
diff --git a/src/relaxit/distributions/LogisticNormalSoftmax.py b/src/relaxit/distributions/LogisticNormalSoftmax.py
index 743a8a6..fdf3b6c 100644
--- a/src/relaxit/distributions/LogisticNormalSoftmax.py
+++ b/src/relaxit/distributions/LogisticNormalSoftmax.py
@@ -3,8 +3,6 @@
from pyro.distributions.transforms import SoftmaxTransform
-# We implement LogisticNormal distribution with SoftmaxTransform instead of
-# StickBreakingTransform, which is originally applied in the PyTorch and Pyro
class LogisticNormalSoftmax(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
diff --git a/src/relaxit/distributions/STEstimator.py b/src/relaxit/distributions/STEstimator.py
deleted file mode 100644
index 8d97a78..0000000
--- a/src/relaxit/distributions/STEstimator.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch.distributions import constraints
-from pyro.distributions.torch_distribution import TorchDistribution
-from torch.distributions import constraints
-from torch.distributions.utils import probs_to_logits
-
-
-class StraightThroughEstimator(TorchDistribution):
- r"""
- Implimentation of the Straight Through Estimator from https://arxiv.org/abs/1910.02176
-
- :param a: logits.
- :type a: torch.Tensor
- :param validate_args: Whether to validate arguments.
- :type validate_args: bool
- """
-
- arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
- has_rsample = True
-
- def __init__(
- self,
- probs: torch.Tensor = None,
- logits: torch.Tensor = None,
- validate_args: bool = None,
- ):
- r"""Initializes the ST Estimator.
-
- :param probs: TODO
- :param logits: the log-odds of sampling `1`.
- :type logits: torch.Tensor
- :param validate_args: Whether to validate arguments.
- :type validate_args: bool
- """
- if probs is None and logits is None:
- raise ValueError("Pass `probs` or `logits`!")
- elif probs is None:
- self.probs = logits / logits.sum(dim=-1, keepdim=True)
- self.logits=logits
- super().__init__(validate_args=validate_args)
-
- @property
- def batch_shape(self) -> torch.Size:
- """
- Returns the batch shape of the distribution.
-
- The batch shape represents the shape of independent distributions.
- For example, if `loc` is vector of length 3,
- the batch shape will be `[3]`, indicating 3 independent distributions.
- """
- return self.probs.shape
-
- @property
- def event_shape(self) -> torch.Size:
- """
- Returns the event shape of the distribution.
-
- The event shape represents the shape of each individual event.
- """
- return torch.Size()
-
- def rsample(self) -> torch.Tensor:
- """
- Generates a sample from the distribution using the Gaussian-soft max topK trick.
-
- :return: A sample from the distribution.
- :rtype: torch.Tensor
- """
-
- binary_sample = torch.bernoulli(self.probs).detach()
-
- return binary_sample + (self.probs - self.probs.detach())
-
- def sample(self) -> torch.Tensor:
- """
- Generates a sample from the distribution with no grad.
-
- :return: A sample from the distribution.
- :rtype: torch.Tensor
- """
- with torch.no_grad():
- return self.rsample()
-
- def log_prob(self, value: torch.Tensor) -> torch.Tensor:
- """
- Computes the log probability of the given value.
- :param value: The value for which to compute the log probability.
- :type value: torch.Tensor
- :return: The log probability of the given value.
- :rtype: torch.Tensor
- """
- if self._validate_args:
- self._validate_sample(value)
-
- return -F.binary_cross_entropy(self.probs, value, reduction="none")
-
- def _validate_sample(self, value: torch.Tensor):
- """
- Validates the given sample value.
-
- Args:
- - value (Tensor): The sample value to validate.
- """
- if self._validate_args:
- if ((value != 1.0) & (value != 0.0)).any():
- ValueError(
- f"All coordinates in `value` must be 0 or 1 and you have {value}"
- )
\ No newline at end of file
diff --git a/src/relaxit/distributions/StochasticTimesSmooth.py b/src/relaxit/distributions/StochasticTimesSmooth.py
new file mode 100644
index 0000000..6731c8d
--- /dev/null
+++ b/src/relaxit/distributions/StochasticTimesSmooth.py
@@ -0,0 +1,27 @@
+import torch
+from pyro.distributions import Bernoulli
+
+
+class StochasticTimesSmooth(Bernoulli):
+ r"""
+    Implementation of the Stochastic Times Smooth gradient estimator from https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235
+
+ Creates a Bernoulli distribution parameterized by :attr:`probs`
+ or :attr:`logits` (but not both).
+
+ Samples are binary (0 or 1). They take the value `1` with probability `p`
+ and `0` with probability `1 - p`.
+
+    However, it supports gradient flow through its parameters via the
+    stochastic times smooth gradient estimator.
+ """
+ has_rsample = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+    def rsample(self, sample_shape: torch.Size = torch.Size()):
+        # Reparameterized sample: sqrt(p) * Bernoulli(sqrt(p)). The Bernoulli draw is
+        # non-differentiable, but the smooth sqrt(p) factor carries the gradient.
+        shape = self._extended_shape(sample_shape)
+        sqrt_probs = self.probs.expand(shape).sqrt()
+        sample = sqrt_probs * torch.bernoulli(sqrt_probs)
+        return sample
\ No newline at end of file
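
For reference, a minimal usage sketch of the new distribution (assuming the package is installed as `relaxit`; the same pattern is exercised end-to-end by `demo/vae_stochastic_times_smooth.py`):

```python
import torch
from relaxit.distributions import StochasticTimesSmooth

logits = torch.randn(4, 20, requires_grad=True)
q_z = StochasticTimesSmooth(logits=logits)
z = q_z.rsample()        # values are 0 or sqrt(p), elementwise
print(z.shape)           # torch.Size([4, 20])
print(z.requires_grad)   # True: gradient flows through the smooth sqrt(p) factor
```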
diff --git a/src/relaxit/distributions/StraightThroughBernoulli.py b/src/relaxit/distributions/StraightThroughBernoulli.py
index bad2db3..a222158 100644
--- a/src/relaxit/distributions/StraightThroughBernoulli.py
+++ b/src/relaxit/distributions/StraightThroughBernoulli.py
@@ -1,108 +1,27 @@
import torch
-from pyro.distributions.torch_distribution import TorchDistribution
-from torch.distributions import constraints
+from pyro.distributions import Bernoulli
-class StraightThroughBernoulli(TorchDistribution):
- """
+class StraightThroughBernoulli(Bernoulli):
+ r"""
+    Implementation of the Straight-Through Bernoulli distribution from https://arxiv.org/abs/1910.02176
+
+ Creates a Bernoulli distribution parameterized by :attr:`probs`
+ or :attr:`logits` (but not both).
- Parameters:
- - a (Tensor): logits
+ Samples are binary (0 or 1). They take the value `1` with probability `p`
+ and `0` with probability `1 - p`.
+
+    However, it supports gradient flow through its parameters via the
+    straight-through gradient estimator.
"""
-
- arg_constraints = {"a": constraints.real}
- support = constraints.real
has_rsample = True
-
- def __init__(self, a: torch.Tensor, validate_args: bool = None):
- """
-
- Args:
- - a (Tensor): logits
- - validate_args (bool): Whether to validate arguments.
- """
-
- self.a = a.float() # Ensure a is a float tensor
- self.uniform = torch.distributions.Uniform(
- torch.tensor([0.0], device=self.a.device),
- torch.tensor([1.0], device=self.a.device),
- )
- super().__init__(validate_args=validate_args)
-
- @property
- def batch_shape(self) -> torch.Size:
- """
- Returns the batch shape of the distribution.
-
- The batch shape represents the shape of independent distributions.
- For example, if `loc` is vector of length 3,
- the batch shape will be `[3]`, indicating 3 independent Bernoulli distributions.
- """
- return self.a.shape
-
- @property
- def event_shape(self) -> torch.Size:
- """
- Returns the event shape of the distribution.
-
- The event shape represents the shape of each individual event.
- """
- return torch.Size()
-
- def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
- """
- Generates a sample from the distribution using the reparameterization trick.
-
- Args:
- - sample_shape (torch.Size): The shape of the sample.
-
- Returns:
- - torch.Tensor: A sample from the distribution.
- """
- p = torch.nn.functional.sigmoid(self.a)
- b = torch.distributions.Bernoulli(torch.sqrt(p)).sample(
- sample_shape=sample_shape
- )
- z = torch.sqrt(p) * b
- return z
-
- def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
- """
- Generates a sample from the distribution.
-
- Args:
- - sample_shape (torch.Size): The shape of the sample.
-
- Returns:
- - torch.Tensor: A sample from the distribution.
- """
- with torch.no_grad():
- return self.rsample(sample_shape)
-
- def log_prob(self, value: torch.Tensor) -> torch.Tensor:
- """
- Computes the log probability of the given value.
-
- Args:
- - value (Tensor): The value for which to compute the log probability.
-
- Returns:
- - torch.Tensor: The log probability of the given value.
- """
- if self._validate_args:
- self._validate_sample(value)
- p = torch.nn.functional.sigmoid(self.a)
-
- log_prob = torch.where(value == 0, torch.log(1 - p), torch.log(p))
- return log_prob
-
- def _validate_sample(self, value: torch.Tensor):
- """
- Validates the given sample value.
-
- Args:
- - value (Tensor): The sample value to validate.
- """
- if self._validate_args:
- if (value < 0).any():
- raise ValueError("Sample value must be non negative")
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+    def rsample(self, sample_shape: torch.Size = torch.Size()):
+        # Straight-through estimator: the forward value is a hard 0/1 draw, while
+        # `+ probs - probs.detach()` routes the backward pass through `probs`.
+        shape = self._extended_shape(sample_shape)
+        probs = self.probs.expand(shape)
+        sample = torch.bernoulli(probs).detach()
+        return sample + probs - probs.detach()
\ No newline at end of file
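
A quick usage sketch of the reworked class (assuming the package is installed as `relaxit`): the forward value is a hard 0/1 draw, while the surrogate term keeps a gradient path to the parameters.

```python
import torch
from relaxit.distributions import StraightThroughBernoulli

logits = torch.randn(3, requires_grad=True)
q_z = StraightThroughBernoulli(logits=logits)
z = q_z.rsample()
print(z)               # hard 0/1 values in the forward pass
z.sum().backward()     # identity surrogate gradient w.r.t. probs
print(logits.grad)     # sigmoid'(logits) = p * (1 - p), non-zero
```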
diff --git a/src/relaxit/distributions/__init__.py b/src/relaxit/distributions/__init__.py
index b6e18f3..68e1512 100644
--- a/src/relaxit/distributions/__init__.py
+++ b/src/relaxit/distributions/__init__.py
@@ -1,17 +1,20 @@
-from .GaussianRelaxedBernoulli import GaussianRelaxedBernoulli
from .CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli
-from .StraightThroughBernoulli import StraightThroughBernoulli
+from .GaussianRelaxedBernoulli import GaussianRelaxedBernoulli
+from .GumbelSoftmaxTopK import GumbelSoftmaxTopK
from .HardConcrete import HardConcrete
from .InvertibleGaussian import InvertibleGaussian
from .LogisticNormalSoftmax import LogisticNormalSoftmax
-from .GumbelSoftmaxTopK import GumbelSoftmaxTopK
+from .StochasticTimesSmooth import StochasticTimesSmooth
+from .StraightThroughBernoulli import StraightThroughBernoulli
+
__all__ = [
- "GaussianRelaxedBernoulli",
"CorrelatedRelaxedBernoulli",
- "StraightThroughBernoulli",
+ "GaussianRelaxedBernoulli",
+ "GumbelSoftmaxTopK",
"HardConcrete",
"InvertibleGaussian",
"LogisticNormalSoftmax",
- "GumbelSoftmaxTopK",
+ "StochasticTimesSmooth",
+ "StraightThroughBernoulli"
]
diff --git a/src/relaxit/distributions/kl.py b/src/relaxit/distributions/kl.py
index 5d37dc6..84da13a 100644
--- a/src/relaxit/distributions/kl.py
+++ b/src/relaxit/distributions/kl.py
@@ -4,7 +4,23 @@
@register_kl(InvertibleGaussian, InvertibleGaussian)
-def _kl_igr_igr(p, q):
+def _kl_igr_igr(p: InvertibleGaussian, q: InvertibleGaussian):
+ r"""
+ Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
+
+ Based on the paper https://arxiv.org/abs/1912.09588.
+
+ .. math::
+
+ KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx
+
+ Args:
+ p (InvertibleGaussian): A :class:`~relaxit.distributions.InvertibleGaussian` object.
+ q (InvertibleGaussian): A :class:`~relaxit.distributions.InvertibleGaussian` object.
+
+ Returns:
+ Tensor: A batch of KL divergences of shape `batch_shape`.
+ """
p_normal = Normal(p.loc, p.scale)
q_normal = Normal(q.loc, q.scale)
return kl_divergence(p_normal, q_normal)
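
Because the rule is registered for the `(InvertibleGaussian, InvertibleGaussian)` pair, the divergence can be requested through the standard dispatcher. A usage sketch, assuming the registration hooks into `torch.distributions.kl_divergence` and the `InvertibleGaussian(loc, scale, temperature)` signature shown in the quickstart:

```python
import torch
from torch.distributions import kl_divergence
from relaxit.distributions import InvertibleGaussian

temperature = torch.tensor([1.0])
p = InvertibleGaussian(torch.zeros(3, 4, 5), torch.ones(3, 4, 5), temperature)
q = InvertibleGaussian(torch.ones(3, 4, 5), 2.0 * torch.ones(3, 4, 5), temperature)
kl = kl_divergence(p, q)   # delegates to the Normal-Normal KL defined above
print(kl.shape)            # elementwise KL over loc/scale, shape (3, 4, 5)
```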
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/distributions/test_CorrelatedRelaxedBernoulli.py b/tests/distributions/test_CorrelatedRelaxedBernoulli.py
index 33eff73..3ea3771 100644
--- a/tests/distributions/test_CorrelatedRelaxedBernoulli.py
+++ b/tests/distributions/test_CorrelatedRelaxedBernoulli.py
@@ -1,10 +1,7 @@
import torch
import sys, os
-sys.path.append(
- os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli
+from relaxit.distributions import CorrelatedRelaxedBernoulli
# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution
@@ -13,9 +10,8 @@ def test_sample_shape():
pi = torch.tensor([0.1, 0.2, 0.3])
R = torch.tensor([[1.0]])
tau = torch.tensor([2.0])
-
- distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
- samples = distr.rsample()
+ distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
+ samples = distribution.rsample()
assert samples.shape == torch.Size([3])
@@ -23,19 +19,26 @@ def test_sample_grad():
pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
R = torch.tensor([[1.0]])
tau = torch.tensor([2.0])
-
- distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
- samples = distr.rsample()
+ distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
+ samples = distribution.rsample()
assert samples.requires_grad == True
-def test_log_prob():
- pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
+def test_log_prob_shape():
+ pi = torch.tensor([0.1, 0.2, 0.3])
R = torch.tensor([[1.0]])
tau = torch.tensor([2.0])
-
- distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
+ distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
value = torch.tensor([1.0])
- log_prob = distr.log_prob(value)
+ log_prob = distribution.log_prob(value)
assert log_prob.shape == torch.Size([3])
- assert log_prob.requires_grad == True
+
+
+def test_log_prob_grad():
+ pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
+ R = torch.tensor([[1.0]])
+ tau = torch.tensor([2.0])
+ distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
+ value = torch.tensor([1.0])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_GaussianRelaxedBernoulli.py b/tests/distributions/test_GaussianRelaxedBernoulli.py
index 07a1f15..f348066 100644
--- a/tests/distributions/test_GaussianRelaxedBernoulli.py
+++ b/tests/distributions/test_GaussianRelaxedBernoulli.py
@@ -1,10 +1,7 @@
import torch
import sys, os
-sys.path.append(
- os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.GaussianRelaxedBernoulli import GaussianRelaxedBernoulli
+from relaxit.distributions import GaussianRelaxedBernoulli
# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution
@@ -12,32 +9,32 @@
def test_sample_shape():
loc = torch.tensor([0.0])
scale = torch.tensor([1.0])
-
- distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
- samples = distr.rsample(sample_shape=torch.Size([3]))
+ distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
+ samples = distribution.rsample(sample_shape=torch.Size([3]))
assert samples.shape == torch.Size([3, 1])
def test_sample_grad():
loc = torch.tensor([0.0], requires_grad=True)
scale = torch.tensor([1.0], requires_grad=True)
- distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
- samples = distr.rsample()
+ distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
+ samples = distribution.rsample()
assert samples.requires_grad == True
-def test_log_prob():
- loc = torch.tensor([0.0], requires_grad=True)
- scale = torch.tensor([1.0], requires_grad=True)
- distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
-
+def test_log_prob_shape():
+ loc = torch.tensor([0.0])
+ scale = torch.tensor([1.0])
+ distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
value = torch.tensor([1.0])
- log_prob = distr.log_prob(value)
+ log_prob = distribution.log_prob(value)
assert log_prob.shape == torch.Size([1])
- assert log_prob.requires_grad == True
-if __name__ == "__main__":
- test_sample_shape()
- test_sample_grad()
- test_log_prob()
+def test_log_prob_grad():
+ loc = torch.tensor([0.0], requires_grad=True)
+ scale = torch.tensor([1.0], requires_grad=True)
+ distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
+ value = torch.tensor([1.0])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_GumbelSoftmaxTopK.py b/tests/distributions/test_GumbelSoftmaxTopK.py
index b5ac731..9bf9362 100644
--- a/tests/distributions/test_GumbelSoftmaxTopK.py
+++ b/tests/distributions/test_GumbelSoftmaxTopK.py
@@ -1,44 +1,46 @@
import torch
import sys, os
-sys.path.append(
- os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK
+from relaxit.distributions import GumbelSoftmaxTopK
# Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution
def test_sample_shape():
- a = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
+ logits = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
K = torch.tensor(1)
tau = torch.tensor(0.1)
- distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+ distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
sample = distribution.rsample()
- assert sample.shape == a.shape
+ assert sample.shape == logits.shape
def test_sample_grad():
- a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
+ logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
K = torch.tensor(2)
tau = torch.tensor(0.1)
- distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+ distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
sample = distribution.rsample()
assert sample.requires_grad == True
-def test_log_prob():
- a = torch.tensor([1.0, 2.0, 3.0])
+def test_log_prob_shape():
+ logits = torch.tensor([1.0, 2.0, 3.0])
K = torch.tensor(3)
tau = torch.tensor(0.1)
- distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+ distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
sample = distribution.rsample()
value = torch.tensor([1.0, 1.0, 1.0])
log_prob = distribution.log_prob(value)
- assert log_prob - torch.tensor(0) < 1e-6
+ assert log_prob.shape == torch.Size([3])
+
-
-if __name__ == "__main__":
- test_sample_shape()
- test_sample_grad()
- test_log_prob()
+def test_log_prob_grad():
+ logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+ K = torch.tensor(3)
+ tau = torch.tensor(0.1)
+ distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
+ sample = distribution.rsample()
+ value = torch.tensor([1.0, 1.0, 1.0])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_HardConcrete.py b/tests/distributions/test_HardConcrete.py
index c907e2b..ebc0888 100644
--- a/tests/distributions/test_HardConcrete.py
+++ b/tests/distributions/test_HardConcrete.py
@@ -1,10 +1,7 @@
import torch
import sys, os
-sys.path.append(
- os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.HardConcrete import HardConcrete
+from relaxit.distributions import HardConcrete
# Testing reparameterized sampling from the HardConcrete distribution
@@ -14,9 +11,8 @@ def test_sample_shape():
beta = torch.tensor([2.0])
gamma = torch.tensor([-3.0])
xi = torch.tensor([4.0])
- distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
- samples = distr.rsample(sample_shape=torch.Size([3]))
-
+ distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
+ samples = distribution.rsample(sample_shape=torch.Size([3]))
assert samples.shape == torch.Size([3, 1])
@@ -25,20 +21,28 @@ def test_sample_grad():
beta = torch.tensor([2.0], requires_grad=True)
gamma = torch.tensor([-3.0], requires_grad=True)
xi = torch.tensor([4.0], requires_grad=True)
- distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
- samples = distr.rsample(sample_shape=torch.Size([3]))
-
+ distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
+ samples = distribution.rsample(sample_shape=torch.Size([3]))
assert samples.requires_grad == True
-def test_log_prob():
+def test_log_prob_shape():
+ alpha = torch.tensor([1.0])
+ beta = torch.tensor([2.0])
+ gamma = torch.tensor([-3.0])
+ xi = torch.tensor([4.0])
+ distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
+ value = torch.tensor([1.0])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.shape == torch.Size([1])
+
+
+def test_log_prob_grad():
alpha = torch.tensor([1.0], requires_grad=True)
beta = torch.tensor([2.0], requires_grad=True)
gamma = torch.tensor([-3.0], requires_grad=True)
xi = torch.tensor([4.0], requires_grad=True)
- distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
-
+ distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
value = torch.tensor([1.0])
- log_prob = distr.log_prob(value)
- assert log_prob.shape == torch.Size([1])
- assert log_prob.requires_grad == True
+ log_prob = distribution.log_prob(value)
+ assert log_prob.requires_grad == True
\ No newline at end of file
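A minimal sketch of the HardConcrete API covered above, assuming only the constructor arguments and methods that appear in the tests:

import torch
from relaxit.distributions import HardConcrete

alpha = torch.tensor([1.0], requires_grad=True)
beta = torch.tensor([2.0])
gamma = torch.tensor([-3.0])
xi = torch.tensor([4.0])

distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples = distribution.rsample(sample_shape=torch.Size([3]))  # shape [3, 1]
log_prob = distribution.log_prob(torch.tensor([1.0]))         # shape [1], differentiable w.r.t. alpha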
diff --git a/tests/distributions/test_InvertibleGaussian.py b/tests/distributions/test_InvertibleGaussian.py
index e28b2df..90216a4 100644
--- a/tests/distributions/test_InvertibleGaussian.py
+++ b/tests/distributions/test_InvertibleGaussian.py
@@ -1,15 +1,14 @@
import torch
import sys
-sys.path.append("../../src")
-from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions import InvertibleGaussian
# Testing reparameterized sampling from the InvertibleGaussian distribution
def test_sample_shape():
- loc = torch.zeros(3, 4, 5, requires_grad=True)
- scale = torch.ones(3, 4, 5, requires_grad=True)
+ loc = torch.zeros(3, 4, 5)
+ scale = torch.ones(3, 4, 5)
temperature = torch.tensor([1e-0])
distribution = InvertibleGaussian(loc, scale, temperature)
sample = distribution.rsample()
@@ -26,8 +25,8 @@ def test_sample_grad():
def test_log_prob_shape():
- loc = torch.zeros(3, 4, 5, requires_grad=True)
- scale = torch.ones(3, 4, 5, requires_grad=True)
+ loc = torch.zeros(3, 4, 5)
+ scale = torch.ones(3, 4, 5)
temperature = torch.tensor([1e-0])
distribution = InvertibleGaussian(loc, scale, temperature)
value = 0.5 * torch.ones(3, 4, 6)
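For context, a sketch of the InvertibleGaussian usage these tests rely on; the extra unit in the last dimension of value (6 vs. 5) is copied from the test and presumably reflects the one-higher-dimensional simplex embedding of the invertible-Gaussian reparameterization:

import torch
from relaxit.distributions import InvertibleGaussian

loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5)
temperature = torch.tensor([1e-0])

distribution = InvertibleGaussian(loc, scale, temperature)
sample = distribution.rsample()    # differentiable w.r.t. loc
value = 0.5 * torch.ones(3, 4, 6)  # last dim is one larger than loc's, as in the test
log_prob = distribution.log_prob(value)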
diff --git a/tests/distributions/test_LogisticNormalSoftmax.py b/tests/distributions/test_LogisticNormalSoftmax.py
index 135c4bf..7c32e96 100644
--- a/tests/distributions/test_LogisticNormalSoftmax.py
+++ b/tests/distributions/test_LogisticNormalSoftmax.py
@@ -1,15 +1,14 @@
import torch
import sys
-sys.path.append("../../src")
-from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
+from relaxit.distributions import LogisticNormalSoftmax
# Testing reparameterized sampling from the LogisticNormalSoftmax distribution
def test_sample_shape():
- loc = torch.zeros(3, 4, 5, requires_grad=True)
- scale = torch.ones(3, 4, 5, requires_grad=True)
+ loc = torch.zeros(3, 4, 5)
+ scale = torch.ones(3, 4, 5)
distribution = LogisticNormalSoftmax(loc, scale)
sample = distribution.rsample()
assert sample.shape == torch.Size([3, 4, 5])
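And the analogous sketch for LogisticNormalSoftmax, using only the constructor and method shown in the test:

import torch
from relaxit.distributions import LogisticNormalSoftmax

loc = torch.zeros(3, 4, 5)
scale = torch.ones(3, 4, 5)

distribution = LogisticNormalSoftmax(loc, scale)
sample = distribution.rsample()  # shape [3, 4, 5], as asserted above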
diff --git a/tests/distributions/test_STEstimator.py b/tests/distributions/test_STEstimator.py
deleted file mode 100644
index 524f5cf..0000000
--- a/tests/distributions/test_STEstimator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import torch
-import sys, os
-
-sys.path.append(
- os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.STEstimator import StraightThroughEstimator
-
-# Testing reparameterized sampling and log prob from the StraightThroughEstimator distribution
-
-
-def test_sample_shape():
- a = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
- distribution = StraightThroughEstimator(logits=a)
- sample = distribution.rsample()
- assert sample.shape == a.shape
-
-
-def test_sample_grad():
- a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
- distribution = StraightThroughEstimator(logits=a)
- sample = distribution.rsample()
- assert sample.requires_grad == True
-
-
-def test_log_prob():
- a = torch.tensor([1.0, 2.0, 3.0])
- distribution = StraightThroughEstimator(logits=a)
- value = torch.tensor([1.0, 1.0, 1.0])
- log_prob = distribution.log_prob(value)
- assert log_prob.shape == torch.Size([3])
diff --git a/tests/distributions/test_StochasticTimesSmooth.py b/tests/distributions/test_StochasticTimesSmooth.py
new file mode 100644
index 0000000..92e3946
--- /dev/null
+++ b/tests/distributions/test_StochasticTimesSmooth.py
@@ -0,0 +1,37 @@
+import torch
+import sys, os
+
+from relaxit.distributions import StochasticTimesSmooth
+
+# Testing reparameterized sampling and log prob from the StochasticTimesSmooth distribution
+
+
+def test_sample_shape():
+ logits = torch.tensor([1., 2., 3.])
+ distribution = StochasticTimesSmooth(logits=logits)
+ samples = distribution.rsample()
+ assert samples.shape == torch.Size([3])
+
+
+def test_sample_grad():
+ logits = torch.tensor([1., 2., 3.], requires_grad=True)
+ distribution = StochasticTimesSmooth(logits=logits)
+ samples = distribution.rsample()
+ assert samples.requires_grad == True
+
+
+def test_log_prob_shape():
+ logits = torch.tensor([1., 2., 3.])
+ distribution = StochasticTimesSmooth(logits=logits)
+ value = torch.Tensor([1., 1., 1.])
+ log_prob = distribution.log_prob(value)
+ print('log_prob.shape:', log_prob.shape)
+ assert log_prob.shape == torch.Size([3])
+
+
+def test_log_prob_grad():
+ logits = torch.tensor([1., 2., 3.], requires_grad=True)
+ distribution = StochasticTimesSmooth(logits=logits)
+ value = torch.Tensor([1., 1., 1.])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.requires_grad == True
\ No newline at end of file
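A minimal sketch of the new StochasticTimesSmooth distribution as exercised by this test file; only the calls shown above are assumed:

import torch
from relaxit.distributions import StochasticTimesSmooth

logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
distribution = StochasticTimesSmooth(logits=logits)

sample = distribution.rsample()                                  # shape [3], gradients flow to logits
log_prob = distribution.log_prob(torch.tensor([1.0, 1.0, 1.0]))  # shape [3], also differentiable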
diff --git a/tests/distributions/test_StraightThroughBernoulli.py b/tests/distributions/test_StraightThroughBernoulli.py
index 9c75c89..8821b1c 100644
--- a/tests/distributions/test_StraightThroughBernoulli.py
+++ b/tests/distributions/test_StraightThroughBernoulli.py
@@ -1,29 +1,36 @@
import torch
import sys, os
-sys.path.append(
- os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.StraightThroughBernoulli import StraightThroughBernoulli
+from relaxit.distributions import StraightThroughBernoulli
-# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution
+# Testing reparameterized sampling and log prob from the StraightThroughBernoulli distribution
def test_sample_shape():
- a = torch.tensor([1, 2, 3])
- distr = StraightThroughBernoulli(a=a)
- samples = distr.rsample()
- assert samples.shape == torch.Size([3])
-
-
-# def test_sample_grad():
-# a = torch.tensor([1., 2., 3.], requires_grad=True)
-# distr = StraightThroughBernoulli(a = a)
-# samples = distr.rsample()
-# assert samples.requires_grad == True
-
-# def test_log_prob():
-# a = torch.tensor([1, 2, 3])
-# distr = StraightThroughBernoulli(a = a)
-# value = torch.Tensor([1.])
-# print(distr.log_prob(value))
+ logits = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
+ distribution = StraightThroughBernoulli(logits=logits)
+ sample = distribution.rsample()
+ assert sample.shape == logits.shape
+
+
+def test_sample_grad():
+ logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
+ distribution = StraightThroughBernoulli(logits=logits)
+ sample = distribution.rsample()
+ assert sample.requires_grad == True
+
+
+def test_log_prob_shape():
+ logits = torch.tensor([1.0, 2.0, 3.0])
+ distribution = StraightThroughBernoulli(logits=logits)
+ value = torch.tensor([1.0, 1.0, 1.0])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.shape == torch.Size([3])
+
+
+def test_log_prob_grad():
+ logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+ distribution = StraightThroughBernoulli(logits=logits)
+ value = torch.tensor([1.0, 1.0, 1.0])
+ log_prob = distribution.log_prob(value)
+ assert log_prob.requires_grad == True
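The StraightThroughBernoulli tests follow the same pattern; a condensed sketch, assuming only the interface shown above:

import torch
from relaxit.distributions import StraightThroughBernoulli

logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
distribution = StraightThroughBernoulli(logits=logits)

sample = distribution.rsample()                                  # same shape as logits, carries gradients back to logits
log_prob = distribution.log_prob(torch.tensor([1.0, 1.0, 1.0]))  # shape [3], differentiable w.r.t. logits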
diff --git a/tests/distributions/test_approx.py b/tests/distributions/test_approx.py
index d3412b0..9a312ba 100644
--- a/tests/distributions/test_approx.py
+++ b/tests/distributions/test_approx.py
@@ -1,7 +1,6 @@
import torch
import sys
-sys.path.append("../../src")
from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
from relaxit.distributions.approx import (
lognorm_approximation_fn,
diff --git a/tests/distributions/test_kl.py b/tests/distributions/test_kl.py
index fc74420..10710b6 100644
--- a/tests/distributions/test_kl.py
+++ b/tests/distributions/test_kl.py
@@ -1,8 +1,7 @@
import torch
import sys
-sys.path.append("../../src")
-from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions import InvertibleGaussian
from relaxit.distributions.kl import kl_divergence
# Testing KL-divergence between two InvertibleGaussian distributions
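A sketch of the KL computation this file checks, assuming the conventional kl_divergence(p, q) argument order:

import torch
from relaxit.distributions import InvertibleGaussian
from relaxit.distributions.kl import kl_divergence

temperature = torch.tensor([1e-0])
p = InvertibleGaussian(torch.zeros(3, 4, 5), torch.ones(3, 4, 5), temperature)
q = InvertibleGaussian(torch.ones(3, 4, 5), torch.ones(3, 4, 5), temperature)

kl = kl_divergence(p, q)  # elementwise KL between the two relaxed distributions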