Commit

Merge branch 'main' of https://github.com/jafioti/luminal
jafioti committed Apr 28, 2024
2 parents fa2b7ac + 868e1c6 commit 1e5d639
Showing 5 changed files with 55 additions and 29 deletions.
4 changes: 2 additions & 2 deletions crates/luminal_metal/src/elementwise_fusion.rs
@@ -1090,9 +1090,9 @@ mod tests {
         let unopt_out = out.data();
         out.drop();

-        cx.compile(<(GenericCompiler, MetalCompiler<f32>)>::default(), &mut out);
+        cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut out);
         cx.execute();

-        assert_close_precision(&out.data(), &unopt_out, 1e-3);
+        assert_close_precision(&out.data(), &unopt_out, 1e-2);
     }
 }
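
For context on the relaxed threshold, below is a minimal, self-contained sketch of what an element-wise closeness check like `assert_close_precision` does. This is an illustrative stand-in, not luminal's actual test helper; the exact signature and the relative-error formula are assumptions.

```rust
// Illustrative stand-in for a closeness assertion (not part of this commit).
fn assert_close_precision(a: &[f32], b: &[f32], threshold: f32) {
    assert_eq!(a.len(), b.len(), "output lengths differ");
    for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
        let diff = (x - y).abs();
        // Allow error proportional to the reference magnitude, with an
        // absolute floor so values near zero don't demand exact equality.
        let tol = threshold * y.abs().max(1.0);
        assert!(diff <= tol, "element {i}: {x} vs {y} (diff {diff} > tol {tol})");
    }
}

fn main() {
    // f16 keeps roughly three decimal digits, so per-element errors around
    // 1e-3 relative are expected; they pass at 1e-2 but can trip 1e-3.
    assert_close_precision(&[1.0005, 2.001], &[1.0, 2.0], 1e-2);
}
```

Switching the fused test to `MetalCompiler<f16>` is what motivates loosening the tolerance from 1e-3 to 1e-2.
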
3 changes: 3 additions & 0 deletions crates/luminal_metal/src/prim.rs
@@ -1600,6 +1600,9 @@ impl<T: MetalFloat + 'static> Compiler for PrimitiveCompiler<T> {
                 graph.remove_edge(edge_id);
             }

+            if graph.no_delete.contains(&function_node) {
+                graph.no_delete.insert(copy_node);
+            }
             if let Some(w) = graph.to_retrieve.get(&function_node) {
                 graph.to_retrieve.insert(copy_node, *w);
             }
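
The hunk above mirrors retention flags from the original function node onto the copy node inserted after it, so later passes don't free a buffer that is still wanted. A simplified sketch of that bookkeeping pattern, using stand-in types (luminal's real `Graph` tracks `no_delete` and `to_retrieve` on its own node indices; the types and field layout below are assumptions for illustration only):

```rust
use std::collections::{HashMap, HashSet};

// Stand-in node id; the real graph uses petgraph node indices.
type NodeId = usize;

struct Retention {
    no_delete: HashSet<NodeId>,       // nodes whose outputs must not be freed
    to_retrieve: HashMap<NodeId, u8>, // nodes whose outputs are fetched back out
}

impl Retention {
    // When a copy node is spliced in after `function_node`, it now produces
    // the tensor callers care about, so both flags must follow it.
    fn propagate(&mut self, function_node: NodeId, copy_node: NodeId) {
        if self.no_delete.contains(&function_node) {
            self.no_delete.insert(copy_node);
        }
        if let Some(w) = self.to_retrieve.get(&function_node).copied() {
            self.to_retrieve.insert(copy_node, w);
        }
    }
}
```
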
66 changes: 44 additions & 22 deletions examples/llama/src/loader.rs
@@ -1,4 +1,6 @@
+use itertools::Itertools;
 use std::fs::File;
+use std::io::{Read, Seek};
 use std::path::Path;

 use luminal::{op::Function, prelude::*};
@@ -8,11 +10,6 @@ use {luminal_cuda::CudaData, luminal_cudarc::driver::CudaDevice};

 use crate::gguf::*;

-#[cfg(not(feature = "metal"))]
-use {
-    itertools::Itertools,
-    std::io::{Read, Seek},
-};
 #[cfg(feature = "metal")]
 use {
     luminal_metal::MetalBuffer,
@@ -53,23 +50,48 @@ pub fn q8_load<P: AsRef<Path>, M: SerializeModule>(
                 }
                 _ => panic!("Unsupported dtype: {data_type:?}"),
             };
-            loading_node.1 = Box::new(move |_| {
-                let mmap_buffer = unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() };
-                let buffer = Device::system_default()
-                    .unwrap()
-                    .new_buffer_with_bytes_no_copy(
-                        unsafe {
-                            mmap_buffer
-                                .as_ptr()
-                                .add(buffer_offset + tensor_data_offset as usize)
-                                as *const _
-                        },
-                        n_bytes as u64,
-                        MTLResourceOptions::StorageModeShared,
-                        None,
-                    );
-                vec![Tensor::new(MetalBuffer(buffer))]
-            });
+            if let GgmlDType::F32 = data_type {
+                loading_node.1 = Box::new(move |_| {
+                    // Read bytes
+                    let mut bytes = vec![0; n_bytes];
+                    let mut file = File::open(&file_path).unwrap();
+                    file.seek(std::io::SeekFrom::Start(
+                        buffer_offset as u64 + tensor_data_offset,
+                    ))
+                    .unwrap();
+                    file.read_exact(&mut bytes).unwrap();
+                    vec![Tensor::new(
+                        bytes
+                            .into_iter()
+                            .chunks(4)
+                            .into_iter()
+                            .map(|c| {
+                                let c = c.collect::<Vec<_>>();
+                                f32::from_le_bytes([c[0], c[1], c[2], c[3]])
+                            })
+                            .collect::<Vec<_>>(),
+                    )]
+                });
+            } else {
+                loading_node.1 = Box::new(move |_| {
+                    let mmap_buffer =
+                        unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() };
+                    let buffer = Device::system_default()
+                        .unwrap()
+                        .new_buffer_with_bytes_no_copy(
+                            unsafe {
+                                mmap_buffer
+                                    .as_ptr()
+                                    .add(buffer_offset + tensor_data_offset as usize)
+                                    as *const _
+                            },
+                            n_bytes as u64,
+                            MTLResourceOptions::StorageModeShared,
+                            None,
+                        );
+                    vec![Tensor::new(MetalBuffer(buffer))]
+                });
+            }
         }
     }
     q8_weights
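
Outside the diff, a self-contained sketch of what the new F32 branch does: seek to the tensor's byte offset and decode raw little-endian bytes into `f32`s. The path and offsets below are hypothetical; in the loader they come from the GGUF header (`tensor_data_offset`) plus the per-tensor `buffer_offset`.

```rust
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};

// Read an F32 tensor stored as raw little-endian bytes at a known offset.
// `offset` and `n_elements` are placeholders; the real loader derives them
// from the GGUF metadata.
fn read_f32_tensor(path: &str, offset: u64, n_elements: usize) -> std::io::Result<Vec<f32>> {
    let mut file = File::open(path)?;
    file.seek(SeekFrom::Start(offset))?;

    let mut bytes = vec![0u8; n_elements * 4]; // 4 bytes per f32
    file.read_exact(&mut bytes)?;

    // chunks_exact(4) walks the buffer in 4-byte steps, so this standalone
    // version doesn't need the itertools chunking used in the diff.
    Ok(bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect())
}

fn main() -> std::io::Result<()> {
    // Hypothetical file and offset, purely for illustration.
    let values = read_f32_tensor("model.gguf", 0, 4)?;
    println!("first values: {values:?}");
    Ok(())
}
```

Quantized tensors keep the mmap + `new_buffer_with_bytes_no_copy` path in the `else` branch, so only F32 weights are copied into host memory.
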
9 changes: 5 additions & 4 deletions examples/llama/src/main.rs
@@ -45,7 +45,7 @@ fn main() {
         .collect();
     cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
     let model = model::MistralLM::initialize(&mut cx);
-    let mut model_weights = downstream(params(&model), &cx);
+    let mut model_weights = params(&model);
     cx.keep_tensors(&model_weights);
     let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
     let mut logits = logits
@@ -67,9 +67,9 @@
         (
             GenericCompiler::default(),
             #[cfg(feature = "metal")]
-            luminal_metal::quantized::MetalQuantizedCompiler::<f32>::new(q_weights),
+            luminal_metal::quantized::MetalQuantizedCompiler::<f16>::new(q_weights),
             #[cfg(feature = "cuda")]
-            luminal_cuda::CudaQuantizedCompiler::<f32>::new(q_weights),
+            luminal_cuda::CudaQuantizedCompiler::<f16>::new(q_weights),
             #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
             luminal_cpu::CPUCompiler::default(),
         ),
@@ -97,7 +97,8 @@ fn main() {
     println!("\t\t - {}ms", now.elapsed().as_millis());

     // Now that weights are loaded, delete the loading nodes so they don't run again
-    delete_inputs(&model_weights, &mut cx);
+    delete_inputs(&downstream(model_weights, &cx), &mut cx);
+
     // Run prompt processing pass
     let mut input_ids = tokenizer
         .encode(&cli_args.prompt as &str, false)
2 changes: 1 addition & 1 deletion examples/llama/src/model.rs
@@ -6,7 +6,7 @@ use luminal_nn::{Embedding, PermutedLinear, RMSNorm};
 // Llama3 8B Config
 pub const VOCAB_SIZE: usize = 128256;
 pub const HIDDEN_DIM: usize = 4096;
-pub const NUM_LAYERS: usize = 1;
+pub const NUM_LAYERS: usize = 32;
 pub const N_HEADS: usize = 32;
 pub const N_KV_HEADS: usize = 8;
 pub const MLP_DIM: usize = 14336;
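
A quick sketch of what these constants imply for the attention shapes. `HEAD_DIM` is defined elsewhere in model.rs (it is referenced as `model::HEAD_DIM` in main.rs); the relationships below are the standard Llama-style ones and are stated here as assumptions rather than quoted from the file.

```rust
// Llama3 8B config values from model.rs above.
const HIDDEN_DIM: usize = 4096;
const NUM_LAYERS: usize = 32;
const N_HEADS: usize = 32;
const N_KV_HEADS: usize = 8;

fn main() {
    let head_dim = HIDDEN_DIM / N_HEADS; // 4096 / 32 = 128
    let gqa_group = N_HEADS / N_KV_HEADS; // 4 query heads share each KV head
    let kv_width = N_KV_HEADS * head_dim; // 8 * 128 = 1024-wide K/V projections
    println!("layers={NUM_LAYERS} head_dim={head_dim} gqa_group={gqa_group} kv_width={kv_width}");
}
```

Restoring `NUM_LAYERS` to 32 builds the full 8B stack; the previous value of 1 instantiated only a single transformer block.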
