Skip to content

Commit

Permalink
Merge pull request #42 from cov-lineages/dev
Browse files (browse the repository at this point in the history)
Handle multiple levels of definition within a constellation
  • Loading branch information
rmcolq authored Dec 15, 2021
2 parents: fad0291 + f7cba5c; commit 2ae9e99
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 38 deletions.
2 changes: 1 addition & 1 deletion scorpio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
_program = "scorpio"
__version__ = "0.3.15"
__version__ = "0.3.16"
90 changes: 61 additions & 29 deletions scorpio/scripts/type_constellations.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ def parse_json_in(refseq, features_dict, variants_file, constellation_names=None
variant_list.append(record)

if "rules" in json_dict:
rules = json_dict["rules"]
if type(json_dict["rules"]) == dict and "default" in json_dict["rules"]:
rules = json_dict["rules"]
else:
rules = {"default": json_dict["rules"]}

in_json.close()
sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"]))
Expand Down Expand Up @@ -346,9 +349,9 @@ def parse_csv_in(refseq, features_dict, variants_file, constellation_names=None,
csv_in.close()
rules = None
if len(compulsory) > 0:
rules = {}
rules = {"default": {}}
for var in compulsory:
rules[var] = "alt"
rules["default"][var] = "alt"
sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"]))
return sorted_variants, name, rules

Expand Down Expand Up @@ -570,24 +573,24 @@ def var_follows_rules(call, rule):
else:
return call == rule_call

def counts_follow_rules(counts, rules):
def counts_follow_rules(counts, rules, key):
# rules allowed include "max_ref", "min_alt", "min_snp_alt"
is_rule_follower = True
notes = []
for rule in rules:
for rule in rules[key]:
if ":" in rule:
continue
elif str(rule).startswith("min") or str(rule).startswith("max"):
rule_parts = rule.split("_")
if len(rule_parts) <= 1:
continue
elif len(rule_parts) == 2:
if rule_parts[0] == "min" and counts[rule_parts[1]] < rules[rule]:
if rule_parts[0] == "min" and counts[rule_parts[1]] < rules[key][rule]:
is_rule_follower = False
elif rule_parts[0] == "max" and counts[rule_parts[1]] > rules[rule]:
elif rule_parts[0] == "max" and counts[rule_parts[1]] > rules[key][rule]:
is_rule_follower = False
else:
counts["rules"] += 1
counts["rules"][key] += 1
elif len(rule_parts) == 3:
part = None
if rule_parts[1] in ["substitution", "snp"]:
Expand All @@ -596,24 +599,27 @@ def counts_follow_rules(counts, rules):
part = "indel"
if not part:
is_rule_follower = False
elif rule_parts[0] == "min" and counts[part][rule_parts[2]] < rules[rule]:
elif rule_parts[0] == "min" and counts[part][rule_parts[2]] < rules[key][rule]:
is_rule_follower = False
notes.append("%s_%s_count=%i is less than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[rule]))
elif rule_parts[0] == "max" and counts[part][rule_parts[2]] > rules[rule]:
notes.append("%s_%s_count=%i is less than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[key][rule]))
elif rule_parts[0] == "max" and counts[part][rule_parts[2]] > rules[key][rule]:
is_rule_follower = False
notes.append("%s_%s_count=%i is more than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[rule]))
notes.append("%s_%s_count=%i is more than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[key][rule]))
else:
counts["rules"] += 1
counts["rules"][key] += 1
else:
logging.warning("Warning: Ignoring rule %s:%s" % (rule, str(rules[rule])))
logging.warning("Warning: Ignoring rule %s:%s" % (rule, str(rules[key][rule])))
return is_rule_follower, ";".join(notes)

def count_and_classify(record_seq, variant_list, rules):
assert rules is not None
counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0, 'rules': 0,
counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0, 'rules': {},
'substitution': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0},
'indel': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0}}
is_rule_follower = True
is_rule_follower_dict = {}
for key in rules:
is_rule_follower_dict[key] = True
counts["rules"][key] = 0

for var in variant_list:
call, query_allele = call_variant_from_fasta(record_seq, var)
Expand All @@ -623,20 +629,27 @@ def count_and_classify(record_seq, variant_list, rules):
counts["substitution"][call] += 1
elif var['type'] in ["ins", "del"]:
counts["indel"][call] += 1
if var["name"] in rules:
if var_follows_rules(call, rules[var["name"]]):
counts['rules'] += 1
elif is_rule_follower:
is_rule_follower = False
for key in rules:
if var["name"] in rules[key]:
if var_follows_rules(call, rules[key][var["name"]]):
counts['rules'][key] += 1
elif is_rule_follower_dict[key]:
is_rule_follower_dict[key] = False

counts['support'] = round(counts['alt']/float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']),4)
counts['conflict'] = round(counts['ref'] /float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']),4)

if not is_rule_follower:
return counts, False, ""
else:
call, note = counts_follow_rules(counts, rules)
return counts, call, note
for key in rules:
if not is_rule_follower_dict[key]:
continue
else:
call, note = counts_follow_rules(counts, rules, key)
if call:
counts["rules"] = counts["rules"][key]
call = key
return counts, call, note
counts["rules"] = counts["rules"]["default"]
return counts, False, ""


def generate_barcode(record_seq, variant_list, ref_char=None, ins_char="?", oth_char="X",constellation_count_dict=None):
Expand Down Expand Up @@ -920,7 +933,15 @@ def combine_counts_call_notes(counts1, call1, note1, counts2, call2, note2):
counts[key] = counts1[key] + counts2[key]
counts['support'] = round(counts['alt'] / float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']), 4)
counts['conflict'] = round(counts['ref'] / float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']), 4)
call = call1 and call2
if not call1 or not call2:
call = False
elif call1 == call2:
call = call1
elif call1 == "default":
call = call2
else:
call = call1

note = note1
if note != "" and note2 != "":
note += ";" + note2
Expand Down Expand Up @@ -989,10 +1010,12 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na
best_support = 0
best_conflict = 1
best_counts = None
best_call = False
scores = {}
children = {}
for constellation in constellation_dict:
constellation_name = name_dict[constellation]
logging.debug("Consider constellation %s" %constellation_name)
parents = []
if not constellation_name:
continue
Expand All @@ -1015,20 +1038,25 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na
children[parent].append(constellation)

if call:
logging.debug("Have call for %s" %constellation_name)
if call_all:
if call != "default":
constellation_name = "%s %s" %(call, constellation_name)
lineages.append(constellation_name)
names.append(constellation)
elif constellation in children and best_constellation in children[constellation]:
continue
logging.debug("Ignore as parent of best constellation")
elif (not best_constellation) \
or (counts['support'] > best_support) \
or (counts['support'] == best_support and counts['conflict'] < best_conflict)\
or (counts['support'] == best_support and counts['conflict'] == best_conflict and counts['rules'] > best_counts["rules"])\
or (best_constellation in parents):
best_constellation = constellation
logging.debug("Set best constellation %s" %best_constellation)
best_support = counts['support']
best_conflict = counts['conflict']
best_counts = counts
best_call = call

if interspersion:
if counts["alt"] > 1:
Expand All @@ -1042,7 +1070,11 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na
counts['oth'], counts['rules'], counts['support'],
counts['conflict'], call, constellation, note))
if not call_all and best_constellation:
lineages.append(name_dict[best_constellation])
if best_call != "default":
best_constellation_name = "%s %s" % (best_call, name_dict[best_constellation])
else:
best_constellation_name = name_dict[best_constellation]
lineages.append(best_constellation_name)
names.append(best_constellation)

out_entries = [record.id, "|".join(lineages), "|".join([mrca_lineage_dict[n] for n in names])]
Expand Down
16 changes: 8 additions & 8 deletions scorpio/tests/type_constellations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def test_parse_json_in():
assert len([v for v in variant_list if v["type"] == "del"]) == 3
assert len([v for v in variant_list if v["type"] == "aa"]) == 15
assert name == "Lineage_X"
assert rules["min_alt"] == 4
assert rules["max_ref"] == 6
assert rules["s:E484K"] == "alt"
assert rules["default"]["min_alt"] == 4
assert rules["default"]["max_ref"] == 6
assert rules["default"]["s:E484K"] == "alt"
assert mrca_lineage == "B.1.1.7"
assert incompatible_lineages == "A|B.1.351"

Expand All @@ -162,7 +162,7 @@ def test_parse_csv_in():
assert len([v for v in variant_list if v["type"] == "del"]) == 3
assert len([v for v in variant_list if v["type"] == "aa"]) == 15
assert name == "lineage_X"
assert rules["s:E484K"] == "alt"
assert rules["default"]["s:E484K"] == "alt"


def test_parse_textfile_in():
Expand All @@ -178,8 +178,8 @@ def test_parse_textfile_in():
def test_parse_variants_in():
in_files = ["%s/lineage_X.json" % data_dir, "%s/lineage_X.csv" % data_dir, "%s/lineage_X.txt" % data_dir]
expect_names = ["Lineage_X", "lineage_X", "lineage_X"]
rule_dict_json = {"min_alt": 4, "max_ref": 6, "s:E484K": "alt"}
rule_dict_csv = {"s:E484K": "alt"}
rule_dict_json = {"default": {"min_alt": 4, "max_ref": 6, "s:E484K": "alt"}}
rule_dict_csv = {"default": {"s:E484K": "alt"}}
rule_dict_txt = None
expect_rules = [rule_dict_json, rule_dict_csv, rule_dict_txt]

Expand Down Expand Up @@ -248,8 +248,8 @@ def test_count_and_classify():
oth_string = "gaaattcgcccgta-gctcgcaatag"
seqs = [Seq(ref_string), Seq(alt_string), Seq(alt_plus_string), Seq(oth_string)]

rules = {"min_alt": 1, "max_ref": 1, "snp2": "alt"}
expect_classify = [False, False, True, False]
rules = {"default": {"min_alt": 1, "max_ref": 1, "snp2": "alt"}}
expect_classify = [False, False, "default", False]
expect_counts = [{"ref": 5, "alt": 0, "ambig": 0, "oth": 1, "rules": 0, 'substitution': {'ref': 4, 'alt': 0, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 1, 'alt': 0, 'ambig': 0, 'oth': 1}, "support": 0.0, "conflict": 0.8333},
{"ref": 1, "alt": 4, "ambig": 0, "oth": 1, "rules": 0, 'substitution': {'ref': 1, 'alt': 3, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 1}, "support": 0.6667, "conflict": 0.1667},
{"ref": 0, "alt": 5, "ambig": 0, "oth": 1, "rules": 3, 'substitution': {'ref': 0, 'alt': 4, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 1}, "support": 0.8333, "conflict": 0.0},
Expand Down

0 comments on commit 2ae9e99

Please sign in to comment.