diff --git a/.gitignore b/.gitignore
index b071db4e..e7de84a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@ Cargo.lock
 *.npx
 *.npz
 /**/llama-7b-hf
-/**/mistral-7b-hf
\ No newline at end of file
+/**/mistral-7b-hf
+/**/setup_weights/target
diff --git a/examples/mistral/loader.rs b/examples/mistral/loader.rs
new file mode 100644
index 00000000..5d4b1510
--- /dev/null
+++ b/examples/mistral/loader.rs
@@ -0,0 +1,60 @@
+use std::fs::File;
+
+use luminal::{op::Function, prelude::*};
+use memmap2::MmapOptions;
+use metal_rs::{Device, MTLResourceOptions};
+use safetensors::SafeTensors;
+
+/// Load the model in the same way dfdx-llama does
+pub struct MetalFp16SafetensorsLoader {
+    paths: Vec<String>,
+}
+
+impl MetalFp16SafetensorsLoader {
+    pub fn new<S: ToString>(paths: &[S]) -> Self {
+        Self {
+            paths: paths.iter().map(|s| s.to_string()).collect(),
+        }
+    }
+}
+
+impl Loader for MetalFp16SafetensorsLoader {
+    fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) {
+        for (weight_name, node_index) in state_dict(model) {
+            if let Some(loading_node) = graph
+                .graph
+                .node_weight_mut(node_index)
+                .and_then(|op| op.as_any_mut().downcast_mut::<Function>())
+            {
+                let file_paths = self.paths.clone();
+                // Swap the node's load function for one that reads the weight lazily from disk
+                loading_node.1 = Box::new(move |_| {
+                    for file_path in file_paths.iter() {
+                        let file = File::open(file_path).unwrap();
+                        // Memory-map the shard rather than reading it into RAM
+                        let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
+                        let safetensors = SafeTensors::deserialize(&buffer).unwrap();
+
+                        if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', "."))
+                        {
+                            // Hand the mapped fp16 bytes to Metal as a shared buffer, zero-copy
+                            let buffer = Device::system_default()
+                                .unwrap()
+                                .new_buffer_with_bytes_no_copy(
+                                    tensor_view.data().as_ptr() as *const _,
+                                    tensor_view.data().len() as u64,
+                                    MTLResourceOptions::StorageModeShared,
+                                    None,
+                                );
+                            return vec![Tensor {
+                                data: Box::new(buffer),
+                            }];
+                        }
+                    }
+
+                    panic!("Tensor \"{weight_name}\" not found in files");
+                });
+            }
+        }
+    }
+}
diff --git a/examples/mistral/main.rs b/examples/mistral/main.rs
index 035cf22d..b6bcfd4e 100644
--- a/examples/mistral/main.rs
+++ b/examples/mistral/main.rs
@@ -2,6 +2,7 @@ use std::{io::Write, marker::PhantomData, time::Instant};
 
 use colored::Colorize;
 use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
+mod loader;
 mod model;
 
 use luminal::{prelude::*, shape::symbolic::Expression};
@@ -40,10 +41,10 @@ fn main() {
     kv_cache.keep();
 
     // Set up model loading
-    SafeTensorLoader::new(&[
-        "./examples/mistral/setup/mistral-7b-hf/model-00001-of-00003.safetensors",
-        "./examples/mistral/setup/mistral-7b-hf/model-00002-of-00003.safetensors",
-        "./examples/mistral/setup/mistral-7b-hf/model-00003-of-00003.safetensors",
+    loader::MetalFp16SafetensorsLoader::new(&[
+        "./examples/mistral/setup/mistral-7b-hf/converted-model-00001-of-00003.safetensors",
+        "./examples/mistral/setup/mistral-7b-hf/converted-model-00002-of-00003.safetensors",
+        "./examples/mistral/setup/mistral-7b-hf/converted-model-00003-of-00003.safetensors",
     ])
     .load(&model, &mut cx1);
 
diff --git a/examples/mistral/model.rs b/examples/mistral/model.rs
index 30fed0ce..89148045 100644
--- a/examples/mistral/model.rs
+++ b/examples/mistral/model.rs
@@ -26,9 +26,9 @@ pub type KVCache<Batch: Dimension, Seq: Dimension> = (
 );
 
 pub struct Mlp<const I: usize, const H: usize> {
-    pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
-    pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
-    pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
+    pub gate_proj: GraphTensor<(Const<H>, Const<I>)>,
+    pub down_proj: GraphTensor<(Const<I>, Const<H>)>,
+    pub up_proj: GraphTensor<(Const<H>, Const<I>)>,
 }
 
 impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
@@ -39,9 +39,9 @@ where
 {
     type Output = GraphTensor<Sh>;
     fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
-        let gate = input.matmul(self.gate_proj.permute()).swish();
-        let up = input.matmul(self.up_proj.permute()) * gate;
-        up.matmul(self.down_proj.permute())
+        let gate = input.matmul(self.gate_proj).swish();
+        let up = input.matmul(self.up_proj) * gate;
+        up.matmul(self.down_proj)
     }
 }
 
@@ -153,8 +153,8 @@ impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
 
 pub struct SelfAttention {
     pub q_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
-    pub k_proj: GraphTensor<R2<ATTN_PROJ, HIDDEN>>,
-    pub v_proj: GraphTensor<R2<ATTN_PROJ, HIDDEN>>,
+    pub k_proj: GraphTensor<R2<HIDDEN, ATTN_PROJ>>,
+    pub v_proj: GraphTensor<R2<HIDDEN, ATTN_PROJ>>,
     pub o_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
     pub rotary_embeddings: RotaryEmbedding,
 }
@@ -180,15 +180,15 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
     ) -> Self::Output {
         // Apply the Projections
         let query_states = x
-            .matmul(self.q_proj.permute())
+            .matmul(self.q_proj)
             .reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
             .permute::<_, Axes4<0, 2, 1, 3>>();
         let key_states = x
-            .matmul(self.k_proj.permute())
+            .matmul(self.k_proj)
             .reshape::<(Batch, CurSeq, Const<KV_HEADS>, Const<HEAD_DIM>)>()
             .permute::<_, Axes4<0, 2, 1, 3>>();
         let value_states = x
-            .matmul(self.v_proj.permute())
+            .matmul(self.v_proj)
             .reshape::<(Batch, CurSeq, Const<KV_HEADS>, Const<HEAD_DIM>)>()
             .permute::<_, Axes4<0, 2, 1, 3>>();
 
@@ -234,7 +234,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
                 .matmul(repeated_value_states)
                 .permute::<_, Axes4<0, 2, 1, 3>>()
                 .reshape::<(Batch, CurSeq, Const<HIDDEN>)>()
-                .matmul(self.o_proj.permute()),
+                .matmul(self.o_proj),
             (key_states, value_states),
         )
     }
@@ -341,7 +341,7 @@ pub struct MistralLM {
     // Final Norm layer
     pub norm: RMSNorm<HIDDEN>,
     // LM Head Layer
-    pub lm_head: GraphTensor<R2<VOCAB, HIDDEN>>,
+    pub lm_head: GraphTensor<R2<HIDDEN, VOCAB>>,
 }
 
 impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
@@ -377,7 +377,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
         }
 
         hidden_states = self.norm.forward(hidden_states);
-        (hidden_states.matmul(self.lm_head.permute()), new_caches)
+        (hidden_states.matmul(self.lm_head), new_caches)
     }
 }
diff --git a/examples/mistral/setup/setup.sh b/examples/mistral/setup/setup.sh
index 3d00f6c3..efe77f2b 100644
--- a/examples/mistral/setup/setup.sh
+++ b/examples/mistral/setup/setup.sh
@@ -30,7 +30,7 @@ echo "Downloading Tokenizer"
 
 curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model?download=true --output $SCRIPT_DIR/mistral-7b-hf/tokenizer.model
 
-echo "Downloading Model Files" 
+echo "Downloading Model Files"
 
 curl\
 --parallel --parallel-immediate --parallel-max 3\
@@ -38,4 +38,7 @@ curl\
 --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/model-00002-of-00003.safetensors?download=true --output $SCRIPT_DIR/mistral-7b-hf/model-00002-of-00003.safetensors\
 --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/model-00003-of-00003.safetensors?download=true --output $SCRIPT_DIR/mistral-7b-hf/model-00003-of-00003.safetensors
 
-echo "Done Downloading Model"
\ No newline at end of file
+echo "Done Downloading Model"
+
+# Convert model weights
+cargo run --manifest-path $SCRIPT_DIR/setup_weights/Cargo.toml --release
diff --git a/src/compilers/metal/prim.rs b/src/compilers/metal/prim.rs
index a41419a4..99321f95 100644
--- a/src/compilers/metal/prim.rs
+++ b/src/compilers/metal/prim.rs
@@ -3,7 +3,6 @@ use std::{
 };
 
 use super::*;
-use block::ConcreteBlock;
 use metal_rs::*;
 use objc::rc::autoreleasepool;
 use petgraph::visit::EdgeRef;
diff --git a/src/core/serialization.rs b/src/core/serialization.rs
index 619daba6..641779ef 100644
--- a/src/core/serialization.rs
+++ b/src/core/serialization.rs
@@ -119,18 +119,14 @@ impl Loader for SafeTensorLoader {
                        if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', "."))
                        {
                            // Convert to fp32
+                           let bytes = tensor_view.data().to_vec();
                            let data: Vec<f32> = match tensor_view.dtype() {
-                               Dtype::F32 => {
-                                   unsafe { std::mem::transmute::<_, &[f32]>(tensor_view.data()) }
-                                       .to_vec()
-                               }
-                               Dtype::F16 => tensor_view
-                                   .data()
+                               Dtype::F32 => unsafe { std::mem::transmute::<_, Vec<f32>>(bytes) },
+                               Dtype::F16 => bytes
                                    .chunks_exact(2)
                                    .map(|c| f16::from_ne_bytes([c[0], c[1]]).to_f32())
                                    .collect(),
-                               Dtype::BF16 => tensor_view
-                                   .data()
+                               Dtype::BF16 => bytes
                                    .chunks_exact(2)
                                    .map(|c| bf16::from_ne_bytes([c[0], c[1]]).to_f32())
                                    .collect(),
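
Reviewer note: the `setup_weights` crate that `setup.sh` now invokes is not included in this diff, yet its conversion step is what ties the patch together: every tensor is cast to fp16, so `MetalFp16SafetensorsLoader` can hand the mapped bytes straight to `new_buffer_with_bytes_no_copy`, and each projection matrix is stored pre-transposed on disk, which is why every `.permute()` before a `matmul` disappears from `model.rs`. Below is a minimal sketch of such a converter, assuming bf16 source weights; the file names, the `needs_transpose` helper, and the tensor selection are illustrative assumptions, not the actual crate's code.

```rust
// Hypothetical setup_weights/src/main.rs -- a sketch, not the crate from this PR.
use std::{collections::BTreeMap, fs::File, path::Path};

use half::{bf16, f16};
use memmap2::MmapOptions;
use safetensors::{serialize_to_file, tensor::TensorView, Dtype, SafeTensors};

/// Assumed: weights that model.rs now matmuls without a permute get transposed on disk.
fn needs_transpose(name: &str) -> bool {
    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
        .iter()
        .any(|p| name.contains(p))
}

fn main() {
    for shard in [
        "model-00001-of-00003.safetensors",
        "model-00002-of-00003.safetensors",
        "model-00003-of-00003.safetensors",
    ] {
        let file = File::open(shard).unwrap();
        let mmap = unsafe { MmapOptions::new().map(&file).unwrap() };
        let tensors = SafeTensors::deserialize(&mmap).unwrap();

        let mut converted: BTreeMap<String, (Vec<usize>, Vec<u8>)> = BTreeMap::new();
        for (name, view) in tensors.tensors() {
            // Assumes the HF checkpoint stores bf16, as Mistral's does
            let values: Vec<f32> = view
                .data()
                .chunks_exact(2)
                .map(|c| bf16::from_ne_bytes([c[0], c[1]]).to_f32())
                .collect();
            let (shape, values) = if needs_transpose(&name) && view.shape().len() == 2 {
                let (r, c) = (view.shape()[0], view.shape()[1]);
                let mut t = vec![0.0; values.len()];
                for i in 0..r {
                    for j in 0..c {
                        // Row-major (i, j) of the original becomes (j, i) of the transpose
                        t[j * r + i] = values[i * c + j];
                    }
                }
                (vec![c, r], t)
            } else {
                (view.shape().to_vec(), values)
            };
            // Cast to fp16 so the runtime can map the bytes into Metal buffers directly
            let bytes: Vec<u8> = values
                .iter()
                .flat_map(|&v| f16::from_f32(v).to_ne_bytes())
                .collect();
            converted.insert(name, (shape, bytes));
        }

        // Borrow the converted buffers as TensorViews and write the new shard
        let views: BTreeMap<String, TensorView> = converted
            .iter()
            .map(|(name, (shape, bytes))| {
                (
                    name.clone(),
                    TensorView::new(Dtype::F16, shape.clone(), bytes).unwrap(),
                )
            })
            .collect();
        serialize_to_file(views, &None, Path::new(&format!("converted-{shard}"))).unwrap();
    }
}
```

Pre-transposing on disk trades a one-time setup cost for skipping a permute on every forward pass, and keeping the converted shards in fp16 is what makes the zero-copy Metal mapping in `loader.rs` possible.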