diff --git a/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py b/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
index 53b7050..7388d7a 100644
--- a/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
+++ b/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
@@ -15,6 +15,7 @@
 from unidecode import unidecode
 
 from .pre_tokenizer import PreTokenizer
+from .token_buffer import TokenBuffer
 
 
 class DOMSnapshotPreTokenizer(PreTokenizer):
@@ -35,7 +36,7 @@ def special_tokens(self):
             if attr.endswith("token")
         ]
 
-    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
+    def pre_tokenize_dom(self, buf: TokenBuffer, serialized: str):
         """Transform a serialized DOM into a sequence of tokens.
         """
         snapshot = json.loads(serialized)
@@ -44,9 +45,7 @@ def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
         if not any(key in snapshot for key in ("documents", "strings")):
             snapshot = snapshot.get("result", snapshot)
 
-        return (ns.original
-                for ns in chain.from_iterable(
-                    self._split_serialized(snapshot)))
+        buf.extend(chain.from_iterable(self._split_serialized(snapshot)))
 
     def _split_serialized(self, snapshot: dict) -> Iterable[list[NormalizedString]]:
         emitter = TokenEmitter(self, snapshot)
diff --git a/src/dom_tokenizers/pre_tokenizers/pre_tokenizer.py b/src/dom_tokenizers/pre_tokenizers/pre_tokenizer.py
index 51f8ffd..2d3c621 100644
--- a/src/dom_tokenizers/pre_tokenizers/pre_tokenizer.py
+++ b/src/dom_tokenizers/pre_tokenizers/pre_tokenizer.py
@@ -2,11 +2,12 @@
 import weakref
 
 from abc import ABC, abstractmethod
-from collections.abc import Iterable
 
 from tokenizers import NormalizedString, PreTokenizedString
 from tokenizers.pre_tokenizers import PreTokenizer as _PreTokenizer
 
+from .token_buffer import TokenBuffer
+
 logger = logging.getLogger(__name__)
 
 
@@ -62,16 +63,15 @@ def _pre_tokenize_dom(
             split: NormalizedString,
     ) -> list[NormalizedString]:
         try:
-            return [
-                NormalizedString(token)
-                for token in self.pre_tokenize_dom(split.original)
-            ]
+            buf = TokenBuffer()
+            self.pre_tokenize_dom(buf, split.original)
+            return buf.tokens
         except Exception as e:
             logger.exception(f"{type(e).__name__} in pre-tokenizer:")
             raise
 
     @abstractmethod
-    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
+    def pre_tokenize_dom(self, buf: TokenBuffer, serialized: str):
         """Transform a serialized DOM into a sequence of tokens.
         """
         raise NotImplementedError
diff --git a/src/dom_tokenizers/pre_tokenizers/token_buffer.py b/src/dom_tokenizers/pre_tokenizers/token_buffer.py
new file mode 100644
index 0000000..a91fb99
--- /dev/null
+++ b/src/dom_tokenizers/pre_tokenizers/token_buffer.py
@@ -0,0 +1,21 @@
+from collections.abc import Iterable
+
+from tokenizers import NormalizedString
+
+
+class TokenBuffer:
+    def __init__(self):
+        self._buf = []
+
+    @property
+    def tokens(self) -> list[NormalizedString]:
+        return self._buf
+
+    def append(self, token: str | NormalizedString):
+        if not isinstance(token, NormalizedString):
+            token = NormalizedString(token)
+        self._buf.append(token)
+
+    def extend(self, tokens: Iterable[str | NormalizedString]):
+        for token in tokens:
+            self.append(token)
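A quick usage sketch of the new TokenBuffer (illustrative only: the import path assumes the package's src/ layout, and the token values are made up; only the TokenBuffer API and NormalizedString.original come from this change):

    from tokenizers import NormalizedString

    from dom_tokenizers.pre_tokenizers.token_buffer import TokenBuffer

    buf = TokenBuffer()
    buf.append("hello")                            # a plain str is wrapped in NormalizedString
    buf.extend([NormalizedString("world"), "!"])   # mixed str/NormalizedString input is accepted
    assert all(isinstance(t, NormalizedString) for t in buf.tokens)
    assert [t.original for t in buf.tokens] == ["hello", "world", "!"]

With this change, pre_tokenize_dom() implementations push str or NormalizedString values into the buffer instead of returning an iterable of plain strings, and the base class's _pre_tokenize_dom() hands buf.tokens straight back to the tokenizers library, so the NormalizedString wrapping happens in one place, TokenBuffer.append().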