From 4918be067ecd91a13366232c7f83f1ff234f500b Mon Sep 17 00:00:00 2001
From: Gary Benson
Date: Sun, 19 May 2024 00:33:21 +0100
Subject: [PATCH] Tokenizer creation function

---
 .flake8                          |  2 +-
 src/dom_tokenizers/__init__.py   |  1 +
 src/dom_tokenizers/dump.py       |  8 ++------
 src/dom_tokenizers/tokenizers.py | 19 +++++++++++++++++++
 src/dom_tokenizers/train.py      | 28 +++++++++++++---------------
 tests/conftest.py                | 17 +++++++----------
 tests/test_train_tokenizer.py    |  2 +-
 7 files changed, 44 insertions(+), 33 deletions(-)
 create mode 100644 src/dom_tokenizers/tokenizers.py

diff --git a/.flake8 b/.flake8
index 998689f..53559fb 100644
--- a/.flake8
+++ b/.flake8
@@ -2,6 +2,6 @@
 exclude = .git,__pycache__,venv*,.venv*,build,dist,.local,.#*,#*,*~
 per-file-ignores =
     # imported but unused
-    src/dom_tokenizers/**/__init__.py: F401
+    src/**/__init__.py: F401
     # line too long
     src/dom_tokenizers/pre_tokenizers/dom_snapshot.py: E501
diff --git a/src/dom_tokenizers/__init__.py b/src/dom_tokenizers/__init__.py
index e69de29..0023a73 100644
--- a/src/dom_tokenizers/__init__.py
+++ b/src/dom_tokenizers/__init__.py
@@ -0,0 +1 @@
+from .tokenizers import DOMSnapshotTokenizer
diff --git a/src/dom_tokenizers/dump.py b/src/dom_tokenizers/dump.py
index 759dc26..1c51588 100644
--- a/src/dom_tokenizers/dump.py
+++ b/src/dom_tokenizers/dump.py
@@ -4,10 +4,8 @@
 from argparse import ArgumentParser
 
 from datasets import load_dataset
-from tokenizers.pre_tokenizers import PreTokenizer
 
-from .internal.transformers import AutoTokenizer
-from .pre_tokenizers import DOMSnapshotPreTokenizer
+from .tokenizers import DOMSnapshotTokenizer
 
 DEFAULT_DATASET = "gbenson/interesting-dom-snapshots"
 DEFAULT_SPLIT = "train"
@@ -31,9 +29,7 @@ def main():
 
     warnings.filterwarnings("ignore", message=r".*resume_download.*")
 
-    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
-    tokenizer.backend_tokenizer.pre_tokenizer = \
-        PreTokenizer.custom(DOMSnapshotPreTokenizer())
+    tokenizer = DOMSnapshotTokenizer.from_pretrained(args.tokenizer)
 
     dataset = load_dataset(args.dataset, split=args.split)
     rows = ((row["source_index"], row["dom_snapshot"]) for row in dataset)
diff --git a/src/dom_tokenizers/tokenizers.py b/src/dom_tokenizers/tokenizers.py
new file mode 100644
index 0000000..907bb29
--- /dev/null
+++ b/src/dom_tokenizers/tokenizers.py
@@ -0,0 +1,19 @@
+from tokenizers.pre_tokenizers import PreTokenizer
+
+from .internal.transformers import AutoTokenizer
+from .pre_tokenizers import DOMSnapshotPreTokenizer
+
+
+class Tokenizer:
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
+        impl = tokenizer.backend_tokenizer
+        pt_impl = cls.PRE_TOKENIZER_CLASS()
+        impl.pre_tokenizer = PreTokenizer.custom(pt_impl)
+        tokenizer.backend_pre_tokenizer = pt_impl  # so we can find it
+        return tokenizer
+
+
+class DOMSnapshotTokenizer(Tokenizer):
+    PRE_TOKENIZER_CLASS = DOMSnapshotPreTokenizer
diff --git a/src/dom_tokenizers/train.py b/src/dom_tokenizers/train.py
index 69f3220..b32db20 100644
--- a/src/dom_tokenizers/train.py
+++ b/src/dom_tokenizers/train.py
@@ -6,10 +6,9 @@
 from math import log10, floor
 
 from datasets import load_dataset
-from tokenizers.pre_tokenizers import PreTokenizer, WhitespaceSplit
+from tokenizers.pre_tokenizers import WhitespaceSplit
 
-from .internal.transformers import AutoTokenizer
-from .pre_tokenizers import DOMSnapshotPreTokenizer
+from .tokenizers import DOMSnapshotTokenizer
 
 DEFAULT_BASE_TOKENIZER = "bert-base-cased"
 DEFAULT_SPLIT = "train"
@@ -25,17 +24,8 @@ def train_tokenizer(
 
     # Create the base tokenizer we'll train our new tokenizer from.
     if isinstance(base_tokenizer, str):
-        base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)
-
-    # Create the custom pretokenizer our new tokenizer will use.
-    new_pretokenizer = DOMSnapshotPreTokenizer()
-
-    # List the custom special tokens that need adding to our tokenizer.
-    new_special_tokens = [
-        special_token
-        for special_token in new_pretokenizer.special_tokens
-        if base_tokenizer.tokenize(special_token) != [special_token]
-    ]
+        base_tokenizer = DOMSnapshotTokenizer.from_pretrained(
+            base_tokenizer)
 
     # It's not possible to train using a custom pre-tokenizer, the Rust
     # code raises "Exception: Custom PreTokenizer cannot be serialized"
@@ -44,9 +34,9 @@ def train_tokenizer(
     # whitespace and hope the regular pretokenizer takes it back apart
     # how we need it to.
 
+    new_pretokenizer = base_tokenizer.backend_tokenizer.pre_tokenizer
     base_tokenizer.backend_tokenizer.pre_tokenizer = WhitespaceSplit()
     base_pretokenizer = base_tokenizer.backend_tokenizer.pre_tokenizer
-    new_pretokenizer = PreTokenizer.custom(new_pretokenizer)
 
     def futz_input(real_input):
         pretokenized = new_pretokenizer.pre_tokenize_str(real_input)
@@ -57,6 +47,14 @@ def futz_input(real_input):
         assert got_tokens == want_tokens
         return futzed_input
 
+    # List the custom special tokens that need adding to our tokenizer.
+    dom_snapshot_pre_tokenizer = base_tokenizer.backend_pre_tokenizer
+    new_special_tokens = [
+        special_token
+        for special_token in dom_snapshot_pre_tokenizer.special_tokens
+        if base_tokenizer.tokenize(special_token) != [special_token]
+    ]
+
     def get_training_corpus():
         for row in training_dataset:
             yield futz_input(json.dumps(row["dom_snapshot"]))
diff --git a/tests/conftest.py b/tests/conftest.py
index 7e13c41..edf4fec 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,23 +1,20 @@
 import pytest
 
-from tokenizers.pre_tokenizers import PreTokenizer
-
-from dom_tokenizers.internal.transformers import AutoTokenizer
-from dom_tokenizers.pre_tokenizers import DOMSnapshotPreTokenizer
+from dom_tokenizers import DOMSnapshotTokenizer
 from dom_tokenizers.train import DEFAULT_BASE_TOKENIZER
 
 
 @pytest.fixture
-def base_tokenizer():
-    """An instance of the default base tokenizer we train our
-    tokenizers from.
+def dom_snapshot_tokenizer():
+    """An instance of a tokenizer that consumes JSON-serialized
+    DOM snapshots.
     """
-    return AutoTokenizer.from_pretrained(DEFAULT_BASE_TOKENIZER)
+    return DOMSnapshotTokenizer.from_pretrained(DEFAULT_BASE_TOKENIZER)
 
 
 @pytest.fixture
-def dom_snapshot_pre_tokenizer():
+def dom_snapshot_pre_tokenizer(dom_snapshot_tokenizer):
     """An instance of a pre-tokenizer that consumes JSON-serialized
    DOM snapshots.
     """
-    return PreTokenizer.custom(DOMSnapshotPreTokenizer())
+    return dom_snapshot_tokenizer.backend_tokenizer.pre_tokenizer
diff --git a/tests/test_train_tokenizer.py b/tests/test_train_tokenizer.py
index 974d1e8..af7bc7c 100644
--- a/tests/test_train_tokenizer.py
+++ b/tests/test_train_tokenizer.py
@@ -7,7 +7,7 @@
 from .util import load_resource
 
 
-def test_base64(base_tokenizer, dom_snapshot_pre_tokenizer):
+def test_base64(dom_snapshot_pre_tokenizer):
     """Test that base64 is entered successfully.
     Incorrectly-sequenced lowercasing (i.e. applied prior to
     pre-tokenization) will cause this test to fail.
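
Usage note (illustrative only, not part of the diff): a minimal sketch of
the new DOMSnapshotTokenizer.from_pretrained() entry point, mirroring how
dump.py and train.py use it after this patch.  The checkpoint, dataset,
split and column names are the defaults already present in the patch; the
snippet itself is an assumption about typical use, not project code.

    import json

    from datasets import load_dataset
    from dom_tokenizers import DOMSnapshotTokenizer

    # One call builds the base tokenizer and installs the custom
    # DOM-snapshot pre-tokenizer behind it.
    tokenizer = DOMSnapshotTokenizer.from_pretrained("bert-base-cased")

    # The custom pre-tokenizer is reachable both wrapped (as the backend
    # tokenizer's pre_tokenizer) and unwrapped (the attribute added
    # "so we can find it").
    wrapped = tokenizer.backend_tokenizer.pre_tokenizer
    dom_pre_tokenizer = tokenizer.backend_pre_tokenizer

    # Tokenize a JSON-serialized DOM snapshot, as dump.py and train.py do.
    dataset = load_dataset("gbenson/interesting-dom-snapshots", split="train")
    tokens = tokenizer.tokenize(json.dumps(dataset[0]["dom_snapshot"]))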