Introduce TokenBuffer
gbenson committed May 23, 2024
1 parent 4026e8c commit 882c259
Showing 3 changed files with 30 additions and 10 deletions.
7 changes: 3 additions & 4 deletions src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
@@ -15,6 +15,7 @@
 from unidecode import unidecode
 
 from .pre_tokenizer import PreTokenizer
+from .token_buffer import TokenBuffer
 
 
 class DOMSnapshotPreTokenizer(PreTokenizer):
@@ -35,7 +36,7 @@ def special_tokens(self):
             if attr.endswith("token")
         ]
 
-    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
+    def pre_tokenize_dom(self, buf: TokenBuffer, serialized: str):
         """Transform a serialized DOM into a sequence of tokens.
         """
         snapshot = json.loads(serialized)
@@ -44,9 +45,7 @@ def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
         if not any(key in snapshot for key in ("documents", "strings")):
             snapshot = snapshot.get("result", snapshot)
 
-        return (ns.original
-                for ns in chain.from_iterable(
-                    self._split_serialized(snapshot)))
+        buf.extend(chain.from_iterable(self._split_serialized(snapshot)))
 
     def _split_serialized(self, snapshot: dict) -> Iterable[list[NormalizedString]]:
         emitter = TokenEmitter(self, snapshot)
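
The practical effect of this hunk is that pre_tokenize_dom now fills a caller-supplied TokenBuffer instead of returning a generator of strings. A minimal sketch of the new calling convention; the no-argument construction of DOMSnapshotPreTokenizer and the trivial snapshot JSON are assumptions for illustration, not part of this diff:

    from dom_tokenizers.pre_tokenizers.dom_snapshot import DOMSnapshotPreTokenizer
    from dom_tokenizers.pre_tokenizers.token_buffer import TokenBuffer

    # Illustrative only: an empty DOM snapshot in the expected shape.
    serialized_snapshot = '{"documents": [], "strings": []}'

    pre_tokenizer = DOMSnapshotPreTokenizer()
    buf = TokenBuffer()
    pre_tokenizer.pre_tokenize_dom(buf, serialized_snapshot)  # fills buf in place
    tokens = buf.tokens  # list of NormalizedString
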
12 changes: 6 additions & 6 deletions src/dom_tokenizers/pre_tokenizers/pre_tokenizer.py
@@ -2,11 +2,12 @@
 import weakref
 
 from abc import ABC, abstractmethod
-from collections.abc import Iterable
 
 from tokenizers import NormalizedString, PreTokenizedString
 from tokenizers.pre_tokenizers import PreTokenizer as _PreTokenizer
 
+from .token_buffer import TokenBuffer
+
 logger = logging.getLogger(__name__)
 
 
@@ -62,16 +63,15 @@ def _pre_tokenize_dom(
         split: NormalizedString,
     ) -> list[NormalizedString]:
         try:
-            return [
-                NormalizedString(token)
-                for token in self.pre_tokenize_dom(split.original)
-            ]
+            buf = TokenBuffer()
+            self.pre_tokenize_dom(buf, split.original)
+            return buf.tokens
         except Exception as e:
             logger.exception(f"{type(e).__name__} in pre-tokenizer:")
             raise
 
     @abstractmethod
-    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
+    def pre_tokenize_dom(self, buf: TokenBuffer, serialized: str):
         """Transform a serialized DOM into a sequence of tokens.
         """
         raise NotImplementedError
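
With the abstract hook changed, concrete pre-tokenizers implement the buffer-filling form. A hypothetical minimal subclass, assuming pre_tokenize_dom is the only abstract method a subclass must provide; the class name and the whitespace-splitting rule are invented for illustration:

    from dom_tokenizers.pre_tokenizers.pre_tokenizer import PreTokenizer
    from dom_tokenizers.pre_tokenizers.token_buffer import TokenBuffer

    class WhitespacePreTokenizer(PreTokenizer):
        def pre_tokenize_dom(self, buf: TokenBuffer, serialized: str):
            # Fill the caller-supplied buffer; the hook returns nothing,
            # so the base class reads buf.tokens afterwards.
            buf.extend(serialized.split())
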
21 changes: 21 additions & 0 deletions src/dom_tokenizers/pre_tokenizers/token_buffer.py
@@ -0,0 +1,21 @@
+from collections.abc import Iterable
+
+from tokenizers import NormalizedString
+
+
+class TokenBuffer:
+    def __init__(self):
+        self._buf = []
+
+    @property
+    def tokens(self) -> list[NormalizedString]:
+        return self._buf
+
+    def append(self, token: str | NormalizedString):
+        if not isinstance(token, NormalizedString):
+            token = NormalizedString(token)
+        self._buf.append(token)
+
+    def extend(self, tokens: Iterable[str | NormalizedString]):
+        for token in tokens:
+            self.append(token)
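
Because append wraps plain strings on entry, str and NormalizedString values can be mixed freely, and everything comes back out of tokens as NormalizedString. A quick illustration of the API added above:

    from tokenizers import NormalizedString

    from dom_tokenizers.pre_tokenizers.token_buffer import TokenBuffer

    buf = TokenBuffer()
    buf.append("hello")                    # plain str is wrapped
    buf.append(NormalizedString("world"))  # passes through unchanged
    buf.extend(["foo", "bar"])             # extend accepts either kind too
    assert all(isinstance(t, NormalizedString) for t in buf.tokens)
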
