reversed mistral weight transpose
Joe Fioti authored and Joe Fioti committed Jan 12, 2024
1 parent a240e2a commit ec09c02
Showing 3 changed files with 34 additions and 67 deletions.
28 changes: 14 additions & 14 deletions examples/mistral/model.rs
@@ -26,9 +26,9 @@ pub type KVCache<Batch, Seq> = (
);

pub struct Mlp<const I: usize, const H: usize> {
-pub gate_proj: GraphTensor<(Const<H>, Const<I>)>,
-pub down_proj: GraphTensor<(Const<I>, Const<H>)>,
-pub up_proj: GraphTensor<(Const<H>, Const<I>)>,
+pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
+pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
+pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
}

impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
@@ -39,9 +39,9 @@ where
type Output = GraphTensor<Sh>;

fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
-let gate = input.matmul(self.gate_proj).swish();
-let up = input.matmul(self.up_proj) * gate;
-up.matmul(self.down_proj)
+let gate = input.matmul(self.gate_proj.permute()).swish();
+let up = input.matmul(self.up_proj.permute()) * gate;
+up.matmul(self.down_proj.permute())
}
}

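Note: with this change every Mlp weight keeps the (out_features, in_features) layout that the upstream safetensors files already use, and the transpose happens lazily via .permute() inside forward rather than eagerly at setup time. A minimal sketch of why the two are equivalent, using plain slices instead of luminal's GraphTensor (the matmul_t helper is hypothetical, written only for this illustration):

    // x @ W^T without materializing the transpose: output o is the dot
    // product of x with row o of a row-major (out_dim, in_dim) weight.
    // This is what `input.matmul(weight.permute())` expresses above.
    fn matmul_t(x: &[f32], w: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
        (0..out_dim)
            .map(|o| (0..in_dim).map(|i| x[i] * w[o * in_dim + i]).sum())
            .collect()
    }

    fn main() {
        let x = [1.0, 2.0, 3.0]; // one input row, in_dim = 3
        let w = [
            1.0, 0.0, 0.0, // row 0 of W, shape (out_dim = 2, in_dim = 3)
            0.0, 1.0, 1.0, // row 1 of W
        ];
        assert_eq!(matmul_t(&x, &w, 2, 3), vec![1.0, 5.0]);
    }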
@@ -153,8 +153,8 @@ impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> InitModule

pub struct SelfAttention {
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
-pub k_proj: GraphTensor<R2<HIDDEN_DIM, ATTN_PROJ_DIM>>,
-pub v_proj: GraphTensor<R2<HIDDEN_DIM, ATTN_PROJ_DIM>>,
+pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
+pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub rotary_embeddings: RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>,
}
@@ -180,15 +180,15 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
) -> Self::Output {
// Apply the Projections
let query_states = x
-.matmul(self.q_proj)
+.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
let key_states = x
-.matmul(self.k_proj)
+.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
let value_states = x
-.matmul(self.v_proj)
+.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();

@@ -234,7 +234,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
.matmul(repeated_value_states)
.permute::<_, Axes4<0, 2, 1, 3>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>()
-.matmul(self.o_proj),
+.matmul(self.o_proj.permute()),
(key_states, value_states),
)
}
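The asymmetry between q_proj (square) and k_proj/v_proj above comes from grouped-query attention: keys and values are projected for only N_KV_HEADS heads, so their output dimension ATTN_PROJ_DIM is smaller than HIDDEN_DIM. A quick consistency check, assuming the usual Mistral-7B constants (these values are inferred from the (1024, 4096) transpose deleted in the setup step below, not stated in this diff):

    const HIDDEN_DIM: usize = 4096;
    const N_HEADS: usize = 32;
    const N_KV_HEADS: usize = 8;
    const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS; // 128
    // k/v project into N_KV_HEADS * HEAD_DIM features, hence the
    // R2<ATTN_PROJ_DIM, HIDDEN_DIM> weight shapes above.
    const ATTN_PROJ_DIM: usize = N_KV_HEADS * HEAD_DIM;

    fn main() {
        assert_eq!((ATTN_PROJ_DIM, HIDDEN_DIM), (1024, 4096));
    }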
@@ -341,7 +341,7 @@ pub struct MistralLM
// Final Norm layer
pub norm: RMSNorm<HIDDEN_DIM>,
// LM Head Layer
-pub lm_head: GraphTensor<R2<HIDDEN_DIM, VOCAB_SIZE>>,
+pub lm_head: GraphTensor<R2<VOCAB_SIZE, HIDDEN_DIM>>,
}

impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
@@ -377,7 +377,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
}
hidden_states = self.norm.forward(hidden_states);

-(hidden_states.matmul(self.lm_head), new_caches)
+(hidden_states.matmul(self.lm_head.permute()), new_caches)
}
}

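lm_head follows the same convention: stored as (VOCAB_SIZE, HIDDEN_DIM) and permuted at the final matmul, so the logits come out with one vocabulary distribution per position. A rough shape trace, assuming VOCAB_SIZE = 32000 and HIDDEN_DIM = 4096 (the dimensions the removed lm_head transpose below used):

    // Shape bookkeeping only; tuples stand in for tensor shapes.
    fn main() {
        let hidden_states = (1, 10, 4096); // (Batch, CurSeq, HIDDEN_DIM) after the final norm
        let lm_head = (32000, 4096); // stored (out, in), matching the safetensors layout
        let permuted = (lm_head.1, lm_head.0); // (4096, 32000) after .permute()
        let logits = (hidden_states.0, hidden_states.1, permuted.1);
        assert_eq!(logits, (1, 10, 32000));
    }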
25 changes: 0 additions & 25 deletions examples/mistral/setup/setup_weights/src/main.rs
@@ -32,17 +32,6 @@ fn main() {
.collect(),
_ => panic!("{:?} is not a supported dtype", tensor.dtype()),
};
-if weight_name.contains("q_proj") || weight_name.contains("o_proj") {
-data = transpose(&data, 4096, 4096);
-} else if weight_name.contains("k_proj") || weight_name.contains("v_proj") {
-data = transpose(&data, 1024, 4096);
-} else if weight_name.contains("gate_proj") || weight_name.contains("up_proj") {
-data = transpose(&data, 14336, 4096);
-} else if weight_name.contains("down_proj") {
-data = transpose(&data, 4096, 14336);
-} else if weight_name.contains("lm_head") {
-data = transpose(&data, 32000, 4096);
-}
println!("Converted {weight_name}");
let len = data.len();
weights.insert(weight_name, Fp16Vec(data, vec![len]));
@@ -56,20 +45,6 @@ fn main() {
}
}

-fn transpose(matrix: &Vec<f16>, rows: usize, cols: usize) -> Vec<f16> {
-let mut transposed = vec![f16::ZERO; rows * cols];
-
-for i in 0..rows {
-for j in 0..cols {
-let original_index = i * cols + j; // original index for row-major order
-let transposed_index = j * rows + i; // transposed index for a row-major order matrix
-transposed[transposed_index] = matrix[original_index];
-}
-}
-
-transposed
-}
-
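For reference, the index math in the deleted eager transpose is the standard row-major swap: element (i, j) at i * cols + j moves to (j, i) at j * rows + i. A self-contained f32 version of the same routine with a small sanity check (the original's f16 presumably comes from the half crate):

    fn transpose(matrix: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut transposed = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                // (i, j) in the (rows x cols) input -> (j, i) in the (cols x rows) output
                transposed[j * rows + i] = matrix[i * cols + j];
            }
        }
        transposed
    }

    fn main() {
        // [[1, 2, 3], [4, 5, 6]] transposes to [[1, 4], [2, 5], [3, 6]]
        let m = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(transpose(&m, 2, 3), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }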
struct Fp16Vec(Vec<f16>, Vec<usize>);

impl View for Fp16Vec {
48 changes: 20 additions & 28 deletions src/compilers/metal/fp16/matmul.rs
@@ -57,24 +57,20 @@ impl MetalKernel for MatVec {
encoder.set_i32(4, n as i32);
encoder.set_i32(5, 0 as i32);
encoder.set_i32(6, 0 as i32);
-encoder.set_threadgroup_memory_length(0, BN * BM * 8);
+encoder.set_threadgroup_memory_length(
+0,
+if inputs[1].1.is_contiguous() {
+BN * BM * 4
+} else {
+BN * 8
+},
+);

encoder.set_compute_pipeline_state(&self.pipeline);
+let b = if inputs[1].1.is_contiguous() { BN } else { BM };
encoder.dispatch_thread_groups(
-MTLSize {
-width: if inputs[1].1.is_contiguous() {
-(n as u64 + BN * 4 - 1).div_ceil(BN * 4)
-} else {
-(n as u64 + BM * 4 - 1).div_ceil(BM * 4)
-},
-height: 1,
-depth: 1,
-},
-MTLSize {
-width: BN,
-height: BM,
-depth: 1,
-},
+MTLSize::new((n as u64 + b * 4 - 1).div_ceil(b * 4), 1, 1),
+MTLSize::new(BN, BM, 1),
);
encoder.end_encoding();
}
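The refactor above folds the two dispatch branches into one: b selects the tile width (BN when the weight input is contiguous, BM otherwise), and a single MTLSize expression covers both paths. A sketch of the group-count arithmetic with placeholder tile sizes (BN here is an illustrative value, not the kernel's actual constant):

    // Mirrors `(n as u64 + b * 4 - 1).div_ceil(b * 4)` from the diff above.
    // Note that `+ b * 4 - 1` is a manual ceiling adjustment applied on top
    // of `div_ceil`, so an exact multiple of b * 4 still gets one extra group.
    fn groups(n: u64, b: u64) -> u64 {
        (n + b * 4 - 1).div_ceil(b * 4)
    }

    fn main() {
        let bn = 8; // placeholder tile size, so each group covers 32 columns
        assert_eq!(groups(4096, bn), 129); // 4096 / 32 = 128, plus the extra group
        assert_eq!(groups(4000, bn), 126); // ceil(4000 / 32) = 125, plus one
    }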
@@ -179,22 +175,18 @@ impl MetalKernel for Matmul {
encoder.set_i32(3, m as i32);
encoder.set_i32(4, n as i32);
encoder.set_i32(5, k as i32);
-encoder.set_i32(6, (m * k) as i32);
-encoder.set_i32(7, 0);
-encoder.set_i32(8, (m * n) as i32);
+encoder.set_i32(6, (m * k) as i32); // A batch stride
+encoder.set_i32(7, 0); // B batch stride
+encoder.set_i32(8, (m * n) as i32); // C batch stride

// Execute
encoder.dispatch_thread_groups(
-MTLSize {
-width: (n + 32 - 1).div_ceil(32) as u64,
-height: (m + 32 - 1).div_ceil(32) as u64,
-depth: batch_size as u64,
-},
-MTLSize {
-width: 32,
-height: 2,
-depth: 2,
-},
+MTLSize::new(
+(n + 32 - 1).div_ceil(32) as u64,
+(m + 32 - 1).div_ceil(32) as u64,
+batch_size as u64,
+),
+MTLSize::new(32, 2, 2),
);
encoder.end_encoding();
}
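The inline comments added on arguments 6 through 8 pin down the batched-GEMM layout: each batch advances m * k elements into A and m * n into C, while B's stride of 0 suggests the single weight matrix is reused across every batch. A sketch of that offset arithmetic (dimension values are arbitrary examples):

    // Per-batch element offsets for A (m x k), B (k x n, shared), C (m x n),
    // mirroring the strides passed as kernel arguments 6, 7, and 8.
    fn batch_offsets(batch: usize, m: usize, n: usize, k: usize) -> (usize, usize, usize) {
        (batch * m * k, 0, batch * m * n) // B offset stays 0: weights are broadcast
    }

    fn main() {
        let (m, n, k) = (3, 32000, 4096);
        assert_eq!(batch_offsets(0, m, n, k), (0, 0, 0));
        assert_eq!(batch_offsets(2, m, n, k), (2 * m * k, 0, 2 * m * n));
    }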
