From 7397efce7316ee9a869bbf8bf6a225b9b00c75e9 Mon Sep 17 00:00:00 2001
From: vepricov
Date: Tue, 26 Nov 2024 01:27:51 +0300
Subject: [PATCH] fix topk

---
 demo/reinforce.py                             |   1 +
 .../distributions/GumbelSoftmaxTopK.py        | 121 +++++++++---------
 tests/distributions/test_GumbelSoftmaxTopK.py |  34 +++--
 3 files changed, 86 insertions(+), 70 deletions(-)

diff --git a/demo/reinforce.py b/demo/reinforce.py
index 2729da3..3a9f7d0 100644
--- a/demo/reinforce.py
+++ b/demo/reinforce.py
@@ -55,6 +55,7 @@ def select_action(state):
     probs = policy(state)
     m = Categorical(probs)
    action = m.sample()
+    print(m.log_prob(action))  # print the log-probability of the sampled action
     policy.saved_log_probs.append(m.log_prob(action))
     return action.item()

diff --git a/src/relaxit/distributions/GumbelSoftmaxTopK.py b/src/relaxit/distributions/GumbelSoftmaxTopK.py
index c86e253..50365c7 100644
--- a/src/relaxit/distributions/GumbelSoftmaxTopK.py
+++ b/src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -1,44 +1,49 @@
 import torch
+import torch.nn.functional as F
 from torch.distributions import constraints
 from pyro.distributions.torch_distribution import TorchDistribution
-from pyro.distributions import Gumbel
 from torch.distributions import constraints


 class GumbelSoftmaxTopK(TorchDistribution):
-    """
+    r'''
     Implementation of the Gumbel-Softmax top-K trick from https://arxiv.org/pdf/1903.06059

-    :param a: logits, if not from Simples, we project a into it.
+    :param a: logits.
     :type a: torch.Tensor
     :param K: how many samples without replacement to pick.
-    :type K: int
-    :param support: support of the discrete distribution. If None, it will be `torch.arange(a.numel()).reshape(a.shape)`. It must be the same `shape` as `a`.
-    :type support: torch.Tensor
-    """
+    :type K: torch.Tensor
+    :param tau: Temperature hyper-parameter.
+    :type tau: torch.Tensor
+    :param hard: if `True`, the returned sample will be discretized as a K-hot vector, but will be differentiated as if it were the soft sample in autograd.
+    :type hard: bool
+    :param validate_args: Whether to validate arguments.
+    :type validate_args: bool
+    '''

-    arg_constraints = {'a': constraints.real}
+    arg_constraints = {'a'  : constraints.real,
+                       'K'  : constraints.positive_integer,
+                       'tau': constraints.positive}
     has_rsample = True

-    def __init__(self, a: torch.Tensor, K: int,
-                 support: torch.Tensor = None, validate_args: bool = None):
-        """
-        Initializes the GumbelSoftmaxTopK distribution.
-
-        Args:
-        - a (Tensor): logits, if not from Simples, we project a into it
-        - K (int): how many samples without replacement to pick
-        - support (Tensor): support of the discrete distribution. If None, it will be `torch.range(len(a))`. It must be the same `len` as `a`.
-        - validate_args (bool): Whether to validate arguments.
-        """
-        self.a = a.float() / a.sum()  # Ensure loc is a float tensor from simplex
-        self.gumbel = Gumbel(loc=0, scale=1, validate_args=validate_args)
-        if support is None:
-            self.supp = torch.arange(a.numel()).reshape(a.shape)
-        else:
-            if support.shape != a.shape:
-                raise ValueError("support and a must have the same shape")
-            self.supp = support
-        self.K = int(K)  # Ensure K is a int number
+    def __init__(self, a: torch.Tensor, K: torch.Tensor,
+                 tau: torch.Tensor, hard: bool = True, validate_args: bool = None):
+        r''' Initializes the GumbelSoftmaxTopK distribution.
+
+        :param a: logits.
+        :type a: torch.Tensor
+        :param K: how many samples without replacement to pick.
+        :type K: torch.Tensor
+        :param tau: Temperature hyper-parameter.
+        :type tau: torch.Tensor
+        :param hard: if `True`, the returned sample will be discretized as a K-hot vector, but will be differentiated as if it were the soft sample in autograd.
+        :type hard: bool
+        :param validate_args: Whether to validate arguments.
+        :type validate_args: bool
+        '''
+        self.a = a.float()  # Ensure the logits are a float tensor
+        self.K = K.int()  # Ensure K is an int tensor
+        self.tau = tau
+        self.hard = hard
         super().__init__(validate_args=validate_args)

     @property
@@ -61,53 +66,45 @@ def event_shape(self) -> torch.Size:
         """
         return torch.Size()

-    def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
+    def rsample(self) -> torch.Tensor:
         """
         Generates a sample from the distribution using the Gumbel-Softmax top-K trick.

-        Args:
-        - sample_shape (torch.Size): The shape of the sample.
-
-        Returns:
-        - torch.Tensor: A sample from the distribution.
+        :return: A sample from the distribution.
+        :rtype: torch.Tensor
         """
-        if sample_shape is None:
-            sample_shape = torch.Size([self.K])
-        G = self.gumbel.rsample(sample_shape=self.a.shape)
-        _, idxs = torch.topk(G + torch.log(self.a), k = self.K)
-        return self.supp.reshape(-1)[idxs].reshape(shape=sample_shape)
+        top_k_logits = torch.zeros_like(self.a)  # accumulates the K-hot sample
+        logits = torch.clone(self.a)
+        for _ in range(self.K):
+            top1_gumbel = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)  # draw one (relaxed) one-hot sample
+            top_k_logits += top1_gumbel
+            logits -= top1_gumbel * 1e10  # mask the selected entry for the next draw
+
+        return top_k_logits

-    def sample(self, sample_shape: torch.Size = None) -> torch.Tensor:
-        """
-        Generates a sample from the distribution.
+    def sample(self) -> torch.Tensor:
+        """
+        Generates a sample from the distribution with no grad.

-        Args:
-        - sample_shape (torch.Size): The shape of the sample.
-
-        Returns:
-        - torch.Tensor: A sample from the distribution.
+        :return: A sample from the distribution.
+        :rtype: torch.Tensor
         """
         with torch.no_grad():
-            return self.rsample(sample_shape)
+            return self.rsample()

-    def log_prob(self, value: torch.Tensor, shape: torch.Size = torch.Size([1])) -> torch.Tensor:
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         """
         Computes the log probability of the given value.

-        Args:
-        - value (Tensor): The value for which to compute the log probability.
-        - shape(torch.Size): The shape of the output
-
-        Returns:
-        - torch.Tensor: The log probability of the given value.
+        :param value: The value for which to compute the log probability.
+        :type value: torch.Tensor
+        :return: The log probability of the given value.
+        :rtype: torch.Tensor
         """
         if self._validate_args:
             self._validate_sample(value)

-        idx = (self.supp.reshape(-1) == value).nonzero().squeeze()
-
-        return torch.log(self.a.reshape(-1)[idx]).reshape(shape=shape)
+        return torch.log((self.a * value).sum() / self.a.sum())

     def _validate_sample(self, value: torch.Tensor):
         """
@@ -117,5 +114,7 @@ def _validate_sample(self, value: torch.Tensor):
         - value (Tensor): The sample value to validate.
         """
         if self._validate_args:
-            if value not in self.supp:
-                raise ValueError("Sample value must be in the support")
\ No newline at end of file
+            if self.hard and ((value != 1.) & (value != 0.)).any():
+                raise ValueError(f"If `self.hard` is `True`, then all coordinates in `value` must be 0 or 1, but got {value}")
+            if not self.hard and (value < 0).any():
+                raise ValueError(f"If `self.hard` is `False`, then all coordinates in `value` must be >= 0, but got {value}")
\ No newline at end of file
diff --git a/tests/distributions/test_GumbelSoftmaxTopK.py b/tests/distributions/test_GumbelSoftmaxTopK.py
index 35ddaa6..9160bcb 100644
--- a/tests/distributions/test_GumbelSoftmaxTopK.py
+++ b/tests/distributions/test_GumbelSoftmaxTopK.py
@@ -6,16 +6,32 @@

 # Testing reparameterized sampling and log prob from the GumbelSoftmaxTopK distribution
 def test_sample_shape():
-    a = torch.tensor([[1, 2, 3]])
-    distribution = GumbelSoftmaxTopK(a, K=2)
+    a = torch.tensor([1., 2., 3., 4., 5.])
+    K = torch.tensor(2)
+    tau = torch.tensor(0.1)
+    distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
     sample = distribution.rsample()
-    assert sample.shape == torch.Size([2])
-    print("$")
+    assert sample.shape == a.shape
+
+def test_sample_grad():
+    a = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
+    K = torch.tensor(2)
+    tau = torch.tensor(0.1)
+    distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+    sample = distribution.rsample()
+    assert sample.requires_grad

 def test_log_prob():
     a = torch.tensor([1., 2., 3.])
-    distribution = GumbelSoftmaxTopK(a, K=1)
-    value = 1
-    log_prob_my = distribution.log_prob(value, shape=torch.Size([]))
-    log_prob_true = torch.log(a[value] / 6.)
-    assert log_prob_my - log_prob_true < 1e-6
\ No newline at end of file
+    K = torch.tensor(3)
+    tau = torch.tensor(0.1)
+    distribution = GumbelSoftmaxTopK(a, K=K, tau=tau)
+    sample = distribution.rsample()
+    value = torch.tensor([1., 1., 1.])
+    log_prob = distribution.log_prob(value)
+    assert torch.abs(log_prob) < 1e-6
+
+if __name__ == "__main__":
+    test_sample_shape()
+    test_sample_grad()
+    test_log_prob()
\ No newline at end of file
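
For context, a minimal usage sketch of the patched distribution follows (not part of the patch). It assumes the class is importable as written below, mirroring the src/relaxit/distributions/GumbelSoftmaxTopK.py path above; the surrogate loss is purely illustrative.

import torch
# Assumed import path, matching the package layout in this patch.
from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK

# Relaxed top-K subset selection: pick K = 2 of 5 items while keeping
# gradients with respect to the logits `a`.
a = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
distribution = GumbelSoftmaxTopK(a, K=torch.tensor(2), tau=torch.tensor(0.1), hard=True)

sample = distribution.rsample()        # K-hot vector; straight-through gradients flow back to `a`
loss = -distribution.log_prob(sample)  # illustrative surrogate objective
loss.backward()                        # populates a.grad through the relaxed sample
print(sample, a.grad)

Because rsample() masks the chosen entry with a large negative offset after each of the K Gumbel-Softmax draws, the K picks are made without replacement, which is what the new test_sample_shape and test_sample_grad tests exercise.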