Skip to content

Commit

Permalink
Merge pull request #199 from medema-group/hotfix/greedy-simple-extends
Browse files Browse the repository at this point in the history
Hotfix/greedy simple extends
  • Loading branch information
nlouwen authored Oct 23, 2024
2 parents 1210a0a + 122c9d0 commit a902b61
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 10 deletions.
30 changes: 21 additions & 9 deletions big_scape/comparison/extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,20 +540,21 @@ def extend_greedy(pair: RecordPair) -> None:
logging.debug("before greedy extend:")
logging.debug(pair.comparable_region)

a_cds = list(pair.record_a.get_cds())
b_cds = list(pair.record_b.get_cds())
a_domains = list(pair.record_a.get_hsps())
b_domains = list(pair.record_b.get_hsps())

if pair.comparable_region.reverse:
b_domains = b_domains[::-1]

# index of domain to cds position
a_index = get_target_indexes(a_domains)
b_index = get_target_indexes(b_domains)

common_domains = set(a_index.keys()).intersection(set(b_index.keys()))

a_cds_min = len(a_cds)
a_cds_min = len(list(pair.record_a.get_cds()))
a_cds_max = 0
b_cds_min = len(b_cds)
b_cds_min = len(list(pair.record_b.get_cds()))
b_cds_max = 0

a_domain_min = len(a_domains)
Expand Down Expand Up @@ -613,12 +614,23 @@ def extend_simple_match(pair: RecordPair, match, gap):

# so we'll do a loop through cds and through domains to keep track of everything
for cds_idx, cds in enumerate(pair.record_a.get_cds_with_domains()):
for domain in cds.hsps:
a_domains.append((domain, cds_idx))
if cds.strand == 1:
a_domains.extend([(domain, cds_idx) for domain in cds.hsps])
else:
a_domains.extend([(domain, cds_idx) for domain in cds.hsps[::-1]])

for cds_idx, cds in enumerate(pair.record_b.get_cds_with_domains()):
for domain in cds.hsps:
b_domains.append((domain, cds_idx))
b_cds = list(pair.record_b.get_cds_with_domains())

if pair.comparable_region.reverse:
b_cds = b_cds[::-1]

for cds_idx, cds in enumerate(b_cds):
if (cds.strand == 1 and not pair.comparable_region.reverse) or (
cds.strand == -1 and pair.comparable_region.reverse
):
b_domains.extend([(domain, cds_idx) for domain in cds.hsps])
else:
b_domains.extend([(domain, cds_idx) for domain in cds.hsps[::-1]])

# get the common domains
common_domains = set([a[0] for a in a_domains]).intersection(
Expand Down
71 changes: 70 additions & 1 deletion test/comparison/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_check_pass(self):
self.assertEqual(expected_result, actual_result)


class TestScoreExtend(unittest.TestCase):
class TestExtendLegacy(unittest.TestCase):
"""Tests for score extension"""

def test_get_target_indexes(self):
Expand Down Expand Up @@ -846,6 +846,8 @@ def test_extend_glocal_multi_domain(self):

self.assertTrue(all(conditions))


class TestExtendGreedy(unittest.TestCase):
def test_extend_greedy(self):
"""Tests greedy extension
Expand Down Expand Up @@ -883,6 +885,29 @@ def test_extend_greedy(self):

self.assertTrue(all(conditions))

def test_extend_greedy_rev(self):
"""Tests greedy extension on a reverse pair"""

a_cds, b_cds = generate_mock_cds_lists(11, 24, [1, 4, 9], [11, 13, 14], True)
record_a = generate_mock_region(a_cds)
record_b = generate_mock_region(b_cds)
pair = big_scape.comparison.record_pair.RecordPair(record_a, record_b)
pair.comparable_region.reverse = True

bs_comp.extend.extend_greedy(pair)
expected_greedy = bs_comp.ComparableRegion(1, 10, 11, 15, 1, 10, 11, 15, True)
conditions = [
pair.comparable_region == expected_greedy, # tests cds start/stops
pair.comparable_region.domain_a_start == expected_greedy.domain_a_start,
pair.comparable_region.domain_b_start == expected_greedy.domain_b_start,
pair.comparable_region.domain_a_stop == expected_greedy.domain_a_stop,
pair.comparable_region.domain_b_stop == expected_greedy.domain_b_stop,
]

self.assertTrue(all(conditions))


class TestExtendSimple(unittest.TestCase):
def test_match_extend(self):
"""Tests the new match extend implementation
Expand Down Expand Up @@ -974,3 +999,47 @@ def test_extend_simple_match_multi_domain(self):
pair.comparable_region.domain_b_stop,
expected_comparable_region.domain_b_stop,
)

def test_extend_simple_match_multi_domain_rev_stranded(self):
"""Tests simple match on multidomain cdss, reverse pair, strand-aware"""
# brackets indicate a cds with multiple domains
#
# vvvvv complementary strand
# A: [XX][XXXXB]X[A BC]D EX XXXX
# B: XXXXXXXXXXX X A[BC]E[DXXXX]X[XXXX]
#
a_cds, b_cds = generate_mock_cds_lists(
10, 18, [3, 3, 3, 4, 5], [12, 13, 13, 15, 14], True
)
a_cds[0].hsps.append(a_cds[0].hsps[0])
a_cds[1].hsps = [a_cds[3].hsps[1]] + [a_cds[0].hsps[0]] * 4
a_cds[1].strand = -1
b_cds[0].hsps.extend([b_cds[0].hsps[0]] * 3)
b_cds[2].hsps = [b_cds[0].hsps[0]] * 4 + b_cds[2].hsps
record_a = generate_mock_region(a_cds)
record_b = generate_mock_region(b_cds)
pair = big_scape.comparison.record_pair.RecordPair(record_a, record_b)
pair.comparable_region = bs_comp.ComparableRegion(
3, 4, 12, 14, 8, 11, 12, 15, True
)
bs_comp.extend.extend_simple_match(pair, 5, -2)
expected_comparable_region = bs_comp.ComparableRegion(
1, 6, 12, 16, 6, 13, 12, 17, True
)
self.assertEqual(pair.comparable_region, expected_comparable_region)
self.assertEqual(
pair.comparable_region.domain_a_start,
expected_comparable_region.domain_a_start,
)
self.assertEqual(
pair.comparable_region.domain_b_start,
expected_comparable_region.domain_b_start,
)
self.assertEqual(
pair.comparable_region.domain_a_stop,
expected_comparable_region.domain_a_stop,
)
self.assertEqual(
pair.comparable_region.domain_b_stop,
expected_comparable_region.domain_b_stop,
)

0 comments on commit a902b61

Please sign in to comment.