Skip to content

Commit

Permalink
small changes
Browse files (browse the repository at this point in the history)
  • Loading branch information
jafioti committed Jan 28, 2024
1 parent 56de7fa commit 162859d
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 38 deletions.
8 changes: 3 additions & 5 deletions examples/mistral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ pub struct CLIArgs {

fn main() {
let cli_args = CLIArgs::parse();
let prompt = cli_args.prompt.as_str();
let tokens_to_generate = cli_args.gen_tokens;

let tokenizer = SentencePieceBpeTokenizer::from_file(
"./examples/mistral/setup/mistral-7b-hf/tokenizer.model",
Expand Down Expand Up @@ -137,7 +135,7 @@ fn main() {
delete_inputs(&model_weights, &mut cx1);

// Run inference first pass
let mut input_ids = encode(&tokenizer, prompt);
let mut input_ids = encode(&tokenizer, &cli_args.prompt);

input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
Expand All @@ -158,7 +156,7 @@ fn main() {
// Decode token
print!(
"{}{}",
prompt.white().bold(),
cli_args.prompt.white().bold(),
decode(&tokenizer, &[output_id]).bright_green()
);
io::stdout().flush().unwrap();
Expand All @@ -170,7 +168,7 @@ fn main() {

// Decode loop
let mut token_decode_times = vec![];
for _ in 0..tokens_to_generate {
for _ in 0..cli_args.gen_tokens {
single_input.set(vec![*input_ids.last().unwrap() as f32]);
cx2.set_dyn_dim('p', input_ids.len() - 1);
cx2.set_dyn_dim('t', input_ids.len());
Expand Down
4 changes: 2 additions & 2 deletions examples/mistral/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{marker::PhantomData, ops::Mul};
use std::{marker::PhantomData, ops::Div};

use luminal::{
nn::{embedding::Embedding, norm::RMSNorm},
Expand Down Expand Up @@ -144,7 +144,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
let mut attention_weights = queries
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
.matmul(repeated_keys.permute())
.mul((HEAD_DIM as f32).sqrt().recip());
.div((HEAD_DIM as f32).sqrt());

// We only mask on a non-kv cache pass
if cache.is_none() {
Expand Down
38 changes: 21 additions & 17 deletions src/compilers/metal/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,30 @@ impl<T> MetalKernel for Matmul<T> {
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let (a_shape, b_shape) = (inputs[0].1.shape(), inputs[1].1.shape());
let (a_shape, b_shape) = (
inputs[0]
.1
.shape()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
inputs[1]
.1
.shape()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
);
let a_dims = a_shape.len();
let m = a_shape[a_dims - 2].to_usize().unwrap();
let batch_size = a_shape
.iter()
.take(a_dims - 2)
.map(|i| i.to_usize().unwrap())
.product::<usize>()
.max(1);
let m = a_shape[a_dims - 2];
let batch_size = a_shape.iter().take(a_dims - 2).product::<usize>().max(1);
// if m == 1 && a_shape.len() > 2 {
// m *= a_shape[a_shape.len() - 3].to_usize().unwrap();
// m *= a_shape[a_shape.len() - 3];
// batch_size /= m;
// }
let b_dims = b_shape.len();
let k = b_shape[b_dims - 2].to_usize().unwrap();
let n = b_shape[b_dims - 1].to_usize().unwrap();
let k = b_shape[b_dims - 2];
let n = b_shape[b_dims - 1];

let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
Expand Down Expand Up @@ -106,12 +114,8 @@ impl<T> MetalKernel for Matmul<T> {
.any(|i| !inputs[1].1.fake[*i])
// At least one non-fake dimension before 3rd to last
{
encoder.set_i32(
8,
inputs[1].1.shape()[inputs[1].1.len() - 3]
.to_usize()
.unwrap() as i32,
); // B batch size 2
encoder.set_i32(8, b_shape[inputs[1].1.len() - 3] as i32);
// B batch size 2
} else {
encoder.set_i32(8, 1 as i32); // B batch size
}
Expand Down
6 changes: 4 additions & 2 deletions src/compilers/metal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,12 @@ pub trait MetalKernel: Debug {
.output_buffer_sizes(&inp_shapes)
.into_iter()
.map(|n| {
dev.new_buffer(
let b = dev.new_buffer(
n.exec(dyn_map).unwrap() as u64,
MTLResourceOptions::StorageModeShared,
)
);
// println!("Allocated {} bytes", n.exec(dyn_map).unwrap());
b
})
.collect::<Vec<_>>();
let output_buffers_ref = output_buffers.iter().collect::<Vec<_>>();
Expand Down
6 changes: 0 additions & 6 deletions src/compilers/metal/prim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,8 @@ impl<T: MetalFloat> Operator for MetalCopyToDevice<T> {
data_ptr,
(data_len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
// This causes a double free, so I guess metal frees it?
// Some(&ConcreteBlock::new(|_, _| {
// let data = unsafe { Vec::from_raw_parts(data_ptr, data_len, data_len) };
// drop(data);
// })),
None,
);
data.leak(); // Is this ok? I don't know if metal frees the data once the buffer is discarded
vec![Tensor {
data: Box::new(buffer),
}]
Expand Down
16 changes: 10 additions & 6 deletions src/compilers/metal/storage_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl Compiler for StorageBufferCompiler {
if let Some((buffer_index, source_node, _)) = first_pass[&node]
.1
.iter()
.filter(|i| !graph.no_delete.contains(i))
.filter(|i| !used.contains(i))
.filter(|i| available_buffers.contains_key(i))
.flat_map(|i| {
Expand Down Expand Up @@ -171,6 +172,7 @@ impl Compiler for StorageBufferCompiler {
if let Some((buffer_index, source_node, _)) = first_pass[&node]
.1
.iter()
.filter(|i| !graph.no_delete.contains(i))
.filter(|i| !used.contains(i))
.filter(|i| available_buffers.contains_key(i))
.flat_map(|i| {
Expand Down Expand Up @@ -273,26 +275,28 @@ impl Operator for AllocateMetalBuffers {
let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
// Allocate all buffers
if buffers.is_empty() {
let mut dyn_map = dyn_map.clone();
dyn_map.insert('t', 1000); // Shouldn't be here, just for debug
*buffers = self
.buffer_sizes
.iter()
.map(|e| {
self.dev.new_buffer(
e.exec(&dyn_map).unwrap() as u64,
e.exec(dyn_map).unwrap() as u64,
MTLResourceOptions::StorageModeShared,
)
})
.collect();
} else if !dyn_map.contains_key(&'t') {
} else {
for (size, buffer) in self.buffer_sizes.iter().zip(buffers) {
let size = size.exec(dyn_map).unwrap() as u64;
if buffer.length() < size {
buffer.set_purgeable_state(metal_rs::MTLPurgeableState::Empty);
// Similar allocation strategy to Rust's Vec
let mut length = buffer.length();
while length < size {
length *= 2;
}
*buffer = self
.dev
.new_buffer(size, MTLResourceOptions::StorageModeShared);
.new_buffer(length, MTLResourceOptions::StorageModeShared);
}
}
}
Expand Down

0 comments on commit 162859d

Please sign in to comment.