Tokenizer creation function
gbenson committed May 19, 2024
1 parent b0518b4 commit 4918be0
Showing 7 changed files with 44 additions and 33 deletions.

.flake8 (2 changes: 1 addition & 1 deletion)

@@ -2,6 +2,6 @@
 exclude = .git,__pycache__,venv*,.venv*,build,dist,.local,.#*,#*,*~
 per-file-ignores =
     # imported but unused
-    src/dom_tokenizers/**/__init__.py: F401
+    src/**/__init__.py: F401
     # line too long
     src/dom_tokenizers/pre_tokenizers/dom_snapshot.py: E501

src/dom_tokenizers/__init__.py (1 change: 1 addition & 0 deletions)

@@ -0,0 +1 @@
+from .tokenizers import DOMSnapshotTokenizer

src/dom_tokenizers/dump.py (8 changes: 2 additions & 6 deletions)

@@ -4,10 +4,8 @@
 from argparse import ArgumentParser
 
 from datasets import load_dataset
-from tokenizers.pre_tokenizers import PreTokenizer
 
-from .internal.transformers import AutoTokenizer
-from .pre_tokenizers import DOMSnapshotPreTokenizer
+from .tokenizers import DOMSnapshotTokenizer
 
 DEFAULT_DATASET = "gbenson/interesting-dom-snapshots"
 DEFAULT_SPLIT = "train"
@@ -31,9 +29,7 @@ def main():
 
     warnings.filterwarnings("ignore", message=r".*resume_download.*")
 
-    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
-    tokenizer.backend_tokenizer.pre_tokenizer = \
-        PreTokenizer.custom(DOMSnapshotPreTokenizer())
+    tokenizer = DOMSnapshotTokenizer.from_pretrained(args.tokenizer)
 
     dataset = load_dataset(args.dataset, split=args.split)
     rows = ((row["source_index"], row["dom_snapshot"]) for row in dataset)

src/dom_tokenizers/tokenizers.py (19 changes: 19 additions & 0 deletions)

@@ -0,0 +1,19 @@
+from tokenizers.pre_tokenizers import PreTokenizer
+
+from .internal.transformers import AutoTokenizer
+from .pre_tokenizers import DOMSnapshotPreTokenizer
+
+
+class Tokenizer:
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
+        impl = tokenizer.backend_tokenizer
+        pt_impl = cls.PRE_TOKENIZER_CLASS()
+        impl.pre_tokenizer = PreTokenizer.custom(pt_impl)
+        tokenizer.backend_pre_tokenizer = pt_impl  # so we can find it
+        return tokenizer
+
+
+class DOMSnapshotTokenizer(Tokenizer):
+    PRE_TOKENIZER_CLASS = DOMSnapshotPreTokenizer
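
The new class returns an ordinary transformers tokenizer with the custom pre-tokenizer already installed, reachable either through backend_tokenizer.pre_tokenizer or through the backend_pre_tokenizer attribute it sets. A minimal usage sketch (the bert-base-cased checkpoint is borrowed from DEFAULT_BASE_TOKENIZER in train.py; nothing here is new API beyond what this commit adds):

    from dom_tokenizers import DOMSnapshotTokenizer

    # Load a pretrained tokenizer and install the DOM-snapshot pre-tokenizer.
    tokenizer = DOMSnapshotTokenizer.from_pretrained("bert-base-cased")

    # The wrapped pre-tokenizer (a PreTokenizer.custom(...) instance):
    pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer

    # The underlying DOMSnapshotPreTokenizer, kept "so we can find it":
    dom_pre_tokenizer = tokenizer.backend_pre_tokenizer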

src/dom_tokenizers/train.py (28 changes: 13 additions & 15 deletions)

@@ -6,10 +6,9 @@
 from math import log10, floor
 
 from datasets import load_dataset
-from tokenizers.pre_tokenizers import PreTokenizer, WhitespaceSplit
+from tokenizers.pre_tokenizers import WhitespaceSplit
 
-from .internal.transformers import AutoTokenizer
-from .pre_tokenizers import DOMSnapshotPreTokenizer
+from .tokenizers import DOMSnapshotTokenizer
 
 DEFAULT_BASE_TOKENIZER = "bert-base-cased"
 DEFAULT_SPLIT = "train"
@@ -25,17 +24,8 @@ def train_tokenizer(
 
     # Create the base tokenizer we'll train our new tokenizer from.
     if isinstance(base_tokenizer, str):
-        base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)
-
-    # Create the custom pretokenizer our new tokenizer will use.
-    new_pretokenizer = DOMSnapshotPreTokenizer()
-
-    # List the custom special tokens that need adding to our tokenizer.
-    new_special_tokens = [
-        special_token
-        for special_token in new_pretokenizer.special_tokens
-        if base_tokenizer.tokenize(special_token) != [special_token]
-    ]
+        base_tokenizer = DOMSnapshotTokenizer.from_pretrained(
+            base_tokenizer)
 
     # It's not possible to train using a custom pre-tokenizer, the Rust
     # code raises "Exception: Custom PreTokenizer cannot be serialized"
@@ -44,9 +34,9 @@
     # whitespace and hope the regular pretokenizer takes it back apart
     # how we need it to.
 
+    new_pretokenizer = base_tokenizer.backend_tokenizer.pre_tokenizer
     base_tokenizer.backend_tokenizer.pre_tokenizer = WhitespaceSplit()
     base_pretokenizer = base_tokenizer.backend_tokenizer.pre_tokenizer
-    new_pretokenizer = PreTokenizer.custom(new_pretokenizer)
 
     def futz_input(real_input):
         pretokenized = new_pretokenizer.pre_tokenize_str(real_input)
@@ -57,6 +47,14 @@ def futz_input(real_input):
         assert got_tokens == want_tokens
         return futzed_input
 
+    # List the custom special tokens that need adding to our tokenizer.
+    dom_snapshot_pre_tokenizer = base_tokenizer.backend_pre_tokenizer
+    new_special_tokens = [
+        special_token
+        for special_token in dom_snapshot_pre_tokenizer.special_tokens
+        if base_tokenizer.tokenize(special_token) != [special_token]
+    ]
+
     def get_training_corpus():
         for row in training_dataset:
             yield futz_input(json.dumps(row["dom_snapshot"]))
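
The comments in this file describe the training workaround: a custom pre-tokenizer can't be serialized for training, so the input is pre-tokenized up front, rejoined with spaces, and handed to a WhitespaceSplit pre-tokenizer that should split it back into the same tokens. The body of futz_input is not shown in this diff; the sketch below only illustrates that round trip, and the helper name is made up for the example:

    from tokenizers.pre_tokenizers import WhitespaceSplit

    def rejoin_for_whitespace_split(new_pretokenizer, real_input):
        # Run the custom pre-tokenizer ourselves...
        want_tokens = [t for t, _ in new_pretokenizer.pre_tokenize_str(real_input)]
        # ...then build a space-joined string that the regular WhitespaceSplit
        # pre-tokenizer should take back apart into exactly the same tokens.
        futzed_input = " ".join(want_tokens)
        got_tokens = [t for t, _ in WhitespaceSplit().pre_tokenize_str(futzed_input)]
        assert got_tokens == want_tokens
        return futzed_input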

tests/conftest.py (17 changes: 7 additions & 10 deletions)

@@ -1,23 +1,20 @@
 import pytest
 
-from tokenizers.pre_tokenizers import PreTokenizer
-
-from dom_tokenizers.internal.transformers import AutoTokenizer
-from dom_tokenizers.pre_tokenizers import DOMSnapshotPreTokenizer
+from dom_tokenizers import DOMSnapshotTokenizer
 from dom_tokenizers.train import DEFAULT_BASE_TOKENIZER
 
 
 @pytest.fixture
-def base_tokenizer():
-    """An instance of the default base tokenizer we train our
-    tokenizers from.
+def dom_snapshot_tokenizer():
+    """An instance of a tokenizer that consumes JSON-serialized
+    DOM snapshots.
     """
-    return AutoTokenizer.from_pretrained(DEFAULT_BASE_TOKENIZER)
+    return DOMSnapshotTokenizer.from_pretrained(DEFAULT_BASE_TOKENIZER)
 
 
 @pytest.fixture
-def dom_snapshot_pre_tokenizer():
+def dom_snapshot_pre_tokenizer(dom_snapshot_tokenizer):
     """An instance of a pre-tokenizer that consumes JSON-serialized
     DOM snapshots.
     """
-    return PreTokenizer.custom(DOMSnapshotPreTokenizer())
+    return dom_snapshot_tokenizer.backend_tokenizer.pre_tokenizer

tests/test_train_tokenizer.py (2 changes: 1 addition & 1 deletion)

@@ -7,7 +7,7 @@
 from .util import load_resource
 
 
-def test_base64(base_tokenizer, dom_snapshot_pre_tokenizer):
+def test_base64(dom_snapshot_pre_tokenizer):
     """Test that base64 is entered successfully. Incorrectly-sequenced
     lowercasing (i.e. applied prior to pre-tokenization) will cause this
     test to fail.
