
Commit

fix topk
Vepricov committed Nov 25, 2024
1 parent 3853a11 commit 7397efc
Showing 3 changed files with 86 additions and 70 deletions.
1 change: 1 addition & 0 deletions demo/reinforce.py
@@ -55,6 +55,7 @@ def select_action(state):
probs = policy(state)
m = Categorical(probs)
action = m.sample()
print(m.log_prob(action))
policy.saved_log_probs.append(m.log_prob(action))
return action.item()

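Aside: the log-probabilities saved in select_action are what a REINFORCE update later consumes. A minimal, hypothetical sketch of that update (the saved_log_probs list and the tensor of discounted returns are assumed to be collected per episode elsewhere in the demo; this is not code from the commit):

import torch

def reinforce_loss(saved_log_probs, returns):
    # Normalize returns to reduce the variance of the gradient estimate (optional but common).
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    # REINFORCE objective: minimize -log pi(a|s) * return for each step of the episode.
    losses = [-log_prob * ret for log_prob, ret in zip(saved_log_probs, returns)]
    return torch.stack(losses).sum()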
121 changes: 60 additions & 61 deletions src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -1,44 +1,49 @@
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions import Gumbel
from torch.distributions import constraints

class GumbelSoftmaxTopK(TorchDistribution):
"""
r'''
Implementation of the Gumbel-softmax top-K trick from https://arxiv.org/pdf/1903.06059
:param a: logits; if not from the simplex, we project a onto it.
:param a: logits.
:type a: torch.Tensor
:param K: how many samples without replacement to pick.
:type K: int
:param support: support of the discrete distribution. If None, it will be `torch.arange(a.numel()).reshape(a.shape)`. It must be the same `shape` as `a`.
:type support: torch.Tensor
"""

arg_constraints = {'a': constraints.real}
:type K: torch.Tensor
:param tau: Temperature hyper-parameter.
:type tau: torch.Tensor
:param hard: if `True`, the returned samples are discretized as one-hot vectors, but are differentiated as if they were the soft samples in autograd
:type hard: bool
:param validate_args: Whether to validate arguments.
:type validate_args: bool
'''

arg_constraints = {'a' : constraints.real,
'K' : constraints.positive_integer,
'tau': constraints.positive}
has_rsample = True

def __init__(self, a: torch.Tensor, K: int,
support: torch.Tensor = None, validate_args: bool = None):
"""
Initializes the GumbelSoftmaxTopK distribution.
Args:
- a (Tensor): logits; if not from the simplex, we project a onto it
- K (int): how many samples without replacement to pick
- support (Tensor): support of the discrete distribution. If None, it will be `torch.range(len(a))`. It must be the same `len` as `a`.
- validate_args (bool): Whether to validate arguments.
"""
self.a = a.float() / a.sum() # Ensure loc is a float tensor from simplex
self.gumbel = Gumbel(loc=0, scale=1, validate_args=validate_args)
if support is None:
self.supp = torch.arange(a.numel()).reshape(a.shape)
else:
if support.shape != a.shape:
raise ValueError("support and a must have the same shape")
self.supp = support
self.K = int(K) # Ensure K is a int number
def __init__(self, a: torch.Tensor, K: torch.Tensor,
tau: torch.Tensor, hard: bool = True, validate_args: bool = None):
r''' Initializes the GumbelSoftmaxTopK distribution.
:param a: logits.
:type a: torch.Tensor
:param K: how many samples without replacement to pick.
:type K: torch.Tensor
:param tau: Temperature hyper-parameter.
:type tau: torch.Tensor
:param hard: if `True`, the returned samples are discretized as one-hot vectors, but are differentiated as if they were the soft samples in autograd
:type hard: bool
:param validate_args: Whether to validate arguments.
:type validate_args: bool
'''
self.a = a.float() # Ensure the logits `a` are stored as a float tensor
self.K = K.int() # Ensure K is an integer tensor
self.tau = tau
self.hard = hard
super().__init__(validate_args=validate_args)

@property
@@ -61,53 +66,45 @@ def event_shape(self) -> torch.Size:
"""
return torch.Size()

def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
def rsample(self) -> torch.Tensor:
"""
Generates a sample from the distribution using the Gumbel-softmax top-K trick.
Args:
- sample_shape (torch.Size): The shape of the sample.
Returns:
- torch.Tensor: A sample from the distribution.
:return: A sample from the distribution.
:rtype: torch.Tensor
"""
if sample_shape is None:
sample_shape = torch.Size([self.K])

G = self.gumbel.rsample(sample_shape=self.a.shape)
_, idxs = torch.topk(G + torch.log(self.a), k = self.K)
return self.supp.reshape(-1)[idxs].reshape(shape=sample_shape)
top_k_logits = torch.zeros_like(self.a)
logits = torch.clone(self.a)
for _ in range(self.K):
top1_gumbel = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
top_k_logits += top1_gumbel
logits -= top1_gumbel * 1e10 # mask the selected entry

def sample(self, sample_shape: torch.Size = None) -> torch.Tensor:
"""
Generates a sample from the distribution.
return top_k_logits

Args:
- sample_shape (torch.Size): The shape of the sample.
def sample(self) -> torch.Tensor:
"""
Generates a sample from the distribution with no grad.
Returns:
- torch.Tensor: A sample from the distribution.
:return: A sample from the distribution.
:rtype: torch.Tensor
"""
with torch.no_grad():
return self.rsample(sample_shape)
return self.rsample()

def log_prob(self, value: torch.Tensor, shape: torch.Size = torch.Size([1])) -> torch.Tensor:
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Computes the log probability of the given value.
Args:
- value (Tensor): The value for which to compute the log probability.
- shape(torch.Size): The shape of the output
Returns:
- torch.Tensor: The log probability of the given value.
:param value: The value for which to compute the log probability.
:type value: torch.Tensor
:return: The log probability of the given value.
:rtype: torch.Tensor
"""
if self._validate_args:
self._validate_sample(value)

idx = (self.supp.reshape(-1) == value).nonzero().squeeze()

return torch.log(self.a.reshape(-1)[idx]).reshape(shape=shape)
return torch.log((self.a * value).sum() / self.a.sum())

def _validate_sample(self, value: torch.Tensor):
"""
Expand All @@ -117,5 +114,7 @@ def _validate_sample(self, value: torch.Tensor):
- value (Tensor): The sample value to validate.
"""
if self._validate_args:
if value not in self.supp:
raise ValueError("Sample value must be in the support")
if self.hard and ((value != 1.) & (value != 0.)).any():
raise ValueError(f"If `self.hard` is `True`, all coordinates in `value` must be 0 or 1, but got {value}")
if not self.hard and (value < 0).any():
raise ValueError(f"If `self.hard` is `False`, all coordinates in `value` must be >= 0, but got {value}")
34 changes: 25 additions & 9 deletions tests/distributions/test_GumbelSoftmaxTopK.py
@@ -6,16 +6,32 @@
# Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution

def test_sample_shape():
a = torch.tensor([[1, 2, 3]])
distribution = GumbelSoftmaxTopK(a, K=2)
a = torch.tensor([1., 2., 3., 4., 5.])
K = torch.tensor(2)
tau = torch.tensor(0.1)
distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
sample = distribution.rsample()
assert sample.shape == torch.Size([2])
print("$")
assert sample.shape == a.shape

def test_sample_grad():
a = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
K = torch.tensor(2)
tau = torch.tensor(0.1)
distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
sample = distribution.rsample()
assert sample.requires_grad == True

def test_log_prob():
a = torch.tensor([1., 2., 3.])
distribution = GumbelSoftmaxTopK(a, K=1)
value = 1
log_prob_my = distribution.log_prob(value, shape=torch.Size([]))
log_prob_true = torch.log(a[value] / 6.)
assert log_prob_my - log_prob_true < 1e-6
K = torch.tensor(3)
tau = torch.tensor(0.1)
distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
sample = distribution.rsample()
value = torch.tensor([1., 1., 1.])
log_prob = distribution.log_prob(value)
assert log_prob - torch.tensor(0) < 1e-6

if __name__ == "__main__":
test_sample_shape()
test_sample_grad()
test_log_prob()
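A hypothetical downstream use of the hard K-hot sample as a differentiable mask for selecting K out of N features; it is not part of this repository and only illustrates why rsample keeps gradients:

import torch
from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK

x = torch.randn(5)                          # feature vector to select from
logits = torch.ones(5, requires_grad=True)  # learnable selection scores (hypothetical)
dist = GumbelSoftmaxTopK(logits, K=torch.tensor(2), tau=torch.tensor(0.5), hard=True)
mask = dist.rsample()                       # K-hot mask, differentiable w.r.t. the logits
selected = (mask * x).sum()                 # sum of the selected features
selected.backward()                         # gradient reaches the selection logits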
