diff --git a/tests/data/modules/permitted_tokens.json b/tests/data/modules/permitted_tokens.json
index 69814122..7bf20cd7 100644
--- a/tests/data/modules/permitted_tokens.json
+++ b/tests/data/modules/permitted_tokens.json
@@ -9,7 +9,7 @@
       "input_tokens": [""],
       "permitted_tokens": [""]
     },
-    "target_morpheme": "init"
+    "target_property": "init"
   },
   "case002": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
@@ -21,7 +21,7 @@
       "input_tokens": ["", ""],
       "permitted_tokens": ["計", "計算"]
     },
-    "target_morpheme": "surf"
+    "target_property": "surf"
   },
   "case003": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
@@ -33,7 +33,7 @@
       "input_tokens": ["", "", "計"],
       "permitted_tokens": ["算"]
     },
-    "target_morpheme": "surf"
+    "target_property": "surf"
   },
   "case004": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
@@ -45,105 +45,101 @@
       "input_tokens": ["", "", "計算"],
       "permitted_tokens": [""]
     },
-    "target_morpheme": "surf"
+    "target_property": "surf"
   },
   "case005": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", ""],
       "permitted_tokens": "reading_candidates",
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": [""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", ""],
       "permitted_tokens": "reading_candidates",
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": [""]
     },
-    "target_morpheme": "reading"
+    "target_property": "reading"
   },
   "case006": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい"],
-      "permitted_tokens": "reading_candidates",
-      "banned_tokens": [""]
+      "permitted_tokens": "reading_candidates"
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい"],
-      "permitted_tokens": "reading_candidates",
-      "banned_tokens": [""]
+      "permitted_tokens": "reading_candidates"
     },
-    "target_morpheme": "reading"
+    "target_property": "reading"
   },
   "case007": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい", "さん", ""],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", ""],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", "", ""]
     },
-    "target_morpheme": "lemma"
+    "target_property": "lemma"
   },
   "case008": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい", "さん", "", "計算"],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算"],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
    },
-    "target_morpheme": "lemma"
+    "target_property": "lemma"
   },
   "case009": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい", "さん", "", "計算", ""],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", ""],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
-    "target_morpheme": "canon"
+    "target_property": "canon"
   },
   "case010": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい", "さん", "", "計算", "", "計算"],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算"],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", ""]
     },
-    "target_morpheme": "canon"
+    "target_property": "canon"
   },
   "case011": {
     "surfs": ["計算", "機", "に", "よる", "言語", "理解", "を", "実現", "する"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい", "さん", "", "計算", "", "計算", "/", "けい", "さん", ""],
-      "permitted_tokens": ["機"],
-      "banned_tokens": ["", "", "", ""]
+      "permitted_tokens": ["機"]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算", "/", "けい", "さん", ""],
-      "permitted_tokens": ["機"],
-      "banned_tokens": ["", "", "", ""]
+      "permitted_tokens": ["機"]
     },
-    "target_morpheme": "surf"
+    "target_property": "surf"
   },
   "case012": {
     "surfs": ["計算"],
@@ -155,34 +151,34 @@
       "input_tokens": ["", "", "計算"],
       "permitted_tokens": [""]
     },
-    "target_morpheme": "surf"
+    "target_property": "surf"
   },
   "case013": {
     "surfs": ["計算"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", ""],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", ""],
       "permitted_tokens": [],
-      "banned_tokens": ["", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
-    "target_morpheme": "canon"
+    "target_property": "canon"
   },
   "case014": {
     "surfs": ["計算"],
     "t5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算", "/", "けい", "さん"],
       "permitted_tokens": [],
-      "banned_tokens": []
+      "prohibited_tokens": []
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算", "/", "けい", "さん"],
       "permitted_tokens": [],
-      "banned_tokens": []
+      "prohibited_tokens": []
     },
-    "target_morpheme": "canon"
+    "target_property": "canon"
   }
 }
diff --git a/tests/modules/components/test_logits_processor.py b/tests/modules/components/test_logits_processor.py
index c8ad4ec7..dc610a02 100644
--- a/tests/modules/components/test_logits_processor.py
+++ b/tests/modules/components/test_logits_processor.py
@@ -5,15 +5,13 @@
 import pytest
 import torch
-from rhoknp import Sentence
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
-from kwja.datamodule.datasets.seq2seq import Seq2SeqFormatter
 from kwja.modules.components.logits_processor import (
-    ForcedLogitsProcessor,
-    TargetMorpheme,
-    get_char2tokens,
-    get_reading_candidates,
+    SurfForcedDecodingLogitsProcessor,
+    TargetProperty,
+    get_char2token_items,
+    get_reading_candidate_token_ids,
 )
 from kwja.utils.constants import HALF_SPACE_TOKEN, LEMMA_TOKEN
@@ -25,9 +23,9 @@ def test_get_char2tokens() -> None:
         pretrained_model_name_or_path="google/mt5-small",
         additional_special_tokens=SPECIAL_TOKENS,
     )
-    mt5_char2tokens = get_char2tokens(mt5_tokenizer)
-    assert len(mt5_char2tokens) == 19455
-    assert mt5_char2tokens["京"] == {
+    mt5_char2token_items = get_char2token_items(mt5_tokenizer)
+    assert len(mt5_char2token_items) == 19455
+    assert mt5_char2token_items["京"] == {
         "京东": 165392,
         "京娱乐": 178804,
         "京都府": 166766,
@@ -37,8 +35,8 @@
         "京区": 208641,
         "▁京公网安备": 234066,
     }
-    mt5_underscore_tokens: Set[str] = {x for x in mt5_tokenizer.get_vocab() if x.startswith("▁")}
-    mt5_non_underscore_tokens: Set[str] = {x for x in mt5_tokenizer.get_vocab() if not x.startswith("▁")}
+    mt5_underscore_tokens: Set[str] = {x for x in mt5_tokenizer.vocab if x.startswith("▁")}
+    mt5_non_underscore_tokens: Set[str] = {x for x in mt5_tokenizer.vocab if not x.startswith("▁")}
     assert len(mt5_underscore_tokens) == 56369
     assert len(mt5_non_underscore_tokens) == 193831
@@ -46,9 +44,9 @@
         pretrained_model_name_or_path="retrieva-jp/t5-small-short",
         additional_special_tokens=SPECIAL_TOKENS,
     )
-    t5_char2tokens = get_char2tokens(t5_tokenizer)
-    assert len(t5_char2tokens) == 4289
-    assert t5_char2tokens["京"] == {
+    t5_char2token_items = get_char2token_items(t5_tokenizer)
+    assert len(t5_char2token_items) == 4289
+    assert t5_char2token_items["京"] == {
         "京都府": 3411,
         "京都府出身": 26029,
         "京橋": 22889,
@@ -69,13 +67,13 @@
         "京阪": 14311,
         "京都市立": 24756,
     }
-    t5_underscore_tokens: Set[str] = {x for x in t5_tokenizer.get_vocab() if x.startswith("▁")}
-    t5_non_underscore_tokens: Set[str] = {x for x in t5_tokenizer.get_vocab() if not x.startswith("▁")}
+    t5_underscore_tokens: Set[str] = {x for x in t5_tokenizer.vocab if x.startswith("▁")}
+    t5_non_underscore_tokens: Set[str] = {x for x in t5_tokenizer.vocab if not x.startswith("▁")}
     assert len(t5_underscore_tokens) == 531
     assert len(t5_non_underscore_tokens) == 31569
 
 
-def test_get_target_morpheme(data_dir: Path) -> None:
+def test_get_target_property(data_dir: Path) -> None:
     model2pretrained_model_name_or_path: Dict[str, str] = {
         "mt5": "google/mt5-small",
         "t5": "retrieva-jp/t5-small-short",
@@ -85,29 +83,29 @@
             pretrained_model_name_or_path,
             additional_special_tokens=SPECIAL_TOKENS,
         )
-        reading_candidates: Set[int] = get_reading_candidates(tokenizer)
-        char2tokens: Dict[str, Dict[str, int]] = get_char2tokens(tokenizer)
+        reading_candidate_token_ids: List[int] = get_reading_candidate_token_ids(tokenizer)
+        char2token_items: Dict[str, Dict[str, int]] = get_char2token_items(tokenizer)
         test_case_path: Path = data_dir / "modules" / "permitted_tokens.json"
         with open(test_case_path) as f:
             test_cases = json.load(f)
         for test_case in test_cases.values():
-            processor = ForcedLogitsProcessor(
-                surfs=[test_case["surfs"]],
+            processor = SurfForcedDecodingLogitsProcessor(
+                batch_surfs=[test_case["surfs"]],
                 num_beams=1,
                 tokenizer=tokenizer,
-                reading_candidates=reading_candidates,
-                char2tokens=char2tokens,
+                char2token_items=char2token_items,
+                reading_candidate_token_ids=reading_candidate_token_ids,
             )
             input_ids: List[int] = tokenizer.convert_tokens_to_ids(test_case[model]["input_tokens"])
-            target_morpheme: TargetMorpheme = processor._get_target_morpheme(input_ids)
-            assert target_morpheme.surf == (test_case["target_morpheme"] == "surf")
-            assert target_morpheme.reading == (test_case["target_morpheme"] == "reading")
-            assert target_morpheme.lemma == (test_case["target_morpheme"] == "lemma")
-            assert target_morpheme.canon == (test_case["target_morpheme"] == "canon")
+            target_property: TargetProperty = processor._get_target_property(input_ids)
+            assert target_property.surf == (test_case["target_property"] == "surf")
+            assert target_property.reading == (test_case["target_property"] == "reading")
+            assert target_property.lemma == (test_case["target_property"] == "lemma")
+            assert target_property.canon == (test_case["target_property"] == "canon")
 
 
 @pytest.mark.parametrize(
-    ("input_tokens", "surfs", "expected_remaining_surf"),
+    ("input_tokens", "surfs", "expected_ungenerated_surf"),
     [
         (["", ""], ["研究", "する"], "研究"),
         (["", "", "研"], ["研究", "する"], "究"),
@@ -151,80 +149,51 @@ def test_get_target_morpheme(data_dir: Path) -> None:
         ),
     ],
 )
-def test_get_remaining_surf(input_tokens: List[str], surfs: List[str], expected_remaining_surf: str) -> None:
+def test_get_ungenerated_surf(input_tokens: List[str], surfs: List[str], expected_ungenerated_surf: str) -> None:
     for pretrained_model_name_or_path in ["google/mt5-small", "retrieva-jp/t5-small-short"]:
         tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             additional_special_tokens=SPECIAL_TOKENS,
         )
-        char2tokens = get_char2tokens(tokenizer)
-        reading_candidates = get_reading_candidates(tokenizer)
+        char2token_items = get_char2token_items(tokenizer)
+        reading_candidate_token_ids = get_reading_candidate_token_ids(tokenizer)
 
-        processor = ForcedLogitsProcessor(
-            surfs=[surfs],
+        processor = SurfForcedDecodingLogitsProcessor(
+            batch_surfs=[surfs],
             num_beams=1,
             tokenizer=tokenizer,
-            reading_candidates=reading_candidates,
-            char2tokens=char2tokens,
+            char2token_items=char2token_items,
+            reading_candidate_token_ids=reading_candidate_token_ids,
         )
         input_ids: List[int] = tokenizer.convert_tokens_to_ids(input_tokens)
-        assert processor._get_remaining_surf(input_ids, surfs) == expected_remaining_surf
+        assert processor._get_ungenerated_surf(input_ids, surfs) == expected_ungenerated_surf
 
 
 @pytest.mark.parametrize(
-    ("surfs", "permitted_tokens"),
+    ("surfs", "expected_permitted_tokens"),
     [
         (["研究", "を", "する"], ["研究", "研"]),
         ([HALF_SPACE_TOKEN, "研究", "を", "する"], [HALF_SPACE_TOKEN]),
     ],
 )
-def test_get_banned_token_ids(surfs: List[str], permitted_tokens: List[str]) -> None:
+def test_get_permitted_token_ids(surfs: List[str], expected_permitted_tokens: List[str]) -> None:
     for pretrained_model_name_or_path in ["google/mt5-small", "retrieva-jp/t5-small-short"]:
         tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path,
             additional_special_tokens=SPECIAL_TOKENS,
         )
-        reading_candidates: Set[int] = get_reading_candidates(tokenizer)
-        char2tokens: Dict[str, Dict[str, int]] = get_char2tokens(tokenizer)
-        expected_banned_token_ids: Set[int] = set(tokenizer.get_vocab().values()) - set(
-            tokenizer.convert_tokens_to_ids(permitted_tokens)
-        )
+        char2token_items = get_char2token_items(tokenizer)
+        reading_candidate_token_ids = get_reading_candidate_token_ids(tokenizer)
 
-        processor = ForcedLogitsProcessor(
-            surfs=[surfs],
+        processor = SurfForcedDecodingLogitsProcessor(
+            batch_surfs=[surfs],
             num_beams=2,
             tokenizer=tokenizer,
-            reading_candidates=reading_candidates,
-            char2tokens=char2tokens,
+            char2token_items=char2token_items,
+            reading_candidate_token_ids=reading_candidate_token_ids,
         )
-        banned_token_ids: Set[int] = processor._get_banned_token_ids("".join(surfs))
-        assert sorted(list(banned_token_ids)) == sorted(list(expected_banned_token_ids))
-
-
-def test_get_generated_surf(data_dir: Path) -> None:
-    for pretrained_model_name_or_path in ["google/mt5-small", "retrieva-jp/t5-small-short"]:
-        tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            additional_special_tokens=SPECIAL_TOKENS,
-        )
-        formatter: Seq2SeqFormatter = Seq2SeqFormatter(tokenizer)
-        char2tokens = get_char2tokens(tokenizer)
-        reading_candidates = get_reading_candidates(tokenizer)
-
-        test_case_dir: Path = data_dir / "modules" / "juman"
-        for path in test_case_dir.glob("*.juman"):
-            with path.open() as f:
-                sentence: Sentence = Sentence.from_jumanpp(f.read())
-            processor = ForcedLogitsProcessor(
-                surfs=[formatter.get_surfs(sentence)],
-                num_beams=2,
-                tokenizer=tokenizer,
-                reading_candidates=reading_candidates,
-                char2tokens=char2tokens,
-            )
-            tgt_tokens: List[str] = formatter.get_tgt_tokens(sentence)
-            tgt_input_ids: List[int] = [*tokenizer.convert_tokens_to_ids(tgt_tokens), tokenizer.eos_token_id]
-            assert processor.surfs[0] == processor._get_generated_surf(tgt_input_ids)
+        permitted_token_ids: List[int] = processor._get_permitted_token_ids("".join(surfs))
+        assert sorted(permitted_token_ids) == sorted(tokenizer.convert_tokens_to_ids(expected_permitted_tokens))
 
 
 def test_get_mask(data_dir: Path) -> None:
@@ -238,22 +207,25 @@ def test_get_mask(data_dir: Path) -> None:
             additional_special_tokens=SPECIAL_TOKENS,
         )
         vocab_size: int = len(tokenizer.get_vocab())
-        reading_candidates: Set[int] = get_reading_candidates(tokenizer)
-        reading_candidate_tokens: Set[str] = {tokenizer.convert_ids_to_tokens([x])[0] for x in reading_candidates}
-        char2tokens = get_char2tokens(tokenizer)
-        all_tokens: Set[str] = set(tokenizer.get_vocab().keys())
+        char2token_items = get_char2token_items(tokenizer)
+        reading_candidate_token_ids = get_reading_candidate_token_ids(tokenizer)
+        reading_candidate_tokens: Set[str] = {
+            tokenizer.convert_ids_to_tokens(reading_candidate_token_id)
+            for reading_candidate_token_id in reading_candidate_token_ids
+        }
+        all_tokens: Set[str] = set(tokenizer.vocab.keys())
         test_case_path: Path = data_dir / "modules" / "permitted_tokens.json"
         with open(test_case_path) as f:
             test_cases = json.load(f)
-        for test_case in test_cases.values():
-            assert test_case["target_morpheme"] in ["surf", "reading", "lemma", "canon", "init"]
-            processor = ForcedLogitsProcessor(
-                surfs=[test_case["surfs"]],
+        for k, test_case in test_cases.items():
+            assert test_case["target_property"] in ["surf", "reading", "lemma", "canon", "init"]
+            processor = SurfForcedDecodingLogitsProcessor(
+                batch_surfs=[test_case["surfs"]],
                 num_beams=1,
                 tokenizer=tokenizer,
-                reading_candidates=reading_candidates,
-                char2tokens=char2tokens,
+                char2token_items=char2token_items,
+                reading_candidate_token_ids=reading_candidate_token_ids,
            )
            warped_scores: Optional[torch.Tensor] = None
            for idx in range(1, len(test_case[model]["input_tokens"]) + 1):
@@ -273,11 +245,11 @@ def test_get_mask(data_dir: Path) -> None:
            if test_case[model]["permitted_tokens"] == "reading_candidates":
                expected_permitted_tokens: Set[str] = reading_candidate_tokens
                expected_permitted_tokens.add(LEMMA_TOKEN)
-            elif not test_case[model]["permitted_tokens"]:
+            elif len(test_case[model]["permitted_tokens"]) == 0:
                expected_permitted_tokens = copy.deepcopy(all_tokens)
            else:
                expected_permitted_tokens = set(test_case[model]["permitted_tokens"])
-            if "banned_tokens" in test_case[model]:
-                expected_permitted_tokens -= set(test_case[model]["banned_tokens"])
+            if "prohibited_tokens" in test_case[model]:
+                expected_permitted_tokens -= set(test_case[model]["prohibited_tokens"])
            assert sorted(list(permitted_tokens)) == sorted(list(expected_permitted_tokens))
diff --git a/tests/utils/test_seq2seq_format.py b/tests/utils/test_seq2seq_format.py
index 1ea799b6..9a875c5a 100644
--- a/tests/utils/test_seq2seq_format.py
+++ b/tests/utils/test_seq2seq_format.py
@@ -10,7 +10,7 @@
     CANON_TOKEN,
     HALF_SPACE_TOKEN,
     LEMMA_TOKEN,
-    MORPHEME_SPLIT_TOKEN,
+    MORPHEME_DELIMITER_TOKEN,
     NO_CANON_TOKEN,
     READING_TOKEN,
     SURF_TOKEN,
@@ -152,7 +152,7 @@ def test_get_src_tokens(data_dir: Path, seq2seq_tokenizer: PreTrainedTokenizerFa
             normalize_morpheme(morpheme)
         expected_src_tokens: List[str] = []
         for morphemes in src_tokens[idx]:
-            expected_src_tokens.extend([*morphemes.split(" "), MORPHEME_SPLIT_TOKEN])
+            expected_src_tokens.extend([*morphemes.split(" "), MORPHEME_DELIMITER_TOKEN])
         assert [
             x[1:] if x.startswith("▁") else x for x in seq2seq_formatter.get_src_tokens(sentence)
         ] == expected_src_tokens[:-1]
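Note for reviewers (not part of the patch): a minimal sketch of how the renamed SurfForcedDecodingLogitsProcessor is expected to plug into Hugging Face generation, using only the constructor arguments and helper functions exercised by the tests above. The model choice, the input sentence, and the generation settings are assumptions for illustration; the tests additionally register KWJA's special tokens via additional_special_tokens=SPECIAL_TOKENS (value not shown in this diff), and real usage wires the processor up inside KWJA's seq2seq module rather than calling it directly like this.

# Sketch only: keyword arguments mirror the tests in this diff; everything
# else (model, input text, generation settings) is assumed for illustration.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList

from kwja.modules.components.logits_processor import (
    SurfForcedDecodingLogitsProcessor,
    get_char2token_items,
    get_reading_candidate_token_ids,
)

tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/t5-small-short")
model = AutoModelForSeq2SeqLM.from_pretrained("retrieva-jp/t5-small-short")

# One sentence per batch element; its gold surface forms constrain decoding.
processor = SurfForcedDecodingLogitsProcessor(
    batch_surfs=[["計算", "機"]],
    num_beams=1,
    tokenizer=tokenizer,
    char2token_items=get_char2token_items(tokenizer),
    reading_candidate_token_ids=get_reading_candidate_token_ids(tokenizer),
)

inputs = tokenizer("計算機", return_tensors="pt")
outputs = model.generate(
    **inputs,
    num_beams=1,
    max_new_tokens=64,
    logits_processor=LogitsProcessorList([processor]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))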