Skip to content

Commit

Permalink
Fixed metal
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed May 7, 2024
1 parent a5d292a commit 07ab3b3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 16 deletions.
1 change: 1 addition & 0 deletions crates/luminal_metal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ impl DispatchNElements for ComputeCommandEncoderRef {
}
}

#[allow(dead_code)]
trait SetInt {
fn set_i32(&self, index: usize, value: i32);
fn set_u32(&self, index: usize, value: u32);
Expand Down
4 changes: 2 additions & 2 deletions crates/luminal_metal/src/storage_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Compiler for StorageBufferCompiler {
// Assign output buffers
for required_buffer in wrapper.output_buffer_sizes(&input_shapes) {
// Find an applicable buffer
if let Some((buffer_index, source_node, _)) = first_pass[&node]
if let Some((buffer_index, source_node, _)) = first_pass[node]
.1
.iter()
.filter(|i| !graph.no_delete.contains(i))
Expand Down Expand Up @@ -171,7 +171,7 @@ impl Compiler for StorageBufferCompiler {
// Assign intermediate buffers
for required_buffer in wrapper.intermediate_buffer_sizes(&input_shapes) {
// Find an applicable buffer
if let Some((buffer_index, source_node, _)) = first_pass[&node]
if let Some((buffer_index, source_node, _)) = first_pass[node]
.1
.iter()
.filter(|i| !graph.no_delete.contains(i))
Expand Down
20 changes: 6 additions & 14 deletions crates/luminal_metal/src/tests/fp32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,7 @@ fn test_pool_1d_dims() {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
]);
// Stride 1
let out1 = inp1
.pool_last_dim::<R3<4, 2, 3>>(3.into(), 1.into(), 0)
.retrieve();
let out1 = inp1.pool_last_dim::<R3<4, 2, 3>>(3, 1, 0).retrieve();

cx.execute();

Expand All @@ -425,10 +423,10 @@ fn test_pool_2d() {
let out1 = inp1
// Pool first dim first by moving it to end
.permute::<_, LAxes2<1, 0>>()
.pool_last_dim::<R3<4, 2, 3>>(3.into(), 1.into(), 0)
.pool_last_dim::<R3<4, 2, 3>>(3, 1, 0)
// Now move other dim to end
.permute::<_, LAxes3<1, 2, 0>>()
.pool_last_dim::<R4<2, 3, 2, 3>>(3.into(), 1.into(), 0)
.pool_last_dim::<R4<2, 3, 2, 3>>(3, 1, 0)
// Now swap middle two dims
.permute::<_, LAxes4<0, 2, 1, 3>>()
// Now merge both pooled dimensions
Expand All @@ -453,17 +451,11 @@ fn test_pool_1d_dilation() {

let inp1 = cx.tensor::<R1<5>>().set(vec![1., 2., 3., 4., 5.]);
// Stride 1
let out1 = inp1
.pool_last_dim::<R2<3, 2>>(2.into(), 1.into(), 1)
.retrieve();
let out1 = inp1.pool_last_dim::<R2<3, 2>>(2, 1, 1).retrieve();
// Stride 2
let out2 = inp1
.pool_last_dim::<R2<2, 2>>(2.into(), 2.into(), 1)
.retrieve();
let out2 = inp1.pool_last_dim::<R2<2, 2>>(2, 2, 1).retrieve();
// Stride 3
let out3 = inp1
.pool_last_dim::<R2<1, 2>>(2.into(), 3.into(), 1)
.retrieve();
let out3 = inp1.pool_last_dim::<R2<1, 2>>(2, 3, 1).retrieve();

cx.execute();

Expand Down

0 comments on commit 07ab3b3

Please sign in to comment.