Skip to content

Commit

Permalink
Refactor deciphon CLI by removing unused files and enhancing logging
Browse files Browse the repository at this point in the history
- Removed `gencode.py`, `models.py`, `path_like.py`, `percent.py`, `progress.py`
from `cli/deciphon`, streamlining the codebase 🎉.
- Enhanced `worker.py`:
  - Added `__all__` to declare the module's public names.
  - The `_run` methods now ignore SIGTERM in addition to SIGINT.
  - Improved logging with richer messages and log levels (failures are now logged as warnings).
  - Removed the `scan_request_key` helper and inlined its expression in `ScanConsumer.add`.
  - Introduced an MQTT logging callback for detailed logs.
  • Loading branch information
horta committed Feb 10, 2025
1 parent aaf1e24 commit 5eb7670
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 181 deletions.
31 changes: 0 additions & 31 deletions cli/deciphon/gencode.py

This file was deleted.

58 changes: 0 additions & 58 deletions cli/deciphon/models.py

This file was deleted.

5 changes: 0 additions & 5 deletions cli/deciphon/path_like.py

This file was deleted.

21 changes: 0 additions & 21 deletions cli/deciphon/percent.py

This file was deleted.

43 changes: 0 additions & 43 deletions cli/deciphon/progress.py

This file was deleted.

64 changes: 41 additions & 23 deletions cli/deciphon/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
from deciphon.download import download
from deciphon.queue import ShuttableQueue, queue_loop

__all__ = ["LogLevel", "Worker", "WorkType", "setup_logger"]

info = logger.info
warn = logger.warning


class WorkType(str, Enum):
Expand Down Expand Up @@ -70,7 +73,6 @@ def __init__(
):
self._queue: ShuttableQueue[ScanRequest] = ShuttableQueue()
self._poster = poster

self._dbname = dbname
self._multi_hits = multi_hits
self._hmmer3_compat = hmmer3_compat
Expand All @@ -79,26 +81,33 @@ def __init__(

def _run(self):
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)

hmmpath = Path(self._dbname.hmmname.name)
dbpath = Path(self._dbname.name)

if not hmmpath.exists():
info(f"File <{hmmpath}> does not exist, preparing to download it")
with atomic_file_creation(hmmpath) as t:
url = self._poster.download_hmm_url(self._dbname.hmmname.name)
info(f"Downloading <{url}>")
info(f"Downloading <{url}>...")
download(url, t)

if not dbpath.exists():
info(f"File <{dbpath}> does not exist, preparing to download it")
with atomic_file_creation(dbpath) as t:
url = self._poster.download_db_url(self._dbname.name)
info(f"Downloading <{url}>")
info(f"Downloading <{url}>...")
download(url, t)

dbfile = DBFile(path=dbpath)
future = launch_scanner(
dbfile, self._multi_hits, self._hmmer3_compat, cache=True
multi_hits = self._multi_hits
hmmer3_compat = self._hmmer3_compat
info(
"Launching scanner for "
f"<{dbfile.path},multi_hits={multi_hits},hmmer3_compat={hmmer3_compat}>..."
)
future = launch_scanner(dbfile, multi_hits, hmmer3_compat, cache=True)

with future.result() as scanner:
for x in queue_loop(self._queue):
Expand All @@ -112,19 +121,18 @@ def _run(self):

task = scanner.put(snap, sequences)
for i in task.as_progress():
info(f"Progress {i}% on scan_id <{x.id}>...")
info(f"Progress {i}% on <scan_id={x.id}>...")
self._poster.job_patch(JobUpdate.run(x.job_id, i))

info(f"Finished scanning scan_id <{x.id}>")

snappath = task.result().path

self._poster.snap_post(x.id, snappath)
info(f"Finished posting {snappath}")
info(f"Finished posting <{snappath}>")

except Exception as exception:
fail = JobUpdate.fail(x.job_id, str(exception))
info(f"Failed to process {x}: {exception}")
warn(f"Failed to process <{x}>: {exception}")
self._poster.job_patch(fail)
finally:
if snappath is not None:
Expand Down Expand Up @@ -152,6 +160,7 @@ def __init__(self, poster: Poster):

def _run(self):
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
for x in queue_loop(self._queue):
hmmpath: Path | None = None
dbfile: DBFile | None = None
Expand All @@ -164,7 +173,7 @@ def _run(self):
f"gencode{x.gencode}_epsilon{x.epsilon}_{hex}_{x.hmm.name}"
)

info(f"Downloading <{url}>")
info(f"Downloading <{url}>...")
download(url, hmmpath)
hmmfile = HMMFile(path=hmmpath)

Expand All @@ -182,7 +191,6 @@ def _run(self):

info(f"Uploading <{dbfile.path}>...")
post = self._poster.upload_db_post(x.db.name)

self._poster.upload(dbfile.path, post, name=x.db.name)
info(f"Finished uploading <{dbfile.path}>")

Expand All @@ -191,7 +199,7 @@ def _run(self):

except Exception as exception:
fail = JobUpdate.fail(x.job_id, str(exception))
info(f"Failed to process {x}: {exception}")
warn(f"Failed to process <{x}>: {exception}")
self._poster.job_patch(fail)
finally:
if hmmpath is not None:
Expand All @@ -212,10 +220,6 @@ def join(self):
self._process.join()


def scan_request_key(x: ScanRequest):
return f"{x.db.name}-{x.multi_hits}-{x.hmmer3_compat}"


class ScanConsumer:
def __init__(self, poster: Poster):
self._poster = poster
Expand All @@ -225,7 +229,8 @@ def add(self, payload: bytes):
r = ScanRequest.model_validate_json(payload)

info(f"Queuing scan request: {r}")
key = scan_request_key(r)

key = f"{r.db.name}-{r.multi_hits}-{r.hmmer3_compat}"

if key not in self._processes:
self._processes[key] = ScanProcess(
Expand Down Expand Up @@ -278,13 +283,22 @@ def __init__(
elif work == WorkType.press:
self._consumer = PressConsumer(poster)

mqtt_host = mqtt_host
mqtt_port = mqtt_port
def log_callback(client: Client, userdata, level: int, buf: str):
from paho.mqtt.enums import LogLevel

self._client = paho.Client(paho.CallbackAPIVersion.VERSION2) # type: ignore
self._client.enable_logger()
self._client.user_data_set(self._consumer)
self._client.connect(mqtt_host, mqtt_port)
del client
del userdata

if level == LogLevel.MQTT_LOG_INFO:
logger.info(buf)
if level == LogLevel.MQTT_LOG_NOTICE:
logger.info(buf)
if level == LogLevel.MQTT_LOG_WARNING:
logger.warning(buf)
if level == LogLevel.MQTT_LOG_ERR:
logger.error(buf)
if level == LogLevel.MQTT_LOG_DEBUG:
logger.debug(buf)

def on_connect(
client: Client,
Expand All @@ -304,8 +318,12 @@ def on_message(_, consumer: ScanConsumer | PressConsumer, msg: MQTTMessage):
info(f"Received: {msg.payload}")
consumer.add(msg.payload)

self._client = paho.Client(paho.CallbackAPIVersion.VERSION2) # type: ignore
self._client.on_log = log_callback
self._client.on_connect = on_connect
self._client.on_message = on_message
self._client.user_data_set(self._consumer)
self._client.connect(mqtt_host, mqtt_port)

def loop_forever(self):
self._client.loop_forever()
Expand Down

0 comments on commit 5eb7670

Please sign in to comment.