diff --git a/Cargo.toml b/Cargo.toml
index 2eed0c1..d63fabe 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,11 +5,20 @@ edition = "2021"
 rust-version = "1.57.0"
 
 [lib]
-name = "_tiktoken"
-crate-type = ["cdylib"]
+name = "tiktoken"
+crate-type = ["cdylib", "rlib"]
+
+[features]
+default = []
+python = [
+    "pyo3",
+]
 
 [dependencies]
-pyo3 = { version = "0.22.2", default-features = false, features = ["extension-module", "macros"] }
+pyo3 = { version = "0.22.2", default-features = false, features = [
+    "extension-module",
+    "macros",
+], optional = true }
 
 # tiktoken dependencies
 fancy-regex = "0.13.0"
diff --git a/pyproject.toml b/pyproject.toml
index f2a29eb..9d861d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,10 +3,10 @@ name = "tiktoken"
 version = "0.8.0"
 description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
 readme = "README.md"
-license = {file = "LICENSE"}
-authors = [{name = "Shantanu Jain"}, {email = "shantanu@openai.com"}]
+license = { file = "LICENSE" }
+authors = [{ name = "Shantanu Jain" }, { email = "shantanu@openai.com" }]
 dependencies = ["regex>=2022.1.18", "requests>=2.26.0"]
-optional-dependencies = {blobfile = ["blobfile>=2"]}
+optional-dependencies = { blobfile = ["blobfile>=2"] }
 requires-python = ">=3.9"
 
 [project.urls]
@@ -43,3 +43,4 @@ test-command = "pytest {project}/tests --import-mode=append"
 [[tool.cibuildwheel.overrides]]
 select = "*linux_aarch64"
 test-command = """python -c 'import tiktoken; enc = tiktoken.get_encoding("gpt2"); assert enc.encode("hello world") == [31373, 995]'"""
+
diff --git a/setup.py b/setup.py
index a22e8e5..2a3ebbf 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,7 @@
             # Between our use of editable installs and wanting to use Rust for performance sensitive
             # code, it makes sense to just always use --release
             debug=False,
+            features=["python"],
         )
     ],
     package_data={"tiktoken": ["py.typed"]},
diff --git a/src/lib.rs b/src/lib.rs
index 0203acf..a64c3de 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,19 +1,18 @@
-// This check is new and seems buggy (possibly with PyO3 interaction)
-#![allow(clippy::borrow_deref_ref)]
-
+use std::borrow::Borrow;
+use std::borrow::Cow;
 use std::collections::HashSet;
 use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
-use pyo3::exceptions;
+#[cfg(feature = "python")]
 use pyo3::prelude::*;
-use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyBytes, PyList, PyTuple};
-use pyo3::PyResult;
 use rustc_hash::FxHashMap as HashMap;
 
-type Rank = u32;
+#[cfg(feature = "python")]
+mod py;
+
+pub type Rank = u32;
 
 fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
@@ -132,7 +131,7 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> V
 // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
 // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
 
-pub struct FakeThreadId(NonZeroU64);
+struct FakeThreadId(NonZeroU64);
 
 fn hash_current_thread() -> usize {
     // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
@@ -148,8 +147,8 @@ fn hash_current_thread() -> usize {
 }
 
 #[derive(Debug, Clone)]
-struct DecodeKeyError {
-    token: Rank,
+pub struct DecodeKeyError {
+    pub token: Rank,
 }
 
 impl std::fmt::Display for DecodeKeyError {
@@ -158,10 +157,26 @@ impl std::fmt::Display for DecodeKeyError {
     }
 }
 
+impl std::error::Error for DecodeKeyError {}
+
+#[derive(Debug, Clone)]
+pub struct DecodeError {
+    pub message: String,
+}
+
+impl std::fmt::Display for DecodeError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "Could not decode tokens: {}", self.message)
+    }
+}
+
+impl std::error::Error for DecodeError {}
+
 const MAX_NUM_THREADS: usize = 128;
 
-#[pyclass]
-struct CoreBPE {
+#[cfg_attr(feature = "python", pyclass)]
+#[derive(Clone)]
+pub struct CoreBPE {
     encoder: HashMap<Vec<u8>, Rank>,
     special_tokens_encoder: HashMap<String, Rank>,
     decoder: HashMap<Rank, Vec<u8>>,
@@ -183,7 +198,10 @@ impl CoreBPE {
         &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
     }
 
-    fn _decode_native(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
+    /// Decodes tokens into a list of bytes.
+    ///
+    /// The bytes are not guaranteed to be a valid UTF-8 string.
+    fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for &token in tokens {
             let token_bytes = match self.decoder.get(&token) {
@@ -198,7 +216,7 @@ impl CoreBPE {
         Ok(ret)
     }
 
-    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
+    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
@@ -213,7 +231,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
+    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
         let regex = self._get_tl_regex();
         let mut ret = vec![];
@@ -308,12 +326,12 @@ impl CoreBPE {
         (tokens, last_piece_token_len)
     }
 
-    fn _encode_unstable_native(
+    pub fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
     ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
-        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
+        let (tokens, last_piece_token_len) = self.encode(text, allowed_special);
         if last_piece_token_len == 0 {
             // If last_piece_token_len is zero, the last token was a special token and we have
             // no unstable bytes
@@ -323,7 +341,7 @@ impl CoreBPE {
             self._increase_last_piece_token_len(tokens, last_piece_token_len);
 
         let unstable_bytes = self
-            ._decode_native(&tokens[tokens.len() - last_piece_token_len..])
+            .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
             .unwrap();
         tokens.truncate(tokens.len() - last_piece_token_len);
 
@@ -372,7 +390,7 @@ impl CoreBPE {
                 // So convert to UTF-8 and do regex splitting.
                 // E.g. with cl100k_base " !" gets split to " " + " !",
                 // but byte_pair_encode(" !") != byte_pair_encode(" ")
-                Ok(s) => self._encode_ordinary_native(s),
+                Ok(s) => self.encode_ordinary(s),
 
                 // Technically, whether or not this arm is correct depends on whether there
                 // would be a regex split before the UTF-8 truncation point.
@@ -425,26 +443,37 @@ impl CoreBPE {
         (tokens, completions)
     }
-}
-
-#[pymethods]
-impl CoreBPE {
-    #[new]
-    fn new(
+    pub fn new<E, SE, NSE>(
+        encoder: E,
+        special_tokens_encoder: SE,
+        pattern: &str,
+    ) -> Result<Self, fancy_regex::Error>
+    where
+        E: IntoIterator<Item = (Vec<u8>, Rank)>,
+        SE: IntoIterator<Item = (NSE, Rank)>,
+        NSE: Into<String>,
+    {
+        Self::new_internal(
+            HashMap::from_iter(encoder),
+            HashMap::from_iter(special_tokens_encoder.into_iter().map(|(k, v)| (k.into(), v))),
+            pattern,
+        )
+    }
+
+    fn new_internal(
         encoder: HashMap<Vec<u8>, Rank>,
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
-    ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
-            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
+    ) -> Result<Self, fancy_regex::Error> {
+        let regex = Regex::new(pattern)?;
         let special_regex = {
-            let _parts = special_tokens_encoder
+            let parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|"))
-                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
+            Regex::new(&parts.join("|"))?
         };
 
         let decoder: HashMap<Rank, Vec<u8>> =
@@ -464,7 +493,7 @@ impl CoreBPE {
         let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
         sorted_token_bytes.sort();
 
-        Ok(CoreBPE {
+        Ok(Self {
             encoder,
             special_tokens_encoder,
             decoder,
@@ -477,208 +506,22 @@ impl CoreBPE {
         })
     }
 
-    // ====================
-    // Encoding
-    // ====================
-
-    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
-        py.allow_threads(|| self._encode_ordinary_native(text))
-    }
-
-    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<PyBackedStr>) -> Vec<Rank> {
-        py.allow_threads(|| {
-            let allowed_special: HashSet<&str> =
-                allowed_special.iter().map(|s| s.as_ref()).collect();
-            self._encode_native(text, &allowed_special).0
-        })
-    }
-
-    fn encode_to_tiktoken_buffer(
-        &self,
-        py: Python,
-        text: &str,
-        allowed_special: HashSet<PyBackedStr>,
-    ) -> Py<PyAny> {
-        let tokens = py.allow_threads(|| {
-            let allowed_special: HashSet<&str> =
-                allowed_special.iter().map(|s| s.as_ref()).collect();
-            self._encode_native(text, &allowed_special).0
-        });
-        let buffer = TiktokenBuffer { tokens };
-        buffer.into_py(py)
-    }
-
-    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
-        py.allow_threads(|| {
-            match std::str::from_utf8(bytes) {
-                Ok(text) => self._encode_ordinary_native(text),
-                Err(e) => {
-                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
-                    let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
-                    let (mut tokens, last_piece_token_len) =
-                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
-                    if !tokens.is_empty() && last_piece_token_len > 0 {
-                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes = self
-                            ._decode_native(&tokens[tokens.len() - last_piece_token_len..])
-                            .unwrap();
-                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
-
-                        tokens.truncate(tokens.len() - last_piece_token_len);
-                        match self.encoder.get(&unstable_bytes) {
-                            Some(token) => tokens.push(*token),
-                            None => {
-                                tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder))
-                            }
-                        }
-                    }
-                    tokens
-                }
-            }
-        })
-    }
-
-    fn encode_with_unstable(
-        &self,
-        py: Python,
-        text: &str,
-        allowed_special: HashSet<PyBackedStr>,
-    ) -> Py<PyTuple> {
-        let (tokens, completions) = py.allow_threads(|| {
-            let allowed_special: HashSet<&str> =
-                allowed_special.iter().map(|s| s.as_ref()).collect();
-            self._encode_unstable_native(text, &allowed_special)
-        });
-        let py_completions = PyList::new_bound(
-            py,
-            completions
-                .iter()
-                .map(|seq| PyList::new_bound(py, &seq[..])),
-        );
-        (tokens, py_completions).into_py(py)
-    }
-
-    fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
-        if let Some(token) = self.encoder.get(piece).copied() {
-            return Ok(token);
-        }
-        if let Ok(piece_str) = std::str::from_utf8(piece) {
-            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
-                return Ok(token);
-            }
-        }
-        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
-    }
-
-    fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
-        if let Some(token) = self.encoder.get(piece) {
-            return vec![*token];
-        }
-        byte_pair_encode(piece, &self.encoder)
-    }
-
-    // ====================
-    // Decoding
-    // ====================
-
-    fn decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
-        match py.allow_threads(|| self._decode_native(&tokens)) {
-            Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()),
-            Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
-        }
-    }
-
-    fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
-        if let Some(bytes) = self.decoder.get(&token) {
-            return Ok(PyBytes::new_bound(py, bytes).into());
-        }
-        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
-            return Ok(PyBytes::new_bound(py, bytes).into());
-        }
-        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
-    }
-
-    // ====================
-    // Miscellaneous
-    // ====================
-
-    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
-        self.sorted_token_bytes
-            .iter()
-            .map(|x| PyBytes::new_bound(py, x).into())
+    pub fn special_tokens(&self) -> HashSet<&str> {
+        self.special_tokens_encoder
+            .keys()
+            .map(|s| s.as_str())
             .collect()
     }
-}
-
-#[pyclass]
-struct TiktokenBuffer {
-    tokens: Vec<Rank>,
-}
-
-#[pymethods]
-impl TiktokenBuffer {
-    // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
-    unsafe fn __getbuffer__(
-        slf: Bound<'_, Self>,
-        view: *mut pyo3::ffi::Py_buffer,
-        flags: std::os::raw::c_int,
-    ) -> PyResult<()> {
-        if view.is_null() {
-            return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
-        }
-        if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
-            return Err(pyo3::exceptions::PyBufferError::new_err(
-                "Object is not writable",
-            ));
-        }
-
-        (*view).obj = slf.clone().into_any().into_ptr();
-
-        let data = &slf.borrow().tokens;
-        (*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
-        (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize;
-        (*view).readonly = 1;
-        (*view).itemsize = std::mem::size_of::<Rank>() as isize;
-        (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
-            let msg = std::ffi::CString::new("I").unwrap();
-            msg.into_raw()
-        } else {
-            std::ptr::null_mut()
-        };
-        (*view).ndim = 1;
-        (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
-            &mut (*view).len
-        } else {
-            std::ptr::null_mut()
-        };
-        (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES {
-            &mut (*view).itemsize
-        } else {
-            std::ptr::null_mut()
-        };
-        (*view).suboffsets = std::ptr::null_mut();
-        (*view).internal = std::ptr::null_mut();
-
-        Ok(())
-    }
 
-    unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
-        std::mem::drop(std::ffi::CString::from_raw((*view).format));
+    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
+        let allowed_special = self.special_tokens();
+        self.encode(text, &allowed_special).0
     }
 }
-
-#[pymodule]
-fn _tiktoken(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_class::<CoreBPE>()?;
-    Ok(())
-}
 
 #[cfg(test)]
 mod tests {
-
+    use fancy_regex::Regex;
     use rustc_hash::FxHashMap as HashMap;
 
     use crate::{byte_pair_split, Rank};
diff --git a/src/py.rs b/src/py.rs
new file mode 100644
index 0000000..8485462
--- /dev/null
+++ b/src/py.rs
@@ -0,0 +1,236 @@
+use std::collections::HashSet;
+
+use pyo3::{
+    exceptions,
+    prelude::*,
+    pybacked::PyBackedStr,
+    types::{PyBytes, PyList, PyTuple},
+    PyResult,
+};
+use rustc_hash::FxHashMap as HashMap;
+
+use crate::{byte_pair_encode, CoreBPE, Rank};
+
+#[pymethods]
+impl CoreBPE {
+    #[new]
+    fn py_new(
+        encoder: HashMap<Vec<u8>, Rank>,
+        special_tokens_encoder: HashMap<String, Rank>,
+        pattern: &str,
+    ) -> PyResult<Self> {
+        Self::new_internal(
+            encoder,
+            special_tokens_encoder,
+            pattern,
+        )
+        .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
+    }
+
+    // ====================
+    // Encoding
+    // ====================
+
+    #[pyo3(name = "encode_ordinary")]
+    fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
+        py.allow_threads(|| self.encode_ordinary(text))
+    }
+
+    #[pyo3(name = "encode")]
+    fn py_encode(
+        &self,
+        py: Python,
+        text: &str,
+        allowed_special: HashSet<PyBackedStr>,
+    ) -> Vec<Rank> {
+        py.allow_threads(|| {
+            let allowed_special: HashSet<&str> =
+                allowed_special.iter().map(|s| s.as_ref()).collect();
+            self.encode(text, &allowed_special).0
+        })
+    }
+
+    fn encode_to_tiktoken_buffer(
+        &self,
+        py: Python,
+        text: &str,
+        allowed_special: HashSet<PyBackedStr>,
+    ) -> Py<PyAny> {
+        let tokens = py.allow_threads(|| {
+            let allowed_special: HashSet<&str> =
+                allowed_special.iter().map(|s| s.as_ref()).collect();
+            self.encode(text, &allowed_special).0
+        });
+        let buffer = TiktokenBuffer { tokens };
+        buffer.into_py(py)
+    }
+
+    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
+        py.allow_threads(|| {
+            match std::str::from_utf8(bytes) {
+                Ok(text) => self.encode_ordinary(text),
+                Err(e) => {
+                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
+                    let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
+                    let (mut tokens, last_piece_token_len) =
+                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
+                    if !tokens.is_empty() && last_piece_token_len > 0 {
+                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
+                        // Somewhat niche, but this may not be correct if we'd have had a regex
+                        // split between the valid UTF-8 and the invalid bytes, which is why this
+                        // method is private
+                        let mut unstable_bytes = self
+                            .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
+                            .unwrap();
+                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
+
+                        tokens.truncate(tokens.len() - last_piece_token_len);
+                        match self.encoder.get(&unstable_bytes) {
+                            Some(token) => tokens.push(*token),
+                            None => {
+                                tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder))
+                            }
+                        }
+                    }
+                    tokens
+                }
+            }
+        })
+    }
+
+    #[pyo3(name = "encode_with_unstable")]
+    fn py_encode_with_unstable(
+        &self,
+        py: Python,
+        text: &str,
+        allowed_special: HashSet<PyBackedStr>,
+    ) -> Py<PyTuple> {
+        let (tokens, completions) = py.allow_threads(|| {
+            let allowed_special: HashSet<&str> =
+                allowed_special.iter().map(|s| s.as_ref()).collect();
+            self._encode_unstable_native(text, &allowed_special)
+        });
+        let py_completions = PyList::new_bound(
+            py,
+            completions
+                .iter()
+                .map(|seq| PyList::new_bound(py, &seq[..])),
+        );
+        (tokens, py_completions).into_py(py)
+    }
+
+    fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
+        if let Some(token) = self.encoder.get(piece).copied() {
+            return Ok(token);
+        }
+        if let Ok(piece_str) = std::str::from_utf8(piece) {
+            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
+                return Ok(token);
+            }
+        }
+        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
+    }
+
+    fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
+        if let Some(token) = self.encoder.get(piece) {
+            return vec![*token];
+        }
+        byte_pair_encode(piece, &self.encoder)
+    }
+
+    // ====================
+    // Decoding
+    // ====================
+
+    #[pyo3(name = "decode_bytes")]
+    fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
+        match py.allow_threads(|| self.decode_bytes(&tokens)) {
+            Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()),
+            Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
+        }
+    }
+
+    fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
+        if let Some(bytes) = self.decoder.get(&token) {
+            return Ok(PyBytes::new_bound(py, bytes).into());
+        }
+        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
+            return Ok(PyBytes::new_bound(py, bytes).into());
+        }
+        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
+    }
+
+    // ====================
+    // Miscellaneous
+    // ====================
+
+    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
+        self.sorted_token_bytes
+            .iter()
+            .map(|x| PyBytes::new_bound(py, x).into())
+            .collect()
+    }
+}
+
+#[pyclass]
+struct TiktokenBuffer {
+    tokens: Vec<Rank>,
+}
+
+#[pymethods]
+impl TiktokenBuffer {
+    // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
+    unsafe fn __getbuffer__(
+        slf: Bound<'_, Self>,
+        view: *mut pyo3::ffi::Py_buffer,
+        flags: std::os::raw::c_int,
+    ) -> PyResult<()> {
+        if view.is_null() {
+            return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
+        }
+        if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
+            return Err(pyo3::exceptions::PyBufferError::new_err(
+                "Object is not writable",
+            ));
+        }
+
+        (*view).obj = slf.clone().into_any().into_ptr();
+
+        let data = &slf.borrow().tokens;
+        (*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
+        (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize;
+        (*view).readonly = 1;
+        (*view).itemsize = std::mem::size_of::<Rank>() as isize;
+        (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
+            let msg = std::ffi::CString::new("I").unwrap();
+            msg.into_raw()
+        } else {
+            std::ptr::null_mut()
+        };
+        (*view).ndim = 1;
+        (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
+            &mut (*view).len
+        } else {
+            std::ptr::null_mut()
+        };
+        (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES {
+            &mut (*view).itemsize
+        } else {
+            std::ptr::null_mut()
+        };
+        (*view).suboffsets = std::ptr::null_mut();
+        (*view).internal = std::ptr::null_mut();
+
+        Ok(())
+    }
+
+    unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
+        std::mem::drop(std::ffi::CString::from_raw((*view).format));
+    }
+}
+
+#[pymodule]
+fn _tiktoken(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_class::<CoreBPE>()?;
+    Ok(())
+}
diff --git a/tiktoken/load.py b/tiktoken/load.py
index 8434c23..5686528 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -2,12 +2,7 @@
 
 import base64
 import hashlib
-import json
 import os
-import tempfile
-import uuid
-
-import requests
 
 
 def read_file(blobpath: str) -> bytes:
@@ -20,7 +15,10 @@ def read_file(blobpath: str) -> bytes:
         ) from e
         with blobfile.BlobFile(blobpath, "rb") as f:
             return f.read()
 
+    # avoiding blobfile for public files helps avoid auth issues, like MFA prompts
+    import requests
+
     resp = requests.get(blobpath)
     resp.raise_for_status()
     return resp.content
@@ -38,6 +36,8 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
     elif "DATA_GYM_CACHE_DIR" in os.environ:
         cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
     else:
+        import tempfile
+
         cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
         user_specified_cache = False
 
@@ -67,6 +67,8 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
             f"This may indicate a corrupted download. Please try again."
         )
 
+    import uuid
+
     try:
         os.makedirs(cache_dir, exist_ok=True)
         tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
@@ -114,6 +116,8 @@ def decode_data_gym(value: str) -> bytes:
         bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
         n += 1
 
+    import json
+
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
@@ -142,7 +146,13 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
 def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
     # NB: do not add caching to this function
     contents = read_file_cached(tiktoken_bpe_file, expected_hash)
-    return {
-        base64.b64decode(token): int(rank)
-        for token, rank in (line.split() for line in contents.splitlines() if line)
-    }
+    ret = {}
+    for line in contents.splitlines():
+        if not line:
+            continue
+        try:
+            token, rank = line.split()
+            ret[base64.b64decode(token)] = int(rank)
+        except Exception as e:
+            raise ValueError(f"Error parsing line {line} in {tiktoken_bpe_file}") from e
+    return ret
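
With the `rlib` crate type and the optional `python` feature in place, the core tokeniser can be driven from plain Rust without PyO3. A minimal sketch of what a downstream crate could look like, assuming the `tiktoken` library name from `Cargo.toml` above and the signatures introduced in this patch; the toy vocabulary, the split pattern, and the `main` scaffolding are made up for illustration (a real encoder would be loaded from a `.tiktoken` rank file):

    use tiktoken::{CoreBPE, Rank};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Toy ranks: every piece encoded below is present verbatim, so no
        // byte-level merging is exercised here.
        let encoder: Vec<(Vec<u8>, Rank)> = vec![
            (b"hello".to_vec(), 0),
            (b" world".to_vec(), 1),
        ];
        let special_tokens = vec![("<|endoftext|>".to_string(), 2)];
        let bpe = CoreBPE::new(encoder, special_tokens, r"\s?\S+")?;

        // Ordinary encoding never matches special tokens.
        let tokens = bpe.encode_ordinary("hello world");
        assert_eq!(tokens, vec![0, 1]);

        // encode_with_special_tokens allows every registered special token.
        let with_special = bpe.encode_with_special_tokens("hello world<|endoftext|>");
        assert_eq!(with_special, vec![0, 1, 2]);
        Ok(())
    }

Building splits along the same line: `cargo build --release` should compile the pure-Rust `rlib` without pulling in PyO3, while the Python wheel path (`pip install -e .`, which goes through `setup.py` and passes `features=["python"]` to setuptools-rust) turns the bindings back on; `cargo build --release --features python` is the direct way to reproduce that configuration.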