diff --git a/README.md b/README.md index 897c7d2c..d89e50d9 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ use luminal::prelude::*; // Setup graph and tensors let mut cx = Graph::new(); let a = cx.tensor((3, 1)).set([[1.0], [2.0], [3.0]]); -let b = cx.tensor((1, 4).set([[1.0, 2.0, 3.0, 4.0]]); +let b = cx.tensor((1, 4)).set([[1.0, 2.0, 3.0, 4.0]]); // Do math... let mut c = a.matmul(b).retrieve(); diff --git a/src/hl_ops/binary.rs b/src/hl_ops/binary.rs index bd719613..88424464 100644 --- a/src/hl_ops/binary.rs +++ b/src/hl_ops/binary.rs @@ -282,9 +282,9 @@ impl GraphTensor { -(-self).max_f32(-rhs) } - /// Clip a tensor in a range + /// Clip (clamp) a tensor into the range [`min`, `max`] pub fn clip(self, min: f32, max: f32) -> GraphTensor { - self.min_f32(min).max_f32(max) + self.max_f32(min).min_f32(max) } } @@ -297,3 +297,24 @@ impl F32Pow for f32 { e.mul(self.abs().ln()).exp().recip() } } + +#[cfg(test)] +mod tests { + crate::test_imports!(); + + #[test] + fn test_clip() { + let mut cx = Graph::new(); + let a = cx + .tensor((3, 2)) + .set([[[-1.0], [-2.0], [3.0]], [[-1.5], [0.0], [5.0]]]); + let result = a.clip(-1.5, 3.4).retrieve(); + let expected_result = cx + .tensor((3, 2)) + .set([[[-1.0], [-1.5], [3.0]], [[-1.5], [0.0], [3.4]]]) + .retrieve(); + cx.execute(); + + assert_close(&result.data(), &expected_result.data()); + } +}