Fixed matmuls
jafioti committed Jan 21, 2024
1 parent 4219d8e commit b1c435b
Showing 4 changed files with 12 additions and 8 deletions.
11 changes: 8 additions & 3 deletions src/compilers/metal/fp16/matmul.rs
@@ -299,12 +299,12 @@ impl Compiler for MetalMatMulCompiler
         dims.swap(src2_shape.len() - 2, src2_shape.len() - 1);
         src2_shape.permute(&dims);
         // If src1 is padded or sliced, or batch dim isn't first, we need to make it contiguous
-        if (src1_shape
+        if src1_shape
             .indexes
             .iter()
             .take(src1_shape.len() - 2)
             .enumerate()
-            .any(|(a, b)| a != *b))
+            .any(|(a, b)| a != *b)
             || src1_shape.is_sliced()
             || src1_shape.is_padded()
         {
@@ -320,7 +320,12 @@ impl Compiler for MetalMatMulCompiler
             src1_shape = src1_shape.contiguous();
         }
         // If src2 is padded or sliced, or batch dim isn't first, we need to make it contiguous
-        if (src2_shape.len() == 3 && src2_shape.indexes[0] != 0)
+        if src2_shape
+            .indexes
+            .iter()
+            .take(src2_shape.len() - 2)
+            .enumerate()
+            .any(|(a, b)| a != *b)
             || src2_shape.is_sliced()
             || src2_shape.is_padded()
         {
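The src1 change only drops a pair of stray parentheses; the substantive fix is on src2, where the old test (`src2_shape.len() == 3 && src2_shape.indexes[0] != 0`) only caught a displaced first batch dim on 3-D tensors. The new predicate flags any leading (batch) dim that is out of identity order. A minimal standalone sketch of that predicate, with a plain `&[usize]` slice standing in for the `indexes` field of luminal's shape tracker (an assumption for illustration):

```rust
/// Returns true when the leading (batch) dims of `indexes` are not in
/// identity order, i.e. a permute moved them and the matmul kernel
/// cannot assume a plain row-major batch layout.
fn batch_dims_permuted(indexes: &[usize]) -> bool {
    let rank = indexes.len();
    indexes
        .iter()
        .take(rank.saturating_sub(2)) // skip the two matrix dims
        .enumerate()
        .any(|(pos, &idx)| pos != idx)
}

fn main() {
    // Batch dims in order: no contiguous copy needed.
    assert!(!batch_dims_permuted(&[0, 1, 2, 3]));
    // First batch dim displaced: caught by both old and new checks (3-D).
    assert!(batch_dims_permuted(&[1, 0, 2]));
    // 4-D case with a later batch dim displaced: only the new check catches this.
    assert!(batch_dims_permuted(&[0, 2, 1, 3]));
}
```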
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/tests.rs
@@ -812,7 +812,7 @@ fn test_transformer_encoder_block() {
     let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>));
     let d_b = d_model.forward(d_a);

-    assert_close(&b.data(), &d_b.as_vec());
+    assert_close_precision(&b.data(), &d_b.as_vec(), 2);
 }

 #[test]
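The test swaps `assert_close` for `assert_close_precision(..., 2)`, loosening the comparison to roughly two decimal places, since an fp16 Metal matmul path through a whole encoder block accumulates more rounding error than the fp32 reference. A hedged sketch of what a digits-of-precision assert could look like; this is a hypothetical stand-in, not necessarily how the repo's helper is implemented:

```rust
/// Asserts two float slices agree to `digits` decimal places.
/// Hypothetical stand-in for a test helper like `assert_close_precision`.
fn assert_close_precision(a: &[f32], b: &[f32], digits: i32) {
    assert_eq!(a.len(), b.len(), "length mismatch");
    let tol = 10f32.powi(-digits) / 2.0;
    for (i, (x, y)) in a.iter().zip(b).enumerate() {
        assert!(
            (x - y).abs() <= tol,
            "mismatch at {i}: {x} vs {y} (tol {tol})"
        );
    }
}

fn main() {
    // Agrees to 2 decimal places; would fail at 3.
    assert_close_precision(&[1.001, -0.502], &[1.003, -0.499], 2);
}
```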
2 changes: 1 addition & 1 deletion src/nn/transformer/attention.rs
@@ -177,7 +177,7 @@ impl<
         let tokens: GraphTensor<(B, S2, Const<V_DIM>)> = weights
             .matmul(values)
             .permute::<_, Axes4<0, 2, 1, 3>>()
-            .dyn_reshape(vec![B::const_size(), S2::const_size(), V_DIM.into()]);
+            .reshape();
         self.w_o.forward(tokens)
     }
 }
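Here the explicit size vector passed to `dyn_reshape` is dropped in favor of a bare `.reshape()`: the target shape is already pinned by the `GraphTensor<(B, S2, Const<V_DIM>)>` annotation on the binding, so it can presumably be inferred rather than spelled out at runtime. A toy illustration of the same type-driven pattern; `Tensor`, `Rank2`, and `Rank3` are hypothetical types for the sketch, not luminal's API:

```rust
use std::marker::PhantomData;

// Marker types standing in for type-level shapes.
struct Rank2;
struct Rank3;

/// A tensor whose shape is carried in its type parameter.
struct Tensor<Shape> {
    data: Vec<f32>,
    _shape: PhantomData<Shape>,
}

impl<S> Tensor<S> {
    /// Reinterpret the same buffer under a new type-level shape.
    /// The target shape is chosen by the caller's annotation,
    /// so no runtime size vector is needed.
    fn reshape<T>(self) -> Tensor<T> {
        Tensor { data: self.data, _shape: PhantomData }
    }
}

fn main() {
    let x = Tensor::<Rank2> { data: vec![0.0; 6], _shape: PhantomData };
    // The annotation alone drives the reshape, mirroring
    // `let tokens: GraphTensor<(B, S2, Const<V_DIM>)> = ....reshape();`
    let _tokens: Tensor<Rank3> = x.reshape();
}
```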
5 changes: 2 additions & 3 deletions src/nn/transformer/encoder.rs
@@ -61,9 +61,8 @@ impl<const DIM: usize, const FF: usize, const HEADS: usize, S: Dimension, B: Dim
     fn forward(&self, x: GraphTensor<(B, S, Const<DIM>)>) -> Self::Output {
         let y = self.attention.forward(x);
         let x = (x + y).layer_norm::<2>(1e-5);
-        // let y = self.ff.forward(x);
-        // (x + y).layer_norm::<2>(1e-5)
-        x
+        let y = self.ff.forward(x);
+        (x + y).layer_norm::<2>(1e-5)
     }
 }

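This change re-enables the previously commented-out feed-forward sublayer, restoring the standard post-norm encoder block: normalize(x + attention(x)), then normalize(x + ff(x)). A minimal sketch of that two-stage residual pattern over plain vectors; the `layer_norm` here normalizes a 1-D vector with the diff's 1e-5 epsilon and is a simplification of luminal's axis-wise `layer_norm::<2>`:

```rust
/// LayerNorm over a 1-D vector with epsilon 1e-5, as in the diff.
fn layer_norm(x: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + 1e-5).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// Post-norm encoder block: residual + norm around each sublayer.
fn encoder_block(
    x: &[f32],
    attention: impl Fn(&[f32]) -> Vec<f32>,
    feed_forward: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let y = attention(x);
    let x: Vec<f32> = x.iter().zip(&y).map(|(a, b)| a + b).collect();
    let x = layer_norm(&x);
    // The second half restored by this commit:
    let y = feed_forward(&x);
    let x: Vec<f32> = x.iter().zip(&y).map(|(a, b)| a + b).collect();
    layer_norm(&x)
}

fn main() {
    let out = encoder_block(
        &[1.0, 2.0, 3.0],
        |x| x.to_vec(),                          // identity "attention" for the demo
        |x| x.iter().map(|v| v * 0.5).collect(), // toy feed-forward
    );
    println!("{out:?}");
}
```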
