Commit e4d4c12: User-friendly trainer

gbenson committed May 16, 2024
1 parent e599219

Showing 4 changed files with 87 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .flake8
@@ -5,3 +5,5 @@ per-file-ignores =
     src/dom_tokenizers/**/__init__.py: F401
     # line too long
     src/dom_tokenizers/pre_tokenizers/dom_snapshot.py: E501
+    # module level import not at top of file
+    src/dom_tokenizers/train.py: E402
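
The suppression is needed because train.py has to configure the environment before transformers is first imported (see the train.py diff below), so that import cannot sit at the top of the file. A minimal sketch of the pattern that triggers E402:

```python
import os

# This must be set before transformers is first imported, which forces
# the import below the assignment; flake8 reports that as E402, so
# .flake8 now ignores E402 for train.py via per-file-ignores.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = str(True)
from transformers import AutoTokenizer
```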
5 changes: 5 additions & 0 deletions README.md
@@ -29,3 +29,8 @@ python3 -m venv .venv
 pip install --upgrade pip
 pip install -e .[dev,train]
 ```
+
+## Train a tokenizer
+```sh
+train-tokenizer gbenson/interesting-dom-snapshots -n 10000
+```
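
With no -o flag the trained tokenizer is saved to a directory named after the vocabulary size, here dom-tokenizer-10k. Roughly the same operation through the Python API this commit introduces in train.py (a sketch, not part of the commit):

```python
from datasets import load_dataset

from dom_tokenizers.train import train_tokenizer

# Stream the snapshots rather than downloading the whole dataset,
# then train a 10,000-token vocabulary from them.
dataset = load_dataset("gbenson/interesting-dom-snapshots",
                       split="train", streaming=True)
tokenizer = train_tokenizer(dataset, vocab_size=10000)
tokenizer.save_pretrained("dom-tokenizer-10k")
```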
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dom-tokenizers"
-version = "0.0.1"
+version = "0.0.2"
 authors = [{ name = "Gary Benson" }]
 description = "DOM-aware tokenizers for Hugging Face language models"
 readme = "README.md"
102 changes: 79 additions & 23 deletions src/dom_tokenizers/train.py
@@ -1,31 +1,29 @@
 import os
 import json
+import warnings
+
+from argparse import ArgumentParser
+from math import log10, floor
 
 from datasets import load_dataset
 from tokenizers.pre_tokenizers import PreTokenizer, WhitespaceSplit
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = str(True)
 from transformers import AutoTokenizer
 
 from .pre_tokenizers import DOMSnapshotPreTokenizer
 
 FULL_DATASET = "gbenson/webui-dom-snapshots"
 TEST_DATASET = "gbenson/interesting-dom-snapshots"
+DEFAULT_BASE = "bert-base-uncased"
+DEFAULT_SPLIT = "train"
+DEFAULT_SIZE = 1024
+SEND_BUGS_TO = "https://github.com/gbenson/dom-tokenizers/issues"
 
 
 def train_tokenizer(
-        *args,
-        training_dataset=None,
-        base_tokenizer="bert-base-uncased",
-        vocab_size=1024,  # XXX including all tokens and alphabet
-        **kwargs):
-    """
-    XXX
-    base_tokenizer
-    all other args passed to load_dataset for XXX...
-    """
-
-    # Load the training data we'll train our new tokenizer with.
-    if training_dataset is None:
-        training_dataset = load_dataset(*args, **kwargs)
+        training_dataset,
+        base_tokenizer=DEFAULT_BASE,
+        vocab_size=DEFAULT_SIZE,
+        corpus_size=None):
 
     # Create the base tokenizer we'll train our new tokenizer from.
     if isinstance(base_tokenizer, str):
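
train_tokenizer no longer loads the dataset itself; callers pass one in, and because main() below passes a streaming dataset, which has no usable len(), the new corpus_size parameter carries the count that progress tracking needs. The next hunk guards for exactly that case; a sketch of the behavior, assuming the datasets library's usual streaming semantics:

```python
from datasets import load_dataset

# Streaming returns an IterableDataset, which defines no __len__ ...
ds = load_dataset("gbenson/interesting-dom-snapshots",
                  split="train", streaming=True)
try:
    print(len(ds))
except TypeError:
    # ... so the length, if wanted, must be supplied as corpus_size.
    print("length unknown")
```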
@@ -65,25 +63,83 @@ def get_training_corpus():
         for row in training_dataset:
             yield futz_input(json.dumps(row["dom_snapshot"]))
 
+    # Try and get a dataset length, for the progress tracker.
+    if corpus_size is None:
+        try:
+            corpus_size = len(training_dataset)
+        except TypeError:
+            pass
+
     # Train the new tokenizer.
     new_tokenizer = base_tokenizer.train_new_from_iterator(
         text_iterator=get_training_corpus(),
         vocab_size=vocab_size,
         new_special_tokens=new_special_tokens,
-        length=len(training_dataset),  # used for progress tracking
+        length=corpus_size,
         show_progress=True,
     )
 
     return new_tokenizer
 
 
-def main(save_directory="pretrained", use_full_dataset=False):
+def _round_and_prefix(value):
+    """314159 -> '314k'."""
+    whole, frac = divmod(log10(value), 1)
+    unit_index, whole = divmod(floor(whole), 3)
+    value = round(10 ** (whole + frac))
+    unit = ([""] + list("kMBTQ"))[unit_index]
+    return f"{value}{unit}"
+
+
+def main():
+    p = ArgumentParser(
+        description="Train DOM-aware tokenizers.",
+        epilog=f"Report bugs to: <{SEND_BUGS_TO}>.")
+    p.add_argument(
+        "dataset", metavar="DATASET",
+        help="dataset containing the training corpus")
+    p.add_argument(
+        "--base-tokenizer", metavar="ID", default=DEFAULT_BASE,
+        help=f"tokenizer to train ours from [default: {DEFAULT_BASE}]")
+    p.add_argument(
+        "--split", default=DEFAULT_SPLIT, metavar="SPLIT", dest="split_name",
+        help=(f"split of the training dataset to use"
+              f" [default: {DEFAULT_SPLIT}]"))
+    p.add_argument(
+        "-N", "--num-inputs", metavar="N", dest="corpus_size",
+        type=int,
+        help=("number of sequences in the training dataset, if known;"
+              " this is used to provide meaningful progress tracking"))
+    p.add_argument(
+        "-n", "--num-tokens", metavar="N", dest="vocab_size", type=int,
+        default=DEFAULT_SIZE,
+        help=(f"desired vocabulary size, including all special tokens and"
+              f" the initial alphabet [default: {DEFAULT_SIZE} tokens]"))
+    p.add_argument(
+        "-o", "--output", metavar="DIR", dest="save_directory",
+        help=("directory to save the trained tokenizer into"
+              " [default: something based on targeted vocabulary size]"))
+    args = p.parse_args()
+
+    save_directory = args.save_directory
+    if save_directory is None:
+        pretty_size = _round_and_prefix(args.vocab_size)
+        save_directory = f"dom-tokenizer-{pretty_size}"
+    print(f"Output directory: {save_directory}\n")
+
+    warnings.filterwarnings("ignore", message=r".*resume_download.*")
 
-    if use_full_dataset:
-        dataset, kwargs = FULL_DATASET, dict(streaming=True)
-    else:
-        dataset, kwargs = TEST_DATASET, {}
+    tokenizer = train_tokenizer(
+        load_dataset(
+            args.dataset,
+            split=args.split_name,
+            streaming=True),
+        base_tokenizer=args.base_tokenizer,
+        vocab_size=args.vocab_size,
+        corpus_size=args.corpus_size)
+    print(f'\n{tokenizer.tokenize("training complete")}')
 
-    tokenizer = train_tokenizer(dataset, split="train", **kwargs)
     tokenizer.save_pretrained(save_directory)
 
     print(tokenizer.tokenize("tokenizer state saved"))
     print(tokenizer.tokenize("see you soon") + ["!!"])
