
Commit

Merge branch 'main' of github.com:intsystems/discrete-variables-relaxation
Igoreshka committed Nov 26, 2024
2 parents 5d2cbc1 + f5bd81d commit 1430504
Showing 7 changed files with 372 additions and 89 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test.yml
@@ -33,10 +33,10 @@ jobs:
run: |
python src/badge_generator.py
- name: Commit coverage badge
run: |
git config --global user.name 'github-actions[bot]'
git config --global user.email 'github-actions[bot]@users.noreply.github.com'
git add coverage-badge.svg
git commit -m "Update coverage badge"
git push
# - name: Commit coverage badge
# run: |
# git config --global user.name 'github-actions[bot]'
# git config --global user.email 'github-actions[bot]@users.noreply.github.com'
# git add coverage-badge.svg
# git commit -m "Update coverage badge"
# git push
20 changes: 11 additions & 9 deletions README.md
@@ -11,16 +11,18 @@
<a href="https://docs.pyro.ai/en/dev/distributions.html">
<img alt="Inspired by Pyro" src="https://img.shields.io/badge/Inspired_by_Pyro-fecd08">
</a>
<img src="coverage-badge.svg" />
</p>

<p align="center">
<!-- <a href="https://github.com/intsystems/discrete-variables-relaxation/actions"> -->
<!-- <img alt="Tests Passing" src="https://github.com/intsystems/discrete-variables-relaxation/workflows/Test/badge.svg" /> -->
<!-- </a> -->
<!-- <a href="https://codecov.io/gh/intsystems/discrete-variables-relaxation"> -->
<!-- <img alt="Tests Coverage" src="https://codecov.io/gh/intsystems/discrete-variables-relaxation/branch/main/graph/badge.svg" /> -->
<!-- </a> -->
<a href="https://github.com/intsystems/discrete-variables-relaxation/tree/main/tests">
<img alt="Coverage" src="coverage-badge.svg" />
</a>
<a href="https://intsystems.github.io/discrete-variables-relaxation">
<img alt="Docs" src="https://github.com/intsystems/discrete-variables-relaxation/actions/workflows/docs.yml/badge.svg" />
</a>
</p>

<p align="center">
<a href="https://github.com/intsystems/discrete-variables-relaxation/blob/main/LICENSE">
<img alt="License" src="https://badgen.net/github/license/intsystems/discrete-variables-relaxation?color=green" />
</a>
@@ -57,11 +59,11 @@ In this project we implement different alternatives to it.
## 🗃 Algorithms to implement (from simplest to hardest)
- [x] [Relaxed Bernoulli](http://proceedings.mlr.press/v119/yamada20a/yamada20a.pdf)
- [x] [Correlated relaxed Bernoulli](https://openreview.net/pdf?id=oDFvtxzPOx)
- [ ] [Gumbel-softmax TOP-K](https://arxiv.org/pdf/1903.06059)
- [x] [Gumbel-softmax TOP-K](https://arxiv.org/pdf/1903.06059)
- [x] [Straight-Through Bernoulli distribution (not to be confused with the Relaxed distributions from Pyro; see the sketch just after this list)](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235)
- [x] [Invertible Gaussian reparametrization](https://arxiv.org/abs/1912.09588) with KL implemented
- [x] [Hard concrete](https://arxiv.org/pdf/1712.01312)
- [ ] [REINFORCE](http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf) (not actually a distribution; needs a way to integrate with the other distributions)
- [x] [REINFORCE](http://www.cs.toronto.edu/~tingwuwang/REINFORCE.pdf) (not actually a distribution; needs a way to integrate with the other distributions)
- [x] [Logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [Laplace-form approximation of Dirichlet](https://stats.stackexchange.com/questions/535560/approximating-the-logit-normal-by-dirichlet)
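The straight-through trick referenced in the Straight-Through Bernoulli item above can be illustrated in isolation. This is a minimal sketch of the general idea, not the package's implementation:

import torch

# Straight-through Bernoulli sketch: hard 0/1 samples on the forward pass,
# gradients taken through the soft probabilities on the backward pass.
logits = torch.randn(4, requires_grad=True)
probs = torch.sigmoid(logits)
hard = torch.bernoulli(probs)            # non-differentiable 0/1 sample
sample = hard + probs - probs.detach()   # value equals `hard`, gradient flows via `probs`
sample.sum().backward()
print(logits.grad)                       # populated as if `sample` were `probs`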

## 📚 Recommended stack
267 changes: 267 additions & 0 deletions basic/reinforce.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion coverage-badge.svg
3 changes: 1 addition & 2 deletions demo/reinforce.py
@@ -1,4 +1,4 @@
import argparse, os, sys
import argparse
import gym
import numpy as np
from itertools import count
@@ -21,7 +21,6 @@
help='interval between training status logs (default: 10)')
args = parser.parse_args()


env = gym.make('Acrobot-v1')
env.reset(seed=args.seed)
torch.manual_seed(args.seed)
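For orientation, the REINFORCE update such a demo performs can be sketched as follows. This is a minimal sketch, assuming a recent gym with the `(obs, info)` reset and five-tuple step API; the policy network and hyper-parameters are illustrative, not the demo's actual ones.

import gym
import torch
import torch.nn as nn
from torch.distributions import Categorical

env = gym.make('Acrobot-v1')             # 6-dimensional observations, 3 discrete actions
obs, _ = env.reset(seed=543)
torch.manual_seed(543)

policy = nn.Sequential(nn.Linear(6, 128), nn.ReLU(), nn.Linear(128, 3))
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

# Roll out one episode, recording log-probabilities of the sampled actions.
log_probs, rewards, done = [], [], False
while not done:
    dist = Categorical(logits=policy(torch.as_tensor(obs, dtype=torch.float32)))
    action = dist.sample()
    log_probs.append(dist.log_prob(action))
    obs, reward, terminated, truncated, _ = env.step(action.item())
    rewards.append(reward)
    done = terminated or truncated

# Discounted, normalized returns (gamma = 0.99).
returns, R = [], 0.0
for r in reversed(rewards):
    R = r + 0.99 * R
    returns.insert(0, R)
returns = torch.tensor(returns)
returns = (returns - returns.mean()) / (returns.std() + 1e-8)

# REINFORCE: scale each action's log-probability by its return.
loss = -(torch.stack(log_probs) * returns).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()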
121 changes: 60 additions & 61 deletions src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -1,44 +1,49 @@
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions import Gumbel
from torch.distributions import constraints

class GumbelSoftmaxTopK(TorchDistribution):
"""
r'''
Implementation of the Gumbel-softmax top-K trick from https://arxiv.org/pdf/1903.06059
:param a: logits; if not from the simplex, we project a onto it.
:param a: logits.
:type a: torch.Tensor
:param K: how many samples without replacement to pick.
:type K: int
:param support: support of the discrete distribution. If None, it will be `torch.arange(a.numel()).reshape(a.shape)`. It must be the same `shape` as `a`.
:type support: torch.Tensor
"""

arg_constraints = {'a': constraints.real}
:type K: torch.Tensor
:param tau: Temperature hyper-parameter.
:type tau: torch.Tensor
:param hard: if `True`, the returned samples are discretized as k-hot vectors, but are differentiated in autograd as if they were the soft samples
:type hard: bool
:param validate_args: Whether to validate arguments.
:type validate_args: bool
'''

arg_constraints = {'a' : constraints.real,
'K' : constraints.positive_integer,
'tau': constraints.positive}
has_rsample = True

def __init__(self, a: torch.Tensor, K: int,
support: torch.Tensor = None, validate_args: bool = None):
"""
Initializes the GumbelSoftmaxTopK distribution.
Args:
- a (Tensor): logits; if not from the simplex, we project a onto it
- K (int): how many samples without replacement to pick
- support (Tensor): support of the discrete distribution. If None, it will be `torch.range(len(a))`. It must be the same `len` as `a`.
- validate_args (bool): Whether to validate arguments.
"""
self.a = a.float() / a.sum() # Ensure loc is a float tensor from simplex
self.gumbel = Gumbel(loc=0, scale=1, validate_args=validate_args)
if support is None:
self.supp = torch.arange(a.numel()).reshape(a.shape)
else:
if support.shape != a.shape:
raise ValueError("support and a must have the same shape")
self.supp = support
self.K = int(K) # Ensure K is an int
def __init__(self, a: torch.Tensor, K: torch.Tensor,
tau: torch.Tensor, hard: bool = True, validate_args: bool = None):
r''' Initializes the GumbelSoftmaxTopK distribution.
:param a: logits.
:type a: torch.Tensor
:param K: how many samples without replacement to pick.
:type K: torch.Tensor
:param tau: Temperature hyper-parameter.
:type tau: torch.Tensor
:param hard: if `True`, the returned samples are discretized as k-hot vectors, but are differentiated in autograd as if they were the soft samples
:type hard: bool
:param validate_args: Whether to validate arguments.
:type validate_args: bool
'''
self.a = a.float() # Ensure loc is a float tensor
self.K = K.int() # Ensure K is an int tensor
self.tau = tau
self.hard = hard
super().__init__(validate_args=validate_args)

@property
@@ -61,53 +66,45 @@ def event_shape(self) -> torch.Size:
"""
return torch.Size()

def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
def rsample(self) -> torch.Tensor:
"""
Generates a sample from the distribution using the Gumbel-softmax top-K trick.
Args:
- sample_shape (torch.Size): The shape of the sample.
Returns:
- torch.Tensor: A sample from the distribution.
:return: A sample from the distribution.
:rtype: torch.Tensor
"""
if sample_shape is None:
sample_shape = torch.Size([self.K])

G = self.gumbel.rsample(sample_shape=self.a.shape)
_, idxs = torch.topk(G + torch.log(self.a), k = self.K)
return self.supp.reshape(-1)[idxs].reshape(shape=sample_shape)
top_k_logits = torch.zeros_like(self.a)
logits = torch.clone(self.a)
for _ in range(self.K):
top1_gumbel = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
top_k_logits += top1_gumbel
logits -= top1_gumbel * 1e10 # mask the selected entry

def sample(self, sample_shape: torch.Size = None) -> torch.Tensor:
"""
Generates a sample from the distribution.
return top_k_logits

Args:
- sample_shape (torch.Size): The shape of the sample.
def sample(self) -> torch.Tensor:
"""
Generates a sample from the distribution with no grad.
Returns:
- torch.Tensor: A sample from the distribution.
:return: A sample from the distribution.
:rtype: torch.Tensor
"""
with torch.no_grad():
return self.rsample(sample_shape)
return self.rsample()

def log_prob(self, value: torch.Tensor, shape: torch.Size = torch.Size([1])) -> torch.Tensor:
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Computes the log probability of the given value.
Args:
- value (Tensor): The value for which to compute the log probability.
- shape(torch.Size): The shape of the output
Returns:
- torch.Tensor: The log probability of the given value.
:param value: The value for which to compute the log probability.
:type value: torch.Tensor
:return: The log probability of the given value.
:rtype: torch.Tensor
"""
if self._validate_args:
self._validate_sample(value)

idx = (self.supp.reshape(-1) == value).nonzero().squeeze()

return torch.log(self.a.reshape(-1)[idx]).reshape(shape=shape)
return torch.log((self.a * value).sum() / self.a.sum())

def _validate_sample(self, value: torch.Tensor):
"""
@@ -117,5 +114,7 @@ def _validate_sample(self, value: torch.Tensor):
- value (Tensor): The sample value to validate.
"""
if self._validate_args:
if value not in self.supp:
raise ValueError("Sample value must be in the support")
if self.hard and ((value != 1.) & (value != 0.)).any():
    raise ValueError(f"If `self.hard` is `True`, then all coordinates in `value` must be 0 or 1, but got {value}")
if not self.hard and (value < 0).any():
    raise ValueError(f"If `self.hard` is `False`, then all coordinates in `value` must be >= 0, but got {value}")
34 changes: 25 additions & 9 deletions tests/distributions/test_GumbelSoftmaxTopK.py
@@ -6,16 +6,32 @@
# Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution

def test_sample_shape():
a = torch.tensor([[1, 2, 3]])
distribution = GumbelSoftmaxTopK(a, K=2)
a = torch.tensor([1., 2., 3., 4., 5.])
K = torch.tensor(2)
tau = torch.tensor(0.1)
distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
sample = distribution.rsample()
assert sample.shape == torch.Size([2])
print("$")
assert sample.shape == a.shape

def test_sample_grad():
a = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
K = torch.tensor(2)
tau = torch.tensor(0.1)
distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
sample = distribution.rsample()
assert sample.requires_grad == True

def test_log_prob():
a = torch.tensor([1., 2., 3.])
distribution = GumbelSoftmaxTopK(a, K=1)
value = 1
log_prob_my = distribution.log_prob(value, shape=torch.Size([]))
log_prob_true = torch.log(a[value] / 6.)
assert log_prob_my - log_prob_true < 1e-6
K = torch.tensor(3)
tau = torch.tensor(0.1)
distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
sample = distribution.rsample()
value = torch.tensor([1., 1., 1.])
log_prob = distribution.log_prob(value)
assert log_prob - torch.tensor(0) < 1e-6

if __name__ == "__main__":
test_sample_shape()
test_sample_grad()
test_log_prob()
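As a quick sanity check of the new `log_prob` formula: with `a = [1., 2., 3.]` and `value = [1., 1., 1.]`, it reduces to log((a * value).sum() / a.sum()) = log(6 / 6) = 0, which is what the tolerance in `test_log_prob` asserts.

import torch

a = torch.tensor([1., 2., 3.])
value = torch.tensor([1., 1., 1.])
print(torch.log((a * value).sum() / a.sum()))  # tensor(0.)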
