Sync codebase (#389)
hauntsaninja authored Mar 9, 2025
1 parent e35ab09 commit 4560a88
Showing 4 changed files with 30 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_wheels.yml
@@ -38,7 +38,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-22.04-arm]
+        os: [ubuntu-24.04-arm]
         python-version: [39, 310, 311, 312, 313]

     steps:
@@ -55,7 +55,7 @@ jobs:

       - uses: actions/upload-artifact@v4
         with:
-          name: cibw-wheelsaarch64-${{ matrix.os }}-${{ strategy.job-index }}
+          name: cibw-wheels-aarch64-${{ matrix.os }}-${{ strategy.job-index }}
           path: ./wheelhouse/*.whl

   build_sdist:
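As a quick sanity check on wheels produced by this workflow, something along these lines can be run in an environment where the freshly built wheel has been installed (a hypothetical smoke test, not part of this commit):

    import tiktoken

    # Verify the compiled extension loads and a basic round trip works.
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode("hello world")
    assert enc.decode(tokens) == "hello world"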
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -256,7 +256,7 @@ impl CoreBPE
             }
             let end = next_special.map_or(text.len(), |m| m.start());

-            // Okay, here we go, compare this logic to _encode_ordinary_native
+            // Okay, here we go, compare this logic to encode_ordinary
            for mat in regex.find_iter(&text[start..end]) {
                 let piece = mat.unwrap().as_str().as_bytes();
                 if let Some(token) = self.encoder.get(piece) {
@@ -398,7 +398,7 @@ impl CoreBPE
                     // notice all the big holes in the previous unstable token implementation)
                     Err(_) => byte_pair_encode(&possibility, &self.encoder),
                     // Something like the following is intriguing but incorrect:
-                    // Err(e) => self._encode_ordinary_native(unsafe {
+                    // Err(e) => self.encode_ordinary(unsafe {
                     //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                     // }),
                 };
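The updated comments refer to encode_ordinary, the path that skips special-token handling; encode splits on allowed special tokens first and runs the ordinary regex-plus-BPE logic on the pieces in between. A small illustration of that split through the Python API (an aside, not part of the diff; cl100k_base is used only as a convenient example):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    text = "hello <|endoftext|> world"

    # encode_ordinary treats the marker as plain text and BPE-encodes it.
    ordinary = enc.encode_ordinary(text)

    # encode splits on the allowed special token, then encodes the rest ordinarily.
    special = enc.encode(text, allowed_special={"<|endoftext|>"})

    eot = enc.encode("<|endoftext|>", allowed_special="all")[0]
    assert eot in special and eot not in ordinary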
19 changes: 15 additions & 4 deletions src/py.rs
@@ -68,23 +68,34 @@ impl CoreBPE
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
         py.allow_threads(|| {
             match std::str::from_utf8(bytes) {
+                // Straightforward case
                 Ok(text) => self.encode_ordinary(text),
+                // Oops, don't actually have UTF-8. But we need to do the regex splitting in
+                // Unicode space, so we make our best guess at where we would have splits
                 Err(e) => {
                     let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                     let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
                     let (mut tokens, last_piece_token_len) =
                         self._increase_last_piece_token_len(tokens, last_piece_token_len);
+
+                    let mut unstable_bytes;
                     if !tokens.is_empty() && last_piece_token_len > 0 {
                         // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes = self
+                        // This likely matches what models see better, e.g. if you assume we're
+                        // dealing with truncated UTF-8 bytes.
+                        // Niche, but note this may not be correct if we'd have had a regex
+                        // split between the valid UTF-8 and the invalid bytes.
+                        unstable_bytes = self
                             .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
                             .unwrap();
                         unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                         tokens.truncate(tokens.len() - last_piece_token_len);
+                    } else {
+                        unstable_bytes = bytes[e.valid_up_to()..].to_vec();
+                    }
+
+                    if !unstable_bytes.is_empty() {
                         match self.encoder.get(&unstable_bytes) {
                             Some(token) => tokens.push(*token),
                             None => {
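The new else branch matters when the input has no valid UTF-8 prefix at all: previously such input produced no tokens, so the invalid bytes were silently dropped, whereas they are now carried through unstable_bytes. A short sketch of the observable behaviour from Python (not part of the commit; _encode_bytes is a private helper and cl100k_base is just an example encoding):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")

    # Valid UTF-8 prefix plus a truncated multi-byte character: the prefix is
    # encoded normally and the dangling bytes go through the unstable-bytes path.
    data = "hello ".encode() + "실".encode()[:2]
    assert enc.decode_bytes(enc._encode_bytes(data)) == data

    # Entirely invalid UTF-8: this now takes the else branch and still round-trips.
    junk = b"\x80" * 3
    assert enc.decode_bytes(enc._encode_bytes(junk)) == junk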
11 changes: 11 additions & 0 deletions tests/test_encoding.py
@@ -78,6 +78,17 @@ def test_encode_empty():
 def test_encode_bytes():
     enc = tiktoken.get_encoding("cl100k_base")
     assert enc._encode_bytes(b" \xec\x8b\xa4\xed") == [62085]
+    for i in range(10):
+        bytestring = b"\x80" * i
+        assert enc.decode_bytes(enc._encode_bytes(bytestring)) == bytestring
+
+
+@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
+@hypothesis.given(bytestring=st.binary())
+@hypothesis.settings(deadline=None)
+def test_hyp_encode_bytes(make_enc: Callable[[], tiktoken.Encoding], bytestring: bytes):
+    enc = make_enc()
+    assert enc.decode_bytes(enc._encode_bytes(bytestring)) == bytestring


 def test_encode_surrogate_pairs():
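The new property test leans on helpers defined elsewhere in the test suite (ENCODING_FACTORIES, the st alias for hypothesis.strategies, Callable). A self-contained version of the same round-trip property, runnable outside pytest, could look like this (pinning cl100k_base is an assumption; the committed test parametrizes over several encodings):

    import hypothesis
    import hypothesis.strategies as st
    import tiktoken

    @hypothesis.given(bytestring=st.binary())
    @hypothesis.settings(deadline=None)
    def roundtrip_property(bytestring: bytes) -> None:
        enc = tiktoken.get_encoding("cl100k_base")
        assert enc.decode_bytes(enc._encode_bytes(bytestring)) == bytestring

    if __name__ == "__main__":
        # Calling a @given-decorated function with no arguments runs the property test.
        roundtrip_property()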
