From 4111e6d952b248d9ff5230396f5e6ea5394a836c Mon Sep 17 00:00:00 2001
From: David Samuel
Date: Fri, 13 Dec 2024 20:02:07 +1100
Subject: [PATCH] Added option to load database from Turso

---
 g2p_id/g2p.py            |  46 +++++++++--
 g2p_id/turso_db.py       | 170 ++++++++++++++++++++++++++++++++++++++
 tests/inference.py       |  22 +++++
 tests/test_online_g2p.py |  76 +++++++++++++++
 4 files changed, 308 insertions(+), 6 deletions(-)
 create mode 100644 g2p_id/turso_db.py
 create mode 100644 tests/inference.py
 create mode 100644 tests/test_online_g2p.py

diff --git a/g2p_id/g2p.py b/g2p_id/g2p.py
index a9c565e..5604397 100644
--- a/g2p_id/g2p.py
+++ b/g2p_id/g2p.py
@@ -21,6 +21,7 @@
 from builtins import str as unicode
 from itertools import permutations
 from typing import Dict, List, Tuple, Union
+import asyncio
 
 import nltk
 from nltk.tag.perceptron import PerceptronTagger
@@ -29,19 +30,35 @@
 from g2p_id.bert import BERT
 from g2p_id.lstm import LSTM
 from g2p_id.text_processor import TextProcessor
+from g2p_id.turso_db import _fetch_lexicon_from_turso, _fetch_homographs_from_turso
+
+import logging
+
+_LOGGER = logging.getLogger(__name__)
+
 
 nltk.download("wordnet")
 resources_path = os.path.join(os.path.dirname(__file__), "resources")
 
 
-def construct_homographs_dictionary() -> Dict[str, Tuple[str, str, str, str]]:
-    """Creates a dictionary of homographs
+def construct_homographs_dictionary(turso_config=None) -> Dict[str, Tuple[str, str, str, str]]:
+    """Creates a dictionary of homographs.
+    If turso_config is provided, fetches from the Turso database.
+    Otherwise, loads from the local TSV file.
+
+    Args:
+        turso_config: Optional dictionary containing Turso configuration
 
     Returns:
         Dict[str, Tuple[str, str, str, str]]:
            Key: WORD
            Value: (PH1, PH2, POS1, POS2)
     """
+    if turso_config:
+        _LOGGER.info("Loading homographs from Turso database...")
+        return asyncio.run(_fetch_homographs_from_turso(turso_config))
+
+    _LOGGER.info("Loading homographs from local TSV file...")
     homograph_path = os.path.join(resources_path, "homographs_id.tsv")
     homograph2features = {}
     with open(homograph_path, encoding="utf-8") as file:
@@ -53,14 +70,27 @@ def construct_homographs_dictionary() -> Dict[str, Tuple[str, str, str, str]]:
     return homograph2features
 
 
-def construct_lexicon_dictionary() -> Dict[str, str]:
+def construct_lexicon_dictionary(turso_config=None) -> Dict[str, str]:
     """Creates a lexicon dictionary.
+    If turso_config is provided, fetches from the Turso database.
+    Otherwise, loads from the local TSV file.
+
+    Args:
+        turso_config: Optional dictionary containing Turso configuration
+            - url: Turso database URL
+            - auth_token: Authentication token
+            - table: Table name
 
     Returns:
         Dict[str, str]:
            Key: WORD
            Value: Phoneme (IPA)
     """
+    if turso_config:
+        _LOGGER.info("Loading lexicon from Turso database...")
+        return asyncio.run(_fetch_lexicon_from_turso(turso_config))
+
+    _LOGGER.info("Loading lexicon from local TSV file...")
     lexicon_path = os.path.join(resources_path, "lexicon_id.tsv")
     lexicon2features = {}
     with open(lexicon_path, encoding="utf-8") as file:
@@ -84,7 +114,7 @@ class G2p:
     7. Otherwise, predict with a neural network
     """
 
-    def __init__(self, model_type="BERT"):
+    def __init__(self, model_type="BERT", turso_config=None):
         """Constructor for G2p.
 
         Args:
@@ -92,8 +122,12 @@ def __init__(self, model_type="BERT"):
                 Type of neural network to use for prediction.
                 Choices are "LSTM" or "BERT". Defaults to "BERT".
+            turso_config (dict, optional):
+                Turso configuration with "url", "auth_token", and "table" keys.
+                When provided, the lexicon and homograph dictionaries are
+                loaded from Turso instead of the bundled TSV files.
         """
-        self.homograph2features = construct_homographs_dictionary()
-        self.lexicon2features = construct_lexicon_dictionary()
+        self.homograph2features = construct_homographs_dictionary(turso_config)
+        self.lexicon2features = construct_lexicon_dictionary(turso_config)
         self.normalizer = TextProcessor()
         self.tagger = PerceptronTagger(load=False)
         tagger_path = os.path.join(resources_path, "id_posp_tagger.pickle")
diff --git a/g2p_id/turso_db.py b/g2p_id/turso_db.py
new file mode 100644
index 0000000..fe44a9a
--- /dev/null
+++ b/g2p_id/turso_db.py
@@ -0,0 +1,170 @@
+"""Helpers for loading the g2p_id lexicon and homograph data from a Turso database."""
+import asyncio
+import logging
+import time
+from typing import Dict, Tuple
+
+import libsql_client
+
+_LOGGER = logging.getLogger(__name__)
+
+
+async def _fetch_lexicon_from_turso(turso_config) -> Dict[str, str]:
+    """Fetches the lexicon from a Turso database using batched, concurrent queries.
+    Only fetches non-homographs (rows where role is an empty string).
+
+    Args:
+        turso_config: Dictionary containing Turso configuration
+            - url: Turso database URL
+            - auth_token: Authentication token
+            - table: Table name
+
+    Returns:
+        Dict[str, str]: Dictionary mapping words to phonemes
+    """
+    start_time = time.time()
+    _LOGGER.info("Starting to fetch lexicon from Turso...")
+
+    client = libsql_client.create_client(
+        url=turso_config["url"],
+        auth_token=turso_config["auth_token"]
+    )
+
+    try:
+        # Get the total row count first so we know how many batches to fetch
+        count_result = await client.execute(
+            f"SELECT COUNT(*) FROM {turso_config['table']} WHERE role = ''"
+        )
+        total_count = count_result.rows[0][0]
+
+        batch_size = 5000
+
+        # Create concurrent tasks for fetching data; ORDER BY keeps the
+        # LIMIT/OFFSET pagination deterministic across batches
+        tasks = []
+        for offset in range(0, total_count, batch_size):
+            query = f"""
+                SELECT word, phonemes
+                FROM {turso_config['table']}
+                WHERE role = ''
+                ORDER BY word
+                LIMIT {batch_size}
+                OFFSET {offset}
+            """
+            tasks.append(client.execute(query))
+
+        # Execute all queries concurrently
+        results = await asyncio.gather(*tasks)
+
+        # Process results into a dictionary
+        lexicon = {}
+        total_rows = 0
+        for result in results:
+            rows = result.rows
+            if rows:
+                total_rows += len(rows)
+                for row in rows:
+                    lexicon[row[0].lower()] = row[1]
+
+        total_time = time.time() - start_time
+        _LOGGER.info(f"Loaded {total_rows} lexicon entries in {total_time:.2f} seconds")
+
+        return lexicon
+
+    finally:
+        await client.close()
+
+
+async def _fetch_homographs_from_turso(turso_config) -> Dict[str, Tuple[str, str, str, str]]:
+    """Fetches homographs from a Turso database using batched, concurrent queries.
+    Only fetches rows where role is not empty.
+
+    Args:
+        turso_config: Dictionary containing Turso configuration
+            - url: Turso database URL
+            - auth_token: Authentication token
+            - table: Table name
+
+    Returns:
+        Dict[str, Tuple[str, str, str, str]]:
+            Key: WORD
+            Value: (PH1, PH2, POS1, POS2)
+    """
+    start_time = time.time()
+    _LOGGER.info("Starting to fetch homographs from Turso...")
+
+    client = libsql_client.create_client(
+        url=turso_config["url"],
+        auth_token=turso_config["auth_token"]
+    )
+    try:
+        # Get the total row count for rows with non-empty roles
+        count_result = await client.execute(
+            f"SELECT COUNT(*) FROM {turso_config['table']} WHERE role != ''"
+        )
+        total_count = count_result.rows[0][0]
+
+        batch_size = 5000
+
+        # Create concurrent tasks for fetching data
+        tasks = []
+        for offset in range(0, total_count, batch_size):
+            query = f"""
+                SELECT word, phonemes, role
+                FROM {turso_config['table']}
+                WHERE role != ''
+                ORDER BY word, role -- keep each word's pronunciation pair adjacent
+                LIMIT {batch_size}
+                OFFSET {offset}
+            """
+            tasks.append(client.execute(query))
+
+        # Execute all queries concurrently
+        results = await asyncio.gather(*tasks)
+
+        # Fold consecutive rows for the same word into (PH1, PH2, POS1, POS2);
+        # results arrive in batch order, so state carries across batch boundaries
+        homographs = {}
+        current_word = None
+        current_data = []
+        total_pairs = 0
+
+        for result in results:
+            for row in result.rows:
+                word, phonemes, role = row
+                word = word.lower()
+
+                if current_word != word:
+                    # Store the previous word's data if we have a complete pair
+                    if len(current_data) == 2:
+                        homographs[current_word] = (
+                            current_data[0][0],  # ph1
+                            current_data[1][0],  # ph2
+                            current_data[0][1],  # pos1
+                            current_data[1][1],  # pos2
+                        )
+                        total_pairs += 1
+                    # Start a new word
+                    current_word = word
+                    current_data = [(phonemes, role)]
+                elif len(current_data) < 2:
+                    # Add the second pronunciation for the current word
+                    current_data.append((phonemes, role))
+
+        # Process the last word after the loop ends
+        if len(current_data) == 2:
+            homographs[current_word] = (
+                current_data[0][0],  # ph1
+                current_data[1][0],  # ph2
+                current_data[0][1],  # pos1
+                current_data[1][1],  # pos2
+            )
+            total_pairs += 1
+
+        total_time = time.time() - start_time
+        _LOGGER.info(f"Loaded {total_pairs} homograph pairs in {total_time:.2f} seconds")
+
+        return homographs
+
+    finally:
+        await client.close()
diff --git a/tests/inference.py b/tests/inference.py
new file mode 100644
index 0000000..febf5d7
--- /dev/null
+++ b/tests/inference.py
@@ -0,0 +1,22 @@
+import os
+
+from g2p_id import G2p
+
+TURSO_URL = os.getenv("TURSO_URL")
+TURSO_AUTH_TOKEN = os.getenv("TURSO_AUTH_TOKEN")
+
+id_turso_config = {
+    "url": TURSO_URL,
+    "auth_token": TURSO_AUTH_TOKEN,
+    "table": "id_phonemes"
+}
+
+texts = [
+    "Apel itu berwarna merah.",
+    "Rahel bersekolah di S M A Jakarta 17.",
+    "Mereka sedang bermain bola di lapangan.",
+]
+
+g2p = G2p(turso_config=id_turso_config)
+for text in texts:
+    print(g2p(text))
diff --git a/tests/test_online_g2p.py b/tests/test_online_g2p.py
new file mode 100644
index 0000000..8900abc
--- /dev/null
+++ b/tests/test_online_g2p.py
@@ -0,0 +1,76 @@
+"""Compares the Turso-backed (online) g2p_id against the local-TSV (offline) g2p_id."""
+import logging
+import os
+import unittest
+
+from tqdm import tqdm
+
+from g2p_id import G2p
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+
+TURSO_URL = os.getenv("TURSO_URL")
+TURSO_AUTH_TOKEN = os.getenv("TURSO_AUTH_TOKEN")
+
+id_turso_config = {
+    "url": TURSO_URL,
+    "auth_token": TURSO_AUTH_TOKEN,
+    "table": "id_phonemes"
+}
+
+g2p_online = G2p(turso_config=id_turso_config)
+g2p_offline = G2p()
+
+
+def get_phonemes_online(text):
+    return g2p_online(text)
+
+
+def get_phonemes_offline(text):
+    return g2p_offline(text)
+
+
+class TestPhonemeFunctions(unittest.TestCase):
+    def test_phoneme_functions_indonesian(self):
+        logging.info("Testing Indonesian phonemes")
+        error_log_path = 'phoneme_comparison_errors.log'
+        # Local corpus of Indonesian book transcripts used as test input
+        file_path = '/home/s44504/3b01c699-3670-469b-801f-13880b9cac56/dataset_creation/data/indonesian_book_transcripts.txt'
+
+        with open(file_path, 'r', encoding='utf-8') as file, open(error_log_path, 'w', encoding='utf-8') as error_log:
+            lines = file.readlines()
+            for line_number, line in tqdm(enumerate(lines, start=1), total=len(lines), desc="Indonesian Test Progress"):
+                line = line.strip()
+                if not line:
+                    continue  # Skip empty lines
+
+                try:
+                    online_phonemes = get_phonemes_online(line)
+                    offline_phonemes = get_phonemes_offline(line)
+
+                    if online_phonemes != offline_phonemes:
+                        # Find the differing segments
+                        differences = []
+                        for i, (online_seg, offline_seg) in enumerate(zip(online_phonemes, offline_phonemes)):
+                            if online_seg != offline_seg:
+                                differences.append(f"Word {i}: {online_seg} ≠ {offline_seg}")
+
+                        error_message = (
+                            f"Mismatch at line {line_number}:\n"
+                            f"Text: {line}\n"
+                            f"Differences: {' | '.join(differences)}"
+                        )
+                        error_log.write(error_message + "\n")
+                        logging.warning(error_message)
+                except Exception as e:
+                    error_message = f"Error processing line {line_number}: {line}\nError: {str(e)}\n"
+                    error_log.write(error_message)
+                    logging.error(error_message)
+
+
+if __name__ == '__main__':
+    unittest.main()
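
Note: the queries in g2p_id/turso_db.py assume a table with three text columns (word, phonemes, role), where role is the empty string for plain lexicon rows and a POS label for homograph rows, with exactly two homograph rows per word. The patch itself never creates this table, so the following is a minimal seeding sketch inferred from those SELECT statements. It uses the synchronous client from the same libsql-client package; the table layout, constraints, sample words, phonemes, and POS labels are all assumptions for illustration.

    # Sketch of the schema assumed by _fetch_lexicon_from_turso and
    # _fetch_homographs_from_turso. Column names come from the patch;
    # the types, constraints, and sample rows are assumptions.
    import os

    import libsql_client

    client = libsql_client.create_client_sync(
        url=os.environ["TURSO_URL"],
        auth_token=os.environ["TURSO_AUTH_TOKEN"],
    )
    try:
        client.execute(
            "CREATE TABLE IF NOT EXISTS id_phonemes ("
            " word TEXT NOT NULL,"
            " phonemes TEXT NOT NULL,"
            " role TEXT NOT NULL DEFAULT ''"  # '' = lexicon row, POS tag = homograph row
            ")"
        )
        # A plain lexicon row: role = '' so the lexicon fetch picks it up.
        client.execute("INSERT INTO id_phonemes VALUES ('merah', 'merah', '')")
        # A homograph pair: two rows for the same word; ORDER BY word, role in
        # the homograph fetch keeps them adjacent so they are paired into
        # (PH1, PH2, POS1, POS2). Phonemes and POS labels here are made up.
        client.execute("INSERT INTO id_phonemes VALUES ('apel', 'apəl', 'NNO')")
        client.execute("INSERT INTO id_phonemes VALUES ('apel', 'apɛl', 'VBI')")
    finally:
        client.close()

With rows shaped like this, G2p(turso_config={"url": ..., "auth_token": ..., "table": "id_phonemes"}) loads both dictionaries from the database, exactly as tests/inference.py does.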