diff --git a/src/hl_ops/binary.rs b/src/hl_ops/binary.rs index eff44f13..7070c315 100644 --- a/src/hl_ops/binary.rs +++ b/src/hl_ops/binary.rs @@ -278,9 +278,9 @@ impl GraphTensor { -(-self).max_f32(-rhs) } - /// Clip a tensor in a range + /// Clip (clamp) a tensor into the range [`min`, `max`] pub fn clip(self, min: f32, max: f32) -> GraphTensor { - self.min_f32(min).max_f32(max) + self.max_f32(min).min_f32(max) } } @@ -293,3 +293,24 @@ impl F32Pow for f32 { e.mul(self.abs().ln()).exp().recip() } } + +#[cfg(test)] +mod tests { + crate::test_imports!(); + + #[test] + fn test_clip() { + let mut cx = Graph::new(); + let a = cx + .tensor((3, 2)) + .set([[[-1.0], [-2.0], [3.0]], [[-1.5], [0.0], [5.0]]]); + let result = a.clip(-1.5, 3.4).retrieve(); + let expected_result = cx + .tensor((3, 2)) + .set([[[-1.0], [-1.5], [3.0]], [[-1.5], [0.0], [3.4]]]) + .retrieve(); + cx.execute(); + + assert_close(&result.data(), &expected_result.data()); + } +}