Skip to content

Commit

Permalink
Removed forward_kv from llama
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Dec 26, 2023
1 parent 58a56f9 commit bf8f3d9
Show file tree
Hide file tree
Showing 9 changed files with 315 additions and 228 deletions.
15 changes: 12 additions & 3 deletions examples/llama/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ mod config;
mod loader;
mod model;

use std::marker::PhantomData;

use luminal::prelude::*;
use model::LlamaForCausalLM;
use rust_tokenizers::tokenizer::{
SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy::LongestFirst,
};

use crate::model::KVCache;

type Model = LlamaForCausalLM<
{ config::VOCAB },
{ config::HEADS },
Expand Down Expand Up @@ -48,7 +52,11 @@ fn main() {
input.iter().map(|i| *i as f32).collect::<Vec<f32>>(),
vec![1, input.len()],
);
let (out1, cache1) = model.forward(inp);
let (out1, cache1) = model.forward((
inp,
Option::<Vec<KVCache<Const<1>, Const<0>, { config::HEADS }, { config::HEAD_DIM }>>>::None,
PhantomData::<Dyn<'s'>>,
));
out1.retrieve();
cache1.keep();
loader::DfdxDeferredLoader::new("./examples/llama/setup/llama-7b-hf").load(&model, &mut cx1);
Expand All @@ -64,7 +72,8 @@ fn main() {
// Build KV cache forward graph
let kv_model = Model::initialize(&mut cx2);
let single_inp = cx2.named_tensor::<R2<1, 1>>("Input");
let cache_src = (0..config::LAYERS)
let cache_src: Vec<KVCache<Const<1>, Dyn<'p'>, { config::HEADS }, { config::HEAD_DIM }>> = (0
..config::LAYERS)
.map(|_| {
(
cx2.named_tensor("Key Cache"),
Expand All @@ -73,7 +82,7 @@ fn main() {
})
.collect::<Vec<_>>();
let (out, cache_dest) =
kv_model.forward_kv::<_, _, Dyn<'p'>, Dyn<'t'>>((single_inp, cache_src.clone()));
kv_model.forward((single_inp, Some(cache_src.clone()), PhantomData::<Dyn<'t'>>));
out.retrieve();
cache_dest.keep();
cx2.compile(<(PreGenericCompiler, DeviceCompiler, PostGenericCompiler)>::default());
Expand Down
Loading

0 comments on commit bf8f3d9

Please sign in to comment.