Skip to content

Commit

Permalink
Update distributions.py
Browse files Browse the repository at this point in the history
fix Student Distribution
  • Loading branch information
kreininmv authored Dec 2, 2024
1 parent ce069e0 commit 1ad8325
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/irt/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ def __init__(
validate_args (Optional[bool]): If True, validates distribution parameters.
"""
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
self.gamma = Gamma(self.df * 0.5, self.df * 0.5)
batch_shape = self.df.size()
super().__init__(batch_shape, validate_args=validate_args)

Expand Down Expand Up @@ -505,11 +504,11 @@ def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
"""
self.loc = self.loc.expand(self._extended_shape(sample_shape))
self.scale = self.scale.expand(self._extended_shape(sample_shape))

sigma = self.gamma.rsample()

gamma_samples = Gamma(self.df * 0.5, self.df * 0.5).rsample(sample_shape)
normal_samples = Normal(0., 1.).sample(sample_shape)
# Sample from Normal distribution (shape must match after broadcasting)
x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape)
x = self.loc.detach() + self.scale.detach() * normal_samples * torch.rsqrt(gamma_samples)

transform = self._transform(x.detach()) # Standardize the sample
surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient
Expand Down

0 comments on commit 1ad8325

Please sign in to comment.