Skip to content

Commit

Permalink
phi changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed May 3, 2024
1 parent dfd40ed commit 70bbd8d
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions examples/phi/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,32 +121,29 @@ fn main() {
input_ids.len()
);
delete_inputs(&cache_src, &mut cx);
let mut output_ids = vec![sample_index(&logits.data())];
let mut output_ids = vec![argmax(&logits.data())];
logits.drop();

// Decode token
print!("{}", cli_args.prompt.white().bold());
print!(
"{}",
tokenizer.decode(&output_ids, false).unwrap().bright_green()
);
let out_str = tokenizer.decode(&output_ids, false).unwrap();
let mut prev_output_len = out_str.len();
print!("{}", out_str.bright_green());
io::stdout().flush().unwrap();

// Swap caches
transfer_data_same_graph(&cache_dest, &cache_src, &mut cx);

// Decode loop
let start_decode = std::time::Instant::now();
let mut prev_output_len = 0;
for _ in 0..cli_args.gen_tokens {
input.set_dyn(vec![*output_ids.last().unwrap() as f32], &[1, 1]);
cx.set_dyn_dim('p', input_ids.len() + output_ids.len() - 1);
cx.set_dyn_dim('t', input_ids.len() + output_ids.len());
cx.execute();

// Sample tokens
let output_id = sample_index(&logits.data());
// println!("{:?}", &logits.data()[..10]);
let output_id = argmax(&logits.data());
logits.drop();
output_ids.push(output_id);

Expand Down Expand Up @@ -176,8 +173,7 @@ fn main() {
);
}

// Currently just an argmax, do actual sampling here
fn sample_index(dist: &[f32]) -> u32 {
fn argmax(dist: &[f32]) -> u32 {
dist.iter()
.position_max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap() as u32
Expand Down

0 comments on commit 70bbd8d

Please sign in to comment.