Fast mistral loading
Joe Fioti authored and Joe Fioti committed Jan 5, 2024
1 parent a23e536 commit 9aaff41
Showing 7 changed files with 87 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ Cargo.lock
*.npx
*.npz
/**/llama-7b-hf
-/**/mistral-7b-hf
+/**/mistral-7b-hf
+/**/setup_weights/target
57 changes: 57 additions & 0 deletions examples/mistral/loader.rs
@@ -0,0 +1,57 @@
use std::fs::File;

use luminal::{op::Function, prelude::*};
use memmap2::MmapOptions;
use metal_rs::{Device, MTLResourceOptions};
use safetensors::SafeTensors;

/// Loads fp16 safetensors weights by memory-mapping the shard files and
/// wrapping the raw tensor bytes in shared Metal buffers, so nothing is
/// copied or converted at load time.
pub struct MetalFp16SafetensorsLoader {
    paths: Vec<String>,
}

impl MetalFp16SafetensorsLoader {
    pub fn new<S: ToString>(paths: &[S]) -> Self {
        Self {
            paths: paths.iter().map(|s| s.to_string()).collect(),
        }
    }
}

impl Loader for MetalFp16SafetensorsLoader {
    fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) {
        for (weight_name, node_index) in state_dict(model) {
            if let Some(loading_node) = graph
                .graph
                .node_weight_mut(node_index)
                .and_then(|op| op.as_any_mut().downcast_mut::<Function>())
            {
                let file_paths = self.paths.clone();
                loading_node.1 = Box::new(move |_| {
                    for file_path in file_paths.iter() {
                        // Map the file instead of reading it; pages are only
                        // faulted in when the tensor data is actually touched.
                        let file = File::open(file_path).unwrap();
                        let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
                        let safetensors = SafeTensors::deserialize(&buffer).unwrap();

                        if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', "."))
                        {
                            // Hand the mapped fp16 bytes to Metal without copying.
                            let buffer = Device::system_default()
                                .unwrap()
                                .new_buffer_with_bytes_no_copy(
                                    tensor_view.data().as_ptr() as *const _,
                                    tensor_view.data().len() as u64,
                                    MTLResourceOptions::StorageModeShared,
                                    None,
                                );
                            return vec![Tensor {
                                data: Box::new(buffer),
                            }];
                        }
                    }

                    panic!("Tensor \"{weight_name}\" not found in files");
                });
            }
        }
    }
}
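This loader is where the commit gets its speed: memmap2 maps each shard instead of reading it, so disk pages are only faulted in when the weights are first touched, and new_buffer_with_bytes_no_copy with StorageModeShared hands the mapped fp16 bytes straight to Metal through Apple-silicon unified memory, with no copy and no fp32 conversion. A minimal sketch of that pattern in isolation (the file name is illustrative; since no deallocator is passed, the mapping must outlive the buffer):

use std::fs::File;

use memmap2::MmapOptions;
use metal_rs::{Device, MTLResourceOptions};

fn main() {
    // Map the weight file; pages fault in lazily on first access, so this
    // returns almost immediately even for multi-gigabyte files.
    let file = File::open("weights.safetensors").unwrap();
    let mmap = unsafe { MmapOptions::new().map(&file).unwrap() };

    // Wrap the mapped bytes in a Metal buffer without copying. With
    // StorageModeShared, CPU and GPU address the same unified memory.
    // No deallocator is passed, so mmap must stay alive as long as the
    // GPU uses the buffer.
    let device = Device::system_default().unwrap();
    let buffer = device.new_buffer_with_bytes_no_copy(
        mmap.as_ptr() as *const _,
        mmap.len() as u64,
        MTLResourceOptions::StorageModeShared,
        None,
    );
    println!("zero-copy GPU buffer of {} bytes", buffer.length());
}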
9 changes: 5 additions & 4 deletions examples/mistral/main.rs
@@ -2,6 +2,7 @@ use std::{io::Write, marker::PhantomData, time::Instant};

use colored::Colorize;
use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
+mod loader;
mod model;

use luminal::{prelude::*, shape::symbolic::Expression};
@@ -40,10 +41,10 @@ fn main() {
    kv_cache.keep();

    // Set up model loading
-    SafeTensorLoader::new(&[
-        "./examples/mistral/setup/mistral-7b-hf/model-00001-of-00003.safetensors",
-        "./examples/mistral/setup/mistral-7b-hf/model-00002-of-00003.safetensors",
-        "./examples/mistral/setup/mistral-7b-hf/model-00003-of-00003.safetensors",
+    loader::MetalFp16SafetensorsLoader::new(&[
+        "./examples/mistral/setup/mistral-7b-hf/converted-model-00001-of-00003.safetensors",
+        "./examples/mistral/setup/mistral-7b-hf/converted-model-00002-of-00003.safetensors",
+        "./examples/mistral/setup/mistral-7b-hf/converted-model-00003-of-00003.safetensors",
    ])
    .load(&model, &mut cx1);
28 changes: 14 additions & 14 deletions examples/mistral/model.rs
@@ -26,9 +26,9 @@ pub type KVCache<Batch, Seq> = (
);

pub struct Mlp<const I: usize, const H: usize> {
-    pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
-    pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
-    pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
+    pub gate_proj: GraphTensor<(Const<H>, Const<I>)>,
+    pub down_proj: GraphTensor<(Const<I>, Const<H>)>,
+    pub up_proj: GraphTensor<(Const<H>, Const<I>)>,
}

impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
@@ -39,9 +39,9 @@ where
    type Output = GraphTensor<Sh>;

    fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
-        let gate = input.matmul(self.gate_proj.permute()).swish();
-        let up = input.matmul(self.up_proj.permute()) * gate;
-        up.matmul(self.down_proj.permute())
+        let gate = input.matmul(self.gate_proj).swish();
+        let up = input.matmul(self.up_proj) * gate;
+        up.matmul(self.down_proj)
    }
}

@@ -153,8 +153,8 @@ impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> InitModule

pub struct SelfAttention {
    pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
-    pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
-    pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
+    pub k_proj: GraphTensor<R2<HIDDEN_DIM, ATTN_PROJ_DIM>>,
+    pub v_proj: GraphTensor<R2<HIDDEN_DIM, ATTN_PROJ_DIM>>,
    pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
    pub rotary_embeddings: RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>,
}
@@ -180,15 +180,15 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
    ) -> Self::Output {
        // Apply the Projections
        let query_states = x
-            .matmul(self.q_proj.permute())
+            .matmul(self.q_proj)
            .reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
            .permute::<_, Axes4<0, 2, 1, 3>>();
        let key_states = x
-            .matmul(self.k_proj.permute())
+            .matmul(self.k_proj)
            .reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
            .permute::<_, Axes4<0, 2, 1, 3>>();
        let value_states = x
-            .matmul(self.v_proj.permute())
+            .matmul(self.v_proj)
            .reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
            .permute::<_, Axes4<0, 2, 1, 3>>();

@@ -234,7 +234,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
                .matmul(repeated_value_states)
                .permute::<_, Axes4<0, 2, 1, 3>>()
                .reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>()
-                .matmul(self.o_proj.permute()),
+                .matmul(self.o_proj),
            (key_states, value_states),
        )
    }
@@ -341,7 +341,7 @@ pub struct MistralLM {
    // Final Norm layer
    pub norm: RMSNorm<HIDDEN_DIM>,
    // LM Head Layer
-    pub lm_head: GraphTensor<R2<VOCAB_SIZE, HIDDEN_DIM>>,
+    pub lm_head: GraphTensor<R2<HIDDEN_DIM, VOCAB_SIZE>>,
}

impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
@@ -377,7 +377,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
        }
        hidden_states = self.norm.forward(hidden_states);

-        (hidden_states.matmul(self.lm_head.permute()), new_caches)
+        (hidden_states.matmul(self.lm_head), new_caches)
    }
}

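All of the model.rs changes are one pattern: each projection weight now lives in the graph pre-transposed, shaped (in_features, out_features) instead of (out_features, in_features), so the forward passes do a plain x.matmul(w) where they previously did x.matmul(w.permute()). The output is unchanged, because x · Wᵀ with W stored as (out, in) is exactly x · W′ for W′ = Wᵀ stored as (in, out); the transpose is simply paid once, at weight-conversion time, instead of on every inference run. The conversion itself happens in the new setup_weights step (see the sketch after the setup script below).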
7 changes: 5 additions & 2 deletions examples/mistral/setup/setup.sh
@@ -30,12 +30,15 @@ echo "Downloading Tokenizer"
curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model?download=true --output $SCRIPT_DIR/mistral-7b-hf/tokenizer.model


-echo "Downloading Model Files"
+echo "Downloading Model Files"

curl\
--parallel --parallel-immediate --parallel-max 3\
--location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/model-00001-of-00003.safetensors?download=true --output $SCRIPT_DIR/mistral-7b-hf/model-00001-of-00003.safetensors\
--location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/model-00002-of-00003.safetensors?download=true --output $SCRIPT_DIR/mistral-7b-hf/model-00002-of-00003.safetensors\
--location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/model-00003-of-00003.safetensors?download=true --output $SCRIPT_DIR/mistral-7b-hf/model-00003-of-00003.safetensors

-echo "Done Downloading Model"
+echo "Done Downloading Model"
+
+# Convert model weights
+cargo run --manifest-path $SCRIPT_DIR/setup_weights/Cargo.toml --release
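The setup_weights crate invoked above is not included in this diff, so the following is only a hypothetical sketch of what such a converter could look like: it reads each downloaded shard, transposes the 2-D fp16 projection weights from (out, in) to (in, out), and writes the converted-*.safetensors shards that main.rs loads. It assumes the safetensors and memmap2 crates and that the shards are already fp16 (the real converter likely also casts from bf16); a real version would also skip tensors that are not matmul operands, such as the token embedding.

use std::fs::File;
use std::path::Path;

use memmap2::MmapOptions;
use safetensors::{tensor::TensorView, Dtype, SafeTensors};

fn main() {
    for file_name in [
        "model-00001-of-00003.safetensors",
        "model-00002-of-00003.safetensors",
        "model-00003-of-00003.safetensors",
    ] {
        let file = File::open(file_name).unwrap();
        let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
        let safetensors = SafeTensors::deserialize(&buffer).unwrap();

        // Transpose every 2-D fp16 weight from (out, in) to (in, out) so the
        // runtime matmuls need no permute; copy everything else through.
        let mut converted: Vec<(String, Dtype, Vec<usize>, Vec<u8>)> = Vec::new();
        for (name, view) in safetensors.tensors() {
            let (shape, data) = match (view.dtype(), view.shape()) {
                (Dtype::F16, [rows, cols]) => {
                    let (rows, cols) = (*rows, *cols);
                    let src = view.data(); // row-major fp16, 2 bytes per element
                    let mut dst = vec![0u8; src.len()];
                    for r in 0..rows {
                        for c in 0..cols {
                            let s = (r * cols + c) * 2;
                            let d = (c * rows + r) * 2;
                            dst[d..d + 2].copy_from_slice(&src[s..s + 2]);
                        }
                    }
                    (vec![cols, rows], dst)
                }
                _ => (view.shape().to_vec(), view.data().to_vec()),
            };
            converted.push((name, view.dtype(), shape, data));
        }

        // Re-serialize under the file names main.rs expects.
        let views: Vec<(&str, TensorView)> = converted
            .iter()
            .map(|(name, dtype, shape, data)| {
                (name.as_str(), TensorView::new(*dtype, shape.clone(), data).unwrap())
            })
            .collect();
        let out_name = format!("converted-{file_name}");
        safetensors::serialize_to_file(views, &None, Path::new(&out_name)).unwrap();
    }
}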
1 change: 0 additions & 1 deletion src/compilers/metal/prim.rs
@@ -3,7 +3,6 @@ use std::{
};

use super::*;
-use block::ConcreteBlock;
use metal_rs::*;
use objc::rc::autoreleasepool;
use petgraph::visit::EdgeRef;
12 changes: 4 additions & 8 deletions src/core/serialization.rs
@@ -119,18 +119,14 @@ impl Loader for SafeTensorLoader {
                        if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', "."))
                        {
                            // Convert to fp32
+                            let bytes = tensor_view.data().to_vec();
                            let data: Vec<f32> = match tensor_view.dtype() {
-                                Dtype::F32 => {
-                                    unsafe { std::mem::transmute::<_, &[f32]>(tensor_view.data()) }
-                                        .to_vec()
-                                }
-                                Dtype::F16 => tensor_view
-                                    .data()
+                                Dtype::F32 => unsafe { std::mem::transmute::<_, Vec<f32>>(bytes) },
+                                Dtype::F16 => bytes
                                    .chunks_exact(2)
                                    .map(|c| f16::from_ne_bytes([c[0], c[1]]).to_f32())
                                    .collect(),
-                                Dtype::BF16 => tensor_view
-                                    .data()
+                                Dtype::BF16 => bytes
                                    .chunks_exact(2)
                                    .map(|c| bf16::from_ne_bytes([c[0], c[1]]).to_f32())
                                    .collect(),
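One caveat on the new F32 arm: transmuting a Vec<u8> into a Vec<f32> is undefined behavior in Rust (Vec's length and capacity are element counts, and the allocation's layout no longer matches), even when the bytes themselves are valid floats. A safe drop-in sketch that decodes the same bytes per element, mirroring the F16/BF16 arms:

/// Safe alternative to transmute::<_, Vec<f32>>(bytes): decode each 4-byte
/// group as a native-endian f32, the same way the F16/BF16 arms decode theirs.
fn f32s_from_bytes(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}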
