diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 45dddf0..0cb2bc9 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -37,7 +37,7 @@ jobs: - name: Build Docs run: | - sphinx-build -b html docs public + sphinx-build -b html docs/source public touch public/.nojekyll - name: Deploy 🚀 diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 60c5975..8fe1ceb 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -1,4 +1,4 @@ -name: Publish Python 🐍 distribution đŸ“Ļ to PyPI and TestPyPI +name: Publish Python 🐍 distribution đŸ“Ļ to PyPI on: push @@ -90,28 +90,4 @@ jobs: run: >- gh release upload '${{ github.ref_name }}' dist/** - --repo '${{ github.repository }}' - - publish-to-testpypi: - name: Publish Python 🐍 distribution đŸ“Ļ to TestPyPI - needs: - - build - runs-on: ubuntu-latest - - environment: - name: testpypi - url: https://test.pypi.org/p/relaxit - - permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - - steps: - - name: Download all the dists - uses: actions/download-artifact@v4 - with: - name: python-package-distributions - path: dist/ - - name: Publish distribution đŸ“Ļ to TestPyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - repository-url: https://test.pypi.org/legacy/ \ No newline at end of file + --repo '${{ github.repository }}' \ No newline at end of file diff --git a/README.md b/README.md index 7b8f1e4..e9c38a1 100644 --- a/README.md +++ b/README.md @@ -52,9 +52,9 @@ 6. [Tests](https://github.com/intsystems/discrete-variables-relaxation/tree/main/tests) ## 💡 Motivation -For lots of mathematical problems we need an ability to sample discrete random variables. -The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible. -Thus we use different relaxation method. +For many mathematical problems we need the ability to sample discrete random variables. +The problem is that, due to the continuous nature of deep learning optimization, the use of truly discrete random variables is infeasible. +Thus we use different relaxation methods. One of them, [Concrete distribution](https://arxiv.org/abs/1611.00712) or [Gumbel-Softmax](https://arxiv.org/abs/1611.01144) (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages. In this project we implement different alternatives to it.
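For readers skimming this diff: the Concrete / Gumbel-Softmax relaxation mentioned in the Motivation section above is the baseline that this library offers alternatives to, and it already ships with PyTorch. A minimal illustrative sketch (not part of this change set) of the reparameterized sampling such a relaxation enables:

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

# Concrete / Gumbel-Softmax relaxation of a 3-way categorical variable.
logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
relaxed = RelaxedOneHotCategorical(temperature=torch.tensor(0.5), logits=logits)

sample = relaxed.rsample()  # differentiable point on the probability simplex
sample.sum().backward()     # gradients flow back into the logits
print(sample, logits.grad)
```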
@@ -65,11 +65,11 @@ In this project we implement different alternatives to it. - [x] [Relaxed Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/GaussianRelaxedBernoulli.py), also see [📝 paper](http://proceedings.mlr.press/v119/yamada20a/yamada20a.pdf) - [x] [Correlated relaxed Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/CorrelatedRelaxedBernoulli.py), also see [📝 paper](https://openreview.net/pdf?id=oDFvtxzPOx) - [x] [Gumbel-softmax TOP-K](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/GumbelSoftmaxTopK.py), also see [📝 paper](https://arxiv.org/pdf/1903.06059) -- [x] [Straight-Through Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StraightThroughBernoulli.py), also see [📝 paper](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235) +- [x] [Straight-Through Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StraightThroughBernoulli.py), also see [📝 paper](https://arxiv.org/abs/1910.02176) +- [x] [Stochastic Times Smooth](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StochasticTimesSmooth.py), also see [📝 paper](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235) - [x] [Invertible Gaussian](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/InvertibleGaussian.py) with [KL implemented](https://github.com/intsystems/discrete-variables-relaxation/blob/f398ebbbac703582de392bc33d89b55c6c99ea68/src/relaxit/distributions/kl.py#L7), also see [📝 paper](https://arxiv.org/abs/1912.09588) - [x] [Hard Concrete](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/HardConcrete.py), also see [📝 paper](https://arxiv.org/pdf/1712.01312) -- [x] [REINFORCE](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/CorrelatedRelaxedBernoulli.py), also see [đŸ“ē slides](http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf) -- [x] [Logit-Normal](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/LogisticNormalSoftmax.py) and [Laplace-form approximation of Dirichlet](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/approx.py), also see [ℹī¸ wiki](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [đŸ’ģ stackexchange](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet) +- [x] [Logistic-Normal](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/LogisticNormalSoftmax.py) and [Laplace-form approximation of Dirichlet](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/approx.py), also see [ℹī¸ wiki](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [đŸ’ģ stackexchange](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet) ## 🛠ī¸ Install @@ -116,13 +116,15 @@ print('sample.requires_grad:', sample.requires_grad) | ![Laplace Bridge](https://github.com/user-attachments/assets/ac5d5a71-e7d7-4ec3-b9ca-9b72d958eb41) | ![REINFORCE](https://gymnasium.farama.org/_images/acrobot.gif) | 
![VAE](https://github.com/user-attachments/assets/937585c4-df84-4ab0-a2b9-ea6a73997793) | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/laplace-bridge.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/reinforce.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/demo.ipynb) | -For demonstration purposes, we divide our algorithms in three different groups. Each group relates to the particular demo code: +For demonstration purposes, we divide our algorithms into three[^*] different groups. Each group corresponds to a particular demo notebook: - [Laplace bridge between Dirichlet and LogisticNormal distributions](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/laplace-bridge.ipynb) - [REINFORCE](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/reinforce.ipynb) - [Other relaxation methods](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/demo.ipynb) We describe our demo experiments [here](https://github.com/intsystems/discrete-variables-relaxation/tree/main/demo). +[^*]: We also implement the REINFORCE algorithm as a *score function* estimator alternative to our relaxation methods, which are inherently *pathwise derivative* estimators. It is implemented only for the demo experiments and is not included in the package source code. + ## 📚 Stack Some of the alternatives for GS were implemented in [pyro](https://docs.pyro.ai/en/dev/distributions.html), so we base our library on their codebase.
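The headline addition in this change set is the `StochasticTimesSmooth` distribution. A short usage sketch, mirroring the new `tests/distributions/test_StochasticTimesSmooth.py` further down in this diff (the constructor follows pyro's `Bernoulli` and takes `probs` or `logits`):

```python
import torch
from relaxit.distributions import StochasticTimesSmooth

# Bernoulli-like distribution whose rsample() keeps a differentiable
# path to the parameters through the smooth factor sqrt(p).
logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
distribution = StochasticTimesSmooth(logits=logits)

sample = distribution.rsample()            # shape: torch.Size([3])
print(sample.requires_grad)                # True: gradients flow to `logits`
log_prob = distribution.log_prob(torch.ones(3))
```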
diff --git a/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_reconstruction.png b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_reconstruction.png new file mode 100644 index 0000000..df4f0d3 Binary files /dev/null and b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_reconstruction.png differ diff --git a/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_sample.png b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_sample.png new file mode 100644 index 0000000..f3211f0 Binary files /dev/null and b/assets/demo_results/vae_stochastic_times_smooth/vae_stochastic_times_smooth_sample.png differ diff --git a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png index 69eee89..00ea80e 100644 Binary files a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png and b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_reconstruction.png differ diff --git a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png index a44c7ae..94b8bf0 100644 Binary files a/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png and b/assets/demo_results/vae_straight_through_bernoulli/vae_straight_through_bernoulli_sample.png differ diff --git a/demo/README.md b/demo/README.md index 85845e1..94f8bdb 100644 --- a/demo/README.md +++ b/demo/README.md @@ -1,5 +1,5 @@ # Demo experiments code -This repository contains our demo code for various experiments. The main demo code can be found in the notebook `demo/demo.ipynb`. Open the notebook and run the cells to see the demonstration in action. For additional experiments, refer to the section [Additional experiments](#experiments). Before starting any experiments, ensure you follow all the installation steps outlined in the [Installation](#installation) section. +This repository contains our demo code for various experiments. The main demo code can be found in the notebook `demo.ipynb`. Open the notebook and run the cells to see the demonstration in action. For additional experiments, refer to the section [Additional experiments](#experiments). Before starting any experiments, ensure you follow all the installation steps outlined in the [Installation](#installation) section. ## Installation @@ -27,33 +27,36 @@ For additional demo experiments, we have implemented Variational Autoencoders (V 1. **Train and save the models:** To run the additional demo code with VAEs, you need to train all the models and save their results. 
Execute the following commands: ```bash - # VAE with Gaussian Bernoulli latent space - python vae_gaussian_bernoulli.py - # VAE with Correlated Bernoulli latent space python vae_correlated_bernoulli.py - + + # VAE with Gaussian Bernoulli latent space + python vae_gaussian_bernoulli.py + + # VAE with Gumbel-Softmax top-K latent space + python vae_gumbel_softmax_topk.py + # VAE with Hard Concrete latent space python vae_hard_concrete.py - # VAE with Straight Through Bernoullii latent space - python vae_straight_through_bernoulli.py - # VAE with Invertible Gaussian latent space python vae_invertible_gaussian.py - # VAE with Gumbel Softmax TopK latent space - python vae_gumbel_softmax_topk.py + # VAE with Stochastic Times Smooth latent space + python vae_stochastic_times_smooth.py + + # VAE with Straight Through Bernoulli latent space + python vae_straight_through_bernoulli.py ``` 2. **View the results:** - After completing the training and testing of all the models, you can find the results of sampling and reconstruction methods in the directory `demo/results`. + After completing the training and testing of all the models, you can find the results of sampling and reconstruction methods in the directory `results`. Moreover, we conducted experiments with Laplace Bridge between LogisticNormal and Dirichlet distributions. We use two-side Laplace bridge to approximate: -- Dirichlet using logisticNormal -- LogisticNormal using Dirichlet +- Dirichlet using Logistic-Normal +- Logistic-Normal using Dirichlet -These experiments aim to find the best parameters to make the distributions nearly identical on the simplex. The experiments can be found in the notebook `demo/laplace-bridge.ipynb`. +These experiments aim to find the best parameters to make the distributions nearly identical on the simplex. The experiments can be found in the notebook `laplace-bridge.ipynb`. -Furthermore, the Reinforce algorithm is applied in the [Acrobot environment](https://www.gymlibrary.dev/environments/classic_control/acrobot/). Detailed experiments can be viewed in the notebook `demo/reinforce.ipynb`. A script `demo/reinforce.py` can also be used for training. +Furthermore, the REINFORCE algorithm is applied in the [Acrobot environment](https://www.gymlibrary.dev/environments/classic_control/acrobot/). Detailed experiments can be viewed in the notebook `reinforce.ipynb`. A script `reinforce.py` can also be used for training.
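The REINFORCE demo referenced above relies on a score-function gradient rather than the pathwise (reparameterized) gradients the relaxations provide. A generic sketch of that estimator, illustrative only; the actual policy and training loop live in `reinforce.py` / `reinforce.ipynb`:

```python
import torch
from torch.distributions import Categorical

# Score-function (REINFORCE) estimator: grad E[R] = E[R * grad log p(a)].
logits = torch.zeros(3, requires_grad=True)
policy = Categorical(logits=logits)

action = policy.sample()          # discrete, non-differentiable action
reward = torch.tensor(1.0)        # placeholder reward from the environment
loss = -reward * policy.log_prob(action)
loss.backward()                   # unbiased gradient estimate w.r.t. the logits
print(logits.grad)
```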
diff --git a/demo/requirements.txt b/demo/requirements.txt index 546ee61..f468961 100644 --- a/demo/requirements.txt +++ b/demo/requirements.txt @@ -1,11 +1,10 @@ -gym==0.26.2 -numpy==1.24.1 -pyro_ppl==1.9.1 -setuptools==65.5.0 -torch==2.5.1 -torchvision==0.20.1 -matplotlib==3.9.2 -networkx==3.3 -tqdm==4.66.5 -pillow==10.4.0 -relaxit==0.1.2 \ No newline at end of file +gym>=0.26.2 +numpy>=1.24.1 +torchvision>=0.20.1 +matplotlib>=3.9.2 +networkx>=3.3 +tqdm>=4.66.5 +pillow>=10.4.0 +relaxit==1.0.1 +ftfy +regex \ No newline at end of file diff --git a/demo/vae_stochastic_times_smooth.py b/demo/vae_stochastic_times_smooth.py new file mode 100644 index 0000000..deb10e3 --- /dev/null +++ b/demo/vae_stochastic_times_smooth.py @@ -0,0 +1,252 @@ +import os +import argparse +import numpy as np +import torch +import sys +import torch.utils.data +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image + +from relaxit.distributions import StochasticTimesSmooth + + +def parse_arguments() -> argparse.Namespace: + """ + Parse command line arguments. + + Returns: + argparse.Namespace: Parsed command line arguments. + """ + parser = argparse.ArgumentParser(description="VAE MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=128, + metavar="N", + help="input batch size for training (default: 128)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="enables CUDA training" + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log_interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + return parser.parse_args() + + +args = parse_arguments() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +torch.manual_seed(args.seed) + +device = torch.device("cuda" if args.cuda else "cpu") + +os.makedirs("./results/vae_stochastic_times_smooth", exist_ok=True) + +kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} +train_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "./data", train=True, download=True, transform=transforms.ToTensor() + ), + batch_size=args.batch_size, + shuffle=True, + **kwargs +) +test_loader = torch.utils.data.DataLoader( + datasets.MNIST("./data", train=False, transform=transforms.ToTensor()), + batch_size=args.batch_size, + shuffle=True, + **kwargs +) + +steps = 0 + + +class VAE(nn.Module): + """ + Variational Autoencoder (VAE) with StochasticTimesSmooth distribution. + """ + + def __init__(self) -> None: + super(VAE, self).__init__() + + self.fc1 = nn.Linear(784, 400) + self.fc2 = nn.Linear(400, 20) + self.fc3 = nn.Linear(20, 400) + self.fc4 = nn.Linear(400, 784) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode the input by passing through the encoder network + and return the latent code. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Latent code. + """ + h1 = F.relu(self.fc1(x)) + return self.fc2(h1) + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode the latent code by passing through the decoder network + and return the reconstructed input. + + Args: + z (torch.Tensor): Latent code. + + Returns: + torch.Tensor: Reconstructed input. 
+ """ + h3 = F.relu(self.fc3(z)) + return torch.sigmoid(self.fc4(h3)) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through the VAE. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Reconstructed input and latent code. + """ + logits = self.encode(x.view(-1, 784)) + q_z = StochasticTimesSmooth(logits=logits) + probs = q_z.probs + z = q_z.rsample() + return self.decode(z), probs + + +model = VAE().to(device) +optimizer = optim.Adam(model.parameters(), lr=1e-3) + + +def loss_function( + recon_x: torch.Tensor, + x: torch.Tensor, + probs: torch.Tensor, + prior: float = 0.5, + eps: float = 1e-10, +) -> torch.Tensor: + """ + Compute the loss function for the VAE. + + Args: + recon_x (torch.Tensor): Reconstructed input. + x (torch.Tensor): Original input. + probs (torch.Tensor): Probabilities for Bernoulli distribution in latent space. + prior (float): Prior probability. + eps (float): Small value to avoid log(0). + + Returns: + torch.Tensor: Loss value. + """ + BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum") + t1 = probs * ((probs + eps) / prior).log() + t2 = (1 - probs) * ((1 - probs + eps) / (1 - prior)).log() + KLD = torch.sum(t1 + t2, dim=-1).sum() + + return BCE + KLD + + +def train(epoch: int) -> None: + """ + Train the VAE for one epoch. + + Args: + epoch (int): Current epoch number. + """ + global steps + model.train() + train_loss = 0 + for batch_idx, (data, _) in enumerate(train_loader): + data = data.to(device) + optimizer.zero_grad() + recon_batch, probs = model(data) + loss = loss_function(recon_batch, data, probs) + loss.backward() + train_loss += loss.item() + optimizer.step() + + if batch_idx % args.log_interval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item() / len(data), + ) + ) + + steps += 1 + + print( + "====> Epoch: {} Average loss: {:.4f}".format( + epoch, train_loss / len(train_loader.dataset) + ) + ) + + +def test(epoch: int) -> None: + """ + Test the VAE for one epoch. + + Args: + epoch (int): Current epoch number. 
+ """ + model.eval() + test_loss = 0 + with torch.no_grad(): + for i, (data, _) in enumerate(test_loader): + data = data.to(device) + recon_batch, probs = model(data) + test_loss += loss_function(recon_batch, data, probs).item() + if i == 0: + n = min(data.size(0), 8) + comparison = torch.cat( + [data[:n], recon_batch.view(args.batch_size, 1, 28, 28)[:n]] + ) + save_image( + comparison.cpu(), + "results/vae_stochastic_times_smooth/reconstruction_" + + str(epoch) + + ".png", + nrow=n, + ) + + test_loss /= len(test_loader.dataset) + print("====> Test set loss: {:.4f}".format(test_loss)) + + +if __name__ == "__main__": + for epoch in range(1, args.epochs + 1): + train(epoch) + test(epoch) + with torch.no_grad(): + sample = np.random.binomial(1, 0.5, size=(64, 20)) + sample = torch.from_numpy(np.float32(sample)).to(device) + sample = model.decode(sample).cpu() + save_image( + sample.view(64, 1, 28, 28), + "results/vae_stochastic_times_smooth/sample_" + str(epoch) + ".png", + ) diff --git a/demo/vae_straight_through_bernoulli.py b/demo/vae_straight_through_bernoulli.py index 6d5cf69..335fe11 100644 --- a/demo/vae_straight_through_bernoulli.py +++ b/demo/vae_straight_through_bernoulli.py @@ -129,14 +129,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: Returns: tuple[torch.Tensor, torch.Tensor]: Reconstructed input and latent code. """ - a = self.encode(x.view(-1, 784)) - a = a.float() - q_z = StraightThroughBernoulli(a) - + logits = self.encode(x.view(-1, 784)) + q_z = StraightThroughBernoulli(logits=logits) + probs = q_z.probs z = q_z.rsample() - z = z.float() - - return self.decode(z), a + return self.decode(z), probs model = VAE().to(device) @@ -146,7 +143,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def loss_function( recon_x: torch.Tensor, x: torch.Tensor, - a: torch.Tensor, + probs: torch.Tensor, prior: float = 0.5, eps: float = 1e-10, ) -> torch.Tensor: @@ -156,7 +153,7 @@ def loss_function( Args: recon_x (torch.Tensor): Reconstructed input. x (torch.Tensor): Original input. - a (torch.Tensor): Latent code. + probs (torch.Tensor): Probabilities for Bernoulli distribution in latent space. prior (float): Prior probability. eps (float): Small value to avoid log(0). @@ -164,9 +161,8 @@ def loss_function( torch.Tensor: Loss value. 
""" BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum") - q_z = torch.sigmoid(a) - t1 = q_z * ((q_z + eps) / prior).log() - t2 = (1 - q_z) * ((1 - q_z + eps) / (1 - prior)).log() + t1 = probs * ((probs + eps) / prior).log() + t2 = (1 - probs) * ((1 - probs + eps) / (1 - prior)).log() KLD = torch.sum(t1 + t2, dim=-1).sum() return BCE + KLD @@ -185,8 +181,8 @@ def train(epoch: int) -> None: for batch_idx, (data, _) in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() - recon_batch, q_z = model(data) - loss = loss_function(recon_batch, data, q_z) + recon_batch, probs = model(data) + loss = loss_function(recon_batch, data, probs) loss.backward() train_loss += loss.item() optimizer.step() @@ -223,8 +219,8 @@ def test(epoch: int) -> None: with torch.no_grad(): for i, (data, _) in enumerate(test_loader): data = data.to(device) - recon_batch, q_z = model(data) - test_loss += loss_function(recon_batch, data, q_z).item() + recon_batch, probs = model(data) + test_loss += loss_function(recon_batch, data, probs).item() if i == 0: n = min(data.size(0), 8) comparison = torch.cat( diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index a07925a..0000000 --- a/docs/conf.py +++ /dev/null @@ -1,73 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys - -from relaxit import __version__ - - -# -- Project information ----------------------------------------------------- - -project = "Just Relax It" -copyright = "2024, Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov" -author = "Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov" - -version = __version__ -master_doc = "index" - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.napoleon', - 'sphinx.ext.duration', - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'myst_parser' -] -highlight_language = 'python' - -autodoc_mock_imports = ["numpy", "scipy", "sklearn"] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] - -html_extra_path = [] - -html_context = { - "display_github": True, # Integrate GitHub - "github_user": "Intelligent-Systems-Phystech", # Username - "github_repo": "discrete-variables-relaxation", # Repo name - "github_version": "main", # Version - "conf_py_path": "./doc/", # Path in the checkout to the docs root -} - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. 
-# -html_theme = "sphinx_rtd_theme" - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] diff --git a/docs/make_docs.sh b/docs/make_docs.sh index a88d3f6..4fd437a 100644 --- a/docs/make_docs.sh +++ b/docs/make_docs.sh @@ -1,2 +1,7 @@ -rm -rf docs/source -sphinx-apidoc -o docs/source/ src/relaxit \ No newline at end of file +#!/bin/bash + +# Generate API documentation +# sphinx-apidoc -o source ../src/relaxit + +# Build the documentation +sphinx-build -b html source build/html \ No newline at end of file diff --git a/docs/source/_static/css/relaxit.css b/docs/source/_static/css/relaxit.css new file mode 100644 index 0000000..18802fe --- /dev/null +++ b/docs/source/_static/css/relaxit.css @@ -0,0 +1,29 @@ +@import url("theme.css"); + +.wy-side-nav-search { + background-color: #565656; +} + +.wy-side-nav-search a { + margin: 0 +} + +.wy-side-nav-search > div.version { + color: #f26822; +} + +.wy-nav-top { + background: #404040; +} + +.wy-menu-vertical li.on a, .wy-menu-vertical li.current>a { + background: #ccc; +} + +.wy-side-nav-search input[type=text] { + border-color: #313131; +} + +.wy-side-nav-search>a img.logo, .wy-side-nav-search .wy-dropdown>a img.logo { + max-width: 40%; +} \ No newline at end of file diff --git a/docs/source/_static/img/logo-small.png b/docs/source/_static/img/logo-small.png new file mode 100644 index 0000000..10e24a5 Binary files /dev/null and b/docs/source/_static/img/logo-small.png differ diff --git a/docs/source/_static/img/logo.png b/docs/source/_static/img/logo.png new file mode 100644 index 0000000..ca4401e Binary files /dev/null and b/docs/source/_static/img/logo.png differ diff --git a/docs/source/_static/img/overview.png b/docs/source/_static/img/overview.png new file mode 100644 index 0000000..a82bab5 Binary files /dev/null and b/docs/source/_static/img/overview.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..2a67f96 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,71 @@ +# Configuration file for the Sphinx documentation builder. 
+ +import os +import sys + +from relaxit import __version__ + +# -- Path setup -------------------------------------------------------------- + +sys.path.insert(0, os.path.abspath('../../src')) + +# -- Project information ----------------------------------------------------- + +project = "Just Relax It" +copyright = "2024, Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov" +author = "Daniil Dorin, Igor Ignashin, Nikita Kiselev, Andrey Veprikov" + +version = __version__ +master_doc = "index" + +# -- General configuration --------------------------------------------------- + +extensions = [ + 'sphinx.ext.napoleon', + 'sphinx.ext.duration', + 'sphinx.ext.doctest', + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx.ext.githubpages', + 'myst_parser' +] + +highlight_language = 'python' + +autodoc_mock_imports = ["numpy", "scipy", "sklearn"] + +templates_path = ["_templates"] +exclude_patterns = [] +html_extra_path = [] + +html_context = { + "display_github": True, + "github_user": "Intelligent-Systems-Phystech", + "github_repo": "discrete-variables-relaxation", + "github_version": "main", + "conf_py_path": "./doc/", +} + +# -- Options for HTML output ------------------------------------------------- + +html_logo = "_static/img/logo-small.png" + +html_theme = "sphinx_rtd_theme" + +html_theme_options = { + "navigation_depth": 3, + "logo_only": True, +} + +html_static_path = ["_static"] +html_css_files = [ + 'css/relaxit.css', +] + +# -- Options for intersphinx extension --------------------------------------- + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3/', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'torch': ('https://pytorch.org/docs/stable/', None), +} diff --git a/docs/index.rst b/docs/source/index.rst similarity index 60% rename from docs/index.rst rename to docs/source/index.rst index 114f005..e49515f 100644 --- a/docs/index.rst +++ b/docs/source/index.rst @@ -3,29 +3,37 @@ Just Relax It ============= -.. image:: ../assets/logo.png +.. image:: _static/img/logo.png :width: 200 :align: center +.. raw:: html + +
+ "Just Relax It" is a cutting-edge Python library designed to streamline the optimization of discrete probability distributions in neural networks, offering a suite of advanced relaxation techniques compatible with PyTorch. Motivation ---------- -For lots of mathematical problems we need an ability to sample discrete random variables. -The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible. -Thus we use different relaxation method. -One of them, `Concrete distribution `_ or `Gumbel-Softmax `_ (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages. -In this project we implement different alternatives to it. +For lots of mathematical problems we need an ability to sample discrete random variables. +The problem is that due to continuous nature of deep learning optimization, the usage of truly discrete random variables is infeasible. +Thus we use different relaxation methods. +One of them, `Concrete distribution `_ or `Gumbel-Softmax `_ (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages. +In this project we implement different alternatives to it. -.. image:: ../assets/overview.png +.. image:: _static/img/overview.png :width: 600 :align: center +.. raw:: html + +
+ .. toctree:: :maxdepth: 1 :caption: Guidelines - + install quickstart @@ -33,6 +41,4 @@ In this project we implement different alternatives to it. :maxdepth: 1 :caption: Code - source/modules - source/relaxit.distributions - + relaxit.distributions diff --git a/docs/install.md b/docs/source/install.md similarity index 100% rename from docs/install.md rename to docs/source/install.md diff --git a/docs/source/modules.rst b/docs/source/modules.rst index 41f9d2c..4fec1e2 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -2,6 +2,6 @@ relaxit ======= .. toctree:: - :maxdepth: 4 + :maxdepth: 1 relaxit diff --git a/docs/quickstart.md b/docs/source/quickstart.md similarity index 99% rename from docs/quickstart.md rename to docs/source/quickstart.md index 74caf23..ec67fea 100644 --- a/docs/quickstart.md +++ b/docs/source/quickstart.md @@ -16,4 +16,4 @@ distribution = InvertibleGaussian(loc, scale, temperature) sample = distribution.rsample() print('sample.shape:', sample.shape) print('sample.requires_grad:', sample.requires_grad) -``` \ No newline at end of file +``` diff --git a/docs/source/relaxit.distributions.rst b/docs/source/relaxit.distributions.rst index 4b41e2e..33ce68d 100644 --- a/docs/source/relaxit.distributions.rst +++ b/docs/source/relaxit.distributions.rst @@ -1,85 +1,63 @@ -relaxit.distributions package -============================= +relaxit.distributions +===================== -Submodules ----------- - -relaxit.distributions.CorrelatedRelaxedBernoulli module -------------------------------------------------------- - -.. automodule:: relaxit.distributions.CorrelatedRelaxedBernoulli +.. automodule:: relaxit.distributions :members: :undoc-members: :show-inheritance: -relaxit.distributions.GaussianRelaxedBernoulli module ------------------------------------------------------ +Distribution Classes +-------------------- -.. automodule:: relaxit.distributions.GaussianRelaxedBernoulli +.. autoclass:: relaxit.distributions.CorrelatedRelaxedBernoulli.CorrelatedRelaxedBernoulli :members: :undoc-members: :show-inheritance: -relaxit.distributions.GumbelSoftmaxTopK module ----------------------------------------------- - -.. automodule:: relaxit.distributions.GumbelSoftmaxTopK +.. autoclass:: relaxit.distributions.GaussianRelaxedBernoulli.GaussianRelaxedBernoulli :members: :undoc-members: :show-inheritance: -relaxit.distributions.HardConcrete module ------------------------------------------ - -.. automodule:: relaxit.distributions.HardConcrete +.. autoclass:: relaxit.distributions.GumbelSoftmaxTopK.GumbelSoftmaxTopK :members: :undoc-members: :show-inheritance: -relaxit.distributions.InvertibleGaussian module ------------------------------------------------ - -.. automodule:: relaxit.distributions.InvertibleGaussian +.. autoclass:: relaxit.distributions.HardConcrete.HardConcrete :members: :undoc-members: :show-inheritance: -relaxit.distributions.LogisticNormalSoftmax module --------------------------------------------------- - -.. automodule:: relaxit.distributions.LogisticNormalSoftmax +.. autoclass:: relaxit.distributions.InvertibleGaussian.InvertibleGaussian :members: :undoc-members: :show-inheritance: -relaxit.distributions.StraightThroughBernoulli module ------------------------------------------------------ - -.. automodule:: relaxit.distributions.StraightThroughBernoulli +.. 
autoclass:: relaxit.distributions.LogisticNormalSoftmax.LogisticNormalSoftmax :members: :undoc-members: :show-inheritance: -relaxit.distributions.approx module ----------------------------------- - -.. automodule:: relaxit.distributions.approx +.. autoclass:: relaxit.distributions.StochasticTimesSmooth.StochasticTimesSmooth :members: :undoc-members: :show-inheritance: -relaxit.distributions.kl module ------------------------------- - -.. automodule:: relaxit.distributions.kl +.. autoclass:: relaxit.distributions.StraightThroughBernoulli.StraightThroughBernoulli :members: :undoc-members: :show-inheritance: -Module contents +Utility Modules --------------- -.. automodule:: relaxit.distributions +.. automodule:: relaxit.distributions.approx + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: relaxit.distributions.kl :members: :undoc-members: :show-inheritance: diff --git a/docs/source/relaxit.rst b/docs/source/relaxit.rst index 4ce6ca4..06b6fdb 100644 --- a/docs/source/relaxit.rst +++ b/docs/source/relaxit.rst @@ -1,11 +1,11 @@ -relaxit package -=============== +relaxit +======= Subpackages ----------- .. toctree:: - :maxdepth: 4 + :maxdepth: 1 relaxit.distributions diff --git a/setup.py b/setup.py index 07c2af9..b159d33 100644 --- a/setup.py +++ b/setup.py @@ -26,5 +26,5 @@ url="https://github.com/intsystems/discrete-variables-relaxation", package_dir= {"": "src"}, packages=find_packages(where="src"), - install_requires=["pyro-ppl==1.9.1"], + install_requires=["pyro-ppl>=1.9.1"], ) \ No newline at end of file diff --git a/src/relaxit/_version.py b/src/relaxit/_version.py index 0058b93..ff1068c 100644 --- a/src/relaxit/_version.py +++ b/src/relaxit/_version.py @@ -1 +1 @@ -__version__ = "1.0.1" \ No newline at end of file +__version__ = "1.1.0" \ No newline at end of file diff --git a/src/relaxit/distributions/GumbelSoftmaxTopK.py b/src/relaxit/distributions/GumbelSoftmaxTopK.py index b0769ce..c2ac519 100644 --- a/src/relaxit/distributions/GumbelSoftmaxTopK.py +++ b/src/relaxit/distributions/GumbelSoftmaxTopK.py @@ -1,13 +1,13 @@ import torch import torch.nn.functional as F -from torch.distributions import constraints from pyro.distributions.torch_distribution import TorchDistribution from torch.distributions import constraints +from torch.distributions.utils import probs_to_logits, logits_to_probs class GumbelSoftmaxTopK(TorchDistribution): r""" - Implimentation of the Gaussian-soft max topK trick from https://arxiv.org/pdf/1903.06059 + Implementation of the Gumbel-Softmax top-K trick from https://arxiv.org/pdf/1903.06059. :param a: logits. :type a: torch.Tensor @@ -50,10 +50,14 @@ def __init__( :param validate_args: Whether to validate arguments.
:type validate_args: bool """ - if probs is None and logits is None: - raise ValueError("Pass `probs` or `logits`!") - elif probs is None: - self.probs = logits / logits.sum(dim=-1, keepdim=True) + if (probs is None) == (logits is None): + raise ValueError("Pass `probs` or `logits`, but not both of them!") + elif probs is not None: + self.probs = probs + self.logits = probs_to_logits(probs) + else: + self.logits = logits + self.probs = logits_to_probs(logits) self.K = K.int() # Ensure K is a int tensor self.tau = tau self.hard = hard diff --git a/src/relaxit/distributions/InvertibleGaussian.py b/src/relaxit/distributions/InvertibleGaussian.py index 8b9ff04..a6cf742 100644 --- a/src/relaxit/distributions/InvertibleGaussian.py +++ b/src/relaxit/distributions/InvertibleGaussian.py @@ -6,7 +6,7 @@ class InvertibleGaussian(TorchDistribution): """ - Invertible Gaussian distribution class inheriting from Pyro's TorchDistribution. + Invertible Gaussian distribution, as it was presented in https://arxiv.org/abs/1912.09588. Parameters: - loc (Tensor): The mean (mu) of the normal distribution. diff --git a/src/relaxit/distributions/LogisticNormalSoftmax.py b/src/relaxit/distributions/LogisticNormalSoftmax.py index 743a8a6..fdf3b6c 100644 --- a/src/relaxit/distributions/LogisticNormalSoftmax.py +++ b/src/relaxit/distributions/LogisticNormalSoftmax.py @@ -3,8 +3,6 @@ from pyro.distributions.transforms import SoftmaxTransform -# We implement LogisticNormal distribution with SoftmaxTransform instead of -# StickBreakingTransform, which is originally applied in the PyTorch and Pyro class LogisticNormalSoftmax(TransformedDistribution): r""" Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale` diff --git a/src/relaxit/distributions/STEstimator.py b/src/relaxit/distributions/STEstimator.py deleted file mode 100644 index 8d97a78..0000000 --- a/src/relaxit/distributions/STEstimator.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.distributions import constraints -from pyro.distributions.torch_distribution import TorchDistribution -from torch.distributions import constraints -from torch.distributions.utils import probs_to_logits - - -class StraightThroughEstimator(TorchDistribution): - r""" - Implimentation of the Straight Through Estimator from https://arxiv.org/abs/1910.02176 - - :param a: logits. - :type a: torch.Tensor - :param validate_args: Whether to validate arguments. - :type validate_args: bool - """ - - arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} - has_rsample = True - - def __init__( - self, - probs: torch.Tensor = None, - logits: torch.Tensor = None, - validate_args: bool = None, - ): - r"""Initializes the ST Estimator. - - :param probs: TODO - :param logits: the log-odds of sampling `1`. - :type logits: torch.Tensor - :param validate_args: Whether to validate arguments. - :type validate_args: bool - """ - if probs is None and logits is None: - raise ValueError("Pass `probs` or `logits`!") - elif probs is None: - self.probs = logits / logits.sum(dim=-1, keepdim=True) - self.logits=logits - super().__init__(validate_args=validate_args) - - @property - def batch_shape(self) -> torch.Size: - """ - Returns the batch shape of the distribution. - - The batch shape represents the shape of independent distributions. - For example, if `loc` is vector of length 3, - the batch shape will be `[3]`, indicating 3 independent distributions. 
- """ - return self.probs.shape - - @property - def event_shape(self) -> torch.Size: - """ - Returns the event shape of the distribution. - - The event shape represents the shape of each individual event. - """ - return torch.Size() - - def rsample(self) -> torch.Tensor: - """ - Generates a sample from the distribution using the Gaussian-soft max topK trick. - - :return: A sample from the distribution. - :rtype: torch.Tensor - """ - - binary_sample = torch.bernoulli(self.probs).detach() - - return binary_sample + (self.probs - self.probs.detach()) - - def sample(self) -> torch.Tensor: - """ - Generates a sample from the distribution with no grad. - - :return: A sample from the distribution. - :rtype: torch.Tensor - """ - with torch.no_grad(): - return self.rsample() - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - """ - Computes the log probability of the given value. - :param value: The value for which to compute the log probability. - :type value: torch.Tensor - :return: The log probability of the given value. - :rtype: torch.Tensor - """ - if self._validate_args: - self._validate_sample(value) - - return -F.binary_cross_entropy(self.probs, value, reduction="none") - - def _validate_sample(self, value: torch.Tensor): - """ - Validates the given sample value. - - Args: - - value (Tensor): The sample value to validate. - """ - if self._validate_args: - if ((value != 1.0) & (value != 0.0)).any(): - ValueError( - f"All coordinates in `value` must be 0 or 1 and you have {value}" - ) \ No newline at end of file diff --git a/src/relaxit/distributions/StochasticTimesSmooth.py b/src/relaxit/distributions/StochasticTimesSmooth.py new file mode 100644 index 0000000..6731c8d --- /dev/null +++ b/src/relaxit/distributions/StochasticTimesSmooth.py @@ -0,0 +1,27 @@ +import torch +from pyro.distributions import Bernoulli + + +class StochasticTimesSmooth(Bernoulli): + r""" + Implementation of the Stochastic Times Smooth from https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235 + + Creates a Bernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + However, supports gradient flow through parameters due to the + stochastic times smooth gradient estimator. + """ + has_rsample = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def rsample(self, sample_shape: torch.Size = torch.Size()): + shape = self._extended_shape(sample_shape) + sqrt_probs = self.probs.expand(shape).sqrt() + sample = sqrt_probs * torch.bernoulli(sqrt_probs) + return sample \ No newline at end of file diff --git a/src/relaxit/distributions/StraightThroughBernoulli.py b/src/relaxit/distributions/StraightThroughBernoulli.py index bad2db3..a222158 100644 --- a/src/relaxit/distributions/StraightThroughBernoulli.py +++ b/src/relaxit/distributions/StraightThroughBernoulli.py @@ -1,108 +1,27 @@ import torch -from pyro.distributions.torch_distribution import TorchDistribution -from torch.distributions import constraints +from pyro.distributions import Bernoulli -class StraightThroughBernoulli(TorchDistribution): - """ +class StraightThroughBernoulli(Bernoulli): + r""" + Implementation of the Straight Through Bernoulli from https://arxiv.org/abs/1910.02176 + + Creates a Bernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both). 
- Parameters: - - a (Tensor): logits + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + However, supports gradient flow through parameters due to the + straight through gradient estimator. """ - - arg_constraints = {"a": constraints.real} - support = constraints.real has_rsample = True - - def __init__(self, a: torch.Tensor, validate_args: bool = None): - """ - - Args: - - a (Tensor): logits - - validate_args (bool): Whether to validate arguments. - """ - - self.a = a.float() # Ensure a is a float tensor - self.uniform = torch.distributions.Uniform( - torch.tensor([0.0], device=self.a.device), - torch.tensor([1.0], device=self.a.device), - ) - super().__init__(validate_args=validate_args) - - @property - def batch_shape(self) -> torch.Size: - """ - Returns the batch shape of the distribution. - - The batch shape represents the shape of independent distributions. - For example, if `loc` is vector of length 3, - the batch shape will be `[3]`, indicating 3 independent Bernoulli distributions. - """ - return self.a.shape - - @property - def event_shape(self) -> torch.Size: - """ - Returns the event shape of the distribution. - - The event shape represents the shape of each individual event. - """ - return torch.Size() - - def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: - """ - Generates a sample from the distribution using the reparameterization trick. - - Args: - - sample_shape (torch.Size): The shape of the sample. - - Returns: - - torch.Tensor: A sample from the distribution. - """ - p = torch.nn.functional.sigmoid(self.a) - b = torch.distributions.Bernoulli(torch.sqrt(p)).sample( - sample_shape=sample_shape - ) - z = torch.sqrt(p) * b - return z - - def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: - """ - Generates a sample from the distribution. - - Args: - - sample_shape (torch.Size): The shape of the sample. - - Returns: - - torch.Tensor: A sample from the distribution. - """ - with torch.no_grad(): - return self.rsample(sample_shape) - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - """ - Computes the log probability of the given value. - - Args: - - value (Tensor): The value for which to compute the log probability. - - Returns: - - torch.Tensor: The log probability of the given value. - """ - if self._validate_args: - self._validate_sample(value) - p = torch.nn.functional.sigmoid(self.a) - - log_prob = torch.where(value == 0, torch.log(1 - p), torch.log(p)) - return log_prob - - def _validate_sample(self, value: torch.Tensor): - """ - Validates the given sample value. - - Args: - - value (Tensor): The sample value to validate. 
- """ - if self._validate_args: - if (value < 0).any(): - raise ValueError("Sample value must be non negative") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def rsample(self, sample_shape: torch.Size = torch.Size()): + shape = self._extended_shape(sample_shape) + probs = self.probs.expand(shape) + sample = torch.bernoulli(probs).detach() + return sample + probs - probs.detach() \ No newline at end of file diff --git a/src/relaxit/distributions/__init__.py b/src/relaxit/distributions/__init__.py index b6e18f3..68e1512 100644 --- a/src/relaxit/distributions/__init__.py +++ b/src/relaxit/distributions/__init__.py @@ -1,17 +1,20 @@ -from .GaussianRelaxedBernoulli import GaussianRelaxedBernoulli from .CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli -from .StraightThroughBernoulli import StraightThroughBernoulli +from .GaussianRelaxedBernoulli import GaussianRelaxedBernoulli +from .GumbelSoftmaxTopK import GumbelSoftmaxTopK from .HardConcrete import HardConcrete from .InvertibleGaussian import InvertibleGaussian from .LogisticNormalSoftmax import LogisticNormalSoftmax -from .GumbelSoftmaxTopK import GumbelSoftmaxTopK +from .StochasticTimesSmooth import StochasticTimesSmooth +from .StraightThroughBernoulli import StraightThroughBernoulli + __all__ = [ - "GaussianRelaxedBernoulli", "CorrelatedRelaxedBernoulli", - "StraightThroughBernoulli", + "GaussianRelaxedBernoulli", + "GumbelSoftmaxTopK", "HardConcrete", "InvertibleGaussian", "LogisticNormalSoftmax", - "GumbelSoftmaxTopK", + "StochasticTimesSmooth", + "StraightThroughBernoulli" ] diff --git a/src/relaxit/distributions/kl.py b/src/relaxit/distributions/kl.py index 5d37dc6..84da13a 100644 --- a/src/relaxit/distributions/kl.py +++ b/src/relaxit/distributions/kl.py @@ -4,7 +4,23 @@ @register_kl(InvertibleGaussian, InvertibleGaussian) -def _kl_igr_igr(p, q): +def _kl_igr_igr(p: InvertibleGaussian, q: InvertibleGaussian): + r""" + Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions. + + Based on the paper https://arxiv.org/abs/1912.09588. + + .. math:: + + KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx + + Args: + p (InvertibleGaussian): A :class:`~relaxit.distributions.InvertibleGaussian` object. + q (InvertibleGaussian): A :class:`~relaxit.distributions.InvertibleGaussian` object. + + Returns: + Tensor: A batch of KL divergences of shape `batch_shape`. 
+ """ p_normal = Normal(p.loc, p.scale) q_normal = Normal(q.loc, q.scale) return kl_divergence(p_normal, q_normal) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/distributions/test_CorrelatedRelaxedBernoulli.py b/tests/distributions/test_CorrelatedRelaxedBernoulli.py index 33eff73..3ea3771 100644 --- a/tests/distributions/test_CorrelatedRelaxedBernoulli.py +++ b/tests/distributions/test_CorrelatedRelaxedBernoulli.py @@ -1,10 +1,7 @@ import torch import sys, os -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src")) -) -from relaxit.distributions.CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli +from relaxit.distributions import CorrelatedRelaxedBernoulli # Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution @@ -13,9 +10,8 @@ def test_sample_shape(): pi = torch.tensor([0.1, 0.2, 0.3]) R = torch.tensor([[1.0]]) tau = torch.tensor([2.0]) - - distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) - samples = distr.rsample() + distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) + samples = distribution.rsample() assert samples.shape == torch.Size([3]) @@ -23,19 +19,26 @@ def test_sample_grad(): pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) R = torch.tensor([[1.0]]) tau = torch.tensor([2.0]) - - distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) - samples = distr.rsample() + distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) + samples = distribution.rsample() assert samples.requires_grad == True -def test_log_prob(): - pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) +def test_log_prob_shape(): + pi = torch.tensor([0.1, 0.2, 0.3]) R = torch.tensor([[1.0]]) tau = torch.tensor([2.0]) - - distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) + distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) value = torch.tensor([1.0]) - log_prob = distr.log_prob(value) + log_prob = distribution.log_prob(value) assert log_prob.shape == torch.Size([3]) - assert log_prob.requires_grad == True + + +def test_log_prob_grad(): + pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) + R = torch.tensor([[1.0]]) + tau = torch.tensor([2.0]) + distribution = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) + value = torch.tensor([1.0]) + log_prob = distribution.log_prob(value) + assert log_prob.requires_grad == True \ No newline at end of file diff --git a/tests/distributions/test_GaussianRelaxedBernoulli.py b/tests/distributions/test_GaussianRelaxedBernoulli.py index 07a1f15..f348066 100644 --- a/tests/distributions/test_GaussianRelaxedBernoulli.py +++ b/tests/distributions/test_GaussianRelaxedBernoulli.py @@ -1,10 +1,7 @@ import torch import sys, os -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src")) -) -from relaxit.distributions.GaussianRelaxedBernoulli import GaussianRelaxedBernoulli +from relaxit.distributions import GaussianRelaxedBernoulli # Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution @@ -12,32 +9,32 @@ def test_sample_shape(): loc = torch.tensor([0.0]) scale = torch.tensor([1.0]) - - distr = GaussianRelaxedBernoulli(loc=loc, scale=scale) - samples = distr.rsample(sample_shape=torch.Size([3])) + distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale) + samples = distribution.rsample(sample_shape=torch.Size([3])) assert samples.shape == torch.Size([3, 1]) def test_sample_grad(): loc = torch.tensor([0.0], requires_grad=True) scale = 
torch.tensor([1.0], requires_grad=True)
-    distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
-    samples = distr.rsample()
+    distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
+    samples = distribution.rsample()
     assert samples.requires_grad == True
-def test_log_prob():
-    loc = torch.tensor([0.0], requires_grad=True)
-    scale = torch.tensor([1.0], requires_grad=True)
-    distr = GaussianRelaxedBernoulli(loc=loc, scale=scale)
-
+def test_log_prob_shape():
+    loc = torch.tensor([0.0])
+    scale = torch.tensor([1.0])
+    distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
     value = torch.tensor([1.0])
-    log_prob = distr.log_prob(value)
+    log_prob = distribution.log_prob(value)
     assert log_prob.shape == torch.Size([1])
-    assert log_prob.requires_grad == True
-if __name__ == "__main__":
-    test_sample_shape()
-    test_sample_grad()
-    test_log_prob()
+def test_log_prob_grad():
+    loc = torch.tensor([0.0], requires_grad=True)
+    scale = torch.tensor([1.0], requires_grad=True)
+    distribution = GaussianRelaxedBernoulli(loc=loc, scale=scale)
+    value = torch.tensor([1.0])
+    log_prob = distribution.log_prob(value)
+    assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_GumbelSoftmaxTopK.py b/tests/distributions/test_GumbelSoftmaxTopK.py
index b5ac731..9bf9362 100644
--- a/tests/distributions/test_GumbelSoftmaxTopK.py
+++ b/tests/distributions/test_GumbelSoftmaxTopK.py
@@ -1,44 +1,46 @@
 import torch
 import sys, os
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK
+from relaxit.distributions import GumbelSoftmaxTopK
 # Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution
 def test_sample_shape():
-    a = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
+    logits = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
     K = torch.tensor(1)
     tau = torch.tensor(0.1)
-    distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+    distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
     sample = distribution.rsample()
-    assert sample.shape == a.shape
+    assert sample.shape == logits.shape
 def test_sample_grad():
-    a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
+    logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
     K = torch.tensor(2)
     tau = torch.tensor(0.1)
-    distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+    distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
     sample = distribution.rsample()
     assert sample.requires_grad == True
-def test_log_prob():
-    a = torch.tensor([1.0, 2.0, 3.0])
+def test_log_prob_shape():
+    logits = torch.tensor([1.0, 2.0, 3.0])
     K = torch.tensor(3)
     tau = torch.tensor(0.1)
-    distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+    distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
     sample = distribution.rsample()
     value = torch.tensor([1.0, 1.0, 1.0])
     log_prob = distribution.log_prob(value)
-    assert log_prob - torch.tensor(0) < 1e-6
+    assert log_prob.shape == torch.Size([3])
+
-
-if __name__ == "__main__":
-    test_sample_shape()
-    test_sample_grad()
-    test_log_prob()
+def test_log_prob_grad():
+    logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+    K = torch.tensor(3)
+    tau = torch.tensor(0.1)
+    distribution = GumbelSoftmaxTopK(logits=logits, K=K, tau=tau)
+    sample = distribution.rsample()
+    value = torch.tensor([1.0, 1.0, 1.0])
+    log_prob = distribution.log_prob(value)
+    assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_HardConcrete.py b/tests/distributions/test_HardConcrete.py
index c907e2b..ebc0888 100644
--- a/tests/distributions/test_HardConcrete.py
+++ b/tests/distributions/test_HardConcrete.py
@@ -1,10 +1,7 @@
 import torch
 import sys, os
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.HardConcrete import HardConcrete
+from relaxit.distributions import HardConcrete
 # Testing reparameterized sampling from the HardConcrete distribution
@@ -14,9 +11,8 @@ def test_sample_shape():
     beta = torch.tensor([2.0])
     gamma = torch.tensor([-3.0])
     xi = torch.tensor([4.0])
-    distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
-    samples = distr.rsample(sample_shape=torch.Size([3]))
-
+    distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
+    samples = distribution.rsample(sample_shape=torch.Size([3]))
     assert samples.shape == torch.Size([3, 1])
@@ -25,20 +21,28 @@ def test_sample_grad():
     beta = torch.tensor([2.0], requires_grad=True)
     gamma = torch.tensor([-3.0], requires_grad=True)
     xi = torch.tensor([4.0], requires_grad=True)
-    distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
-    samples = distr.rsample(sample_shape=torch.Size([3]))
-
+    distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
+    samples = distribution.rsample(sample_shape=torch.Size([3]))
     assert samples.requires_grad == True
-def test_log_prob():
+def test_log_prob_shape():
+    alpha = torch.tensor([1.0])
+    beta = torch.tensor([2.0])
+    gamma = torch.tensor([-3.0])
+    xi = torch.tensor([4.0])
+    distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
+    value = torch.tensor([1.0])
+    log_prob = distribution.log_prob(value)
+    assert log_prob.shape == torch.Size([1])
+
+
+def test_log_prob_grad():
     alpha = torch.tensor([1.0], requires_grad=True)
     beta = torch.tensor([2.0], requires_grad=True)
     gamma = torch.tensor([-3.0], requires_grad=True)
     xi = torch.tensor([4.0], requires_grad=True)
-    distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
-
+    distribution = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
     value = torch.tensor([1.0])
-    log_prob = distr.log_prob(value)
-    assert log_prob.shape == torch.Size([1])
-    assert log_prob.requires_grad == True
+    log_prob = distribution.log_prob(value)
+    assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_InvertibleGaussian.py b/tests/distributions/test_InvertibleGaussian.py
index e28b2df..90216a4 100644
--- a/tests/distributions/test_InvertibleGaussian.py
+++ b/tests/distributions/test_InvertibleGaussian.py
@@ -1,15 +1,14 @@
 import torch
 import sys
-sys.path.append("../../src")
-from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions import InvertibleGaussian
 # Testing reparameterized sampling from the InvertibleGaussian distribution
 def test_sample_shape():
-    loc = torch.zeros(3, 4, 5, requires_grad=True)
-    scale = torch.ones(3, 4, 5, requires_grad=True)
+    loc = torch.zeros(3, 4, 5)
+    scale = torch.ones(3, 4, 5)
     temperature = torch.tensor([1e-0])
     distribution = InvertibleGaussian(loc, scale, temperature)
     sample = distribution.rsample()
@@ -26,8 +25,8 @@ def test_log_prob_shape():
-    loc = torch.zeros(3, 4, 5, requires_grad=True)
-    scale = torch.ones(3, 4, 5, requires_grad=True)
+    loc = torch.zeros(3, 4, 5)
+    scale = torch.ones(3, 4, 5)
     temperature = torch.tensor([1e-0])
     distribution = InvertibleGaussian(loc, scale, temperature)
     value = 0.5 * torch.ones(3, 4, 6)
diff --git a/tests/distributions/test_LogisticNormalSoftmax.py b/tests/distributions/test_LogisticNormalSoftmax.py
index 135c4bf..7c32e96 100644
--- a/tests/distributions/test_LogisticNormalSoftmax.py
+++ b/tests/distributions/test_LogisticNormalSoftmax.py
@@ -1,15 +1,14 @@
 import torch
 import sys
-sys.path.append("../../src")
-from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
+from relaxit.distributions import LogisticNormalSoftmax
 # Testing reparameterized sampling from the LogisticNormalSoftmax distribution
 def test_sample_shape():
-    loc = torch.zeros(3, 4, 5, requires_grad=True)
-    scale = torch.ones(3, 4, 5, requires_grad=True)
+    loc = torch.zeros(3, 4, 5)
+    scale = torch.ones(3, 4, 5)
     distribution = LogisticNormalSoftmax(loc, scale)
     sample = distribution.rsample()
     assert sample.shape == torch.Size([3, 4, 5])
diff --git a/tests/distributions/test_STEstimator.py b/tests/distributions/test_STEstimator.py
deleted file mode 100644
index 524f5cf..0000000
--- a/tests/distributions/test_STEstimator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import torch
-import sys, os
-
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.STEstimator import StraightThroughEstimator
-
-# Testing reparameterized sampling and log prob from the StraightThroughEstimator distribution
-
-
-def test_sample_shape():
-    a = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
-    distribution = StraightThroughEstimator(logits=a)
-    sample = distribution.rsample()
-    assert sample.shape == a.shape
-
-
-def test_sample_grad():
-    a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
-    distribution = StraightThroughEstimator(logits=a)
-    sample = distribution.rsample()
-    assert sample.requires_grad == True
-
-
-def test_log_prob():
-    a = torch.tensor([1.0, 2.0, 3.0])
-    distribution = StraightThroughEstimator(logits=a)
-    value = torch.tensor([1.0, 1.0, 1.0])
-    log_prob = distribution.log_prob(value)
-    assert log_prob.shape == torch.Size([3])
diff --git a/tests/distributions/test_StochasticTimesSmooth.py b/tests/distributions/test_StochasticTimesSmooth.py
new file mode 100644
index 0000000..92e3946
--- /dev/null
+++ b/tests/distributions/test_StochasticTimesSmooth.py
@@ -0,0 +1,37 @@
+import torch
+import sys, os
+
+from relaxit.distributions import StochasticTimesSmooth
+
+# Testing reparameterized sampling from the StochasticTimesSmooth distribution
+
+
+def test_sample_shape():
+    logits = torch.tensor([1., 2., 3.])
+    distribution = StochasticTimesSmooth(logits=logits)
+    samples = distribution.rsample()
+    assert samples.shape == torch.Size([3])
+
+
+def test_sample_grad():
+    logits = torch.tensor([1., 2., 3.], requires_grad=True)
+    distribution = StochasticTimesSmooth(logits=logits)
+    samples = distribution.rsample()
+    assert samples.requires_grad == True
+
+
+def test_log_prob_shape():
+    logits = torch.tensor([1., 2., 3.])
+    distribution = StochasticTimesSmooth(logits=logits)
+    value = torch.Tensor([1., 1., 1.])
+    log_prob = distribution.log_prob(value)
+    print('log_prob.shape:', log_prob.shape)
+    assert log_prob.shape == torch.Size([3])
+
+
+def test_log_prob_grad():
+    logits = torch.tensor([1., 2., 3.], requires_grad=True)
+    distribution = StochasticTimesSmooth(logits=logits)
+    value = torch.Tensor([1., 1., 1.])
+    log_prob = distribution.log_prob(value)
+    assert log_prob.requires_grad == True
\ No newline at end of file
diff --git a/tests/distributions/test_StraightThroughBernoulli.py b/tests/distributions/test_StraightThroughBernoulli.py
index 9c75c89..8821b1c 100644
--- a/tests/distributions/test_StraightThroughBernoulli.py
+++ b/tests/distributions/test_StraightThroughBernoulli.py
@@ -1,29 +1,36 @@
 import torch
 import sys, os
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
-)
-from relaxit.distributions.StraightThroughBernoulli import StraightThroughBernoulli
+from relaxit.distributions import StraightThroughBernoulli
-# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution
+# Testing reparameterized sampling and log prob from the StraightThroughBernoulli distribution
 def test_sample_shape():
-    a = torch.tensor([1, 2, 3])
-    distr = StraightThroughBernoulli(a=a)
-    samples = distr.rsample()
-    assert samples.shape == torch.Size([3])
-
-
-# def test_sample_grad():
-#     a = torch.tensor([1., 2., 3.], requires_grad=True)
-#     distr = StraightThroughBernoulli(a = a)
-#     samples = distr.rsample()
-#     assert samples.requires_grad == True
-# def test_log_prob():
-#     a = torch.tensor([1, 2, 3])
-#     distr = StraightThroughBernoulli(a = a)
-#     value = torch.Tensor([1.])
-#     print(distr.log_prob(value))
+    logits = torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]])
+    distribution = StraightThroughBernoulli(logits=logits)
+    sample = distribution.rsample()
+    assert sample.shape == logits.shape
+
+
+def test_sample_grad():
+    logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
+    distribution = StraightThroughBernoulli(logits=logits)
+    sample = distribution.rsample()
+    assert sample.requires_grad == True
+
+
+def test_log_prob_shape():
+    logits = torch.tensor([1.0, 2.0, 3.0])
+    distribution = StraightThroughBernoulli(logits=logits)
+    value = torch.tensor([1.0, 1.0, 1.0])
+    log_prob = distribution.log_prob(value)
+    assert log_prob.shape == torch.Size([3])
+
+
+def test_log_prob_grad():
+    logits = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+    distribution = StraightThroughBernoulli(logits=logits)
+    value = torch.tensor([1.0, 1.0, 1.0])
+    log_prob = distribution.log_prob(value)
+    assert log_prob.requires_grad == True
diff --git a/tests/distributions/test_approx.py b/tests/distributions/test_approx.py
index d3412b0..9a312ba 100644
--- a/tests/distributions/test_approx.py
+++ b/tests/distributions/test_approx.py
@@ -1,7 +1,6 @@
 import torch
 import sys
-sys.path.append("../../src")
 from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax
 from relaxit.distributions.approx import (
     lognorm_approximation_fn,
diff --git a/tests/distributions/test_kl.py b/tests/distributions/test_kl.py
index fc74420..10710b6 100644
--- a/tests/distributions/test_kl.py
+++ b/tests/distributions/test_kl.py
@@ -1,8 +1,7 @@
 import torch
 import sys
-sys.path.append("../../src")
-from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions import InvertibleGaussian
 from relaxit.distributions.kl import kl_divergence
 # Testing KL-divergence between two IntertibleGaussian distributions