
Commit

tweak
omukazu committed Apr 20, 2024
1 parent 2e0fda1 commit 043d883
Showing 4 changed files with 52 additions and 42 deletions.
4 changes: 2 additions & 2 deletions scripts/preprocessors/preprocess_reading.py
@@ -17,15 +17,15 @@ def main():
     parser = ArgumentParser()
     parser.add_argument("-m", "--model-name-or-path", type=str, help="model_name_or_path")
     parser.add_argument("-k", "--kanji-dic", type=str, help="path to kanji dic file")
-    parser.add_argument("-i", "--input", type=str, help="path to input directory")
+    parser.add_argument("-i", "--in-dir", type=Path, help="path to input directory")
     args = parser.parse_args()

     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     kanji_dic = KanjiDic(args.kanji_dic)
     reading_aligner = ReadingAligner(tokenizer, kanji_dic)

     reading_counter: Dict[str, int] = Counter()
-    for path in Path(args.input).glob("**/*.knp"):
+    for path in args.in_dir.glob("**/*.knp"):
         logger.info(f"processing {path}")
         with path.open() as f:
             document = Document.from_knp(f.read())
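
The switch from type=str to type=Path is what lets the loop above call .glob() on args.in_dir directly instead of wrapping the raw string in Path(...). A minimal sketch of the pattern, assuming an illustrative directory name (argparse applies the callable passed as type to the raw argument string):

    from argparse import ArgumentParser
    from pathlib import Path

    # Sketch only: "data/knp" is an illustrative directory name.
    parser = ArgumentParser()
    parser.add_argument("-i", "--in-dir", type=Path, help="path to input directory")
    args = parser.parse_args(["-i", "data/knp"])
    for path in args.in_dir.glob("**/*.knp"):  # recursive search for .knp files
        print(path)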
75 changes: 39 additions & 36 deletions scripts/preprocessors/preprocess_typo.py
@@ -9,11 +9,10 @@
 from pathlib import Path
 from textwrap import dedent
 from typing import Dict, List, Optional, Tuple
-from unicodedata import normalize

 from Levenshtein import opcodes

-from kwja.utils.constants import TRANSLATION_TABLE
+from kwja.utils.normalization import normalize_text

 logger = logging.getLogger(__name__)
 logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s: %(message)s", level=logging.DEBUG)
@@ -33,10 +32,6 @@ class OpType(Enum):
     REPLACE = "replace"


-def normalize_text(text: str) -> str:
-    return normalize("NFKC", text).translate(TRANSLATION_TABLE)
-
-
 def normalize_example(example: dict) -> None:
     example["pre_text"] = normalize_text(example["pre_text"])
     example["post_text"] = normalize_text(example["post_text"])
@@ -161,9 +156,42 @@ def load_examples(in_dir: Path, split: str) -> Tuple[Dict[str, List[Dict[str, st
     return category2examples, other_examples


-def build_multi_char_vocab(train_file: Path, out_dir: Path) -> None:
+def save_examples(
+    category2examples: Dict[str, List[Dict[str, str]]],
+    other_examples: List[Dict[str, str]],
+    out_dir: Path,
+    split: str,
+    num_valid_examples_per_category: int,
+) -> None:
+    if split == "train":
+        train_examples: List[Dict[str, str]] = other_examples
+        valid_examples: List[Dict[str, str]] = []
+        for category, examples in category2examples.items():
+            train_examples.extend(examples[num_valid_examples_per_category:])
+            valid_examples.extend(examples[:num_valid_examples_per_category])
+
+        random.shuffle(train_examples)
+        train_dir: Path = out_dir / split
+        train_dir.mkdir(parents=True, exist_ok=True)
+        with (train_dir / f"{split}.jsonl").open(mode="w") as f:
+            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in train_examples) + "\n")
+
+        valid_dir: Path = out_dir / "valid"
+        valid_dir.mkdir(parents=True, exist_ok=True)
+        with (valid_dir / "valid.jsonl").open(mode="w") as f:
+            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in valid_examples) + "\n")
+    elif split == "test":
+        test_dir: Path = out_dir / split
+        test_dir.mkdir(parents=True, exist_ok=True)
+        with (test_dir / f"{split}.jsonl").open(mode="w") as f:
+            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in other_examples) + "\n")
+    else:
+        raise ValueError("invalid split")
+
+
+def build_multi_char_vocab(out_dir: Path) -> None:
     multi_char_vocab: List[str] = []
-    with train_file.open() as f:
+    with (out_dir / "train" / "train.jsonl").open() as f:
         for line in f:
             train_example: dict = json.loads(line)
             for ins_tag in train_example["ins_tags"]:
@@ -214,35 +242,10 @@ def main():

     random.seed(0)

-    for split in ["test", "train"]:
+    for split in ["train", "test"]:
         category2examples, other_examples = load_examples(args.in_dir, split)
-        if split == "test":
-            test_dir: Path = args.out_dir / split
-            test_dir.mkdir(parents=True, exist_ok=True)
-            with (test_dir / f"{split}.jsonl").open(mode="w") as f:
-                f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in other_examples) + "\n")
-        else:
-            train_examples: List[Dict[str, str]] = other_examples
-            valid_examples: List[Dict[str, str]] = []
-            for category, examples in category2examples.items():
-                train_examples.extend(examples[args.num_valid_examples_per_category :])
-                valid_examples.extend(examples[: args.num_valid_examples_per_category])
-
-            random.shuffle(train_examples)
-            train_dir: Path = args.out_dir / split
-            train_dir.mkdir(parents=True, exist_ok=True)
-            with (train_dir / "train.jsonl").open(mode="w") as f:
-                f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in train_examples) + "\n")
-
-            valid_dir: Path = args.out_dir / "valid"
-            valid_dir.mkdir(parents=True, exist_ok=True)
-            with (valid_dir / "valid.jsonl").open(mode="w") as f:
-                f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in valid_examples) + "\n")
-
-    build_multi_char_vocab(
-        train_file=args.out_dir / "train" / "train.jsonl",
-        out_dir=args.out_dir,
-    )
+        save_examples(category2examples, other_examples, args.out_dir, split, args.num_valid_examples_per_category)
+    build_multi_char_vocab(args.out_dir)


 if __name__ == "__main__":
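
The new save_examples function concentrates the per-split writing that main() previously did inline: for the train split, the first num_valid_examples_per_category examples of each category are held out as validation data, and the remainder (plus the uncategorized examples) is shuffled and written to train.jsonl. A self-contained sketch of that split with toy data (the category name, example contents, and output directory are made up):

    import json
    import random
    from pathlib import Path
    from typing import Dict, List

    # Toy stand-ins; real examples carry more fields (e.g. edit tags).
    category2examples: Dict[str, List[Dict[str, str]]] = {
        "sample-category": [{"pre_text": f"text{i}", "post_text": f"text{i}"} for i in range(5)],
    }
    other_examples: List[Dict[str, str]] = [{"pre_text": "misc", "post_text": "misc"}]
    num_valid_examples_per_category = 2

    train_examples = list(other_examples)
    valid_examples: List[Dict[str, str]] = []
    for _, examples in category2examples.items():
        train_examples.extend(examples[num_valid_examples_per_category:])
        valid_examples.extend(examples[:num_valid_examples_per_category])
    random.shuffle(train_examples)

    out_dir = Path("out")  # hypothetical output directory
    for name, split_examples in (("train", train_examples), ("valid", valid_examples)):
        split_dir = out_dir / name
        split_dir.mkdir(parents=True, exist_ok=True)
        with (split_dir / f"{name}.jsonl").open(mode="w") as f:
            f.write("\n".join(json.dumps(e, ensure_ascii=False) for e in split_examples) + "\n")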
8 changes: 4 additions & 4 deletions src/kwja/datamodule/datasets/char.py
@@ -2,7 +2,6 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List
-from unicodedata import normalize

 from rhoknp import Document
 from transformers import BatchEncoding, PreTrainedTokenizerBase
@@ -14,11 +13,11 @@
     IGNORE_INDEX,
     IGNORE_WORD_NORM_OP_TAG,
     SENT_SEGMENTATION_TAGS,
-    TRANSLATION_TABLE,
     WORD_NORM_OP_TAGS,
     WORD_SEGMENTATION_TAGS,
 )
 from kwja.utils.logging_util import track
+from kwja.utils.normalization import normalize_text
 from kwja.utils.word_normalization import SentenceDenormalizer

 logger = logging.getLogger(__name__)
@@ -116,10 +115,11 @@ def _postprocess_document(self, document: Document) -> Document:
             # e.g. です -> でーす
             self.denormalizer.denormalize(sentence, self.denormalize_probability)
             for morpheme in sentence.morphemes:
-                normalized = normalize("NFKC", morpheme.text).translate(TRANSLATION_TABLE)
+                normalized = normalize_text(morpheme.text)
                 if normalized != morpheme.text:
                     logger.warning(f"apply normalization ({morpheme.text} -> {normalized})")
                     morpheme.text = normalized
-                morpheme.lemma = normalize("NFKC", morpheme.lemma).translate(TRANSLATION_TABLE)
+                morpheme.reading = normalize_text(morpheme.reading)
+                morpheme.lemma = normalize_text(morpheme.lemma)
         # propagate updates of morpheme.text to sentence.text and document.text
         return document.reparse()
7 changes: 7 additions & 0 deletions src/kwja/utils/normalization.py
@@ -0,0 +1,7 @@
+from unicodedata import normalize
+
+from kwja.utils.constants import TRANSLATION_TABLE
+
+
+def normalize_text(text: str) -> str:
+    return normalize("NFKC", text).translate(TRANSLATION_TABLE)
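
The new kwja.utils.normalization module pairs Unicode NFKC normalization with the project's TRANSLATION_TABLE, so the same normalize_text helper can be shared by the typo preprocessing script and the char dataset above. A rough illustration of the behavior, using a hypothetical stand-in for TRANSLATION_TABLE (the real table is defined in kwja.utils.constants and its contents may differ):

    from unicodedata import normalize

    # Hypothetical stand-in for kwja.utils.constants.TRANSLATION_TABLE;
    # here it only maps an en dash to an ASCII hyphen as an illustration.
    TRANSLATION_TABLE = str.maketrans({"\u2013": "-"})

    def normalize_text(text: str) -> str:
        return normalize("NFKC", text).translate(TRANSLATION_TABLE)

    # NFKC already folds the full-width alphanumerics and the ideographic space
    # to ASCII; the translation table handles the en dash.
    print(normalize_text("ＫＷＪＡ\u3000ｖ２\u2013ｔｅｓｔ"))  # -> "KWJA v2-test"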
