diff --git a/src/relaxit/distributions/InvertibleGaussian.py b/src/relaxit/distributions/InvertibleGaussian.py index 6452c63..63e0460 100644 --- a/src/relaxit/distributions/InvertibleGaussian.py +++ b/src/relaxit/distributions/InvertibleGaussian.py @@ -122,7 +122,7 @@ def log_prob(self, value): log_det_jacobian = - (K - 1) * torch.log(self.temperature).item() + torch.sum(torch.log(g), dim=-1, keepdim=True) + torch.log(residual) # Adjust the log probability by the Jacobian determinant - log_prob = log_prob_normal + log_det_jacobian + log_prob = log_prob_normal - log_det_jacobian return log_prob