Merge branch 'main' into nikita
kisnikser authored Nov 16, 2024
2 parents a9e1d89 + 4f671c2 commit aaf0e4c
Showing 26 changed files with 1,752 additions and 95 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/docs.yml
@@ -0,0 +1,47 @@
# This workflow builds the Sphinx documentation and deploys it to GitHub Pages
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Build documentation

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: [3.7]

steps:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7

- name: Checkout 🛎️
uses: actions/checkout@v2

- name: Install Dependencies
run: |
pip install -U sphinx
pip install -U sphinx-rtd-theme
pip install torch
pip install pyro-ppl
ls ./
- name: Build Docs
run: |
sphinx-build -b html ./doc/source/ public
touch public/.nojekyll
- name: Deploy 🚀
uses: JamesIves/github-pages-deploy-action@releases/v3
with:
ACCESS_TOKEN: ${{ secrets.ACCESS_TOKEN }}
BRANCH: gh-pages
FOLDER: public
36 changes: 36 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,36 @@
name: Testing

on: [push, pull_request, workflow_dispatch]

jobs:
build:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: [3.7]

steps:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7

- name: Checkout 🛎️
uses: actions/checkout@v2

- name: Install Dependencies
run: |
pip install torch
pip install pyro-ppl
pip install -U pytest pytest-cov
ls ./
- name: Testing
run: |
PYTHONPATH=src/ pytest tests/ --cov=relaxit --cov-report=xml
# - name: Upload to Codecov
# uses: codecov/codecov-action@v2
# with:
# files: ./coverage.xml,
# fail_ci_if_error: true
# verbose: true
20 changes: 16 additions & 4 deletions README.md
@@ -4,22 +4,34 @@
Discrete variables relaxation
</div>

## 📬 Assets

1. [Technical Meeting 1 - Presentation](https://github.com/intsystems/discrete-variables-relaxation/blob/main/assets/presentation_tm1.pdf)
2. [Technical Meeting 2 - Jupyter Notebook](https://github.com/intsystems/discrete-variables-relaxation/blob/main/basic/basic_code.ipynb)
3. [Blog Post](https://github.com/intsystems/discrete-variables-relaxation/blob/main/assets/blog-post.pdf)
4. [Documentation](https://intsystems.github.io/discrete-variables-relaxation/train.html)
5. [Tests](https://github.com/intsystems/discrete-variables-relaxation/tree/main/tests)

## 💡 Motivation
Many mathematical problems require the ability to sample discrete random variables.
The problem is that, due to the continuous nature of deep learning optimization, using truly discrete random variables is infeasible.
Thus we use various relaxation methods.
One of them, the [Concrete distribution](https://arxiv.org/abs/1611.00712), a.k.a. [Gumbel-softmax](https://arxiv.org/abs/1611.01144) (the same distribution proposed in parallel by two research groups), is already implemented in various DL packages.
In this project we implement various alternatives to it.

<div align="center">
<img src="assets/overview.png"/>
</div>
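
As a point of reference, the following minimal sketch uses the stock `RelaxedOneHotCategorical` from PyTorch (not this library) to show what such a relaxation buys: the sample is reparameterized, so gradients flow back to the distribution parameters.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

temperature = torch.tensor(0.5)  # lower temperature -> closer to one-hot
logits = torch.randn(4, requires_grad=True)

# Concrete / Gumbel-softmax relaxation of a categorical variable
q = RelaxedOneHotCategorical(temperature, logits=logits)
z = q.rsample()            # differentiable, approximately one-hot sample
(z ** 2).sum().backward()  # gradients reach the logits through the sample
```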

## 🗃 Algorithms to implement (from simplest to hardest)
- [x] [Relaxed Bernoulli](http://proceedings.mlr.press/v119/yamada20a/yamada20a.pdf)
- [ ] [Correlated relaxed Bernoulli](https://openreview.net/pdf?id=oDFvtxzPOx)
- [ ] [Gumbel-softmax TOP-K](https://arxiv.org/pdf/1903.06059)
- [ ] [Straight-Through Bernoulli distribution (not to be confused with the Relaxed distribution from Pyro)](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
- [ ] [Invertible Gaussian reparametrization](https://arxiv.org/abs/1912.09588) with KL implemented
- [ ] [Hard concrete](https://arxiv.org/pdf/1712.01312)
- [x] [Straight-Through Bernoulli distribution (not to be confused with the Relaxed distribution from Pyro)](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
- [x] [Invertible Gaussian reparametrization](https://arxiv.org/abs/1912.09588) with KL implemented
- [x] [Hard concrete](https://arxiv.org/pdf/1712.01312)
- [ ] [REINFORCE](http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf) (not actually a distribution; think about how to integrate it with the other distributions)
- [ ] [Logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution) with KL implemented and [Laplace-form approximation of Dirichlet](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)
- [x] [Logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [Laplace-form approximation of Dirichlet](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)

## 📚 Recommended stack
Some of the alternatives to GS are already implemented in [pyro](https://docs.pyro.ai/en/dev/distributions.html), so it might be useful to play with them as well.
100 changes: 100 additions & 0 deletions README.rst
@@ -0,0 +1,100 @@
.. container::

.. image:: ./assets/logo.svg
:width: 200px
:align: center

Just Relax It
Discrete variables relaxation

📬 Assets
---------

1. `Technical Meeting 1 -
   Presentation <https://github.com/intsystems/discrete-variables-relaxation/blob/main/assets/presentation_tm1.pdf>`__
2. `Technical Meeting 2 - Jupyter
   Notebook <https://github.com/intsystems/discrete-variables-relaxation/blob/main/basic/basic_code.ipynb>`__
3. `Blog
Post <https://github.com/intsystems/discrete-variables-relaxation/blob/main/assets/blog-post.pdf>`__
4. `Documentation <https://intsystems.github.io/discrete-variables-relaxation/>`__

💡 Motivation
-------------

Many mathematical problems require the ability to sample discrete
random variables. The problem is that, due to the continuous nature of
deep learning optimization, using truly discrete random variables is
infeasible. Thus we use various relaxation methods. One of them, the
`Concrete distribution <https://arxiv.org/abs/1611.00712>`__, a.k.a.
`Gumbel-softmax <https://arxiv.org/abs/1611.01144>`__ (the same
distribution proposed in parallel by two research groups), is already
implemented in various DL packages. In this project we implement
various alternatives to it.

.. container::

   .. image:: ./assets/overview.png

🗃 Algorithms to implement (from simplest to hardest)
----------------------------------------------------

- ☒ `Relaxed
Bernoulli <http://proceedings.mlr.press/v119/yamada20a/yamada20a.pdf>`__
- ☐ `Correlated relaxed
Bernoulli <https://openreview.net/pdf?id=oDFvtxzPOx>`__
- ☐ `Gumbel-softmax TOP-K <https://arxiv.org/pdf/1903.06059>`__
- ☒ `Straight-Through Bernoulli distribution (not to be confused with
  the Relaxed distribution from
  Pyro) <https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235>`__
- ☐ `Invertible Gaussian
reparametrization <https://arxiv.org/abs/1912.09588>`__ with KL
implemented
- ☒ `Hard concrete <https://arxiv.org/pdf/1712.01312>`__
- ☐ `REINFORCE <http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf>`__
  (not actually a distribution; think about how to integrate it with
  the other distributions)
- ☐ `Logit-normal
distribution <https://en.wikipedia.org/wiki/Logit-normal_distribution>`__
with KL implemented and `Laplace-form approximation of
Dirichlet <https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet>`__

📚 Recommended stack
--------------------

Some of the alternatives to GS are already implemented in
`pyro <https://docs.pyro.ai/en/dev/distributions.html>`__, so it might
be useful to play with them as well.
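
For instance, Pyro also ships straight-through variants of the relaxed
distributions; the following is a minimal sketch
(``RelaxedBernoulliStraightThrough`` is Pyro's own class, not part of
this library):

.. code:: python

   import torch
   from pyro.distributions import RelaxedBernoulliStraightThrough

   temperature = torch.tensor(0.5)
   probs = torch.full((10,), 0.3, requires_grad=True)

   # the forward pass yields hard {0, 1} samples, while the backward
   # pass uses the gradient of the underlying relaxed sample
   q = RelaxedBernoulliStraightThrough(temperature, probs=probs)
   z = q.rsample()
   z.sum().backward()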

🧩 Problem details
------------------

To make the library consistent, we integrate imports of distributions
from Pyro and PyTorch into the library, so that all the categorical
distributions can be imported from one entrypoint.
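
For example, ``CorrelatedRelaxedBernoulli`` can then be sampled from
that single entrypoint. A minimal sketch, following the constructor
signature used in ``demo/vae_correlated_bernoulli.py``:

.. code:: python

   import torch
   from relaxit.distributions import CorrelatedRelaxedBernoulli

   pi = torch.tensor([0.1, 0.5, 0.9])  # marginal probabilities
   R = torch.eye(3)                    # correlation matrix
   tau = torch.tensor(0.1)             # relaxation temperature

   q = CorrelatedRelaxedBernoulli(pi, R, tau)
   z = q.rsample()  # reparameterized sample, one value per probability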

👥 Contributors
---------------

- `Daniil Dorin <https://github.com/DorinDaniil>`__ (Basic code
writing, Final demo, Algorithms)
- `Igor Ignashin <https://github.com/ThunderstormXX>`__ (Project
wrapping, Documentation writing, Algorithms)
- `Nikita Kiselev <https://github.com/kisnikser>`__ (Project planning,
Blog post, Algorithms)
- `Andrey Veprikov <https://github.com/Vepricov>`__ (Tests writing,
Documentation writing, Algorithms)

🔗 Useful links
---------------

- `About top-k
GS <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html>`__
- `VAE implementation with different latent
distributions <https://github.com/kampta/pytorch-distributions>`__
- `KL divergence between Dirichlet and Logistic-Normal implemented in
R <https://rdrr.io/cran/Compositional/src/R/kl.diri.normal.R>`__
- `About score function (SF) and pathwise derivate (PD) estimators, VAE
and REINFORCE <https://arxiv.org/abs/1506.05254>`__
Binary file added assets/blog-post.pdf
Binary file not shown.
Binary file added assets/overview.png
Binary file not shown.
313 changes: 284 additions & 29 deletions basic/basic_code.ipynb

Large diffs are not rendered by default.

151 changes: 151 additions & 0 deletions demo/vae_correlated_bernoulli.py
@@ -0,0 +1,151 @@
import os
import argparse
import numpy as np
import torch
import sys
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
from relaxit.distributions import CorrelatedRelaxedBernoulli

parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")

os.makedirs('./results/vae_correlated_bernoulli', exist_ok=True)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)

steps = 0


class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()

self.fc1 = nn.Linear(784, 400)
self.fc2 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)

# Initialize R as an identity matrix
self.R = torch.eye(20, device=device)
self.tau = torch.tensor(0.1, device=device)

def encode(self, x):
h1 = F.relu(self.fc1(x))
return torch.sigmoid(self.fc2(h1))

def decode(self, z):
h3 = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h3))

def forward(self, x, hard=False):
pi = self.encode(x.view(-1, 784))
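        # keep probabilities strictly inside (0, 1) to avoid log(0) in the KL term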
pi = torch.clamp(pi, min=1e-6, max=1-1e-6)
q_z = CorrelatedRelaxedBernoulli(pi, self.R, self.tau)
z = q_z.rsample() # sample with reparameterization

if hard:
# No step function in torch, so using sign instead
z_hard = 0.5 * (torch.sign(z) + 1)
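            # straight-through estimator: the forward pass uses the hard sample,
            # while gradients flow through the relaxed sample z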
z = z + (z_hard - z).detach()

return self.decode(z), pi


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, pi, prior=0.5, eps=1e-10):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # You can also compute log p(x|z) directly: for binary outputs it reduces
    # to the binary cross-entropy error, and for Gaussian outputs to the mean squared error.
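    # closed-form KL(Bernoulli(pi) || Bernoulli(prior)), summed over latents and batch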
t1 = pi * ((pi + eps) / prior).log()
t2 = (1 - pi) * ((1 - pi + eps) / (1 - prior)).log()
KLD = torch.sum(t1 + t2, dim=-1).sum()

return BCE + KLD


def train(epoch):
global steps
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, pi = model(data)
loss = loss_function(recon_batch, data, pi)
loss.backward()
train_loss += loss.item()
optimizer.step()

if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))

steps += 1

print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
model.eval()
test_loss = 0
with torch.no_grad():
for i, (data, _) in enumerate(test_loader):
data = data.to(device)
recon_batch, pi = model(data)
test_loss += loss_function(recon_batch, data, pi).item()
if i == 0:
n = min(data.size(0), 8)
comparison = torch.cat([data[:n],
recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
save_image(comparison.cpu(),
'results/vae_correlated_bernoulli/reconstruction_' + str(epoch) + '.png', nrow=n)

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))


if __name__ == "__main__":
for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)
with torch.no_grad():
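            # decode binary codes drawn from the Bernoulli(0.5) prior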
sample = np.random.binomial(1, 0.5, size=(64, 20))
sample = torch.from_numpy(np.float32(sample)).to(device)
sample = model.decode(sample).cpu()
save_image(sample.view(64, 1, 28, 28),
'results/vae_correlated_bernoulli/sample_' + str(epoch) + '.png')
