Invertible Gaussian Reparameterization and Logistic-Normal approximation of Dirichlet #2

Merged · 3 commits · Nov 16, 2024
175 changes: 175 additions & 0 deletions basic/approx.ipynb

Large diffs are not rendered by default.
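Since the notebook diff is not rendered, here is a minimal illustrative sketch (not taken from the PR) of the kind of round-trip check such a notebook could contain, assuming the lognorm_approximation_fn and dirichlet_approximation_fn helpers added in src/relaxit/distributions/approx.py further down:

import torch
import sys
sys.path.append('../src')

from pyro.distributions import Dirichlet
from relaxit.distributions.approx import lognorm_approximation_fn, dirichlet_approximation_fn

# Start from a Dirichlet with a moderately peaked concentration vector.
dirichlet = Dirichlet(torch.tensor([5.0, 2.0, 1.0]))

# Approximate it with a Logistic-Normal (softmax parameterization) ...
lognorm = lognorm_approximation_fn(dirichlet)

# ... and map that approximation back to a Dirichlet.
dirichlet_back = dirichlet_approximation_fn(lognorm)

# The two maps are mutually consistent, so the recovered concentration
# should be close to the original [5.0, 2.0, 1.0].
print(dirichlet_back.concentration)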

94 changes: 94 additions & 0 deletions basic/invertible_gaussian.ipynb
@@ -0,0 +1,94 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import sys\n",
"sys.path.append('../src')\n",
"from relaxit.distributions.InvertibleGaussian import InvertibleGaussian\n",
"from relaxit.distributions.kl import kl_divergence"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Testing reparameterized sampling from the InvertibleGaussian distribution"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"loc = torch.zeros(3, 4, 5, requires_grad=True)\n",
"scale = torch.ones(3, 4, 5, requires_grad=True)\n",
"temperature = torch.tensor([1e-0])\n",
"\n",
"distribution = InvertibleGaussian(loc, scale, temperature)\n",
"sample = distribution.rsample()\n",
"\n",
"assert sample.shape == torch.Size([3, 4, 6])\n",
"assert sample.requires_grad == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2. Testing KL-divergence between two IntertibleGaussian distributions"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"loc_1 = torch.zeros(3, 4, 5, requires_grad=True)\n",
"scale_1 = torch.ones(3, 4, 5, requires_grad=True)\n",
"temperature_1 = torch.tensor([1e-0])\n",
"\n",
"dist_1 = InvertibleGaussian(loc_1, scale_1, temperature_1)\n",
"\n",
"loc_2 = torch.ones(3, 4, 5, requires_grad=True) # ones, not zeros\n",
"scale_2 = torch.ones(3, 4, 5, requires_grad=True)\n",
"temperature_2 = torch.tensor([1e-2])\n",
"\n",
"dist_2 = InvertibleGaussian(loc_2, scale_2, temperature_2)\n",
"\n",
"div = kl_divergence(dist_1, dist_2)\n",
"\n",
"assert div.shape == torch.Size([3, 4, 5])\n",
"assert div.requires_grad == True\n",
"assert not torch.allclose(div, torch.zeros_like(div))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "nkiselev_relaxit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
119 changes: 119 additions & 0 deletions src/relaxit/distributions/InvertibleGaussian.py
@@ -0,0 +1,119 @@
import torch
from pyro.distributions.torch_distribution import TorchDistribution
from torch.distributions import constraints
from torch.distributions.utils import _standard_normal


class InvertibleGaussian(TorchDistribution):
"""
Invertible Gaussian distribution class inheriting from Pyro's TorchDistribution.

Parameters:
- loc (Tensor): The mean (mu) of the normal distribution.
- scale (Tensor): The standard deviation (sigma) of the normal distribution.
- temperature (Tensor): The temperature of the softmax++ relaxation.
"""

arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
has_rsample = True

def __init__(self, loc, scale, temperature, validate_args: bool = None):
"""
Initializes the Invertible Gaussian distribution.

Args:
- loc (Tensor): Mean of the normal distribution.
- scale (Tensor): Standard deviation of the normal distribution.
- temperature (Tensor): Temperature of the softmax++ relaxation.
- validate_args (bool): Whether to validate arguments.

The batch shape is inferred from the shape of the parameters (loc and scale),
meaning it defines how many independent distributions are parameterized.
"""
self.loc = loc
self.scale = scale
self.temperature = temperature
batch_shape = torch.Size() if loc.dim() == 0 else loc.shape
super().__init__(batch_shape, validate_args=validate_args)

@property
def batch_shape(self):
"""
Returns the batch shape of the distribution.

The batch shape represents the shape of independent distributions.
"""
return self.loc.shape

@property
def event_shape(self):
"""
Returns the event shape of the distribution.

The event shape represents the shape of each individual event.
"""
return torch.Size()

def softmax_plus_plus(self, y, delta=1):
"""
Compute the softmax++ function.

Args:
y (torch.Tensor): Input tensor of shape (batch_size, num_classes).
delta (float): Additional positive term in the denominator (delta > 0).

The temperature is taken from self.temperature rather than passed as an argument.

Returns:
torch.Tensor: Output tensor of the same shape as y.
"""
# Scale the input by the temperature
scaled_y = y / self.temperature

# Compute the exponentials
exp_y = torch.exp(scaled_y)

# Compute the denominator
denominator = torch.sum(exp_y, dim=-1, keepdim=True) + delta

# Compute the softmax++
softmax_pp = exp_y / denominator

return softmax_pp

def rsample(self, sample_shape=torch.Size()):
"""
Generates a sample from the distribution using the reparameterization trick.

Args:
- sample_shape (torch.Size): The shape of the generated samples.
"""
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
y = self.loc + self.scale * eps
g = self.softmax_plus_plus(y)
residual = 1 - torch.sum(g, dim=-1, keepdim=True)
return torch.cat([g, residual], dim=-1)

# def log_prob(self, value):
# """
# Computes the log likelihood of a value.

# Args:
# - value (Tensor): The value for which to compute the log probability.
# """
# var = self.scale ** 2
# log_scale = torch.log(self.scale)
# log_prob_norm = -((value - self.loc) ** 2) / (2 * var) - log_scale - 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=value.device))


def _validate_sample(self, value: torch.Tensor):
"""
Validates the given sample value.

Args:
- value (Tensor): The sample value to validate.
"""
if self._validate_args:
if not (value >= 0).all() or not (value <= 1).all():
raise ValueError("Sample value must be in the range [0, 1]")


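A minimal usage sketch (illustrative, not part of the PR): the temperature controls how sharp the softmax++ relaxation is, so lower temperatures typically push samples closer to the simplex vertices (more one-hot-like):

import torch
from relaxit.distributions.InvertibleGaussian import InvertibleGaussian

loc = torch.zeros(1000, 3)
scale = torch.ones(1000, 3)

soft = InvertibleGaussian(loc, scale, torch.tensor([1.0])).rsample()  # shape (1000, 4)
hard = InvertibleGaussian(loc, scale, torch.tensor([0.1])).rsample()  # shape (1000, 4)

# With the lower temperature, the largest coordinate is typically much closer to 1.
print(soft.max(dim=-1).values.mean())
print(hard.max(dim=-1).values.mean())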
42 changes: 42 additions & 0 deletions src/relaxit/distributions/LogisticNormalSoftmax.py
@@ -0,0 +1,42 @@
from pyro.distributions import constraints, Normal
from pyro.distributions.torch import TransformedDistribution
from pyro.distributions.transforms import SoftmaxTransform

# We implement the LogisticNormal distribution with a SoftmaxTransform instead of
# the StickBreakingTransform used in the original PyTorch and Pyro implementations.
class LogisticNormalSoftmax(TransformedDistribution):
r"""
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
that define the base `Normal` distribution transformed with the
`SoftmaxTransform` such that::

Y ~ Normal(loc, scale)
X = softmax(Y) ~ LogisticNormalSoftmax(loc, scale)

Args:
loc (float or Tensor): mean of the base distribution
scale (float or Tensor): standard deviation of the base distribution
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.simplex
has_rsample = True

def __init__(self, loc, scale, validate_args=None):
base_dist = Normal(loc, scale, validate_args=validate_args)
if not base_dist.batch_shape:
base_dist = base_dist.expand([1])
super().__init__(
base_dist, SoftmaxTransform(), validate_args=validate_args
)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogisticNormalSoftmax, _instance)
return super().expand(batch_shape, _instance=new)

@property
def loc(self):
return self.base_dist.base_dist.loc

@property
def scale(self):
return self.base_dist.base_dist.scale
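
A minimal usage sketch (illustrative, not part of the PR): unlike the stick-breaking LogisticNormal, whose event dimension is one larger than the base Normal, samples here have the same dimensionality as loc and lie on the probability simplex:

import torch
from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax

dist = LogisticNormalSoftmax(torch.zeros(4), torch.ones(4))
sample = dist.rsample()

print(sample.shape)    # torch.Size([4]) -- same dimensionality as loc
print(sample.sum(-1))  # tensor(1.) -- the sample lies on the simplex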
11 changes: 10 additions & 1 deletion src/relaxit/distributions/__init__.py
@@ -2,5 +2,14 @@
from .CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli
from .StraightThroughBernoulli import StraightThroughBernoulli
from .HardConcrete import HardConcrete
from .InvertibleGaussian import InvertibleGaussian
from .LogisticNormalSoftmax import LogisticNormalSoftmax

__all__ = ["GaussianRelaxedBernoulli", "CorrelatedRelaxedBernoulli", "StraightThroughBernoulli", "HardConcrete"]
__all__ = [
"GaussianRelaxedBernoulli",
"CorrelatedRelaxedBernoulli",
"StraightThroughBernoulli",
"HardConcrete",
"InvertibleGaussian",
"LogisticNormalSoftmax",
]
50 changes: 50 additions & 0 deletions src/relaxit/distributions/approx.py
@@ -0,0 +1,50 @@
import torch
from .LogisticNormalSoftmax import LogisticNormalSoftmax
from pyro.distributions import Dirichlet


def lognorm_approximation_fn(dirichlet_distribution: Dirichlet) -> LogisticNormalSoftmax:
"""
Approximates a Dirichlet distribution with a LogisticNormalSoftmax distribution.

Args:
dirichlet_distribution (Dirichlet): The Dirichlet distribution to approximate.

Returns:
LogisticNormalSoftmax: The approximated LogisticNormalSoftmax distribution.
"""
concentration = dirichlet_distribution.concentration
num_events = torch.tensor(dirichlet_distribution.event_shape, dtype=torch.float)

# Compute the location parameter (mu)
loc = concentration.log() - (1 / num_events) * concentration.log().sum(-1).unsqueeze(-1)

# Compute the scale parameter (sigma)
scale = 1 / concentration - (1 / num_events) * (2 / concentration - (1 / num_events) * (1 / concentration).sum(-1).unsqueeze(-1))

# Create the LogisticNormalSoftmax distribution
lognorm_approximation = LogisticNormalSoftmax(loc, scale)

return lognorm_approximation


def dirichlet_approximation_fn(lognorm_distribution: LogisticNormalSoftmax) -> Dirichlet:
"""
Approximates a LogisticNormalSoftmax distribution with a Dirichlet distribution.

Args:
lognorm_distribution (LogisticNormalSoftmax): The LogisticNormalSoftmax distribution to approximate.

Returns:
Dirichlet: The approximated Dirichlet distribution.
"""
num_events = torch.tensor(lognorm_distribution.event_shape, dtype=torch.float)
loc, scale = lognorm_distribution.loc, lognorm_distribution.scale

# Compute the concentration parameter (alpha)
concentration = (1 / scale) * (1 - 2 / num_events + loc.exp() / (num_events ** 2) * loc.neg().exp().sum(-1).unsqueeze(-1))

# Create the Dirichlet distribution
dirichlet_approximation = Dirichlet(concentration)

return dirichlet_approximation
13 changes: 13 additions & 0 deletions src/relaxit/distributions/kl.py
@@ -0,0 +1,13 @@
from torch.distributions import (
kl_divergence,
register_kl,
Normal
)

from .InvertibleGaussian import InvertibleGaussian

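# KL divergence is invariant under applying the same invertible transformation to both
# distributions, so for two InvertibleGaussian distributions the divergence between the
# underlying Gaussians can be used. Note that the temperature is ignored here: if p and q
# use different temperatures, this returns the KL between the base Normals rather than the
# exact divergence between the transformed variables.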
@register_kl(InvertibleGaussian, InvertibleGaussian)
def _kl_igr_igr(p, q):
p_normal = Normal(p.loc, p.scale)
q_normal = Normal(q.loc, q.scale)
return kl_divergence(p_normal, q_normal)