From 6d917dd579825a8f1458fb64524c1b1e010a4f98 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 3 May 2024 18:47:54 -0500 Subject: [PATCH] small changes --- crates/luminal_metal/src/tests/fp16.rs | 2 +- src/hl_ops/movement.rs | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/luminal_metal/src/tests/fp16.rs b/crates/luminal_metal/src/tests/fp16.rs index 6ddf538a..91c2fd57 100644 --- a/crates/luminal_metal/src/tests/fp16.rs +++ b/crates/luminal_metal/src/tests/fp16.rs @@ -770,7 +770,7 @@ fn test_movement() { let data = random_vec(32); let mut cx = Graph::new(); let a = cx.tensor::>().set(data.clone()); - let b: GraphTensor> = a.pad(&[(0, 10)]).contiguous().retrieve(); + let b: GraphTensor> = a.pad((0, 10)).contiguous().retrieve(); let mut c: GraphTensor> = b .slice((..Expression::from(25),)) .realize() diff --git a/src/hl_ops/movement.rs b/src/hl_ops/movement.rs index 40a83e3d..75a7250b 100644 --- a/src/hl_ops/movement.rs +++ b/src/hl_ops/movement.rs @@ -145,12 +145,11 @@ impl GraphTensor { if n_dims > 1 { // View as single dimension of matrix with wider width let mat_size = (dim_size.big() + stride.big()) * number_of_windows.big(); - let actual_size = - dim_size.big() * self.shape.dims[self.shape.indexes[n_dims - 1]].big(); + let actual_size = dim_size * self.shape.dims[self.shape.indexes[n_dims - 1]]; // Reshape into single dimension to pad self.shape.remove_dim(n_dims); - self.shape.dims[self.shape.indexes[n_dims - 1]] = actual_size.small(); - self.shape.padding[self.shape.indexes[n_dims - 1]].1 = (mat_size - actual_size).small(); + self.shape.dims[self.shape.indexes[n_dims - 1]] = actual_size; + self.shape.padding[self.shape.indexes[n_dims - 1]].1 = mat_size.small() - actual_size; self = self.contiguous(); // Reshape back (mats should be full now) self.shape.add_dim(n_dims, dim_size + stride);