Commit

Merge

Joe Fioti authored and Joe Fioti committed Dec 29, 2023
2 parents 10ee2c7 + 3e0cafb commit 51545ee
Showing 20 changed files with 861 additions and 246 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@ Cargo.lock
 *.st
 *.npx
 *.npz
-/**/llama-7b-hf
+/**/llama-7b-hf
+/**/mistral-7b-hf
19 changes: 12 additions & 7 deletions Cargo.toml
@@ -12,21 +12,26 @@ cuda = ["dep:cudarc"]
 metal = ["dep:metal-rs"]
 
 [dependencies]
-luminal_macro = {path="./resources/luminal_macro"}
+luminal_macro = { path = "./resources/luminal_macro" }
 itertools = "0.11.0"
 matrixmultiply = "0.3.8"
 num-traits = "0.2.16"
-petgraph = {path="./resources/petgraph"} # Literally only need this because the free_node field of the stable graph is private
+petgraph = { path = "./resources/petgraph" } # Literally only need this because the free_node field of the stable graph is private
 rand = "0.8.5"
 strum = { version = "0.25.0", features = ["derive"] }
 urlencoding = "2.1.2"
 webbrowser = "0.8.10"
 dyn-clone = "1.0.12"
-cudarc = {path="./resources/cudarc", features=["cublas", "f16"], optional=true}
-metal-rs = {version="0.26.0", package="metal", optional=true, features=["mps"]}
+cudarc = { path = "./resources/cudarc", features = [
+    "cublas",
+    "f16",
+], optional = true }
+metal-rs = { version = "0.26.0", package = "metal", optional = true, features = [
+    "mps",
+] }
 safetensors = "0.3.1"
-memmap2 = "0.7.1"
-half = {version="2.3.1", features = ["num-traits", "rand_distr"]}
+memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
+half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
 tinyvec = "1.6.0"
 term_size = "0.3.2"
 colored = "2.0.4"
@@ -36,5 +41,5 @@ gemm = "0.15.4"
 objc = "0.2.7"
 
 [dev-dependencies]
-dfdx = {version="0.13", features=["f16"]}
+dfdx = { version = "0.13", features = ["f16"] }
 rust_tokenizers = "8.1.0"
4 changes: 2 additions & 2 deletions examples/llama/main.rs
@@ -61,7 +61,7 @@ fn main() {
     cache1.keep();
     loader::DfdxDeferredLoader::new("./examples/llama/setup/llama-7b-hf").load(&model, &mut cx1);
 
-    cx1.compile(<(PreGenericCompiler, DeviceCompiler, PostGenericCompiler)>::default());
+    cx1.compile(GenericCompiler::<DeviceCompiler>::default());
 
     // Cache model weights
     cx1.compile(RemapDownstream(
@@ -85,7 +85,7 @@ fn main() {
     kv_model.forward((single_inp, Some(cache_src.clone()), PhantomData::<Dyn<'t'>>));
     out.retrieve();
     cache_dest.keep();
-    cx2.compile(<(PreGenericCompiler, DeviceCompiler, PostGenericCompiler)>::default());
+    cx2.compile(GenericCompiler::<DeviceCompiler>::default());
20 changes: 7 additions & 13 deletions examples/llama/model.rs
@@ -65,28 +65,21 @@ impl<
         const HEAD_DIM_OVER_2: usize,
     >
     Module<(
         GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
-        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
         BigExpression,
     )> for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
 {
-    type Output = (
-        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
-        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
-    );
+    type Output = GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>;
 
     fn forward(
         &self,
-        (q, k, prev_seq): (
-            GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
+        (inp, prev_seq): (
             GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
             BigExpression,
         ),
     ) -> Self::Output {
         let (sin, cos) = self.get_sincos::<NUM_HEADS, Seq>(prev_seq);
-        let q_embed = (Self::rotate_half(q) * sin.expand()) + (q * cos.expand());
-        let k_embed = (Self::rotate_half(k) * sin.expand()) + (k * cos.expand());
-        (q_embed, k_embed)
+        (Self::rotate_half(inp) * sin.expand()) + (inp * cos.expand())
     }
 }

@@ -195,9 +188,10 @@ impl<
             .matmul(self.v_proj.permute())
             .reshape::<(Batch, CurSeq, Const<NUM_HEADS>, Const<HEAD_DIM>)>()
             .permute::<_, Axes4<0, 2, 1, 3>>();
-        let (q, k) =
-            self.rotary_embed
-                .forward((q.permute(), k.permute(), PrevSeq::const_size().into()));
+        let q = self
+            .rotary_embed
+            .forward((q.permute(), PrevSeq::const_size().into()));
+        let k = self.rotary_embed.forward((k, PrevSeq::const_size().into()));
 
         let (k, v) = if let Some(cache) = cache {
             // Add KV cache
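Editorial note, not part of the commit: the refactored single-tensor forward computes the standard rotary-embedding identity, which the attention code above now applies to queries and keys in two separate calls rather than one fused pair call:

\mathrm{RoPE}(x) = \mathrm{rotate\_half}(x) \odot \sin\theta + x \odot \cos\theta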
181 changes: 181 additions & 0 deletions examples/mistral/main.rs
@@ -0,0 +1,181 @@
use std::{marker::PhantomData, time::Instant};

use colored::Colorize;
use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
mod model;

use luminal::{prelude::*, shape::symbolic::Expression};

use crate::model::KVCache;

#[cfg(feature = "metal")]
type DeviceCompiler = MetalFp16Compiler;
#[cfg(feature = "cuda")]
type DeviceCompiler = CudaFp16Compiler;
#[cfg(all(not(feature = "cuda"), not(feature = "metal")))]
type DeviceCompiler = CPUCompiler;

fn main() -> Result<(), String> {
println!("Constructing graph...");
let tokenizer = SentencePieceBpeTokenizer::from_file(
"./examples/mistral/setup/mistral-7b-hf/tokenizer.model",
false,
)
.unwrap();

let mut cx1 = Graph::new();
let input = cx1.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
let model = model::MistralLM::initialize(&mut cx1);
let (logits, kv_cache) = model.forward((
input,
Option::<Vec<KVCache<Const<1>, Const<0>>>>::None,
PhantomData::<Dyn<'s'>>,
));
let logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.retrieve();
kv_cache.keep();
SafeTensorLoader::new(vec![
"./examples/mistral/setup/mistral-7b-hf/model-00001-of-00003.safetensors".to_string(),
"./examples/mistral/setup/mistral-7b-hf/model-00002-of-00003.safetensors".to_string(),
"./examples/mistral/setup/mistral-7b-hf/model-00003-of-00003.safetensors".to_string(),
])
.load(&model, &mut cx1);
let mut cx2 = Graph::new();
let single_input = cx2.named_tensor::<R2<1, 1>>("Input");
let kv_model = model::MistralLM::initialize(&mut cx2);
let cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
.map(|_| {
(
cx2.named_tensor("Key Cache"),
cx2.named_tensor("Value Cache"),
)
})
.collect();
let (decode_logits, cache_dest) = kv_model.forward((
single_input,
Some(cache_src.clone()),
PhantomData::<Dyn<'t'>>,
));
decode_logits.retrieve();
cache_dest.keep();

println!("Compiling graph...");
cx1.compile(GenericCompiler::<DeviceCompiler>::default());
// Cache model weights
cx1.compile(RemapDownstream(
state_dict(&model).values().copied().collect(),
));
keep_weights(&model, &mut cx1);

// Compile second graph
cx2.compile(GenericCompiler::<DeviceCompiler>::default());
// Cache model weights
cx2.compile(RemapDownstream(
state_dict(&kv_model).values().copied().collect(),
));
keep_weights(&kv_model, &mut cx2);
delete_inputs(
&state_dict(&kv_model).values().copied().collect::<Vec<_>>(),
&mut cx2,
);
delete_inputs(
&cache_src
.iter()
.flat_map(|(k, v)| [k.id(), v.id()])
.collect::<Vec<_>>(),
&mut cx2,
);

// Initial forward pass to load weights
println!("Loading model...");
input.set_dyn(vec![1.], vec![1, 1]);
cx1.execute();
logits.drop();
kv_cache.drop();

// Now that weights are loaded, delete the loading nodes so they don't run again
delete_inputs(
&state_dict(&model).values().copied().collect::<Vec<_>>(),
&mut cx1,
);

// Run inference first pass
let prompt = "Santa says: Merry";
let mut input_ids = encode(&tokenizer, prompt);

let mut completion = String::new();
input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
vec![1, input_ids.len()],
);
cx1.execute();

let output_id = sample_index(&logits.data());
input_ids.push(output_id);

// Decode token
completion.push_str(&decode(&tokenizer, &[output_id]));
println!("{}{}", prompt.on_black().white().bold(), completion.green());

// Transfer weights and kv cache
transfer_weights(&model, &mut cx1, &kv_model, &mut cx2);
for ((key_src, val_src), (key_dest, val_dest)) in kv_cache.into_iter().zip(cache_src.iter()) {
cx2.set_tensor(key_dest.id(), 0, cx1.get_tensor(key_src.id(), 0).unwrap());
cx2.set_tensor(val_dest.id(), 0, cx1.get_tensor(val_src.id(), 0).unwrap());
}

// Decode loop
for _ in 0..100 {
single_input.set(vec![*input_ids.last().unwrap() as f32]);
cx2.set_dyn_dim('p', input_ids.len() - 1);
cx2.set_dyn_dim('t', input_ids.len());

let now = Instant::now();
cx2.execute();
println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());

// Sample tokens
let output_id = sample_index(&decode_logits.data());
decode_logits.drop();
completion.push_str(&decode(&tokenizer, &[output_id]));
input_ids.push(output_id);
println!("{}{}", prompt.on_black().white().bold(), completion.green());

// Swap caches
for ((src_k, src_v), (dest_k, dest_v)) in cache_src.iter().zip(cache_dest.iter()) {
// Move dest caches to src
cx2.swap_tensors(*src_k, *dest_k);
cx2.swap_tensors(*src_v, *dest_v);
// // Drop dest caches
// dest_k.drop();
// dest_v.drop();
}
}

Ok(())
}

// Method to encode text as vector
pub fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
let mut vector = tokenizer
.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
.token_ids;
vector.insert(0, 1); // Start token
vector
}

pub fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
tokenizer
.decode(token_ids, true, false)
.replace("<0x0A>", "\n")
}

// Currently just an argmax, do actual sampling here
fn sample_index(dist: &[f32]) -> i64 {
dist.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap()
.0 as i64
}
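The sample_index helper above is a placeholder argmax, as its comment notes. A minimal sketch of what "actual sampling" could look like, assuming a hypothetical temperature parameter and using the rand crate already listed in Cargo.toml (not part of this commit):

// Hypothetical alternative to sample_index: temperature sampling (sketch only).
use rand::distributions::{Distribution, WeightedIndex};

fn sample_with_temperature(dist: &[f32], temperature: f32) -> i64 {
    // Softmax over temperature-scaled logits; subtract the max for numerical stability.
    let max = dist.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = dist
        .iter()
        .map(|&logit| ((logit - max) / temperature).exp())
        .collect();
    // Draw an index with probability proportional to its softmax weight.
    WeightedIndex::new(&weights)
        .map(|w| w.sample(&mut rand::thread_rng()) as i64)
        .expect("weights must be finite and non-negative")
}

As temperature approaches 0 this reduces to the argmax above; at 1.0 it samples directly from the softmax distribution.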