Simplified llama
Joe Fioti authored and Joe Fioti committed Dec 26, 2023
1 parent bf8f3d9 commit eaa4ad8
Showing 1 changed file with 20 additions and 100 deletions.
120 changes: 20 additions & 100 deletions examples/llama/model.rs
@@ -231,7 +231,7 @@ impl<
};

let mut w = q
.matmul(k.permute::<_, Axes4<0, 1, 3, 2>>())
.matmul(k.permute())
.mul((HEAD_DIM as f64).sqrt().recip() as f32);
// We don't need a causal mask on a KV-cached pass
if cache.is_none() {
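
For reference, a minimal plain-Rust sketch of what this hunk computes (not luminal code; the function name, the kv_cached flag, and the slice-based shapes are illustrative): attention scores w = q · kᵀ scaled by 1/sqrt(HEAD_DIM), with the causal mask applied only on a non-cached pass, since a KV-cached decode step scores new queries against all previous keys and needs no masking.

fn attention_scores(q: &[Vec<f32>], k: &[Vec<f32>], head_dim: usize, kv_cached: bool) -> Vec<Vec<f32>> {
    // Mirrors q.matmul(k.permute()).mul((HEAD_DIM as f64).sqrt().recip() as f32)
    let scale = (head_dim as f64).sqrt().recip() as f32;
    let mut w = vec![vec![0.0f32; k.len()]; q.len()];
    for (i, qi) in q.iter().enumerate() {
        for (j, kj) in k.iter().enumerate() {
            if !kv_cached && j > i {
                // Non-cached pass: mask out future key positions (causal mask)
                w[i][j] = f32::NEG_INFINITY;
            } else {
                // Dot product of the i-th query with the j-th key, then scale
                w[i][j] = qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale;
            }
        }
    }
    w
}
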
@@ -372,7 +372,7 @@ impl<
}
}

pub struct Llama<
pub struct LlamaForCausalLM<
const VOCAB: usize,
const NUM_HEADS: usize,
const HIDDEN: usize,
@@ -385,6 +385,7 @@ pub struct Llama<
pub layers: Vec<DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>>,
pub norm: RMSNorm<HIDDEN>,
pub graph_ref: *mut Graph,
pub lm_head: GraphTensor<(Const<VOCAB>, Const<HIDDEN>)>,
}

impl<
@@ -404,32 +405,36 @@ impl<
GraphTensor<(Batch, CurSeq)>,
Option<Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>>,
PhantomData<TotSeq>,
)> for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
)>
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
Vec<KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>>,
);
fn forward(
&self,
(input, cache, _): (
(input, caches, _): (
GraphTensor<(Batch, CurSeq)>,
Option<Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>>,
PhantomData<TotSeq>,
),
) -> Self::Output {
let mut hidden_states = self.embed_tokens.forward(input);
let mut caches = vec![];
let mut new_caches = vec![];
for (i, layer_i) in self.layers.iter().enumerate() {
let (new_hidden_states, (k_cache, v_cache)) = layer_i.forward((
hidden_states,
cache.as_ref().map(|v| v[i]),
caches.as_ref().map(|v| v[i]),
PhantomData::<TotSeq>,
));
hidden_states = new_hidden_states;
caches.push((k_cache.contiguous(), v_cache.contiguous()));
new_caches.push((k_cache.contiguous(), v_cache.contiguous()));
}
(self.norm.forward(hidden_states), caches)
let hidden_states = self.norm.forward(hidden_states);
let o = hidden_states.matmul(self.lm_head.permute());
(o, new_caches)
}
}
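
The merged forward embeds the input tokens, threads the optional KV caches through each decoder layer, applies the final RMSNorm, and then projects onto the vocabulary with lm_head directly instead of delegating to a nested Llama module. A rough plain-Rust sketch of that final projection (hypothetical helper; nested Vecs stand in for GraphTensors): lm_head is stored as [VOCAB, HIDDEN], so multiplying by its permutation maps hidden states of shape (Batch, CurSeq, HIDDEN) to logits of shape (Batch, CurSeq, VOCAB).

fn project_logits(hidden: &[Vec<Vec<f32>>], lm_head: &[Vec<f32>]) -> Vec<Vec<Vec<f32>>> {
    // hidden: [batch][seq][HIDDEN], lm_head: [VOCAB][HIDDEN]
    // Corresponds to hidden_states.matmul(self.lm_head.permute()) in the code above
    hidden
        .iter()
        .map(|batch| {
            batch
                .iter()
                .map(|h| {
                    // One logit per vocab row: dot product with the hidden vector
                    lm_head
                        .iter()
                        .map(|row| h.iter().zip(row).map(|(a, b)| a * b).sum())
                        .collect()
                })
                .collect()
        })
        .collect()
}
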

@@ -442,103 +447,14 @@ impl<
const HEAD_DIM_OVER_2: usize,
const LAYERS: usize,
> InitModule
for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
fn initialize(cx: &mut Graph) -> Self {
Self {
norm: InitModule::initialize(cx),
embed_tokens: InitModule::initialize(cx),
layers: (0..LAYERS).map(|_| InitModule::initialize(cx)).collect(),
graph_ref: cx,
lm_head: cx.named_tensor("LM Head"),
}
}
}

impl<
const VOCAB: usize,
const NUM_HEADS: usize,
const HIDDEN: usize,
const INTERMEDIATE: usize,
const HEAD_DIM: usize,
const HEAD_DIM_OVER_2: usize,
const LAYERS: usize,
> SerializeModule
for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
fn serialize(&self, s: &mut Serializer) {
s.module("norm", &self.norm);
s.module("embed_tokens", &self.embed_tokens);
for (i, l) in self.layers.iter().enumerate() {
s.module(&format!("layers/{i}"), l);
}
}
}

pub struct LlamaForCausalLM<
const VOCAB: usize,
const NUM_HEADS: usize,
const HIDDEN: usize,
const INTERMEDIATE: usize,
const HEAD_DIM: usize,
const HEAD_DIM_OVER_2: usize,
const LAYERS: usize,
> {
pub llama: Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>,
pub lm_head: GraphTensor<(Const<VOCAB>, Const<HIDDEN>)>,
}

impl<
const VOCAB: usize,
const NUM_HEADS: usize,
const HIDDEN: usize,
const INTERMEDIATE: usize,
const HEAD_DIM: usize,
const HEAD_DIM_OVER_2: usize,
const LAYERS: usize,
Batch: Dimension,
CurSeq: Dimension,
PrevSeq: Dimension,
TotSeq: Dimension,
>
Module<(
GraphTensor<(Batch, CurSeq)>,
Option<Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>>,
PhantomData<TotSeq>,
)>
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
Vec<KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>>,
);
fn forward(
&self,
(input, caches, _): (
GraphTensor<(Batch, CurSeq)>,
Option<Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>>,
PhantomData<TotSeq>,
),
) -> Self::Output {
let (hidden_states, caches) = self.llama.forward((input, caches, PhantomData::<TotSeq>));
let o = hidden_states.matmul(self.lm_head.permute());
(o, caches)
}
}

impl<
const VOCAB: usize,
const NUM_HEADS: usize,
const HIDDEN: usize,
const INTERMEDIATE: usize,
const HEAD_DIM: usize,
const HEAD_DIM_OVER_2: usize,
const LAYERS: usize,
> InitModule
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
fn initialize(cx: &mut Graph) -> Self {
Self {
llama: InitModule::initialize(cx),
lm_head: cx.named_tensor("LM Head"),
}
}
@@ -556,7 +472,11 @@ impl<
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
fn serialize(&self, s: &mut Serializer) {
s.module("model", &self.llama);
s.module("model/norm", &self.norm);
s.module("model/embed_tokens", &self.embed_tokens);
for (i, l) in self.layers.iter().enumerate() {
s.module(&format!("model/layers/{i}"), l);
}
s.tensor("lm_head/weight", self.lm_head);
}
}
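
Because Serializer paths compose with "/", the flattened serialize appears to produce the same key layout as the old nested s.module("model", &self.llama) version: per-module prefixes under model/ plus the lm_head/weight tensor. A small sketch (hypothetical, not part of the commit) that enumerates those prefixes for a 2-layer model:

fn main() {
    // Module prefixes are expanded further by each submodule's own serialize impl;
    // "lm_head/weight" is a plain tensor key.
    let layers = 2;
    let mut keys = vec!["model/norm".to_string(), "model/embed_tokens".to_string()];
    for i in 0..layers {
        keys.push(format!("model/layers/{i}"));
    }
    keys.push("lm_head/weight".to_string());
    for key in &keys {
        println!("{key}");
    }
}
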
