-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added option to load database from Turso
- Loading branch information
1 parent
7e1492e
commit 4111e6d
Showing 4 changed files with 324 additions and 6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import typing | ||
import asyncio # | ||
import logging | ||
import sqlite3 | ||
import libsql_client | ||
import asyncio | ||
import time # Add this import at the top | ||
from typing import Dict, Tuple | ||
|
||
from pathlib import Path | ||
from gruut.const import PHONEMES_TYPE | ||
from gruut.phonemize import SqlitePhonemizer | ||
|
||
# Configure logging | ||
logging.basicConfig( | ||
level=logging.INFO, | ||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
) | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
async def _fetch_lexicon_from_turso(turso_config) -> Dict[str, str]:
    """Fetch the non-homograph lexicon from a Turso database.

    Rows are paged in fixed-size batches and the batch queries run
    concurrently; only rows whose ``role`` is the empty string are fetched.

    Args:
        turso_config: Dictionary containing Turso configuration
            - url: Turso database URL
            - auth_token: Authentication token
            - table: Table name (trusted config value; it is interpolated
              directly into the SQL text, so never source it from user input)

    Returns:
        Dict[str, str]: Dictionary mapping lower-cased words to phonemes.
    """
    start_time = time.time()
    _LOGGER.info("Starting to fetch lexicon from Turso...")

    client = libsql_client.create_client(
        url=turso_config["url"],
        auth_token=turso_config["auth_token"]
    )

    try:
        # Get total count first so we know how many pages to request.
        count_result = await client.execute(
            f"SELECT COUNT(*) FROM {turso_config['table']} WHERE role = ''"
        )
        total_count = count_result.rows[0][0]

        # Rows fetched per query; balances round-trips vs. payload size.
        batch_size = 5000

        # One query per page. ORDER BY makes LIMIT/OFFSET paging
        # deterministic -- without it SQLite may return rows in any order,
        # so independent pages could overlap or skip rows entirely.
        tasks = []
        for offset in range(0, total_count, batch_size):
            query = f"""
            SELECT word, phonemes
            FROM {turso_config['table']}
            WHERE role = ''
            ORDER BY word
            LIMIT {batch_size}
            OFFSET {offset}
            """
            tasks.append(client.execute(query))

        # Execute all page queries concurrently.
        results = await asyncio.gather(*tasks)

        # Merge all pages into a single word -> phonemes mapping.
        lexicon = {}
        total_rows = 0
        for result in results:
            rows = result.rows
            if rows:
                total_rows += len(rows)
                for row in rows:
                    lexicon[row[0].lower()] = row[1]

        total_time = time.time() - start_time
        _LOGGER.info(f"Loaded {total_rows} lexicon entries in {total_time:.2f} seconds")

        return lexicon

    finally:
        # Always release the client, even if a query failed.
        await client.close()
|
||
async def _fetch_homographs_from_turso(turso_config) -> Dict[str, Tuple[str, str, str, str]]:
    """Fetches homographs from Turso database using batching and parallel queries.

    Only fetches rows where role is not empty.

    Args:
        turso_config: Dictionary containing Turso configuration
            - url: Turso database URL
            - auth_token: Authentication token
            - table: Table name

    Returns:
        Dict[str, Tuple[str, str, str, str]]:
            Key: WORD (lower-cased)
            Value: (PH1, PH2, POS1, POS2)
    """
    start_time = time.time()
    _LOGGER.info("Starting to fetch homographs from Turso...")

    client = libsql_client.create_client(
        url=turso_config["url"],
        auth_token=turso_config["auth_token"]
    )
    try:
        # Get total count first for rows with non-empty roles
        count_result = await client.execute(
            f"SELECT COUNT(*) FROM {turso_config['table']} WHERE role != ''"
        )
        total_count = count_result.rows[0][0]
        # Initialize batch size
        batch_size = 5000

        # Create concurrent tasks for fetching data. The ORDER BY keeps
        # LIMIT/OFFSET pages stable and groups both pronunciations of each
        # word together, which the pairing loop below relies on.
        tasks = []
        for offset in range(0, total_count, batch_size):
            query = f"""
            SELECT word, phonemes, role
            FROM {turso_config['table']}
            WHERE role != ''
            ORDER BY word, role -- Ensure consistent ordering for pairs
            LIMIT {batch_size}
            OFFSET {offset}
            """
            tasks.append(client.execute(query))

        # Execute all queries concurrently. asyncio.gather preserves task
        # order, so iterating `results` visits rows in global sorted order
        # even when a word's pair straddles a batch boundary.
        results = await asyncio.gather(*tasks)

        # Process results into dictionary
        homographs = {}
        current_word = None   # word currently being accumulated
        current_data = []     # [(phonemes, role), ...] for current_word
        total_pairs = 0

        # Process all results
        for result in results:
            for row in result.rows:
                word, phonemes, role = row
                word = word.lower()

                if current_word != word:
                    # Store previous word's data if we have a complete pair.
                    # Words with only one row are silently dropped; words
                    # with more than two roles keep only the first two.
                    if len(current_data) == 2:
                        homographs[current_word] = (
                            current_data[0][0],  # ph1
                            current_data[1][0],  # ph2
                            current_data[0][1],  # pos1
                            current_data[1][1]   # pos2
                        )
                        total_pairs += 1
                    # Start new word
                    current_word = word
                    current_data = [(phonemes, role)]
                else:
                    # Add second pronunciation for current word
                    if len(current_data) < 2:
                        current_data.append((phonemes, role))

        # Don't forget to process the last word
        if len(current_data) == 2:
            homographs[current_word] = (
                current_data[0][0],  # ph1
                current_data[1][0],  # ph2
                current_data[0][1],  # pos1
                current_data[1][1]   # pos2
            )
            total_pairs += 1

        total_time = time.time() - start_time
        _LOGGER.info(f"Loaded {total_pairs} homograph pairs in {total_time:.2f} seconds")

        return homographs

    finally:
        await client.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
import libsql_client | ||
from g2p_id import G2p | ||
|
||
# Demo: run the Indonesian G2P with its lexicon loaded from Turso.
TURSO_URL = os.getenv("TURSO_URL")
TURSO_AUTH_TOKEN = os.getenv("TURSO_AUTH_TOKEN")

# Fail fast with a clear message instead of an opaque connection error
# from deep inside the database client later on.
if not TURSO_URL or not TURSO_AUTH_TOKEN:
    raise RuntimeError(
        "TURSO_URL and TURSO_AUTH_TOKEN environment variables must be set"
    )

id_turso_config = {
    "url": TURSO_URL,
    "auth_token": TURSO_AUTH_TOKEN,
    "table": "id_phonemes",  # table holding the Indonesian phoneme lexicon
}

# Sample sentences covering plain words, spelled-out letters, and numbers.
texts = [
    "Apel itu berwarna merah.",
    "Rahel bersekolah di S M A Jakarta 17.",
    "Mereka sedang bermain bola di lapangan.",
]

g2p = G2p(turso_config=id_turso_config)
for text in texts:
    print(g2p(text))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Tests the online g2p_id with the original g2p_id""" | ||
import logging | ||
import typing | ||
import requests | ||
import os | ||
import logging | ||
import unittest | ||
|
||
from tqdm import tqdm | ||
from g2p_id import G2p | ||
|
||
# Shared logging for the comparison run: DEBUG for this module, while the
# noisy "gruut" logger is kept at INFO.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logging.getLogger("gruut").setLevel(logging.INFO)

# Turso connection settings come from the environment; both may be None
# when the variables are unset (the online G2p will then fail to connect).
TURSO_URL = os.getenv("TURSO_URL")
TURSO_AUTH_TOKEN = os.getenv("TURSO_AUTH_TOKEN")

id_turso_config = {
    "url": TURSO_URL,
    "auth_token": TURSO_AUTH_TOKEN,
    "table": "id_phonemes"  # table holding the Indonesian phoneme lexicon
}

# Two phonemizers built at import time: one backed by the Turso database
# ("online") and one using the bundled local data ("offline").
g2p_online = G2p(turso_config=id_turso_config)
g2p_offline = G2p()
|
||
def get_phonemes_online(text):
    """Phonemize *text* with the Turso-backed (online) G2p instance."""
    phonemes = g2p_online(text)
    return phonemes
|
||
def get_phonemes_offline(text):
    """Phonemize *text* with the locally-backed (offline) G2p instance."""
    phonemes = g2p_offline(text)
    return phonemes
|
||
|
||
class TestPhonemeFunctions(unittest.TestCase):
    """Compares online (Turso-backed) vs offline phonemization line by line
    over an Indonesian transcript corpus, logging every mismatch."""

    def test_phoneme_functions_english(self):
        # NOTE(review): the method name says "english" but the corpus and
        # progress bar are Indonesian; the name is kept so existing test
        # selection by name still works.
        logging.info("Testing Indonesian language phonemes")
        error_log_path = 'phoneme_comparison_errors.log'
        file_path = '/home/s44504/3b01c699-3670-469b-801f-13880b9cac56/dataset_creation/data/indonesian_book_transcripts.txt'

        # Skip (rather than error) when the hard-coded corpus is absent,
        # e.g. when the suite runs on another machine.
        if not os.path.exists(file_path):
            self.skipTest(f"corpus file not found: {file_path}")

        with open(file_path, 'r', encoding='utf-8') as file, \
                open(error_log_path, 'w', encoding='utf-8') as error_log:
            lines = file.readlines()
            for line_number, line in tqdm(enumerate(lines, start=1), total=len(lines), desc="Indonesia Test Progress"):
                line = line.strip()
                if not line:
                    continue  # Skip empty lines

                try:
                    online_phonemes = get_phonemes_online(line)
                    offline_phonemes = get_phonemes_offline(line)

                    if online_phonemes != offline_phonemes:
                        # Collect the differing segments. zip() truncates to
                        # the shorter sequence, so record a length mismatch
                        # explicitly -- otherwise a pure length difference
                        # would produce an empty diff.
                        differences = []
                        if len(online_phonemes) != len(offline_phonemes):
                            differences.append(
                                f"Length mismatch: {len(online_phonemes)} vs {len(offline_phonemes)}"
                            )
                        for i, (online_seg, offline_seg) in enumerate(zip(online_phonemes, offline_phonemes)):
                            if online_seg != offline_seg:
                                differences.append(f"Word {i}: {online_seg} ≠ {offline_seg}")

                        error_message = f"""Mismatch at line {line_number}:
Text: {line}
Differences: {' | '.join(differences)}
"""
                        error_log.write(error_message + "\n")
                        logging.warning(error_message)
                except Exception as e:
                    # Best-effort: log and keep scanning the rest of the corpus.
                    error_message = f"Error processing line {line_number}: {line}\nError: {str(e)}\n"
                    error_log.write(error_message)
                    logging.error(error_message)
|
||
if __name__ == '__main__':
    # Logging is already configured at import time; a second basicConfig()
    # call would be a no-op once the root logger has handlers, so it was
    # removed here.
    unittest.main()