Skip to content

Commit

Permalink
Merge pull request #75 from medema-group/feature/gcf-tree
Browse files Browse the repository at this point in the history
Feature/gcf tree
  • Loading branch information
adraismawur authored Nov 6, 2023
2 parents ecca9f1 + 8dfeb5f commit 69e2784
Show file tree
Hide file tree
Showing 15 changed files with 1,149 additions and 71 deletions.
9 changes: 9 additions & 0 deletions big_scape/comparison/comparable_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# from other modules
from big_scape.genbank import BGCRecord
from big_scape.hmm import HSP
from big_scape.enums import ALIGNMENT_MODE

# from this module
from .legacy_lcs import legacy_find_cds_lcs
Expand All @@ -35,6 +36,7 @@ class ComparableRegion:
domain_lists: Optional[tuple[list[HSP], list[HSP]]]
domain_sets: Optional[tuple[set[HSP], set[HSP]]]
domain_dicts: Optional[tuple[dict[HSP, list[int]], dict[HSP, list[int]]]]
alignment_mode: ALIGNMENT_MODE
"""

def __init__(
Expand All @@ -47,6 +49,7 @@ def __init__(
reverse: bool,
):
self.pair = pair
# store possibly extended comparable region
self.a_start = a_start
self.b_start = b_start

Expand All @@ -60,6 +63,12 @@ def __init__(
self.domain_dicts: Optional[
tuple[dict[HSP, list[int]], dict[HSP, list[int]]]
] = None
# store lcs without any extensions
self.lcs_a_start = a_start
self.lcs_b_start = b_start
self.lcs_a_stop = a_stop
self.lcs_b_stop = b_stop
self.alignment_mode: ALIGNMENT_MODE = ALIGNMENT_MODE.GLOBAL

def get_domain_sets(
self, regenerate=False, cache=True
Expand Down
132 changes: 114 additions & 18 deletions big_scape/comparison/legacy_workflow_alt.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,7 @@ def generate_edges(
if len(results) != len(task_batch):
raise ValueError("Mismatch between task length and result length")

for idx, pair in enumerate(task_batch):
distance, jaccard, adjacency, dss = results[idx]
yield (
pair.region_a._db_id,
pair.region_b._db_id,
distance,
jaccard,
adjacency,
dss,
pair_generator.weights,
)
yield from results

done_pairs += len(results)
if callback:
Expand All @@ -198,6 +188,10 @@ def do_lcs_pair(
)

# set the comparable region
pair.comparable_region.lcs_a_start = a_start
pair.comparable_region.lcs_a_stop = a_stop
pair.comparable_region.lcs_b_start = b_start
pair.comparable_region.lcs_b_stop = b_stop
pair.comparable_region.a_start = a_start
pair.comparable_region.a_stop = a_stop
pair.comparable_region.b_start = b_start
Expand Down Expand Up @@ -227,37 +221,99 @@ def expand_pair(pair: RecordPair) -> float:
jc = calc_jaccard_pair(pair)
return jc

pair.comparable_region.alignment_mode = bs_enums.ALIGNMENT_MODE.GLOCAL
jc = calc_jaccard_pair(pair)

return jc


def calculate_scores_pair(
data: tuple[list[RecordPair], bs_enums.ALIGNMENT_MODE, str]
) -> list[tuple[float, float, float, float]]: # pragma no cover
) -> list[
tuple[
Optional[int],
Optional[int],
float,
float,
float,
float,
str,
int,
int,
int,
int,
int,
int,
int,
int,
bool,
str,
]
]: # pragma no cover
"""Calculate the scores for a list of pairs
Args:
data (tuple[list[RecordPair], str, str]): list of pairs, alignment mode, bin
label
Returns:
list[tuple[float, float, float, float]]: list of scores for each pair in the
order as the input data list
list[tuple[int, int, float, float, float, float, int, int, int, int, int, int,
int, int, bool, str,]]: list of scores for each pair in the
order as the input data list, including lcs and extension coordinates
"""
pairs, alignment_mode, weights_label = data

results = []

for pair in pairs:
if pair.region_a.parent_gbk == pair.region_b.parent_gbk:
results.append((0.0, 1.0, 1.0, 1.0))
results.append(
(
pair.region_a._db_id,
pair.region_b._db_id,
0.0,
1.0,
1.0,
1.0,
weights_label,
pair.comparable_region.lcs_a_start,
pair.comparable_region.lcs_a_stop,
pair.comparable_region.lcs_b_start,
pair.comparable_region.lcs_b_stop,
pair.comparable_region.a_start,
pair.comparable_region.a_stop,
pair.comparable_region.b_start,
pair.comparable_region.b_stop,
pair.comparable_region.reverse,
pair.comparable_region.alignment_mode.value,
)
)
continue

jaccard = calc_jaccard_pair(pair)

if jaccard == 0.0:
results.append((1.0, 0.0, 0.0, 0.0))
results.append(
(
pair.region_a._db_id,
pair.region_b._db_id,
1.0,
0.0,
0.0,
0.0,
weights_label,
pair.comparable_region.lcs_a_start,
pair.comparable_region.lcs_a_stop,
pair.comparable_region.lcs_b_start,
pair.comparable_region.lcs_b_stop,
pair.comparable_region.a_start,
pair.comparable_region.a_stop,
pair.comparable_region.b_start,
pair.comparable_region.b_stop,
pair.comparable_region.reverse,
pair.comparable_region.alignment_mode.value,
)
)
continue

# in the form [bool, Pair]. true bools means they need expansion, false they don't
Expand All @@ -267,7 +323,27 @@ def calculate_scores_pair(
jaccard = expand_pair(pair)

if jaccard == 0.0:
results.append((1.0, 0.0, 0.0, 0.0))
results.append(
(
pair.region_a._db_id,
pair.region_b._db_id,
1.0,
0.0,
0.0,
0.0,
weights_label,
pair.comparable_region.lcs_a_start,
pair.comparable_region.lcs_a_stop,
pair.comparable_region.lcs_b_start,
pair.comparable_region.lcs_b_stop,
pair.comparable_region.a_start,
pair.comparable_region.a_stop,
pair.comparable_region.b_start,
pair.comparable_region.b_stop,
pair.comparable_region.reverse,
pair.comparable_region.alignment_mode.value,
)
)
continue

if weights_label not in LEGACY_WEIGHTS:
Expand All @@ -285,6 +361,26 @@ def calculate_scores_pair(
similarity = jaccard * jc_weight + adjacency * ai_weight + dss * dss_weight
distance = 1 - similarity

results.append((distance, jaccard, adjacency, dss))
results.append(
(
pair.region_a._db_id,
pair.region_b._db_id,
distance,
jaccard,
adjacency,
dss,
weights_label,
pair.comparable_region.lcs_a_start,
pair.comparable_region.lcs_a_stop,
pair.comparable_region.lcs_b_start,
pair.comparable_region.lcs_b_stop,
pair.comparable_region.a_start,
pair.comparable_region.a_stop,
pair.comparable_region.b_start,
pair.comparable_region.b_stop,
pair.comparable_region.reverse,
pair.comparable_region.alignment_mode.value,
)
)

return results
85 changes: 78 additions & 7 deletions big_scape/comparison/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,60 @@
# from other modules
from big_scape.data import DB
from big_scape.comparison.binning import RecordPairGenerator, RecordPair
from big_scape.enums import ALIGNMENT_MODE


def save_edge_to_db(
edge: tuple[int, int, float, float, float, float, str], upsert=False
edge: tuple[
int,
int,
float,
float,
float,
float,
str,
int,
int,
int,
int,
int,
int,
int,
int,
bool,
ALIGNMENT_MODE,
],
upsert=False,
) -> None:
"""Save edge to the database
Args:
edge (tuple[int, int, float, float, float, float]): edge tuple containing
region_a_id, region_b_id, distance, jaccard, adjacency, dss
edge (tuple[int, int, float, float, float, float, str, int, int, int, int, int,
int, int, int, bool, ALIGNMENT_MODE,]): edge tuple containing
region_a_id, region_b_id, distance, jaccard, adjacency, dss, weights,
lcs start/stop, extension start/stop, reverse, alignment_mode
upsert (bool, optional): whether to upsert the edge into the database.
"""

region_a_id, region_b_id, distance, jaccard, adjacency, dss, weights = edge
(
region_a_id,
region_b_id,
distance,
jaccard,
adjacency,
dss,
weights,
lcs_a_start,
lcs_a_stop,
lcs_b_start,
lcs_b_stop,
ext_a_start,
ext_a_stop,
ext_b_start,
ext_b_stop,
reverse,
alignment_mode,
) = edge

# save the comparison data to the database

Expand All @@ -41,6 +81,16 @@ def save_edge_to_db(
adjacency=adjacency,
dss=dss,
weights=weights,
lcs_a_start=lcs_a_start,
lcs_a_stop=lcs_a_stop,
lcs_b_start=lcs_b_start,
lcs_b_stop=lcs_b_stop,
ext_a_start=ext_a_start,
ext_a_stop=ext_a_stop,
ext_b_start=ext_b_start,
ext_b_stop=ext_b_stop,
reverse=reverse,
alignment_mode=alignment_mode.value,
)

if upsert:
Expand All @@ -50,12 +100,33 @@ def save_edge_to_db(


def save_edges_to_db(
edges: list[tuple[int, int, float, float, float, float, str]]
edges: list[
tuple[
int,
int,
float,
float,
float,
float,
str,
int,
int,
int,
int,
int,
int,
int,
int,
bool,
str,
]
]
) -> None:
"""Save many edges to the database
Args:
edges (list[tuple[int, int, float, float, float, float, str]]): list of edges to save
edges (list[tuple[int, int, float, float, float, float, str, int, int, int, int,
int, int, int, int, bool, str]]): list of edges to save
"""
# save the comparison data to the database
# using raw sqlite for this because sqlalchemy is not fast enough
Expand All @@ -75,7 +146,7 @@ def save_edges_to_db(
# create a query
# TODO: this should not need ignore. it's there now because protoclusters somehow
# trigger an integrityerror
query = "INSERT OR IGNORE INTO distance VALUES (?, ?, ?, ?, ?, ?, ?)"
query = "INSERT OR IGNORE INTO distance VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"

cursor.executemany(query, edges)

Expand Down
10 changes: 10 additions & 0 deletions big_scape/data/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ CREATE TABLE IF NOT EXISTS distance (
adjacency REAL NOT NULL,
dss REAL NOT NULL,
weights TEXT NOT NULL,
lcs_a_start INTEGER NOT NULL,
lcs_a_stop INTEGER NOT NULL,
lcs_b_start INTEGER NOT NULL,
lcs_b_stop INTEGER NOT NULL,
ext_a_start INTEGER NOT NULL,
ext_a_stop INTEGER NOT NULL,
ext_b_start INTEGER NOT NULL,
ext_b_stop INTEGER NOT NULL,
reverse BOOLEAN NOT NULL,
alignment_mode TEXT NOT NULL,
UNIQUE(region_a_id, region_b_id, weights)
FOREIGN KEY(region_a_id) REFERENCES bgc_record(id)
FOREIGN KEY(region_b_id) REFERENCES bgc_record(id)
Expand Down
Loading

0 comments on commit 69e2784

Please sign in to comment.