From 5f730aef1f6fae73a5be3ef1ece3afcb1b1dcb99 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 3 May 2024 19:00:35 -0500 Subject: [PATCH] Changed arcmax --- crates/luminal_metal/src/tests/fp16.rs | 1 - src/hl_ops/unary.rs | 13 ++++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/crates/luminal_metal/src/tests/fp16.rs b/crates/luminal_metal/src/tests/fp16.rs index 91c2fd57..de1eb5df 100644 --- a/crates/luminal_metal/src/tests/fp16.rs +++ b/crates/luminal_metal/src/tests/fp16.rs @@ -800,7 +800,6 @@ fn test_slice_add() { cx.compile(MetalCompiler::::default(), &mut b); cx.execute(); - cx.display(); } #[test] diff --git a/src/hl_ops/unary.rs b/src/hl_ops/unary.rs index 05c51f2c..17b0c04c 100644 --- a/src/hl_ops/unary.rs +++ b/src/hl_ops/unary.rs @@ -154,15 +154,22 @@ impl GraphTensor { /// Get the indicies of the max elements along the last axis pub fn argmax(self) -> GraphTensor<::LastAxis>>::Reduced> { + // Get one-hot along last dimension let x_equal = self.equals(self.max_reduce::<_, S::LastAxis>().expand_to(self.shape)); - // ARange to shape + // Create index arange for last dimension let r = self .graph() .constant(1.) - .expand_to(self.shape) + .expand_to::<(Dyn<'-'>,)>(ShapeTracker::new(&[self + .shape + .shape() + .last() + .unwrap() + .small()])) .cumsum_last_dim() - 1.; - (x_equal * r).max_reduce::<_, S::LastAxis>() + // Multiply one-hot by expanded index arange + (x_equal * r.expand_to(self.shape)).max_reduce() } /// Take the absolute value