From 831002ddb15ce1ae1897d35e815d0498f93c9256 Mon Sep 17 00:00:00 2001 From: kisnikser Date: Tue, 26 Nov 2024 12:58:42 +0300 Subject: [PATCH] fix sign before logdet in logprob --- src/relaxit/distributions/InvertibleGaussian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relaxit/distributions/InvertibleGaussian.py b/src/relaxit/distributions/InvertibleGaussian.py index 6452c63..63e0460 100644 --- a/src/relaxit/distributions/InvertibleGaussian.py +++ b/src/relaxit/distributions/InvertibleGaussian.py @@ -122,7 +122,7 @@ def log_prob(self, value): log_det_jacobian = - (K - 1) * torch.log(self.temperature).item() + torch.sum(torch.log(g), dim=-1, keepdim=True) + torch.log(residual) # Adjust the log probability by the Jacobian determinant - log_prob = log_prob_normal + log_det_jacobian + log_prob = log_prob_normal - log_det_jacobian return log_prob