Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/jafioti/luminal
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jul 30, 2024
2 parents e1279d9 + 0db8f6c commit c47b38a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use luminal::prelude::*;
// Setup graph and tensors
let mut cx = Graph::new();
let a = cx.tensor((3, 1)).set([[1.0], [2.0], [3.0]]);
let b = cx.tensor((1, 4).set([[1.0, 2.0, 3.0, 4.0]]);
let b = cx.tensor((1, 4)).set([[1.0, 2.0, 3.0, 4.0]]);

// Do math...
let mut c = a.matmul(b).retrieve();
Expand Down
25 changes: 23 additions & 2 deletions src/hl_ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ impl GraphTensor {
-(-self).max_f32(-rhs)
}

/// Clip a tensor in a range
/// Clip (clamp) a tensor into the range [`min`, `max`]
pub fn clip(self, min: f32, max: f32) -> GraphTensor {
self.min_f32(min).max_f32(max)
self.max_f32(min).min_f32(max)
}
}

Expand All @@ -297,3 +297,24 @@ impl F32Pow for f32 {
e.mul(self.abs().ln()).exp().recip()
}
}

#[cfg(test)]
mod tests {
crate::test_imports!();

#[test]
fn test_clip() {
let mut cx = Graph::new();
let a = cx
.tensor((3, 2))
.set([[[-1.0], [-2.0], [3.0]], [[-1.5], [0.0], [5.0]]]);
let result = a.clip(-1.5, 3.4).retrieve();
let expected_result = cx
.tensor((3, 2))
.set([[[-1.0], [-1.5], [3.0]], [[-1.5], [0.0], [3.4]]])
.retrieve();
cx.execute();

assert_close(&result.data(), &expected_result.data());
}
}

0 comments on commit c47b38a

Please sign in to comment.