From b4c687ef3625e1737fba4f6643d7bedb9d6d2b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Mon, 15 Jan 2024 21:26:25 +0100 Subject: [PATCH] Avoid calling byte_pair_encode for existing tokens This way byte_pair_encode can be optimized further, assuming we'll always have at least 2 tokens --- src/lib.rs | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index fc97503d..6fb2e0e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -162,10 +162,10 @@ fn hash_current_thread() -> usize { // that works great for our use case of avoiding collisions in our array. Unfortunately, // it's private. However, there are only so many ways you can layout a u64, so just transmute // https://github.com/rust-lang/rust/issues/67939 - const _: [u8; 8] = [0; std::mem::size_of::()]; + const _: [u8; 8] = [0; std::mem::size_of::()]; const _: [u8; 8] = [0; std::mem::size_of::()]; let x = unsafe { - std::mem::transmute::(thread::current().id()).0 + std::mem::transmute::(thread::current().id()).0 }; u64::from(x) as usize } @@ -214,11 +214,10 @@ impl CoreBPE { let mut ret = vec![]; for mat in regex.find_iter(text) { let piece = mat.unwrap().as_str().as_bytes(); - if let Some(token) = self.encoder.get(piece) { - ret.push(*token); - continue; + match self.encoder.get(piece) { + Some(token) => ret.push(*token), + None => ret.extend(&byte_pair_encode(piece, &self.encoder)), } - ret.extend(&byte_pair_encode(piece, &self.encoder)); } ret } @@ -516,7 +515,10 @@ impl CoreBPE { unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); tokens.truncate(tokens.len() - last_piece_token_len); - tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); + match self.encoder.get(&unstable_bytes) { + Some(token) => tokens.push(*token), + None => tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)), + } } tokens } @@ -597,15 +599,26 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { mod tests {
use rustc_hash::FxHashMap as HashMap; - use crate::byte_pair_split; + use crate::{byte_pair_split, Rank}; - #[test] - fn very_simple_test() { - let mut ranks = HashMap::default(); - ranks.insert(b"ab".to_vec(), 1); - ranks.insert(b"cd".to_vec(), 2); + fn setup_ranks() -> HashMap<Vec<u8>, Rank> { + HashMap::from_iter([ + (b"ab".to_vec(), 0), + (b"cd".to_vec(), 1), + ]) + } + #[test] + fn test_simple_characters() { + let ranks = setup_ranks(); let res = byte_pair_split(b"abcd", &ranks); assert_eq!(res, vec![b"ab", b"cd"]); } + + #[test] + fn test_repeated_characters() { + let ranks = setup_ranks(); + let res = byte_pair_split(b"abab", &ranks); + assert_eq!(res, vec![b"ab", b"ab"]); + } }