Skip to content

Commit

Permalink
feat: automatically unpack ml model archive
Browse files Browse the repository at this point in the history
  • Loading branch information
igboyes committed Jan 19, 2024
1 parent 061bc84 commit fda4b8a
Show file tree
Hide file tree
Showing 36 changed files with 1,180 additions and 1,041 deletions.
8 changes: 5 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
repos:
- repo: https://github.com/psf/black
rev: 22.3.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
hooks:
- id: black
- id: ruff
args: [ --fix ]
- id: ruff-format
Binary file modified example/ml/model.tar.gz
Binary file not shown.
16 changes: 0 additions & 16 deletions main.py

This file was deleted.

1,724 changes: 977 additions & 747 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 1 addition & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,9 @@ exclude = [
".ruff_cache",
"__pypackages__",
]
indent-width = 4
line-length = 88
target-version = "py310"

[tool.ruff.lint]
fixable = ["ALL", "I001"]

[tool.ruff.format]
indent-style = "space"
line-ending = "auto"
quote-style = "double"
select = ["ALL"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
4 changes: 1 addition & 3 deletions tests/api/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@


async def test_retry(caplog):
"""
Test that the retry utility retries failing HTTP requests and logs the attempts.
"""
"""Test that the retry utility retries failing HTTP requests and logs the attempts."""
caplog.set_level(logging.INFO)

class Retry:
Expand Down
14 changes: 6 additions & 8 deletions tests/data/test_analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import pytest
from pyfixtures import FixtureScope

from virtool_workflow.pytest_plugin.data import Data
from virtool_workflow.data.analyses import WFAnalysis
from virtool_workflow.errors import JobsAPINotFound, JobsAPIConflict
from virtool_workflow.errors import JobsAPIConflict, JobsAPINotFound
from virtool_workflow.pytest_plugin.data import Data


async def test_ok(data: Data, scope: FixtureScope):
"""
Test that the analysis fixture returns an Analysis object with the expected values.
"""Test that the analysis fixture returns an Analysis object with the expected values.
"""
data.job.args["analysis_id"] = data.analysis.id

Expand All @@ -28,10 +27,9 @@ async def test_not_found(data: Data, scope: FixtureScope):


async def test_upload_file(
captured_uploads_path: Path, data: Data, scope: FixtureScope, work_path: Path
captured_uploads_path: Path, data: Data, scope: FixtureScope, work_path: Path,
):
"""
Test that the ``Analysis`` object returned by the fixture can be used to upload an
"""Test that the ``Analysis`` object returned by the fixture can be used to upload an
analysis file.
"""
...
Expand All @@ -42,7 +40,7 @@ async def test_upload_file(

path = work_path / "blank.txt"

with open(path, "wt") as f:
with open(path, "w") as f:
f.write("hello world")

await analysis.upload_file(path, "unknown")
Expand Down
28 changes: 14 additions & 14 deletions tests/data/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import pytest
from pyfixtures import FixtureScope
from syrupy import SnapshotAssertion
from syrupy import SnapshotSession

from virtool_workflow.pytest_plugin.data import Data
from virtool_workflow.data.ml import WFMLModelRelease
from virtool_workflow.pytest_plugin.data import Data

test = SnapshotSession

Expand All @@ -21,11 +20,8 @@ async def test_ok(
data: Data,
example_path: Path,
scope: FixtureScope,
snapshot: SnapshotAssertion,
work_path: Path,
):
"""
Test that the ML fixture instantiates, contains the expected data, and
"""Test that the ML fixture instantiates, contains the expected data, and
downloads the sample files to the work path.
"""
data.job.args["analysis_id"] = data.analysis.id
Expand All @@ -39,16 +35,20 @@ async def test_ok(
assert ml.name == data.ml.name

assert ml.path.is_dir()
assert (ml.path / "model.tar.gz").is_file()
assert ml.file_path == ml.path / "model.tar.gz"

assert (
open(ml.file_path, "rb").read()
== open(example_path / "ml/model.tar.gz", "rb").read()
)
assert sorted([p.name for p in ml.path.iterdir()]) == [
"mappability_profile.rds",
"model.tar.gz",
"nucleotide_info.csv",
"reference.json.gz",
"trained_rf.rds",
"trained_xgb.rds",
"virus_segments.rds",
]

async def test_none(
self, data: Data, scope: FixtureScope, snapshot: SnapshotAssertion
self,
data: Data,
scope: FixtureScope,
):
"""Test that the ML fixture returns None when no ML model is specified."""
data.job.args["analysis_id"] = data.analysis.id
Expand Down
33 changes: 15 additions & 18 deletions virtool_workflow/analysis/fastqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from __future__ import annotations

import asyncio
import statistics
import shutil
import statistics
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, TextIO, IO
from typing import IO, Protocol, TextIO

from pyfixtures import fixture

Expand Down Expand Up @@ -49,16 +49,16 @@ def composite(self, parser: BaseQualityParser):
mean=statistics.mean([this.mean, other.mean]),
median=statistics.mean([this.median, other.median]),
lower_quartile=statistics.mean(
[this.lower_quartile, other.lower_quartile]
[this.lower_quartile, other.lower_quartile],
),
upper_quartile=statistics.mean(
[this.upper_quartile, other.upper_quartile]
[this.upper_quartile, other.upper_quartile],
),
tenth_percentile=statistics.mean(
[this.tenth_percentile, other.tenth_percentile]
[this.tenth_percentile, other.tenth_percentile],
),
ninetieth_percentile=statistics.mean(
[this.ninetieth_percentile, other.ninetieth_percentile]
[this.ninetieth_percentile, other.ninetieth_percentile],
),
)
for this, other in zip(self.data, parser.data)
Expand Down Expand Up @@ -109,7 +109,7 @@ def handle(self, f: TextIO):
upper_quartile=upper_quartile,
tenth_percentile=tenth_percentile,
ninetieth_percentile=ninetieth_percentile,
)
),
)

if i - max_index != 1:
Expand Down Expand Up @@ -208,7 +208,7 @@ def handle(self, f: TextIO):
split = line.split()

try:
g, a, t, c = [float(value) for value in split[1:]]
g, a, t, c = (float(value) for value in split[1:])
except ValueError as err:
if "NaN" not in str(err):
raise
Expand Down Expand Up @@ -276,8 +276,7 @@ def _calculate_index_range(base: str) -> range:


def _handle_base_quality_nan(split_line: list) -> list:
"""
Parse a per-base quality line from FastQC containing NaN values.
"""Parse a per-base quality line from FastQC containing NaN values.
:param split_line: the quality line split into a :class:`.List`
:return: replacement values
Expand All @@ -301,8 +300,7 @@ def _handle_base_quality_nan(split_line: list) -> list:


def _parse_fastqc(fastqc_path: Path, output_path: Path) -> dict:
"""
Parse the FastQC results at `fastqc_path`.
"""Parse the FastQC results at `fastqc_path`.
All FastQC data except the textual data file are removed.
Expand Down Expand Up @@ -333,7 +331,7 @@ def _parse_fastqc(fastqc_path: Path, output_path: Path) -> dict:
nucleotide_composition = NucleotideCompositionParser()
sequence_quality = SequenceQualityParser()

with open(new_path, "r") as f:
with open(new_path) as f:
while True:
line = f.readline()

Expand All @@ -358,7 +356,7 @@ def _parse_fastqc(fastqc_path: Path, output_path: Path) -> dict:
basic_statistics=basic_statistics,
nucleotide_composition=nucleotide_composition,
sequence_quality=sequence_quality,
)
),
)

if len(sides) == 1:
Expand Down Expand Up @@ -412,7 +410,7 @@ def _parse_fastqc(fastqc_path: Path, output_path: Path) -> dict:
"composition": [
[round(n, 1) for n in [point.g, point.a, point.t, point.c]]
for point in left.nucleotide_composition.composite(
right.nucleotide_composition
right.nucleotide_composition,
).data
],
"count": basic.count,
Expand All @@ -432,14 +430,13 @@ async def __call__(self, paths: ReadPaths, output_path: Path) -> dict:

@fixture
async def fastqc(run_subprocess: RunSubprocess):
"""
Provides an asynchronous function that can run FastQC as a subprocess.
"""Provides an asynchronous function that can run FastQC as a subprocess.
The function takes one or two paths to FASTQ read files (:class:`.ReadPaths`) in
a tuple.
Example:
-------
.. code-block:: python
@step
Expand Down
35 changes: 14 additions & 21 deletions virtool_workflow/analysis/skewer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Utilities and a fixture for using `Skewer <https://github.com/relipmoc/skewer>`_ to trim reads.
"""Utilities and a fixture for using `Skewer <https://github.com/relipmoc/skewer>`_ to
trim reads.
"""
import asyncio
import os
Expand Down Expand Up @@ -85,18 +85,16 @@ class SkewerResult:

@property
def left(self) -> Path:
"""
The path to one of:
- the FASTQ trimming result for an unpaired Illumina dataset
- the FASTA trimming result for the left reads of a paired Illumina dataset
"""The path to one of:
- the FASTQ trimming result for an unpaired Illumina dataset
- the FASTA trimming result for the left reads of a paired Illumina dataset
"""
return self.read_paths[0]

@property
def right(self) -> Path | None:
"""
The path to the right reads of a paired Illumina dataset.
"""The path to the right reads of a paired Illumina dataset.
``None`` if the dataset is unpaired.
Expand All @@ -110,16 +108,14 @@ def right(self) -> Path | None:


def calculate_skewer_trimming_parameters(
sample: WFSample, min_read_length: int
sample: WFSample, min_read_length: int,
) -> SkewerConfiguration:
"""
Calculates trimming parameters based on the library type, and minimum allowed trim length.
"""Calculates trimming parameters based on the library type, and minimum allowed trim length.
:param sample: The sample to calculate trimming parameters for.
:param min_read_length: The minimum length of a read before it is discarded.
:return: the trimming parameters
"""

config = SkewerConfiguration(
min_length=min_read_length,
mode=SkewerMode.PAIRED_END if sample.paired else SkewerMode.SINGLE_END,
Expand All @@ -145,15 +141,14 @@ class SkewerRunner(Protocol):
"""A protocol describing callables that can be used to run Skewer."""

async def __call__(
self, config: SkewerConfiguration, paths: ReadPaths, output_path: Path
self, config: SkewerConfiguration, paths: ReadPaths, output_path: Path,
) -> SkewerResult:
...


@fixture
def skewer(proc: int, run_subprocess: RunSubprocess) -> SkewerRunner:
"""
Provides an asynchronous function that can run skewer.
"""Provides an asynchronous function that can run skewer.
The provided function takes a :class:`.SkewerConfiguration` and a tuple of paths to
the left and right reads to trim. If a single member tuple is provided, the dataset
Expand All @@ -163,7 +158,7 @@ def skewer(proc: int, run_subprocess: RunSubprocess) -> SkewerRunner:
for the workflow run.
Example:
-------
.. code-block:: python
@step
Expand All @@ -183,7 +178,7 @@ async def step_one(skewer: SkewerRunner, work_path: Path):
raise RuntimeError("skewer is not installed.")

async def func(
config: SkewerConfiguration, read_paths: ReadPaths, output_path: Path
config: SkewerConfiguration, read_paths: ReadPaths, output_path: Path,
):
temp_path = Path(await asyncio.to_thread(mkdtemp, suffix="_virtool_skewer"))

Expand Down Expand Up @@ -224,7 +219,7 @@ async def func(
)

read_paths = await asyncio.to_thread(
_rename_trimming_results, temp_path, output_path
_rename_trimming_results, temp_path, output_path,
)

return SkewerResult(command, output_path, process, read_paths)
Expand All @@ -233,12 +228,10 @@ async def func(


def _rename_trimming_results(temp_path: Path, output_path: Path) -> ReadPaths:
"""
Rename Skewer output to a simple name used in Virtool.
"""Rename Skewer output to a simple name used in Virtool.
:param path: The path containing the results from Skewer
"""

shutil.move(
temp_path / "reads-trimmed.log",
output_path / "trim.log",
Expand Down
Loading

0 comments on commit fda4b8a

Please sign in to comment.