
Commit

upd demo and readme
kisnikser committed Dec 6, 2024
1 parent dd2a400 commit 0c497d8
Showing 10 changed files with 302 additions and 50 deletions.
16 changes: 9 additions & 7 deletions README.md
@@ -52,9 +52,9 @@
6. [Tests](https://github.com/intsystems/discrete-variables-relaxation/tree/main/tests)

## 💡 Motivation
-For lots of mathematical problems we need an ability to sample discrete random variables.
-The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible.
-Thus we use different relaxation method.
+For many mathematical problems we need the ability to sample discrete random variables.
+The problem is that, due to the continuous nature of deep learning optimization, using truly discrete random variables is infeasible.
+Thus we use different relaxation methods.
One of them, [Concrete distribution](https://arxiv.org/abs/1611.00712) or [Gumbel-Softmax](https://arxiv.org/abs/1611.01144) (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages.
In this project we implement different alternatives to it.
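
For reference, the Gumbel-Softmax relaxation that these methods are alternatives to is already available in PyTorch; a minimal sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)

# Relaxed one-hot samples: differentiable w.r.t. logits (temperature tau controls sharpness).
soft = F.gumbel_softmax(logits, tau=0.5, hard=False)

# Straight-through variant: one-hot in the forward pass, soft gradients in the backward pass.
hard = F.gumbel_softmax(logits, tau=0.5, hard=True)
print(soft.sum(dim=-1), hard.sum(dim=-1))
```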
<div align="center">
@@ -65,11 +65,11 @@ In this project we implement different alternatives to it.
- [x] [Relaxed Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/GaussianRelaxedBernoulli.py), also see [📝 paper](http://proceedings.mlr.press/v119/yamada20a/yamada20a.pdf)
- [x] [Correlated relaxed Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/CorrelatedRelaxedBernoulli.py), also see [📝 paper](https://openreview.net/pdf?id=oDFvtxzPOx)
- [x] [Gumbel-softmax TOP-K](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/GumbelSoftmaxTopK.py), also see [📝 paper](https://arxiv.org/pdf/1903.06059)
-- [x] [Straight-Through Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StraightThroughBernoulli.py), also see [📝 paper](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
+- [x] [Straight-Through Bernoulli](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StraightThroughBernoulli.py), also see [📝 paper](https://arxiv.org/abs/1910.02176)
- [x] [Stochastic Times Smooth](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/StochasticTimesSmooth.py), also see [📝 paper](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
- [x] [Invertible Gaussian](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/InvertibleGaussian.py) with [KL implemented](https://github.com/intsystems/discrete-variables-relaxation/blob/f398ebbbac703582de392bc33d89b55c6c99ea68/src/relaxit/distributions/kl.py#L7), also see [📝 paper](https://arxiv.org/abs/1912.09588)
- [x] [Hard Concrete](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/HardConcrete.py), also see [📝 paper](https://arxiv.org/pdf/1712.01312)
- [x] [REINFORCE](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/CorrelatedRelaxedBernoulli.py), also see [📺 slides](http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf)
-- [x] [Logit-Normal](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/LogisticNormalSoftmax.py) and [Laplace-form approximation of Dirichlet](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/approx.py), also see [ℹ️ wiki](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [💻 stackexchange](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)
+- [x] [Logistic-Normal](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/LogisticNormalSoftmax.py) and [Laplace-form approximation of Dirichlet](https://github.com/intsystems/discrete-variables-relaxation/blob/main/src/relaxit/distributions/approx.py), also see [ℹ️ wiki](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [💻 stackexchange](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)
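
A minimal usage sketch, based on the `StochasticTimesSmooth` interface used in the demo script added in this commit (construction from `logits`, a `.probs` attribute, and a reparameterized `.rsample()`); the other distributions are assumed to follow the same pattern:

```python
import torch
from relaxit.distributions import StochasticTimesSmooth

# Encoder output plays the role of Bernoulli logits for a 20-dimensional latent code.
logits = torch.randn(4, 20, requires_grad=True)

q_z = StochasticTimesSmooth(logits=logits)
z = q_z.rsample()   # relaxed, reparameterized sample, differentiable w.r.t. logits
loss = z.sum()
loss.backward()     # gradients flow back to the logits
print(z.shape, logits.grad is not None)
```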

## 🛠️ Install

@@ -116,13 +116,15 @@ print('sample.requires_grad:', sample.requires_grad)
| ![Laplace Bridge](https://github.com/user-attachments/assets/ac5d5a71-e7d7-4ec3-b9ca-9b72d958eb41) | ![REINFORCE](https://gymnasium.farama.org/_images/acrobot.gif) | ![VAE](https://github.com/user-attachments/assets/937585c4-df84-4ab0-a2b9-ea6a73997793) |
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/laplace-bridge.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/reinforce.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/demo.ipynb) |

-For demonstration purposes, we divide our algorithms in three different groups. Each group relates to the particular demo code:
+For demonstration purposes, we divide our algorithms into three[^*] groups. Each group relates to particular demo code:
- [Laplace bridge between Dirichlet and LogisticNormal distributions](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/laplace-bridge.ipynb)
- [REINFORCE](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/reinforce.ipynb)
- [Other relaxation methods](https://github.com/intsystems/discrete-variables-relaxation/blob/main/demo/demo.ipynb)

We describe our demo experiments [here](https://github.com/intsystems/discrete-variables-relaxation/tree/main/demo).

[^*]: We also implement the REINFORCE algorithm as a *score function* estimator alternative to our relaxation methods, which are inherently *pathwise derivative* estimators. It is implemented only for the demo experiments and is not included in the package source code.
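
To make the score-function vs. pathwise distinction concrete, here is a small standalone sketch (not part of the package) estimating the gradient of E_{z ~ Bernoulli(theta)}[f(z)] in both ways:

```python
import torch

torch.manual_seed(0)
theta = torch.tensor(0.3, requires_grad=True)
f = lambda z: (z - 0.7) ** 2  # an arbitrary objective

# Score-function (REINFORCE) estimator: works with hard, non-differentiable samples.
z = torch.bernoulli(theta.detach().expand(100_000))
log_prob = z * theta.log() + (1 - z) * (1 - theta).log()
reinforce_grad = torch.autograd.grad((f(z) * log_prob).mean(), theta)[0]

# Pathwise estimator: needs a relaxed, reparameterized sample (here a Relaxed Bernoulli).
q = torch.distributions.RelaxedBernoulli(temperature=torch.tensor(0.5), probs=theta)
z_soft = q.rsample((100_000,))
pathwise_grad = torch.autograd.grad(f(z_soft).mean(), theta)[0]

# Both approximate the true gradient -0.4; the pathwise estimate is biased by the relaxation.
print(reinforce_grad.item(), pathwise_grad.item())
```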

## 📚 Stack
Some of the alternatives for GS were implemented in [pyro](https://docs.pyro.ai/en/dev/distributions.html), so we base our library on their codebase.

33 changes: 18 additions & 15 deletions demo/README.md
@@ -1,5 +1,5 @@
# Demo experiments code
-This repository contains our demo code for various experiments. The main demo code can be found in the notebook `demo/demo.ipynb`. Open the notebook and run the cells to see the demonstration in action. For additional experiments, refer to the section [Additional experiments](#experiments). Before starting any experiments, ensure you follow all the installation steps outlined in the [Installation](#installation) section.
+This repository contains our demo code for various experiments. The main demo code can be found in the notebook `demo.ipynb`. Open the notebook and run the cells to see the demonstration in action. For additional experiments, refer to the section [Additional experiments](#experiments). Before starting any experiments, ensure you follow all the installation steps outlined in the [Installation](#installation) section.

## Installation <a name="installation"></a>

@@ -27,33 +27,36 @@ For additional demo experiments, we have implemented Variational Autoencoders (V
1. **Train and save the models:**
To run the additional demo code with VAEs, you need to train all the models and save their results. Execute the following commands:
```bash
-# VAE with Gaussian Bernoulli latent space
-python vae_gaussian_bernoulli.py

# VAE with Correlated Bernoulli latent space
python vae_correlated_bernoulli.py

+# VAE with Gaussian Bernoulli latent space
+python vae_gaussian_bernoulli.py

+# VAE with Gumbel-Softmax top-K latent space
+python vae_gumbel_softmax_topk.py

# VAE with Hard Concrete latent space
python vae_hard_concrete.py

-# VAE with Straight Through Bernoulli latent space
-python vae_straight_through_bernoulli.py

# VAE with Invertible Gaussian latent space
python vae_invertible_gaussian.py

-# VAE with Gumbel Softmax TopK latent space
-python vae_gumbel_softmax_topk.py

+# VAE with Stochastic Times Smooth latent space
+python vae_stochastic_times_smooth.py

+# VAE with Straight Through Bernoulli latent space
+python vae_straight_through_bernoulli.py
```
2. **View the results:**
-After completing the training and testing of all the models, you can find the results of sampling and reconstruction methods in the directory `demo/results`.
+After completing the training and testing of all the models, you can find the results of sampling and reconstruction methods in the directory `results`.

Moreover, we conducted experiments with the Laplace bridge between LogisticNormal and Dirichlet distributions. We use the two-sided Laplace bridge to approximate:
-- Dirichlet using logisticNormal
-- LogisticNormal using Dirichlet
+- Dirichlet using Logistic-Normal
+- Logistic-Normal using Dirichlet

-These experiments aim to find the best parameters to make the distributions nearly identical on the simplex. The experiments can be found in the notebook `demo/laplace-bridge.ipynb`.
+These experiments aim to find the best parameters to make the distributions nearly identical on the simplex. The experiments can be found in the notebook `laplace-bridge.ipynb`.
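
One standard form of the two-sided bridge uses a diagonal covariance and matches parameters in closed form; the sketch below shows that matching (the parameterization used in the notebook may differ):

```python
import torch

def dirichlet_to_gaussian(alpha: torch.Tensor):
    """Map Dirichlet(alpha) to a diagonal Gaussian in softmax (logistic-normal) space."""
    K = alpha.shape[-1]
    log_alpha = alpha.log()
    mu = log_alpha - log_alpha.mean(-1, keepdim=True)
    var = (1.0 / alpha) * (1 - 2.0 / K) + (1.0 / alpha).sum(-1, keepdim=True) / K**2
    return mu, var

def gaussian_to_dirichlet(mu: torch.Tensor, var: torch.Tensor):
    """Inverse map: diagonal Gaussian in softmax space back to Dirichlet concentrations."""
    K = mu.shape[-1]
    return (1 - 2.0 / K + mu.exp() * (-mu).exp().sum(-1, keepdim=True) / K**2) / var

alpha = torch.tensor([2.0, 5.0, 3.0])
mu, var = dirichlet_to_gaussian(alpha)
print(gaussian_to_dirichlet(mu, var))  # recovers approximately [2, 5, 3]
```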

-Furthermore, the Reinforce algorithm is applied in the [Acrobot environment](https://www.gymlibrary.dev/environments/classic_control/acrobot/). Detailed experiments can be viewed in the notebook `demo/reinforce.ipynb`. A script `demo/reinforce.py` can also be used for training.
+Furthermore, the REINFORCE algorithm is applied in the [Acrobot environment](https://www.gymlibrary.dev/environments/classic_control/acrobot/). Detailed experiments can be viewed in the notebook `reinforce.ipynb`. A script `reinforce.py` can also be used for training.
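
A hypothetical minimal REINFORCE loop for Acrobot-v1, shown only to illustrate the score-function update (the actual `reinforce.py` likely adds discounting, a baseline, and logging):

```python
import gym
import torch
from torch import nn, optim

env = gym.make("Acrobot-v1")
policy = nn.Sequential(nn.Linear(6, 128), nn.ReLU(), nn.Linear(128, 3))
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

for episode in range(200):
    obs, _ = env.reset()
    log_probs, rewards = [], []
    done = False
    while not done:
        dist = torch.distributions.Categorical(
            logits=policy(torch.as_tensor(obs, dtype=torch.float32))
        )
        action = dist.sample()                   # hard, non-differentiable action
        log_probs.append(dist.log_prob(action))  # score function carries the gradient
        obs, reward, terminated, truncated, _ = env.step(action.item())
        rewards.append(reward)
        done = terminated or truncated
    ret = sum(rewards)  # undiscounted return as a simple (high-variance) weight
    loss = -torch.stack(log_probs).sum() * ret
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```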


21 changes: 10 additions & 11 deletions demo/requirements.txt
@@ -1,11 +1,10 @@
-gym==0.26.2
-numpy==1.24.1
-pyro_ppl==1.9.1
-setuptools==65.5.0
-torch==2.5.1
-torchvision==0.20.1
-matplotlib==3.9.2
-networkx==3.3
-tqdm==4.66.5
-pillow==10.4.0
-relaxit==0.1.2
+gym>=0.26.2
+numpy>=1.24.1
+torchvision>=0.20.1
+matplotlib>=3.9.2
+networkx>=3.3
+tqdm>=4.66.5
+pillow>=10.4.0
+relaxit==1.0.1
+ftfy
+regex
252 changes: 252 additions & 0 deletions demo/vae_stochastic_times_smooth.py
@@ -0,0 +1,252 @@
import os
import argparse
import numpy as np
import torch
import sys
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from relaxit.distributions import StochasticTimesSmooth


def parse_arguments() -> argparse.Namespace:
    """
    Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed command line arguments.
    """
    parser = argparse.ArgumentParser(description="VAE MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        metavar="N",
        help="input batch size for training (default: 128)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    return parser.parse_args()


args = parse_arguments()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")

os.makedirs("./results/vae_stochastic_times_smooth", exist_ok=True)

kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data", train=True, download=True, transform=transforms.ToTensor()
    ),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("./data", train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)

steps = 0


class VAE(nn.Module):
    """
    Variational Autoencoder (VAE) with StochasticTimesSmooth distribution.
    """

    def __init__(self) -> None:
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input by passing through the encoder network
        and return the latent code.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Latent code.
        """
        h1 = F.relu(self.fc1(x))
        return self.fc2(h1)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode the latent code by passing through the decoder network
        and return the reconstructed input.

        Args:
            z (torch.Tensor): Latent code.

        Returns:
            torch.Tensor: Reconstructed input.
        """
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the VAE.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Reconstructed input and latent code.
        """
        logits = self.encode(x.view(-1, 784))
        q_z = StochasticTimesSmooth(logits=logits)
        probs = q_z.probs
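        # rsample() draws a relaxed, reparameterized sample, so gradients can flow
        # back through the latent code to the encoder (a pathwise estimator).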
        z = q_z.rsample()
        return self.decode(z), probs


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def loss_function(
    recon_x: torch.Tensor,
    x: torch.Tensor,
    probs: torch.Tensor,
    prior: float = 0.5,
    eps: float = 1e-10,
) -> torch.Tensor:
    """
    Compute the loss function for the VAE.

    Args:
        recon_x (torch.Tensor): Reconstructed input.
        x (torch.Tensor): Original input.
        probs (torch.Tensor): Probabilities for Bernoulli distribution in latent space.
        prior (float): Prior probability.
        eps (float): Small value to avoid log(0).

    Returns:
        torch.Tensor: Loss value.
    """
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
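    # Analytic KL divergence between Bernoulli(probs) and the Bernoulli(prior) prior,
    # summed over latent dimensions and over the batch.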
    t1 = probs * ((probs + eps) / prior).log()
    t2 = (1 - probs) * ((1 - probs + eps) / (1 - prior)).log()
    KLD = torch.sum(t1 + t2, dim=-1).sum()

    return BCE + KLD


def train(epoch: int) -> None:
    """
    Train the VAE for one epoch.

    Args:
        epoch (int): Current epoch number.
    """
    global steps
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, probs = model(data)
        loss = loss_function(recon_batch, data, probs)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item() / len(data),
                )
            )

        steps += 1

    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, train_loss / len(train_loader.dataset)
        )
    )


def test(epoch: int) -> None:
    """
    Test the VAE for one epoch.

    Args:
        epoch (int): Current epoch number.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, probs = model(data)
            test_loss += loss_function(recon_batch, data, probs).item()
            if i == 0:
                n = min(data.size(0), 8)
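                # Save the first eight test digits next to their reconstructions.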
                comparison = torch.cat(
                    [data[:n], recon_batch.view(args.batch_size, 1, 28, 28)[:n]]
                )
                save_image(
                    comparison.cpu(),
                    "results/vae_stochastic_times_smooth/reconstruction_"
                    + str(epoch)
                    + ".png",
                    nrow=n,
                )

    test_loss /= len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(test_loss))


if __name__ == "__main__":
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
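            # Generate unconditional samples: decode hard Bernoulli(0.5) latent codes.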
            sample = np.random.binomial(1, 0.5, size=(64, 20))
            sample = torch.from_numpy(np.float32(sample)).to(device)
            sample = model.decode(sample).cpu()
            save_image(
                sample.view(64, 1, 28, 28),
                "results/vae_stochastic_times_smooth/sample_" + str(epoch) + ".png",
            )