Add write_into_mask method on Guide. Add kernels for torch and numpy
unaidedelf8777 committed Feb 25, 2025
1 parent 33d6deb commit c01e1b8
Showing 11 changed files with 253 additions and 3 deletions.
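Not part of the commit itself, but a minimal sketch of how the new pieces fit together on the torch path. The `Vocabulary`/`Index`/`Guide` construction (including `Vocabulary.from_pretrained`) is assumed from the existing `outlines_core` bindings; the vocabulary size shown is illustrative and must match the vocabulary the index was built from.

import torch

from outlines_core import Guide, Index, Vocabulary
from outlines_core.kernels.torch import (
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
    fill_next_token_bitmask,
)

vocabulary = Vocabulary.from_pretrained("openai-community/gpt2")
guide = Guide(Index(r"\d{3}", vocabulary))

vocab_size = 50_257                        # illustrative; must match the vocabulary above
mask = allocate_token_bitmask(vocab_size)  # int32 tensor of shape (1, ceil(vocab_size / 32))

logits = torch.randn(1, vocab_size)        # stand-in for a model's next-token logits
fill_next_token_bitmask(guide, mask)       # Guide writes the allowed-token bits into `mask`
apply_token_bitmask_inplace(logits, mask)  # disallowed positions become -inf, in place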
1 change: 0 additions & 1 deletion Cargo.toml
@@ -21,7 +21,6 @@ hf-hub = "=0.3.2"
tokenizers = { version = "=0.20.3", features = ["http"] }
rustc-hash = "2.1.0"
regex-automata = "0.4.9"

[features]
python-bindings = ["pyo3", "serde-pyobject"]

2 changes: 2 additions & 0 deletions python/outlines_core/__init__.py
@@ -3,6 +3,8 @@

from .outlines_core_rs import Guide, Index, Vocabulary

from .kernels import torch

try:
__version__ = version("outlines_core")
except PackageNotFoundError:
0 changes: 0 additions & 0 deletions python/outlines_core/kernels/__init__.py
Empty file.
78 changes: 78 additions & 0 deletions python/outlines_core/kernels/numpy.py
@@ -0,0 +1,78 @@
from outlines_core import Guide

try:
import numpy as np
import numba
except ImportError as e:
missing_dep = "numba" if "numba" in str(e) else "numpy"
raise ImportError(
f"To use the kernels in `outlines_core.kernels.numpy`, `{missing_dep}` must be installed."
) from e

def allocate_token_bitmask(vocab_size: int) -> np.ndarray:
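    """
    Allocate a token bitmask for use with the `Guide.write_mask_into` API and the
    numpy logits-masking kernels: an int32 array of shape `(1, ceil(vocab_size / 32))`
    initialized to all ones (-1 in two's complement).
    """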
return np.full(
(1, (vocab_size + 31) // 32),
-1,
dtype=np.int32,
)

@numba.njit
def _apply_token_bitmask_kernel(logits, mask):
mask_len = mask.shape[1]
cutoff = 32 * mask_len

if logits.shape[1] > cutoff:
logits[:, cutoff:] = -np.inf
logits = logits[:, :cutoff]

n_rows, n_cols = logits.shape

for i in range(n_rows):
for mi in range(mask_len):
mval = mask[i, mi]
base = mi * 32
for bit in range(32):
j = base + bit

if j >= n_cols:
break

if ((mval >> bit) & 1) == 0:
logits[i, j] = -np.inf

def apply_token_bitmask_inplace(logits: np.ndarray, mask: np.ndarray) -> None:
if logits.ndim == 1:
logits = np.expand_dims(logits, axis=0)
if mask.ndim == 1:
mask = np.expand_dims(mask, axis=0)

if mask.dtype != np.int32:
raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.")
elif mask.ndim != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.")
    elif logits.ndim != 2:
        raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {logits.ndim}D.")
elif mask.shape[0] != logits.shape[0]:
raise ValueError(
f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
)
_apply_token_bitmask_kernel(logits, mask)

def fill_next_token_bitmask(
guide: Guide, mask: np.ndarray
) -> None:
# timing: all checks take roughly 0.5 microseconds.
if mask.dtype != np.int32:
raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.")
elif mask.ndim != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.")
elif mask.shape[0] != 1:
raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.")
elif not mask.flags["C_CONTIGUOUS"]:
raise ValueError("Mask array must be contiguous in memory. Use `np.ascontiguousarray(mask)`.")

return guide.write_mask_into(
mask.ctypes.data,
mask.size,
mask.itemsize
)
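A short usage sketch of the numpy path, mirroring the torch kernels added below; `guide` and `vocab_size` are assumed to come from the existing `outlines_core` API, and the first call pays numba's JIT-compilation cost.

import numpy as np

from outlines_core.kernels.numpy import (
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
    fill_next_token_bitmask,
)

# `guide` is an outlines_core.Guide built elsewhere; `vocab_size` must match
# the vocabulary its index was built from.
mask = allocate_token_bitmask(vocab_size)
fill_next_token_bitmask(guide, mask)

logits = np.random.randn(1, vocab_size).astype(np.float32)
apply_token_bitmask_inplace(logits, mask)      # disallowed positions become -inf
allowed = np.nonzero(logits[0] != -np.inf)[0]  # token ids that remain available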
106 changes: 106 additions & 0 deletions python/outlines_core/kernels/torch.py
@@ -0,0 +1,106 @@
# Provides kernels for masking a logits tensor, using the
# `write_mask_into` method on the `Guide` object and the bitmask
# it writes into a tensor.
#
# Kernels inspired by https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/torch.py
from outlines_core import Guide

try:
import torch
except Exception as e:
    raise ModuleNotFoundError(
        "`torch` is required to use the kernels from "
        "`outlines_core.kernels.torch`. You can install "
        "`torch` using the official guide at https://pytorch.org/get-started/locally/"
    ) from e

def allocate_token_bitmask(vocab_size: int) -> torch.Tensor:
"""
    Allocate a token bitmask for use with the `Guide.write_mask_into` API and logits
    masking, sized to `vocab_size`.
    Arguments:
        - vocab_size: int
    Returns:
        - torch.Tensor of shape `(1, ceil(vocab_size / 32))`, dtype `torch.int32`
"""
return torch.full(
(1, (vocab_size + 31) // 32),
-1,
dtype=torch.int32,
pin_memory=torch.cuda.is_available(),
)

# This takes roughly 23 microseconds per run with a bitmask of
# 1k allowed tokens and a 128k-entry logits tensor.
# It also compiles to one graph with no graph breaks.
# Performance characteristics:
# - The larger the logits tensor, the longer the kernel takes.
# - Constant time in the mask, i.e. the number of allowed tokens does not
#   affect execution time.
@torch.compile(dynamic=True)
def _apply_token_bitmask_kernel(logits, mask):
    # Tokens beyond the bitmask's coverage (>= 32 * number of mask words) are
    # disallowed. This is a no-op as long as the mask was allocated at the
    # correct size for the vocabulary.
    coverage = 32 * mask.shape[1]
    logits.masked_fill_(
        torch.ge(
            torch.arange(
                logits.shape[1],
                device=logits.device
            ),
            coverage
        ),
        -torch.inf
    )

    n_cols = min(logits.shape[1], coverage)

    # Unpack each 32-bit mask value into 32 individual bits (as booleans),
    # then trim to the logits width.
    bit_masks = (
        (torch.bitwise_right_shift(
            mask.unsqueeze(-1),
            torch.arange(
                32,
                device=mask.device,
                dtype=torch.int32
            )) & 1
        )
        .bool()
        .view(mask.shape[0], -1)
        .narrow(1, 0, n_cols)
    )

    # Apply the bitmask in place on the covered slice of the logits.
    logits.narrow(1, 0, n_cols).masked_fill_(~bit_masks, -torch.inf)


def apply_token_bitmask_inplace(logits: torch.Tensor, mask: torch.Tensor) -> None:
if mask.dtype != torch.int32:
raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.")
elif mask.dim() != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.")
    elif logits.dim() != 2:
        raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {logits.dim()}D.")
elif mask.shape[0] != logits.shape[0]:
raise ValueError(
f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
)
_apply_token_bitmask_kernel(logits, mask)

def fill_next_token_bitmask(guide: Guide, mask: torch.Tensor) -> None:
if mask.dtype != torch.int32:
raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.")
elif mask.dim() != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.")
elif mask.shape[0] != 1:
raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.")
elif not mask.is_contiguous():
raise ValueError("Mask array must be contiguous in memory. Use `mask.contiguous()` to fix it.")
elif mask.device != torch.device("cpu"):
raise ValueError(f"Invalid device: Expected `mask` tensor to be on device `cpu`, but found it on `{mask.device}`.")

guide.write_mask_into(
mask.data_ptr(),
mask.numel(),
mask.element_size()
)
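As a purely illustrative check (not part of the commit), the numpy and torch kernels implement the same bit semantics, so they can be cross-checked on a hand-built mask. The layout assumed here is the one defined by the Rust `write_mask_into`: token id `t` lives in word `t // 32`, bit `t % 32`.

import numpy as np
import torch

from outlines_core.kernels import numpy as np_kernels
from outlines_core.kernels import torch as torch_kernels

vocab_size = 1_000
rng = np.random.default_rng(0)
logits_np = rng.standard_normal((1, vocab_size)).astype(np.float32)
logits_pt = torch.from_numpy(logits_np.copy())

# Hand-built mask that allows only even token ids.
mask_np = np_kernels.allocate_token_bitmask(vocab_size)
mask_np[:] = 0
for token_id in range(0, vocab_size, 2):
    mask_np[0, token_id // 32] |= 1 << (token_id % 32)
mask_pt = torch.from_numpy(mask_np.copy())

np_kernels.apply_token_bitmask_inplace(logits_np, mask_np)
torch_kernels.apply_token_bitmask_inplace(logits_pt, mask_pt)

assert np.allclose(logits_pt.numpy(), logits_np)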
9 changes: 9 additions & 0 deletions python/outlines_core/outlines_core_rs.pyi
@@ -35,6 +35,15 @@ class Guide:
def is_finished(self) -> bool:
"""Checks if the automaton is in a final state."""
...
def write_mask_into(self, data_ptr: int, numel: int, element_size: int) -> None:
"""Write the mask of allowed tokens into the memory specified by data_ptr.
        The size of the buffer is given by `numel` and `element_size`;
        `element_size` must be 4.
        `data_ptr` should be the data pointer of a `torch.Tensor`, `np.ndarray`, or
        other contiguous memory buffer of 32-bit integers."""
...

def __repr__(self) -> str:
"""Gets the debug string representation of the guide."""
...
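For reference, a direct call of the `write_mask_into` stub above from Python with a numpy buffer looks like the following; it is exactly what the `fill_next_token_bitmask` helpers added in this commit do under the hood (`guide` and `vocab_size` are assumed to exist).

import numpy as np

mask = np.full((1, (vocab_size + 31) // 32), -1, dtype=np.int32)
guide.write_mask_into(mask.ctypes.data, mask.size, mask.itemsize)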
12 changes: 12 additions & 0 deletions src/index.rs
@@ -55,6 +55,8 @@ pub struct Index {
transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
/// The token ID reserved for the "end-of-sequence" token.
eos_token_id: TokenId,
    /// The size of the vocabulary used to build the index.
    vocab_size: usize,
}
/// The `Index` structure is designed to efficiently map tokens from a given vocabulary
/// to state transitions within a finite-state automaton.
@@ -99,6 +101,7 @@ pub struct Index {
impl Index {
/// Builds an `Index` from regular expression and vocabulary tokens.
pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
let vocab_size = vocabulary.len();
let eos_token_id = vocabulary.eos_token_id();
let dfa = DFA::new(regex).map_err(Box::new)?;
let start_state = match dfa.universal_start_state(Anchored::Yes) {
@@ -160,6 +163,7 @@ impl Index {
final_states,
transitions,
eos_token_id,
vocab_size
})
}

@@ -190,13 +194,21 @@ impl Index {
.map(|res| res.keys().cloned().collect())
}

    /// Returns an iterator over the allowed token IDs for a given state, or `None`
    /// if the state has no transitions.
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
        self.transitions.get(state).map(|map| map.keys())
    }

/// Returns transition state for a given state and token id or `None` otherwise.
pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
if token_id == &self.eos_token_id {
return None;
}
Some(*self.transitions.get(state)?.get(token_id)?)
}

    /// Returns the size of the vocabulary used to build the index.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }
}

impl std::fmt::Display for Index {
39 changes: 39 additions & 0 deletions src/python_bindings/mod.rs
@@ -70,6 +70,45 @@ impl PyGuide {
self.index.is_final_state(self.state)
}

fn write_mask_into(
&self,
data_ptr: usize,
numel: usize,
element_size: usize
) -> PyResult<()> {

        if element_size != 4 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "The mask buffer must use 4-byte elements (e.g. `torch.int32` or `np.int32`)",
            ));
        } else if data_ptr == 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "data_ptr cannot be null",
            ));
        } else if data_ptr % 4 != 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "data_ptr is not aligned to a 4-byte boundary",
            ));
        } else if ((self.index.0.vocab_size() + 31) / 32) != numel {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Invalid buffer size. Please ensure that the mask has exactly ((vocab_size + 31) / 32) elements of dtype `torch.int32`/`np.int32`.",
            ));
        }
unsafe {
std::ptr::write_bytes(data_ptr as *mut u8, 0, numel * 4);
}
if let Some(tokens) = self.index.0.allowed_tokens_iter(&self.state) {
let slice = unsafe { std::slice::from_raw_parts_mut(data_ptr as *mut u32, numel) };
for &token in tokens {
let bucket = (token as usize) / 32;
if bucket < slice.len() {
slice[bucket] |= 1 << ((token as usize) % 32);
}
}
}
Ok(())
}

fn __repr__(&self) -> String {
format!(
"Guide object with the state={:#?} and {:#?}",
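The bit layout written by `write_mask_into` above is: token id `t` maps to word `t / 32`, bit `t % 32`. A small numpy helper (not part of the commit, and assuming a little-endian host) can decode a filled mask back into token ids, which is handy when debugging.

import numpy as np

def mask_to_token_ids(mask: np.ndarray) -> np.ndarray:
    """Return the token ids whose bits are set in a (1, n_words) int32 bitmask."""
    # On a little-endian host, viewing the int32 words as bytes and unpacking with
    # bitorder="little" places bit k of word w at index w * 32 + k.
    bits = np.unpackbits(mask.view(np.uint8), bitorder="little")
    return np.nonzero(bits)[0]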
5 changes: 5 additions & 0 deletions src/vocabulary/mod.rs
@@ -148,6 +148,11 @@ impl Vocabulary {
self.tokens.remove(&token);
}

    /// Returns the number of tokens in the vocabulary, including the EOS token.
    pub fn len(&self) -> usize {
        // +1 for `eos_token_id`
        self.tokens.len() + 1
    }

/// Filters out `Prepend` kind of tokenizer's normalizers.
fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
// Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
2 changes: 1 addition & 1 deletion tests/test_guide.py
@@ -147,4 +147,4 @@ def test_equality(index):
# progress one of the guides, confirm different state == different guide
guide1.advance(guide1.get_tokens()[-1])
assert guide1 != guide2
assert guide3 == guide2
assert guide3 == guide2
2 changes: 1 addition & 1 deletion tests/test_index.py
@@ -62,4 +62,4 @@ def test_deepcopy(index):
is_deleted = not any(id(o) == index2_id for o in gc.get_objects())
assert is_deleted

assert copy_index2 == index
assert copy_index2 == index
