Commit

Add tokenize_batch method

michael-p committed Nov 28, 2023
1 parent 448c229 commit 1435509
Showing 3 changed files with 89 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -15,6 +15,7 @@ required-features = ["openai-vocabulary-file"]

[dependencies]
ahash = "0.8.6"
ndarray = "0.15.6"
regex = "1.10.2"

[dev-dependencies]
21 changes: 20 additions & 1 deletion benches/bench.rs
@@ -2,6 +2,18 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};

use instant_clip_tokenizer::Tokenizer;

fn tokenize_batch_small(c: &mut Criterion) {
    let tokenizer = Tokenizer::new();
    c.bench_function("tokenize_batch small", |b| {
        b.iter(|| {
            tokenizer.tokenize_batch(
                black_box(["Hi", "How are you?", "I'm fine, thanks!"]),
                black_box(6),
            )
        })
    });
}

fn short(c: &mut Criterion) {
    let tokenizer = Tokenizer::new();
    let mut tokens = Vec::with_capacity(100);
@@ -53,5 +65,12 @@ fn long_sentence(c: &mut Criterion) {
    });
}

criterion_group!(benches, short, realistic, long_word, long_sentence);
criterion_group!(
    benches,
    tokenize_batch_small,
    short,
    realistic,
    long_word,
    long_sentence,
);
criterion_main!(benches);
68 changes: 68 additions & 0 deletions src/lib.rs
@@ -52,8 +52,10 @@
//! tokenizer using [`Tokenizer::with_vocabulary`].
use std::io::{self, BufRead};
use std::iter::ExactSizeIterator;

use ahash::AHashMap;
use ndarray::Array2;
use regex::Regex;

/// A text tokenizer for the CLIP neural network.
@@ -189,6 +191,58 @@ impl Tokenizer {
        })
    }

    /// Tokenize a batch of multiple input strings.
    ///
    /// Each given input string is tokenized and its numeric representation is
    /// written to a row of the resulting two-dimensional matrix of shape
    /// `(texts.len(), context_length)`, with the special `<start_of_text>`
    /// token prepended and `<end_of_text>` appended to each text.
    ///
    /// The individual input strings are lowercased before being tokenized, but
    /// otherwise no pre-processing is performed.
    ///
    /// `context_length` is the maximum number of tokens per text and should
    /// be `77` for all current CLIP models. If tokenizing an input text
    /// results in too many tokens, the token sequence will be truncated to fit
    /// within the resulting row of length `context_length`, always including
    /// the `<start_of_text>` and `<end_of_text>` marker tokens.
    ///
    /// # Examples
    ///
    /// ```
    /// # use ndarray::array;
    /// # use instant_clip_tokenizer::{Token, Tokenizer};
    /// let tokenizer = Tokenizer::new();
    /// let encoded = tokenizer.tokenize_batch(["Hi", "How are you?"], 5);
    /// assert_eq!(encoded, array![
    ///     [49406, 1883, 49407, 0, 0],
    ///     [49406, 829, 631, 592, 49407],
    /// ]);
    /// ```
    pub fn tokenize_batch<'a, I>(&self, texts: I, context_length: usize) -> Array2<u16>
    where
        I: IntoIterator<Item = &'a str>,
        I::IntoIter: ExactSizeIterator,
    {
        if context_length < 3 {
            panic!("context length must be at least 3");
        }
        let texts = texts.into_iter();
        let mut result = Array2::zeros((texts.len(), context_length));
        let mut tokens = Vec::with_capacity(context_length);
        for (text, mut result_row) in texts.zip(result.rows_mut()) {
            tokens.clear();
            tokens.push(self.start_of_text());
            self.encode(text, &mut tokens);
            tokens.truncate(context_length - 1);
            tokens.push(self.end_of_text());
            for (token, result_element) in tokens.iter().zip(&mut result_row) {
                *result_element = token.to_u16();
            }
        }
        result
    }

    /// Encode a `text` input as a sequence of tokens.
    ///
    /// The resulting tokens are appended to `out`. `text` is lowercased before
@@ -354,8 +408,22 @@ impl Token {

#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    #[test]
    fn tokenize_batch() {
        let tokenizer = Tokenizer::new();
        let encoded = tokenizer.tokenize_batch(["Hi", "How are you?", "I'm fine, thanks!"], 6);
        let expected = array![
            [49406, 1883, 49407, 0, 0, 0],
            [49406, 829, 631, 592, 286, 49407],
            [49406, 328, 880, 3797, 267, 49407],
        ];
        assert_eq!(encoded, expected);
    }

    #[test]
    fn encode_special_chars() {
        let tokens = encode("hello world!!!");
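For context beyond the doc example above, the padding and truncation rules described in the new doc comment can be seen end to end in a short program. This is an editor's sketch rather than part of the commit; it assumes the crate's default `openai-vocabulary-file` feature (so `Tokenizer::new()` is available) and the marker token ids `49406`/`49407` shown in the doc example and test above.

```rust
use instant_clip_tokenizer::Tokenizer;

fn main() {
    let tokenizer = Tokenizer::new();

    // Rows are padded with zeros up to `context_length`; every row starts with
    // <start_of_text> (49406) and ends with <end_of_text> (49407).
    let encoded = tokenizer.tokenize_batch(["Hi", "How are you?"], 6);
    assert_eq!(encoded.dim(), (2, 6));
    assert!(encoded.rows().into_iter().all(|row| row[0] == 49406));

    // With a smaller context_length the longer text is truncated: the trailing
    // "?" token is dropped, but <end_of_text> is still appended as the last token.
    let truncated = tokenizer.tokenize_batch(["How are you?"], 5);
    assert_eq!(truncated.row(0)[4], 49407);
}
```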
