Skip to content

Commit

Permalink
Small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 23, 2024
1 parent c64e408 commit 791f139
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 25 deletions.
22 changes: 22 additions & 0 deletions examples/mistral/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,28 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
.permute::<_, Axes4<0, 2, 1, 3>>();

// Apply the Rotary Embeddings
// Get embedding
// let freqs =
// (key_states.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
// let freqs = freqs.inv_pow(1000000.0).recip();
// let t = key_states.graph().arange::<CurSeq>()
// + key_states
// .graph()
// .constant_expr(PrevSeq::const_size().into())
// .expand();
// let freqs = t.expand::<(_, Const<1>), _>().matmul(freqs.expand());
// let emb = freqs.concat_along::<(CurSeq, Const<HEAD_DIM>), Axis<1>, _>(freqs);

// // Rotate input
// let x1 = key_states.slice((.., .., .., ..Expression::from(HEAD_DIM_OVER_2)));
// let x2 = key_states.slice((.., .., .., Expression::from(HEAD_DIM_OVER_2)..));
// let rotated_input = (-x2).concat_along::<(_, _, _, Const<HEAD_DIM>), Axis<3>, _>(x1);
// if cache.is_some() {
// rotated_input.print("");
// }

// Final calculation
// let key_states = rotated_input * emb.sin().expand() + key_states * emb.cos().expand();
let query_states = apply_rotary_embeddings(query_states, PrevSeq::const_size().into());
let key_states = apply_rotary_embeddings(key_states, PrevSeq::const_size().into());

Expand Down
30 changes: 30 additions & 0 deletions src/compilers/metal/tests/fp16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ fn test_rotate() {
assert_close(&unopt, &rotated_a.data());
}

// Compares the graph's output before and after compiling with the Metal f16
// compiler for a "rotate halves" pattern on a permuted tensor: the last axis
// is split in two, the upper half is negated and placed first, and the lower
// half is concatenated after it.
#[test]
fn test_rotate2() {
    const D: usize = 2;
    const S: usize = 2;
    const H: usize = 2;
    // Split point along the last axis.
    const HALF: usize = H / 2;

    let mut cx = Graph::new();
    let data = random_vec(D * S * H);

    // Load the input and swap axes 1 and 2 so the rotation runs on a permuted view.
    let source = cx.tensor::<R4<1, D, S, H>>().set(data.clone()).keep();
    let permuted = source.permute::<_, LAxes4<0, 2, 1, 3>>();

    // Build [-upper, lower] along the last axis.
    let lower = permuted.slice((.., .., .., ..Expression::from(HALF)));
    let upper = permuted.slice((.., .., .., Expression::from(HALF)..));
    let negated_upper = -upper;
    let mut rotated = negated_upper
        .concat_along::<R4<1, S, D, H>, LAxis<3>, _>(lower)
        .retrieve();

    // First run happens before compilation and serves as the reference result.
    cx.execute();
    let reference = rotated.data();

    // Second run goes through the Metal f16 compiler.
    cx.compile(MetalCompiler::<f16>::default(), &mut rotated);
    cx.execute();

    println!("data: {:?}", data);
    println!("unopt: {:?}", reference);
    println!("rotated_a: {:?}", rotated.data());

    assert_close(&reference, &rotated.data());
}

#[test]
fn test_constant() {
let mut cx = Graph::new();
Expand Down
6 changes: 2 additions & 4 deletions src/compilers/metal/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device uint& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]) {{
if (idx < n_elements) {{
if ((idx % {axis_size}) < {half_size}) {{
out[idx] = -inp[idx + {half_size}];
out[idx] = ({type_name})-(float)inp[idx + {half_size}];
}} else {{
out[idx] = inp[idx - {half_size}];
}}
Expand All @@ -1086,9 +1086,7 @@ impl<T> MetalKernel for MetalRotate<T> {
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let mut sh = inputs[0].1;
sh.remove_dim(3);
let n_elements = sh.n_elements().to_usize().unwrap() * self.axis_size;
let n_elements = inputs[0].1.n_physical_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_buffer(0, Some(inputs[0].0), 0);
Expand Down
42 changes: 21 additions & 21 deletions src/core/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,34 +82,34 @@ impl Operator for Print {
.as_any()
.downcast_ref::<Vec<f32>>()
.unwrap();
println!("{} Data: {:?}", i + 1, &d[..10.min(d.len())]);
println!("{} Data: {:?}", i + 1, &d[d.len().saturating_sub(10)..]);
println!("{} Shape: {:?}", i + 1, tracker);
// let mut data = vec![0.; d.len()];
// let (ind, val) = (tracker.index_expression(), tracker.valid_expression());
// #[allow(unused_mut)]
// for (i, mut r) in data.iter_mut().enumerate() {
// if val.exec_single_var(i) != 0 {
// *r = d[ind.exec_single_var(i)];
// }
// }
let mut data = vec![0.; d.len()];
let (ind, val) = (tracker.index_expression(), tracker.valid_expression());
#[allow(unused_mut)]
for (i, mut r) in data.iter_mut().enumerate() {
if val.exec_single_var(i) != 0 {
*r = d[ind.exec_single_var(i)];
}
}
// std::fs::write(
// "../../Desktop/llama-dfdx/out.bin",
// "../../Desktop/out.bin",
// data.iter()
// .flat_map(|i| i.to_ne_bytes())
// .collect::<Vec<_>>(),
// )
// .unwrap();
// let out = std::fs::read("../../Desktop/llama-dfdx/out.bin")
// .unwrap()
// .chunks(4)
// .map(|i| f32::from_ne_bytes([i[0], i[1], i[2], i[3]]))
// .collect::<Vec<_>>();
// assert_eq!(data.len(), out.len(), "Number of elements doesn't match");
// for (i, (a, b)) in data.iter().zip(out.iter()).enumerate() {
// if *a != *b {
// panic!("{} is not equal to {}, index {i}", *a, *b);
// }
// }
let out = std::fs::read("../../Desktop/out.bin")
.unwrap()
.chunks(4)
.map(|i| f32::from_ne_bytes([i[0], i[1], i[2], i[3]]))
.collect::<Vec<_>>();
assert_eq!(data.len(), out.len(), "Number of elements doesn't match");
for (i, (a, b)) in data.iter().zip(out.iter()).enumerate() {
if *a != *b {
panic!("{} is not equal to {}, index {i}", *a, *b);
}
}
}
vec![Tensor {
data: Box::<Vec<f32>>::default(),
Expand Down

0 comments on commit 791f139

Please sign in to comment.