From 043d883cd4376e0733bcb336bc747d507b086a99 Mon Sep 17 00:00:00 2001
From: Kazumasa Omura
Date: Sat, 20 Apr 2024 09:32:29 +0900
Subject: [PATCH] tweak

---
 scripts/preprocessors/preprocess_reading.py |  4 +-
 scripts/preprocessors/preprocess_typo.py    | 75 +++++++++++----------
 src/kwja/datamodule/datasets/char.py        |  8 +--
 src/kwja/utils/normalization.py             |  7 ++
 4 files changed, 52 insertions(+), 42 deletions(-)
 create mode 100644 src/kwja/utils/normalization.py

diff --git a/scripts/preprocessors/preprocess_reading.py b/scripts/preprocessors/preprocess_reading.py
index d3f81972..fbd3ac6a 100644
--- a/scripts/preprocessors/preprocess_reading.py
+++ b/scripts/preprocessors/preprocess_reading.py
@@ -17,7 +17,7 @@ def main():
     parser = ArgumentParser()
     parser.add_argument("-m", "--model-name-or-path", type=str, help="model_name_or_path")
     parser.add_argument("-k", "--kanji-dic", type=str, help="path to kanji dic file")
-    parser.add_argument("-i", "--input", type=str, help="path to input directory")
+    parser.add_argument("-i", "--in-dir", type=Path, help="path to input directory")
     args = parser.parse_args()
 
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
@@ -25,7 +25,7 @@ def main():
     reading_aligner = ReadingAligner(tokenizer, kanji_dic)
 
     reading_counter: Dict[str, int] = Counter()
-    for path in Path(args.input).glob("**/*.knp"):
+    for path in args.in_dir.glob("**/*.knp"):
         logger.info(f"processing {path}")
         with path.open() as f:
             document = Document.from_knp(f.read())
diff --git a/scripts/preprocessors/preprocess_typo.py b/scripts/preprocessors/preprocess_typo.py
index 0a8e3cb3..8be5b6e8 100644
--- a/scripts/preprocessors/preprocess_typo.py
+++ b/scripts/preprocessors/preprocess_typo.py
@@ -9,11 +9,10 @@
 from pathlib import Path
 from textwrap import dedent
 from typing import Dict, List, Optional, Tuple
-from unicodedata import normalize
 
 from Levenshtein import opcodes
 
-from kwja.utils.constants import TRANSLATION_TABLE
+from kwja.utils.normalization import normalize_text
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s: %(message)s", level=logging.DEBUG)
@@ -33,10 +32,6 @@ class OpType(Enum):
     REPLACE = "replace"
 
 
-def normalize_text(text: str) -> str:
-    return normalize("NFKC", text).translate(TRANSLATION_TABLE)
-
-
 def normalize_example(example: dict) -> None:
     example["pre_text"] = normalize_text(example["pre_text"])
     example["post_text"] = normalize_text(example["post_text"])
@@ -161,9 +156,42 @@ def load_examples(in_dir: Path, split: str) -> Tuple[Dict[str, List[Dict[str, st
     return category2examples, other_examples
 
 
-def build_multi_char_vocab(train_file: Path, out_dir: Path) -> None:
+def save_examples(
+    category2examples: Dict[str, List[Dict[str, str]]],
+    other_examples: List[Dict[str, str]],
+    out_dir: Path,
+    split: str,
+    num_valid_examples_per_category: int,
+) -> None:
+    if split == "train":
+        train_examples: List[Dict[str, str]] = other_examples
+        valid_examples: List[Dict[str, str]] = []
+        for category, examples in category2examples.items():
+            train_examples.extend(examples[num_valid_examples_per_category:])
+            valid_examples.extend(examples[:num_valid_examples_per_category])
+
+        random.shuffle(train_examples)
+        train_dir: Path = out_dir / split
+        train_dir.mkdir(parents=True, exist_ok=True)
+        with (train_dir / f"{split}.jsonl").open(mode="w") as f:
+            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in train_examples) + "\n")
+
+        valid_dir: Path = out_dir / "valid"
+        valid_dir.mkdir(parents=True, exist_ok=True)
+        with (valid_dir / "valid.jsonl").open(mode="w") as f:
+            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in valid_examples) + "\n")
+    elif split == "test":
+        test_dir: Path = out_dir / split
+        test_dir.mkdir(parents=True, exist_ok=True)
+        with (test_dir / f"{split}.jsonl").open(mode="w") as f:
+            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in other_examples) + "\n")
+    else:
+        raise ValueError("invalid split")
+
+
+def build_multi_char_vocab(out_dir: Path) -> None:
     multi_char_vocab: List[str] = []
-    with train_file.open() as f:
+    with (out_dir / "train" / "train.jsonl").open() as f:
         for line in f:
             train_example: dict = json.loads(line)
             for ins_tag in train_example["ins_tags"]:
@@ -214,35 +242,10 @@ def main():
 
     random.seed(0)
 
-    for split in ["test", "train"]:
+    for split in ["train", "test"]:
         category2examples, other_examples = load_examples(args.in_dir, split)
-        if split == "test":
-            test_dir: Path = args.out_dir / split
-            test_dir.mkdir(parents=True, exist_ok=True)
-            with (test_dir / f"{split}.jsonl").open(mode="w") as f:
-                f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in other_examples) + "\n")
-        else:
-            train_examples: List[Dict[str, str]] = other_examples
-            valid_examples: List[Dict[str, str]] = []
-            for category, examples in category2examples.items():
-                train_examples.extend(examples[args.num_valid_examples_per_category :])
-                valid_examples.extend(examples[: args.num_valid_examples_per_category])
-
-            random.shuffle(train_examples)
-            train_dir: Path = args.out_dir / split
-            train_dir.mkdir(parents=True, exist_ok=True)
-            with (train_dir / "train.jsonl").open(mode="w") as f:
-                f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in train_examples) + "\n")
-
-            valid_dir: Path = args.out_dir / "valid"
-            valid_dir.mkdir(parents=True, exist_ok=True)
-            with (valid_dir / "valid.jsonl").open(mode="w") as f:
-                f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in valid_examples) + "\n")
-
-            build_multi_char_vocab(
-                train_file=args.out_dir / "train" / "train.jsonl",
-                out_dir=args.out_dir,
-            )
+        save_examples(category2examples, other_examples, args.out_dir, split, args.num_valid_examples_per_category)
+    build_multi_char_vocab(args.out_dir)
 
 
 if __name__ == "__main__":
diff --git a/src/kwja/datamodule/datasets/char.py b/src/kwja/datamodule/datasets/char.py
index c23bece7..d05c5e19 100644
--- a/src/kwja/datamodule/datasets/char.py
+++ b/src/kwja/datamodule/datasets/char.py
@@ -2,7 +2,6 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List
-from unicodedata import normalize
 
 from rhoknp import Document
 from transformers import BatchEncoding, PreTrainedTokenizerBase
@@ -14,11 +13,11 @@
     IGNORE_INDEX,
     IGNORE_WORD_NORM_OP_TAG,
     SENT_SEGMENTATION_TAGS,
-    TRANSLATION_TABLE,
     WORD_NORM_OP_TAGS,
     WORD_SEGMENTATION_TAGS,
 )
 from kwja.utils.logging_util import track
+from kwja.utils.normalization import normalize_text
 from kwja.utils.word_normalization import SentenceDenormalizer
 
 logger = logging.getLogger(__name__)
@@ -116,10 +115,11 @@ def _postprocess_document(self, document: Document) -> Document:
             # e.g. です -> でーす
             self.denormalizer.denormalize(sentence, self.denormalize_probability)
             for morpheme in sentence.morphemes:
-                normalized = normalize("NFKC", morpheme.text).translate(TRANSLATION_TABLE)
+                normalized = normalize_text(morpheme.text)
                 if normalized != morpheme.text:
                     logger.warning(f"apply normalization ({morpheme.text} -> {normalized})")
                     morpheme.text = normalized
-                morpheme.lemma = normalize("NFKC", morpheme.lemma).translate(TRANSLATION_TABLE)
+                morpheme.reading = normalize_text(morpheme.reading)
+                morpheme.lemma = normalize_text(morpheme.lemma)
         # propagate updates of morpheme.text to sentence.text and document.text
         return document.reparse()
diff --git a/src/kwja/utils/normalization.py b/src/kwja/utils/normalization.py
new file mode 100644
index 00000000..f3272feb
--- /dev/null
+++ b/src/kwja/utils/normalization.py
@@ -0,0 +1,7 @@
+from unicodedata import normalize
+
+from kwja.utils.constants import TRANSLATION_TABLE
+
+
+def normalize_text(text: str) -> str:
+    return normalize("NFKC", text).translate(TRANSLATION_TABLE)
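
Note (not part of the commit): a minimal usage sketch of the extracted normalize_text helper, for reviewers. It assumes kwja is importable; the contents of TRANSLATION_TABLE live in kwja.utils.constants and are not touched by this patch, so the exact output may include additional project-specific remappings on top of NFKC folding.

# Minimal sketch: normalize_text applies NFKC folding first, then the
# project-specific TRANSLATION_TABLE from kwja.utils.constants (not shown here).
from kwja.utils.normalization import normalize_text

# NFKC alone already folds full-width alphanumerics to ASCII and composes
# half-width katakana; the translation table may remap further characters.
print(normalize_text("Ｋｙｏｔｏ１２３"))  # expected "Kyoto123" (modulo table remappings)
print(normalize_text("ｷﾞｮｳｻﾞ"))          # expected "ギョウザ" (modulo table remappings)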
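
Note (not part of the commit): a sketch of consuming the files written by the refactored save_examples. The paths follow the patch (<out_dir>/{train,valid,test}/<split>.jsonl, one JSON object per line, written with ensure_ascii=False); the out_dir value below is a placeholder, not an argument defined in this diff.

import json
from pathlib import Path

out_dir = Path("data/typo")  # placeholder for the script's output directory (args.out_dir)

# train/valid are produced from the "train" split, test from the "test" split.
for split in ["train", "valid", "test"]:
    split_file = out_dir / split / f"{split}.jsonl"
    with split_file.open() as f:
        examples = [json.loads(line) for line in f if line.strip()]
    print(split, len(examples))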