Skip to content

Commit

Permalink
subquadratic
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja committed Feb 14, 2025
1 parent bb5805d commit 0a46366
Showing 1 changed file with 137 additions and 10 deletions.
147 changes: 137 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Borrow;
use std::borrow::Cow;
use std::collections::HashSet;
use std::num::NonZeroU64;
use std::thread;
Expand All @@ -14,6 +12,131 @@ mod py;

pub type Rank = u32;

use std::collections::BinaryHeap;

/// A candidate merge queued in the `BinaryHeap` used by `_byte_pair_merge_large`.
///
/// Entries are never removed from the heap when they go stale; the consumer
/// compares `rank` against the live `State::next_rank` for `start` and skips
/// entries that no longer match.
#[derive(Eq, PartialEq, Clone, Copy)]
struct Merge {
/// Byte offset in the piece at which the left-hand token of the pair starts.
start: usize,
/// Rank found in the vocabulary for the byte span the merge would produce.
rank: Rank,
}

impl Ord for Merge {
    /// Orders merges so that a max-heap yields the lowest rank first and,
    /// among equal ranks, the smallest `start` first.
    ///
    /// Both keys are compared in reverse (`other` against `self`), which is
    /// what turns `BinaryHeap`'s max-heap into the min-heap the merge loop needs.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Tuples compare lexicographically: rank decides, start breaks ties.
        (other.rank, other.start).cmp(&(self.rank, self.start))
    }
}

impl PartialOrd for Merge {
    /// Total order, so this always returns `Some` and simply defers to the
    /// `Ord` implementation to keep the two consistent.
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

/// Per-position bookkeeping for `_byte_pair_merge_large`.
///
/// One `State` exists for every byte offset of the piece. The entry at index
/// `i` describes the token that currently starts at offset `i`; entries whose
/// token has been absorbed into a left neighbour are simply never visited again.
struct State {
/// Start offset of the token immediately to the left (`usize::MAX` for the first token).
prev: usize,
/// Exclusive end offset of the token starting here.
end: usize,
/// Exclusive end offset of the token that follows this one.
next_end: usize,
/// Rank of merging this token with the following one, or `Rank::MAX` when
/// no such merge is currently valid.
next_rank: Rank,
/// Rank of this token if it was produced by a merge, `Rank::MAX` while it
/// is still a single unmerged byte.
cur_rank: Rank,
}

/// Byte-pair encodes `piece` using a priority queue of candidate merges.
///
/// Every byte starts out as its own token. The adjacent pair with the lowest
/// rank (leftmost on ties) is merged repeatedly until no adjacent pair is
/// present in `ranks`; the ranks of the surviving tokens are returned in order.
///
/// Stale heap entries are not removed eagerly. Each popped entry is checked
/// against the live `State::next_rank` for its start offset and skipped when
/// the two disagree.
///
/// # Arguments
/// * `ranks` - vocabulary mapping byte sequences to their rank.
/// * `piece` - the bytes to encode.
///
/// # Returns
/// The ranks of the final tokens, left to right. An empty `piece` yields an
/// empty vector.
///
/// # Panics
/// Panics if a byte that is never merged has no single-byte entry in `ranks`.
///
/// `Rank::MAX` is used internally as the "no merge" sentinel, so a vocabulary
/// entry with that rank is treated as absent.
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
    // Nothing to encode. Without this guard `piece.len() - 1` below underflows
    // and the seeded `state[0]` would later index `piece[0..1]` out of bounds.
    if piece.is_empty() {
        return Vec::new();
    }

    let mut state = Vec::with_capacity(piece.len());
    state.push(State {
        prev: usize::MAX,
        end: 1,
        next_end: 2,
        next_rank: Rank::MAX,
        cur_rank: Rank::MAX,
    });

    let mut heap = BinaryHeap::with_capacity(piece.len());
    for i in 0..piece.len() - 1 {
        if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
            heap.push(Merge { start: i, rank });
            state[i].next_rank = rank;
        }
        // Note the offset by one: iteration `i` creates the state for byte `i + 1`.
        state.push(State {
            prev: i,
            end: i + 2,
            next_end: i + 3,
            next_rank: Rank::MAX,
            cur_rank: Rank::MAX,
        });
    }

    // Repeatedly find the valid merge with the smallest rank. We merge the (left) token that
    // starts at `start` and ends at `state[start].end` with the (right) token that starts at
    // `state[start].end` and ends at `state[start].next_end`. This invalidates two old
    // candidates: the one starting at the right token (it no longer exists) and the one
    // starting at the previous token (its right-hand neighbour just grew). The two
    // replacement candidates, if present in `ranks`, are pushed onto the heap.

    // Re-targets the candidate that starts at `start` so that it ends at `next_end_item`,
    // dropping the previous candidate and queueing the new one when the vocabulary has it.
    let potential_merge = |state: &mut Vec<State>,
                           heap: &mut BinaryHeap<Merge>,
                           start: usize,
                           next_end_item: usize| {
        state[start].next_end = next_end_item;
        state[start].next_rank = Rank::MAX; // Always invalidate the old merge
        // `next_end_item` exceeds the piece length when there is no following token.
        if next_end_item <= piece.len() {
            if let Some(&rank) = ranks.get(&piece[start..next_end_item]) {
                // We have a valid potential merge!
                heap.push(Merge { start, rank });
                state[start].next_rank = rank;
            }
        }
    };

    while let Some(left) = heap.pop() {
        if left.rank == Rank::MAX {
            break;
        }
        if left.rank != state[left.start].next_rank {
            continue; // This merge was invalidated, ignore it
        }

        let left_start = left.start;
        let right_start = state[left_start].end;
        let right_end = state[left_start].next_end;
        debug_assert!(right_end == state[right_start].end);
        let right_next_end = state[right_start].next_end;

        // Merge left and right into a single token
        state[left_start].cur_rank = state[left_start].next_rank;
        state[left_start].end = right_end;
        potential_merge(&mut state, &mut heap, left_start, right_next_end);
        // The token after the merged pair now has the merged token as its left neighbour.
        if right_end < state.len() {
            state[right_end].prev = left_start;
        }
        // Update the merge that ends at left_start
        if left_start > 0 {
            let prev_start = state[left_start].prev;
            potential_merge(&mut state, &mut heap, prev_start, right_end);
        }
        // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
        state[right_start].next_rank = Rank::MAX;
    }

    // Walk the surviving tokens left to right by hopping from each start to its end.
    let mut result = Vec::new();
    let mut i = 0;
    while i < state.len() {
        if state[i].cur_rank != Rank::MAX {
            result.push(state[i].cur_rank);
        } else {
            // Never merged: still a single byte, so look its rank up directly.
            result.push(ranks[&piece[i..state[i].end]]);
        }
        i = state[i].end;
    }
    result
}

fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the pair starting at position start.
Expand Down Expand Up @@ -73,21 +196,25 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
}

/// Encodes `piece` into the ranks of its byte-pair tokens.
///
/// # Arguments
/// * `piece` - the bytes to encode.
/// * `ranks` - vocabulary mapping byte sequences to their rank.
///
/// # Returns
/// The token ranks, left to right. An empty `piece` encodes to no tokens.
///
/// # Panics
/// Panics if `piece` is a single byte that has no entry in `ranks`. Longer
/// pieces are handed to `_byte_pair_merge_large`, which may panic as well.
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    match piece.len() {
        0 => Vec::new(),
        // A single byte cannot be merged with anything; look it up directly
        // instead of rejecting it, so existing callers keep working.
        1 => vec![ranks[piece]],
        _ => _byte_pair_merge_large(ranks, piece),
    }
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
pub fn byte_pair_split<'a>(piece: &'a [u8], _ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
assert!(piece.len() > 1);
_byte_pair_merge(ranks, piece)
panic!("Not implemented");
/*
_byte_pair_merge_large(&ranks, &piece)
.windows(2)
.map(|part| &piece[part[0].0..part[1].0])
.collect()
*/
}

// Various performance notes:
Expand Down Expand Up @@ -521,7 +648,7 @@ impl CoreBPE {

#[cfg(test)]
mod tests {
use fancy_regex::Regex;

use rustc_hash::FxHashMap as HashMap;

use crate::{byte_pair_split, Rank};
Expand Down

0 comments on commit 0a46366

Please sign in to comment.