Commit
Added option to load database from Turso
DavidSamuell committed Dec 13, 2024
1 parent 7e1492e commit 4111e6d
Showing 4 changed files with 324 additions and 6 deletions.
53 changes: 47 additions & 6 deletions g2p_id/g2p.py
@@ -21,6 +21,8 @@
from builtins import str as unicode
from itertools import permutations
from typing import Dict, List, Optional, Tuple, Union
import asyncio

import nltk
from nltk.tag.perceptron import PerceptronTagger
@@ -29,19 +31,40 @@
from g2p_id.bert import BERT
from g2p_id.lstm import LSTM
from g2p_id.text_processor import TextProcessor
from g2p_id.turso_db import _fetch_lexicon_from_turso, _fetch_homographs_from_turso

import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
_LOGGER = logging.getLogger(__name__)


nltk.download("wordnet")
resources_path = os.path.join(os.path.dirname(__file__), "resources")


def construct_homographs_dictionary() -> Dict[str, Tuple[str, str, str, str]]:
"""Creates a dictionary of homographs
def construct_homographs_dictionary(turso_config: Optional[Dict[str, str]] = None) -> Dict[str, Tuple[str, str, str, str]]:
    """Creates a dictionary of homographs.

    If `turso_config` is provided, fetches from the Turso database;
    otherwise, loads from the local TSV file.

    Args:
        turso_config: Optional dictionary containing Turso configuration
            - url: Turso database URL
            - auth_token: Authentication token
            - table: Table name

    Returns:
        Dict[str, Tuple[str, str, str, str]]:
            Key: WORD
            Value: (PH1, PH2, POS1, POS2)
    """
if turso_config:
_LOGGER.info("Loading homographs from Turso database...")
return asyncio.run(_fetch_homographs_from_turso(turso_config))

_LOGGER.info("Loading homographs from local TSV file...")
homograph_path = os.path.join(resources_path, "homographs_id.tsv")
homograph2features = {}
with open(homograph_path, encoding="utf-8") as file:
@@ -53,14 +76,27 @@ def construct_homographs_dictionary() -> Dict[str, Tuple[str, str, str, str]]:
return homograph2features


def construct_lexicon_dictionary() -> Dict[str, str]:
def construct_lexicon_dictionary(turso_config: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Creates a lexicon dictionary.

    If `turso_config` is provided, fetches from the Turso database;
    otherwise, loads from the local TSV file.

    Args:
        turso_config: Optional dictionary containing Turso configuration
            - url: Turso database URL
            - auth_token: Authentication token
            - table: Table name

    Returns:
        Dict[str, str]:
            Key: WORD
            Value: Phoneme (IPA)
    """
if turso_config:
_LOGGER.info("Loading lexicon from Turso database...")
return asyncio.run(_fetch_lexicon_from_turso(turso_config))

_LOGGER.info("Loading lexicon from local TSV file...")
lexicon_path = os.path.join(resources_path, "lexicon_id.tsv")
lexicon2features = {}
with open(lexicon_path, encoding="utf-8") as file:
@@ -71,6 +107,10 @@ def construct_lexicon_dictionary() -> Dict[str, str]:
return lexicon2features



class G2p:
"""Grapheme-to-phoneme (g2p) main class for phonemization.
This class provides a high-level API for grapheme-to-phoneme conversion.
@@ -84,16 +124,17 @@ class G2p:
7. Otherwise, predict with a neural network
"""

def __init__(self, model_type="BERT"):
def __init__(self, model_type="BERT", turso_config=None):
"""Constructor for G2p.
Args:
model_type (str, optional):
Type of neural network to use for prediction.
Choices are "LSTM" or "BERT". Defaults to "BERT".
"""
self.homograph2features = construct_homographs_dictionary()
self.lexicon2features = construct_lexicon_dictionary()
self.homograph2features = construct_homographs_dictionary(turso_config)
self.lexicon2features = construct_lexicon_dictionary(turso_config)

self.normalizer = TextProcessor()
self.tagger = PerceptronTagger(load=False)
tagger_path = os.path.join(resources_path, "id_posp_tagger.pickle")
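With this change, G2p accepts an optional turso_config. A minimal sketch of the two code paths (the env-var and table names mirror tests/inference.py below):

import os

from g2p_id import G2p

# Default: dictionaries are loaded from the bundled TSV files.
g2p_offline = G2p()

# With a config dict, both dictionaries are fetched from Turso instead.
g2p_online = G2p(
    model_type="BERT",
    turso_config={
        "url": os.environ["TURSO_URL"],
        "auth_token": os.environ["TURSO_AUTH_TOKEN"],
        "table": "id_phonemes",
    },
)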
179 changes: 179 additions & 0 deletions g2p_id/turso_db.py
@@ -0,0 +1,179 @@
import asyncio
import logging
import time
from typing import Dict, Tuple

import libsql_client

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

_LOGGER = logging.getLogger(__name__)

async def _fetch_lexicon_from_turso(turso_config) -> Dict[str, str]:
    """Fetches lexicon from Turso database using batching and parallel queries.

    Only fetches non-homographs (where role is empty string).

    Args:
        turso_config: Dictionary containing Turso configuration
            - url: Turso database URL
            - auth_token: Authentication token
            - table: Table name

    Returns:
        Dict[str, str]: Dictionary mapping words to phonemes
    """
start_time = time.time()
_LOGGER.info("Starting to fetch lexicon from Turso...")

client = libsql_client.create_client(
url=turso_config["url"],
auth_token=turso_config["auth_token"]
)

try:
# Get total count first
count_result = await client.execute(
f"SELECT COUNT(*) FROM {turso_config['table']} WHERE role = ''"
)
total_count = count_result.rows[0][0]

# Initialize batch size
batch_size = 5000

# Create concurrent tasks for fetching data
tasks = []
for offset in range(0, total_count, batch_size):
query = f"""
SELECT word, phonemes
FROM {turso_config['table']}
WHERE role = ''
LIMIT {batch_size}
OFFSET {offset}
"""
tasks.append(client.execute(query))

# Execute all queries concurrently
results = await asyncio.gather(*tasks)

# Process results into dictionary
lexicon = {}
total_rows = 0
for result in results:
rows = result.rows
if rows:
total_rows += len(rows)
for row in rows:
lexicon[row[0].lower()] = row[1]

total_time = time.time() - start_time
_LOGGER.info(f"Loaded {total_rows} lexicon entries in {total_time:.2f} seconds")

return lexicon

finally:
await client.close()

async def _fetch_homographs_from_turso(turso_config) -> Dict[str, Tuple[str, str, str, str]]:
    """Fetches homographs from Turso database using batching and parallel queries.

    Only fetches rows where role is not empty.

    Args:
        turso_config: Dictionary containing Turso configuration
            - url: Turso database URL
            - auth_token: Authentication token
            - table: Table name

    Returns:
        Dict[str, Tuple[str, str, str, str]]:
            Key: WORD
            Value: (PH1, PH2, POS1, POS2)
    """
start_time = time.time()
_LOGGER.info("Starting to fetch homographs from Turso...")

client = libsql_client.create_client(
url=turso_config["url"],
auth_token=turso_config["auth_token"]
)
try:
# Get total count first for rows with non-empty roles
count_result = await client.execute(
f"SELECT COUNT(*) FROM {turso_config['table']} WHERE role != ''"
)
total_count = count_result.rows[0][0]
# Initialize batch size
batch_size = 5000

# Create concurrent tasks for fetching data
tasks = []
for offset in range(0, total_count, batch_size):
query = f"""
SELECT word, phonemes, role
FROM {turso_config['table']}
WHERE role != ''
ORDER BY word, role -- Ensure consistent ordering for pairs
LIMIT {batch_size}
OFFSET {offset}
"""
tasks.append(client.execute(query))

# Execute all queries concurrently
results = await asyncio.gather(*tasks)

# Process results into dictionary
homographs = {}
current_word = None
current_data = []
total_pairs = 0

# Process all results
for result in results:
for row in result.rows:
word, phonemes, role = row
word = word.lower()

if current_word != word:
# Store previous word's data if we have a complete pair
if len(current_data) == 2:
homographs[current_word] = (
current_data[0][0], # ph1
current_data[1][0], # ph2
current_data[0][1], # pos1
current_data[1][1] # pos2
)
total_pairs += 1
# Start new word
current_word = word
current_data = [(phonemes, role)]
else:
# Add second pronunciation for current word
if len(current_data) < 2:
current_data.append((phonemes, role))

# Don't forget to process the last word
if len(current_data) == 2:
homographs[current_word] = (
current_data[0][0], # ph1
current_data[1][0], # ph2
current_data[0][1], # pos1
current_data[1][1] # pos2
)
total_pairs += 1

total_time = time.time() - start_time
_LOGGER.info(f"Loaded {total_pairs} homograph pairs in {total_time:.2f} seconds")

return homographs

finally:
await client.close()
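For reference, both queries assume a table with word, phonemes, and role columns, where role is the empty string for plain lexicon rows and a POS tag for homograph rows. A sketch of a compatible schema (the column types, constraints, and this helper are assumptions, not part of the commit):

import asyncio
import os

import libsql_client

# Hypothetical schema inferred from the SELECT statements in turso_db.py;
# types and constraints are assumptions.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS id_phonemes (
    word TEXT NOT NULL,            -- grapheme form; lowercased on load
    phonemes TEXT NOT NULL,        -- IPA phoneme string
    role TEXT NOT NULL DEFAULT ''  -- '' = lexicon row, POS tag = homograph row
)
"""

async def create_table() -> None:
    client = libsql_client.create_client(
        url=os.environ["TURSO_URL"],
        auth_token=os.environ["TURSO_AUTH_TOKEN"],
    )
    try:
        await client.execute(SCHEMA_SQL)
    finally:
        await client.close()

if __name__ == "__main__":
    asyncio.run(create_table())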
23 changes: 23 additions & 0 deletions tests/inference.py
@@ -0,0 +1,23 @@
import os

from g2p_id import G2p

TURSO_URL = os.getenv("TURSO_URL")
TURSO_AUTH_TOKEN = os.getenv("TURSO_AUTH_TOKEN")

id_turso_config = {
"url": TURSO_URL,
"auth_token": TURSO_AUTH_TOKEN,
"table": "id_phonemes"
}


texts = [
"Apel itu berwarna merah.",
"Rahel bersekolah di S M A Jakarta 17.",
"Mereka sedang bermain bola di lapangan.",
]

g2p = G2p(turso_config=id_turso_config)
for text in texts:
print(g2p(text))
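The script reads TURSO_URL and TURSO_AUTH_TOKEN from the environment; a guard like the following (added for illustration, not in the commit) fails fast instead of passing url=None into libsql_client:

import os

# Fail fast with a clear message when Turso credentials are missing.
for var in ("TURSO_URL", "TURSO_AUTH_TOKEN"):
    if not os.getenv(var):
        raise RuntimeError(f"{var} must be set to run this script.")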
75 changes: 75 additions & 0 deletions tests/test_online_g2p.py
@@ -0,0 +1,75 @@
"""Tests the online g2p_id with the original g2p_id"""
import logging
import typing
import requests
import os
import logging
import unittest

from tqdm import tqdm
from g2p_id import G2p

logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logging.getLogger("gruut").setLevel(logging.INFO)

TURSO_URL = os.getenv("TURSO_URL")
TURSO_AUTH_TOKEN = os.getenv("TURSO_AUTH_TOKEN")

id_turso_config = {
"url": TURSO_URL,
"auth_token": TURSO_AUTH_TOKEN,
"table": "id_phonemes"
}

g2p_online = G2p(turso_config=id_turso_config)
g2p_offline = G2p()

def get_phonemes_online(text):
return g2p_online(text)

def get_phonemes_offline(text):
return g2p_offline(text)


class TestPhonemeFunctions(unittest.TestCase):
    def test_phoneme_functions_indonesian(self):
        logging.info("Testing Indonesian language phonemes")
error_log_path = 'phoneme_comparison_errors.log'
file_path = '/home/s44504/3b01c699-3670-469b-801f-13880b9cac56/dataset_creation/data/indonesian_book_transcripts.txt'

        with open(file_path, 'r', encoding='utf-8') as file, open(error_log_path, 'w', encoding='utf-8') as error_log:
            lines = file.readlines()
            for line_number, line in tqdm(enumerate(lines, start=1), total=len(lines), desc="Indonesian Test Progress"):
line = line.strip()
if not line:
continue # Skip empty lines

try:
online_phonemes = get_phonemes_online(line)
offline_phonemes = get_phonemes_offline(line)

if online_phonemes != offline_phonemes:
# Find the differing segments
differences = []
for i, (online_seg, offline_seg) in enumerate(zip(online_phonemes, offline_phonemes)):
if online_seg != offline_seg:
                            differences.append(f"Word {i}: {online_seg} != {offline_seg}")

error_message = f"""Mismatch at line {line_number}:
Text: {line}
Differences: {' | '.join(differences)}
"""
error_log.write(error_message + "\n")
logging.warning(error_message)
except Exception as e:
error_message = f"Error processing line {line_number}: {line}\nError: {str(e)}\n"
error_log.write(error_message)
logging.error(error_message)

if __name__ == '__main__':
    unittest.main()
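One caveat in the comparison loop above: zip truncates to the shorter sequence, so a pure length mismatch between the online and offline outputs produces no logged difference. A stricter variant (a sketch, assuming both results are lists of per-word segments):

from itertools import zip_longest

def diff_segments(online, offline):
    """Yield human-readable differences, including length mismatches."""
    for i, (a, b) in enumerate(zip_longest(online, offline, fillvalue="<missing>")):
        if a != b:
            yield f"Word {i}: {a} != {b}"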
