diff --git a/scorpio/__init__.py b/scorpio/__init__.py index f53dd44..8bd5127 100644 --- a/scorpio/__init__.py +++ b/scorpio/__init__.py @@ -1,2 +1,2 @@ _program = "scorpio" -__version__ = "0.3.12" +__version__ = "0.3.13" diff --git a/scorpio/__main__.py b/scorpio/__main__.py index df5f9d1..cfe9804 100644 --- a/scorpio/__main__.py +++ b/scorpio/__main__.py @@ -109,6 +109,10 @@ def main(sysargs = sys.argv[1:]): "--append-genotypes", dest="append_genotypes", action="store_true", help="Output a column per variant with the call" ) + subparser_haplotype.add_argument( + "--combination", dest="combination", action="store_true", + help="Combines the mutations for the specified constellations, and outputs a string across them all, with counts per found constellation" + ) subparser_haplotype.set_defaults(func=scorpio.subcommands.haplotype.run) # _______________________________ report __________________________________# @@ -154,6 +158,10 @@ def main(sysargs = sys.argv[1:]): '--outgroups', dest='outgroups', required=False, help='Two column CSV with group, and pipe separated list of outgroup sequence_names for that list. 
' 'Assumes outgroups will be in main input CSV') + subparser_define.add_argument( + "--protein", dest="protein", action="store_true", + help="Translates definition coordinates to proteins where possible" + ) subparser_define.set_defaults(func=scorpio.subcommands.define.run) @@ -243,7 +251,7 @@ def main(sysargs = sys.argv[1:]): if not args.reference_json: args.reference_json = reference_json logging.info("Found reference %s" %args.reference_json) - if not args.constellations: + if not args.constellations and args.command in ['haplotype', 'classify']: args.constellations = list_constellation_files logging.info("Found constellations:") for c in args.constellations: diff --git a/scorpio/scripts/extract_definitions.py b/scorpio/scripts/extract_definitions.py index d8e3d65..0977f09 100755 --- a/scorpio/scripts/extract_definitions.py +++ b/scorpio/scripts/extract_definitions.py @@ -9,7 +9,7 @@ from Bio.Seq import Seq from operator import itemgetter -from .type_constellations import load_feature_coordinates +from .type_constellations import load_feature_coordinates, resolve_ambiguous_cds def parse_args(): parser = argparse.ArgumentParser(description="""Pick a representative sample for each unique sequence""", @@ -90,6 +90,15 @@ def update_var_dict(var_dict, group, variants): return +def update_feature_dict(feature_dict): + for feature in feature_dict: + if len(feature_dict[feature]) > 2: + cds, aa_pos = resolve_ambiguous_cds(feature_dict[feature][2], feature_dict[feature][0], feature_dict) + if aa_pos: + feature_dict[feature] = (aa_pos, feature_dict[feature][1] + feature_dict[feature][0] - aa_pos, cds) + return feature_dict + + def get_common_mutations(var_dict, min_occurance=3, threshold_common=0.98, threshold_intermediate=0.25): sorted_tuples = sorted(var_dict.items(), key=operator.itemgetter(1)) var_dict = {k: v for k, v in sorted_tuples} @@ -110,7 +119,7 @@ def get_common_mutations(var_dict, min_occurance=3, threshold_common=0.98, thres return common, intermediate -def 
translate_if_possible(nuc_start, nuc_ref, nuc_alt, feature_dict, reference_seq): +def translate_if_possible(nuc_start, nuc_ref, nuc_alt, feature_dict, reference_seq, include_protein=False): nuc_end = nuc_start + len(nuc_ref) nuc_start = int(nuc_start) nuc_end = int(nuc_end) @@ -138,12 +147,26 @@ def translate_if_possible(nuc_start, nuc_ref, nuc_alt, feature_dict, reference_s if ref_allele == query_allele: return "nuc:%s%i%s" % (nuc_ref, nuc_start, nuc_alt) aa_pos = int((start - feature_dict[feature][0]) / 3) + 1 + if include_protein: + feature, aa_pos = translate_to_protein_if_possible(feature, aa_pos, feature_dict) #print(start, end, ref_allele, query_allele, aa_pos, feature) return "%s:%s%i%s" % (feature, ref_allele, aa_pos, query_allele) return "nuc:%s%i%s" % (nuc_ref, nuc_start, nuc_alt) -def define_mutations(list_variants, feature_dict, reference_seq): +def translate_to_protein_if_possible(cds, aa_start, feature_dict): + if not cds.startswith("orf"): + return cds, aa_start + + for feature in feature_dict: + if len(feature_dict[feature]) < 3: + continue # only want nsp definitions + if feature_dict[feature][2] == cds: + if feature_dict[feature][0] <= aa_start <= feature_dict[feature][1]: + return feature, aa_start-feature_dict[feature][0]+1 + return cds, aa_start + +def define_mutations(list_variants, feature_dict, reference_seq, include_protein=False): merged_list = [] if not list_variants: return merged_list @@ -184,7 +207,7 @@ def define_mutations(list_variants, feature_dict, reference_seq): elif new[3]: current[3] = new[3] elif current[0] != "": - var = translate_if_possible(current[1], current[0], current[2], feature_dict, reference_seq) + var = translate_if_possible(current[1], current[0], current[2], feature_dict, reference_seq, include_protein) if freq: merged_list.append("%s:%s" % (var, freq)) else: @@ -193,7 +216,7 @@ def define_mutations(list_variants, feature_dict, reference_seq): else: current = new if current[0] != "": - var = 
translate_if_possible(current[1], current[0], current[2], feature_dict, reference_seq) + var = translate_if_possible(current[1], current[0], current[2], feature_dict, reference_seq, include_protein) if freq: merged_list.append("%s:%s" % (var, freq)) else: @@ -214,17 +237,15 @@ def subtract_outgroup(common, outgroup_common): def write_constellation(prefix, group, list_variants, list_intermediates, list_ancestral): group_dict = {"name": group, "sites": list_variants, "intermediate": list_intermediates, - "rules": {"min_alt": int((len(list_variants) + 1) / 4), "max_ref": int((len(list_variants) - 1) / 4)}} + "rules": {"min_alt": max(len(list_variants) - 3, min(len(list_variants), 3)), "max_ref": 3}} if list_ancestral: group_dict["ancestral"] = list_ancestral - group_dict["rules"]["min_alt"] += int((len(list_ancestral)+1)/4) - group_dict["rules"]["max_ref"] += int((len(list_ancestral)-1)/4) with open('%s/%s.json' % (prefix, group), 'w') as outfile: json.dump(group_dict, outfile, indent=4) def extract_definitions(in_variants, in_groups, group_column, index_column, reference_json, prefix, subset, - threshold_common, threshold_intermediate, outgroup_file): + threshold_common, threshold_intermediate, outgroup_file, include_protein): if not in_groups: in_groups = in_variants @@ -239,6 +260,7 @@ def extract_definitions(in_variants, in_groups, group_column, index_column, refe group_dict = get_group_dict(in_groups, group_column, index_column, groups_to_get) reference_seq, feature_dict = load_feature_coordinates(reference_json) + feature_dict = update_feature_dict(feature_dict) var_dict = {} outgroup_var_dict = {} @@ -283,9 +305,9 @@ def extract_definitions(in_variants, in_groups, group_column, index_column, refe if group in outgroup_var_dict: outgroup_common, outgroup_intermediate = get_common_mutations(outgroup_var_dict[group], min_occurance=1, threshold_common=threshold_common, threshold_intermediate=threshold_intermediate) common, ancestral = subtract_outgroup(common, 
outgroup_common) - nice_common = define_mutations(common, feature_dict, reference_seq) - nice_intermediate = define_mutations(intermediate, feature_dict, reference_seq) - nice_ancestral = define_mutations(ancestral, feature_dict, reference_seq) + nice_common = define_mutations(common, feature_dict, reference_seq, include_protein) + nice_intermediate = define_mutations(intermediate, feature_dict, reference_seq, include_protein) + nice_ancestral = define_mutations(ancestral, feature_dict, reference_seq, include_protein) write_constellation(prefix, group, nice_common, nice_intermediate, nice_ancestral) diff --git a/scorpio/scripts/type_constellations.py b/scorpio/scripts/type_constellations.py index 01a4da0..6bce390 100755 --- a/scorpio/scripts/type_constellations.py +++ b/scorpio/scripts/type_constellations.py @@ -8,6 +8,7 @@ import json import re import logging +import math if sys.version_info[0] < 3: raise Exception("Python 3 or a more recent version is required.") @@ -117,7 +118,7 @@ def get_nuc_position_from_aa_description(cds, aa_pos, features_dict): return int(nuc_pos) -def variant_to_variant_record(l, refseq, features_dict): +def variant_to_variant_record(l, refseq, features_dict, ignore_fails=False): """ convert a variant in one of the following formats @@ -137,8 +138,9 @@ def variant_to_variant_record(l, refseq, features_dict): if "+" in l: m = re.match(r'[aa:]*(?P\w+):(?P\d+)\+(?P[a-zA-Z]+)', l) if not m: - sys.stderr.write("Warning: couldn't parse the following string: %s - ignoring\n" % l) - sys.exit(1) + sys.stderr.write("Warning: couldn't parse the following string: %s\n" % l) + if not ignore_fails: + sys.exit(1) info = m.groupdict() info["type"] = "ins" info["ref_allele"] = "" @@ -154,8 +156,9 @@ def variant_to_variant_record(l, refseq, features_dict): info = {"name": l, "type": "snp"} m = re.match(r'(?P[ACGTUN]+)(?P\d+)(?P[AGCTUN]*)', l[4:]) if not m: - sys.stderr.write("Warning: couldn't parse the following string: %s - ignoring\n" % l) - sys.exit(1) 
+ sys.stderr.write("Warning: couldn't parse the following string: %s\n" % l) + if not ignore_fails: + sys.exit(1) info.update(m.groupdict()) info["ref_start"] = int(info["ref_start"]) ref_allele_check = refseq[info["ref_start"] - 1] @@ -164,7 +167,8 @@ def variant_to_variant_record(l, refseq, features_dict): sys.stderr.write( "variants file says reference nucleotide at position %d is %s, but reference sequence has %s, " "context %s\n" % (info["ref_start"], info["ref_allele"], ref_allele_check, refseq[info["ref_start"] - 4:info["ref_start"] + 3])) - sys.exit(1) + if not ignore_fails: + sys.exit(1) elif lsplit[0] == "del": length = int(lsplit[2]) @@ -174,8 +178,9 @@ def variant_to_variant_record(l, refseq, features_dict): else: m = re.match(r'[aa:]*(?P\w+):(?P[a-zA-Z-*]+)(?P\d+)(?P[a-zA-Z-*]*)', l) if not m: - sys.stderr.write("Warning: couldn't parse the following string: %s - ignoring\n" % l) - sys.exit(1) + sys.stderr.write("Warning: couldn't parse the following string: %s\n" % l) + if not ignore_fails: + sys.exit(1) return info info = m.groupdict() @@ -191,7 +196,8 @@ def variant_to_variant_record(l, refseq, features_dict): if info["ref_allele"] != '?' 
and info["ref_allele"] != ref_allele_check: sys.stderr.write("variants file says reference amino acid in CDS %s at position %d is %s, but " "reference sequence has %s\n" % (cds, aa_pos, info["ref_allele"], ref_allele_check)) - sys.exit(1) + if not ignore_fails: + sys.exit(1) info["cds"] = cds info["aa_pos"] = aa_pos @@ -217,15 +223,18 @@ def parse_name_from_file(constellation_file): return name -def parse_json_in(refseq, features_dict, variants_file, constellation_names=None, include_ancestral=False, label=None): +def parse_json_in(refseq, features_dict, variants_file, constellation_names=None, include_ancestral=False, label=None, ignore_fails=False): """ returns variant_list name and rules """ variant_list = [] name = None + output_name = None rules = None mrca_lineage = "" incompatible_lineage_calls = "" + parent_lineage = None + lineage_name = None in_json = open(variants_file, 'r') json_dict = json.load(in_json, strict=False) @@ -239,34 +248,43 @@ def parse_json_in(refseq, features_dict, variants_file, constellation_names=None mrca_lineage = json_dict[json_dict["type"]]["mrca_lineage"] if "incompatible_lineage_calls" in json_dict[json_dict["type"]]: incompatible_lineage_calls = "|".join(json_dict[json_dict["type"]]["incompatible_lineage_calls"]) + if "parent_lineage" in json_dict[json_dict["type"]]: + parent_lineage = json_dict[json_dict["type"]]["parent_lineage"] + if "lineage_name" in json_dict[json_dict["type"]]: + lineage_name = json_dict[json_dict["type"]]["lineage_name"] - - if label: - if "type" in json_dict and json_dict["type"] in json_dict and label in json_dict[json_dict["type"]]: - name = json_dict[json_dict["type"]][label] + if "name" in json_dict: + name = json_dict["name"] elif "label" in json_dict: name = json_dict["label"] - elif "name" in json_dict: - name = json_dict["name"] else: name = parse_name_from_file(variants_file) + if label: + if "type" in json_dict and json_dict["type"] in json_dict and label in json_dict[json_dict["type"]]: + 
output_name = json_dict[json_dict["type"]][label] + elif "label" in json_dict: + output_name = json_dict["label"] + elif name: + output_name = name + + if not name: - return variant_list, name, rules, mrca_lineage, incompatible_lineage_calls - if constellation_names and name not in constellation_names: - return variant_list, name, rules, mrca_lineage, incompatible_lineage_calls + return variant_list, name, output_name, rules, mrca_lineage, incompatible_lineage_calls, parent_lineage, lineage_name + if constellation_names and name not in constellation_names and output_name not in constellation_names: + return variant_list, name, output_name, rules, mrca_lineage, incompatible_lineage_calls, parent_lineage, lineage_name logging.info("\n") logging.info("Parsing constellation JSON file %s" % variants_file) if "sites" in json_dict: for site in json_dict["sites"]: - record = variant_to_variant_record(site, refseq, features_dict) + record = variant_to_variant_record(site, refseq, features_dict, ignore_fails=ignore_fails) if record != {}: variant_list.append(record) if include_ancestral and "ancestral" in json_dict: for site in json_dict["ancestral"]: - record = variant_to_variant_record(site, refseq, features_dict) + record = variant_to_variant_record(site, refseq, features_dict, ignore_fails=ignore_fails) if record != {}: variant_list.append(record) @@ -274,11 +292,14 @@ def parse_json_in(refseq, features_dict, variants_file, constellation_names=None rules = json_dict["rules"] in_json.close() + sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"])) + for var in sorted_variants: + print(var) - return variant_list, name, rules, mrca_lineage, incompatible_lineage_calls + return sorted_variants, name, output_name, rules, mrca_lineage, incompatible_lineage_calls, parent_lineage, lineage_name -def parse_csv_in(refseq, features_dict, variants_file, constellation_names=None): +def parse_csv_in(refseq, features_dict, variants_file, constellation_names=None, 
ignore_fails=False): """ returns variant_list and name """ @@ -310,7 +331,7 @@ def parse_csv_in(refseq, features_dict, variants_file, constellation_names=None) var = "%s:%s" % (row["gene"], row["id"]) else: var = row["id"] - record = variant_to_variant_record(var, refseq, features_dict) + record = variant_to_variant_record(var, refseq, features_dict, ignore_fails=ignore_fails) if record != {}: variant_list.append(record) if "compulsory" in reader.fieldnames and row["compulsory"] in ["True", True, "Y", "y", "Yes", "yes", "YES"]: @@ -322,10 +343,11 @@ def parse_csv_in(refseq, features_dict, variants_file, constellation_names=None) rules = {} for var in compulsory: rules[var] = "alt" - return variant_list, name, rules + sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"])) + return sorted_variants, name, rules -def parse_textfile_in(refseq, features_dict, variants_file, constellation_names=None): +def parse_textfile_in(refseq, features_dict, variants_file, constellation_names=None, ignore_fails=False): """ returns variant_list and name """ @@ -334,6 +356,7 @@ def parse_textfile_in(refseq, features_dict, variants_file, constellation_names= name = parse_name_from_file(variants_file) if constellation_names and name not in constellation_names: return variant_list, name + logging.info("\n") logging.info("Parsing constellation text file %s" % variants_file) @@ -341,14 +364,15 @@ def parse_textfile_in(refseq, features_dict, variants_file, constellation_names= for line in f: l = line.split("#")[0].strip() # remove comments from the line if len(l) > 0: # skip blank lines (or comment only lines) - record = variant_to_variant_record(l, refseq, features_dict) + record = variant_to_variant_record(l, refseq, features_dict, ignore_fails=ignore_fails) if record != {}: variant_list.append(record) + sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"])) - return variant_list, name + return sorted_variants, name -def parse_variants_in(refseq, 
features_dict, variants_file, constellation_names=None, include_ancestral=False, label=None): +def parse_variants_in(refseq, features_dict, variants_file, constellation_names=None, include_ancestral=False, label=None, ignore_fails=False): """ read in a variants file and parse its contents and return something sensible. @@ -368,18 +392,23 @@ def parse_variants_in(refseq, features_dict, variants_file, constellation_names= """ variant_list = [] rule_dict = None + output_name = None mrca_lineage = "" incompatible_lineage_calls = "" + parent_lineage = None + lineage_name = None if variants_file.endswith(".json"): - variant_list, name, rule_dict, mrca_lineage, incompatible_lineage_calls = parse_json_in(refseq, features_dict, variants_file, constellation_names, include_ancestral=include_ancestral,label=label) + variant_list, name, output_name, rule_dict, mrca_lineage, incompatible_lineage_calls, parent_lineage, lineage_name = parse_json_in(refseq, features_dict, variants_file, constellation_names, include_ancestral=include_ancestral,label=label,ignore_fails=ignore_fails) elif variants_file.endswith(".csv"): - variant_list, name, rule_dict = parse_csv_in(refseq, features_dict, variants_file, constellation_names) + variant_list, name, rule_dict = parse_csv_in(refseq, features_dict, variants_file, constellation_names, ignore_fails=ignore_fails) + output_name = name if len(variant_list) == 0 and not variants_file.endswith(".json"): - variant_list, name = parse_textfile_in(refseq, features_dict, variants_file, constellation_names) + variant_list, name = parse_textfile_in(refseq, features_dict, variants_file, constellation_names, ignore_fails=ignore_fails) + output_name = name - return name, variant_list, rule_dict, mrca_lineage, incompatible_lineage_calls + return name, output_name, variant_list, rule_dict, mrca_lineage, incompatible_lineage_calls, parent_lineage, lineage_name def parse_mutations_in(mutations_file): @@ -398,7 +427,7 @@ def parse_mutations_in(mutations_file): 
return mutations_list -def parse_mutations(refseq, features_dict, mutations_list): +def parse_mutations(refseq, features_dict, mutations_list, ignore_fails=False): """ Parse the mutations specified on command line and make a mutations constellation for them @@ -409,7 +438,7 @@ def parse_mutations(refseq, features_dict, mutations_list): problematic = [] for mutation in mutations_list: - record = variant_to_variant_record(mutation, refseq, features_dict) + record = variant_to_variant_record(mutation, refseq, features_dict, ignore_fails=ignore_fails) if record != {}: variant_list.append(record) else: @@ -536,8 +565,9 @@ def var_follows_rules(call, rule): return call == rule_call def counts_follow_rules(counts, rules): - # rules allowed include "max_ref", "min_alt" + # rules allowed include "max_ref", "min_alt", "min_snp_alt" is_rule_follower = True + notes = [] for rule in rules: if ":" in rule: continue @@ -545,25 +575,48 @@ def counts_follow_rules(counts, rules): rule_parts = rule.split("_") if len(rule_parts) <= 1: continue - if rule_parts[0] == "min" and counts[rule_parts[1]] < rules[rule]: - is_rule_follower = False - elif rule_parts[0] == "max" and counts[rule_parts[1]] > rules[rule]: - is_rule_follower = False - else: - counts["rules"] += 1 + elif len(rule_parts) == 2: + if rule_parts[0] == "min" and counts[rule_parts[1]] < rules[rule]: + is_rule_follower = False + elif rule_parts[0] == "max" and counts[rule_parts[1]] > rules[rule]: + is_rule_follower = False + else: + counts["rules"] += 1 + elif len(rule_parts) == 3: + part = None + if rule_parts[1] in ["substitution", "snp"]: + part = "substitution" + elif rule_parts[1] in ["indel"]: + part = "indel" + if not part: + is_rule_follower = False + elif rule_parts[0] == "min" and counts[part][rule_parts[2]] < rules[rule]: + is_rule_follower = False + notes.append("%s_%s_count=%i is less than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[rule])) + elif rule_parts[0] == "max" and 
counts[part][rule_parts[2]] > rules[rule]: + is_rule_follower = False + notes.append("%s_%s_count=%i is more than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[rule])) + else: + counts["rules"] += 1 else: logging.warning("Warning: Ignoring rule %s:%s" % (rule, str(rules[rule]))) - return is_rule_follower + return is_rule_follower, ";".join(notes) def count_and_classify(record_seq, variant_list, rules): assert rules is not None - counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0, 'rules':0} + counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0, 'rules': 0, + 'substitution': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0}, + 'indel': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0}} is_rule_follower = True for var in variant_list: call, query_allele = call_variant_from_fasta(record_seq, var) #print(var, call, query_allele) counts[call] += 1 + if var['type'] in ["aa", "snp"]: + counts["substitution"][call] += 1 + elif var['type'] in ["ins", "del"]: + counts["indel"][call] += 1 if var["name"] in rules: if var_follows_rules(call, rules[var["name"]]): counts['rules'] += 1 @@ -574,19 +627,27 @@ def count_and_classify(record_seq, variant_list, rules): counts['conflict'] = round(counts['ref'] /float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']),4) if not is_rule_follower: - return counts, False + return counts, False, "" else: - return counts, counts_follow_rules(counts, rules) + call, note = counts_follow_rules(counts, rules) + return counts, call, note -def generate_barcode(record_seq, variant_list, ref_char=None, ins_char="?", oth_char="X"): +def generate_barcode(record_seq, variant_list, ref_char=None, ins_char="?", oth_char="X",constellation_count_dict=None): barcode_list = [] counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0} + sorted_alt_sites = [] for var in variant_list: call, query_allele = call_variant_from_fasta(record_seq, var, ins_char, oth_char) # print(var, call, query_allele) counts[call] += 1 + if constellation_count_dict and 
"constellations" in var: + for constellation in var["constellations"]: + constellation_count_dict[constellation][call] += 1 + if call == "alt": + sorted_alt_sites.append(var["constellations"]) + if ref_char is not None and call == 'ref': barcode_list.append(str(ref_char)) else: @@ -595,48 +656,187 @@ def generate_barcode(record_seq, variant_list, ref_char=None, ins_char="?", oth_ counts['support'] = round(counts['alt'] / float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']), 4) counts['conflict'] = round(counts['ref'] / float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']), 4) - return barcode_list, counts + return barcode_list, counts, constellation_count_dict, sorted_alt_sites -def type_constellations(in_fasta, list_constellation_files, constellation_names, out_csv, reference_json, ref_char=None, - output_counts=False, label=None, append_genotypes=False, mutations_list=None, dry_run=False): - reference_seq, features_dict = load_feature_coordinates(reference_json) - +def load_constellations(list_constellation_files, constellation_names, reference_seq, features_dict, label, + include_ancestral=True, rules_required=False, ignore_fails=False): constellation_dict = {} + name_dict = {} + rule_dict = {} + mrca_lineage_dict = {} + incompatible_dict = {} + parent_lineage_dict = {} + lineage_name_dict = {} for constellation_file in list_constellation_files: - constellation, variants, ignore, mrca_lineage, incompatible_lineage_calls = parse_variants_in(reference_seq, features_dict, constellation_file, constellation_names, label=label) + constellation, output_name, variants, rules, mrca_lineage, \ + incompatible_lineage_calls, parent_lineage, lineage_name = \ + parse_variants_in(reference_seq, + features_dict, + constellation_file, + constellation_names, + include_ancestral=include_ancestral, + label=label, + ignore_fails=ignore_fails) if not constellation: continue - if constellation_names and constellation not in constellation_names: + if 
constellation_names and constellation not in constellation_names and output_name not in constellation_names: + continue + else: + name_dict[constellation] = output_name + if rules_required and not rules: + logging.warning("Warning: No rules provided to classify %s - ignoring" % constellation) continue + else: + rule_dict[constellation] = rules if len(variants) > 0: constellation_dict[constellation] = variants - logging.info("Found file %s for constellation %s containing %i variants" % ( - constellation_file, constellation, len([v["name"] for v in variants]))) + logging.info("Found file %s for constellation %s containing %i defining mutations" % ( + constellation_file, constellation, len([v["name"] for v in variants]))) + if rules_required: + logging.info("Rules %s" % rule_dict[constellation]) + mrca_lineage_dict[constellation] = mrca_lineage + incompatible_dict[constellation] = incompatible_lineage_calls + if parent_lineage: + parent_lineage_dict[constellation] = parent_lineage + if lineage_name: + lineage_name_dict[lineage_name] = constellation else: logging.warning("Warning: %s is not a valid constellation file - ignoring" % constellation_file) - if mutations_list: - new_mutations_list = [] - for entry in mutations_list: - if '.' in entry: - new_mutations_list.extend(parse_mutations_in(entry)) # this is a file + + return constellation_dict, name_dict, rule_dict, mrca_lineage_dict, incompatible_dict, parent_lineage_dict, lineage_name_dict + + +def parse_mutations_list(mutations_list, reference_seq, features_dict): + new_mutations_list = [] + for entry in mutations_list: + if '.' 
in entry: + new_mutations_list.extend(parse_mutations_in(entry)) # this is a file + else: + new_mutations_list.append(entry) + mutations_list = new_mutations_list + mutation_variants = parse_mutations(reference_seq, features_dict, mutations_list) + return mutations_list, mutation_variants + + +def combine_constellations(constellation_dict): + variant_dict = {} + constellation_count_dict = {} + for constellation in constellation_dict: + constellation_count_dict[constellation] = {"total": len(constellation_dict[constellation]), 'ref': 0, 'alt': 0, + 'ambig': 0, 'oth': 0} + for variant in constellation_dict[constellation]: + if variant["name"] in variant_dict: + variant_dict[variant["name"]]["constellations"].append(constellation) else: - new_mutations_list.append(entry) - mutations_list = new_mutations_list - logging.info("Typing provided mutations %s" % ",".join(mutations_list)) - mutation_variants = parse_mutations(reference_seq, features_dict, mutations_list) + variant_dict[variant["name"]] = variant + variant_dict[variant["name"]]["constellations"] = [constellation] + sorted_variants = sorted(variant_dict.values(), key=lambda x: int(x["ref_start"])) + return {"union": sorted_variants}, constellation_count_dict + + +def combine_constellations_by_name(constellation_dict, name_dict): + new_constellation_dict = {} + for constellation in constellation_dict: + constellation_name = name_dict[constellation] + if not constellation_name: + continue + if constellation_name not in new_constellation_dict: + new_constellation_dict[constellation_name] = constellation_dict[constellation] + else: + for variant in constellation_dict[constellation]: + if variant not in new_constellation_dict[constellation_name]: + new_constellation_dict[constellation_name].append(variant) + return new_constellation_dict + +def get_number_switches(barcode_list, ref_char="-", ambig_char = "X"): + previous = None + current = None + number_switches = 0 + len_non_ambig_barcode = 0 + for letter in 
barcode_list: + if letter == ref_char: + current = 1 + elif letter == ambig_char: + continue + else: + current = 0 + + len_non_ambig_barcode += 1 + + if previous == None: + previous = current + continue + + if current != previous: + number_switches += 1 + previous = current + + if number_switches > 1: + number_switches -= 1 + + return number_switches,len_non_ambig_barcode + +#def prob_number_errors_or_fewer(counts): +# # H0: missing is due to errors +# # H1: missing is due to recombination +# prob = counts["substitution"]["alt"]^prob_snp_error + prob_del_error * counts["indel"]["alt"] + +#def prob_number_switches_or_fewer(number_switches, len_non_ambig_barcode): +# # H0: sample is random draws from mixture +# # H1: sample is a due to recombination +# return + + +#def get_interspersion(top_scoring, counts, barcode_list, ref_char="-", ambig_char = "X"): +# number_switches, len_non_ambig_barcode = get_number_switches(barcode_list, ref_char="-", ambig_char = "X" +# # H0: missing is due to sequencing errors +# # H1: sample is a random mixture of +# return + + +def type_constellations(in_fasta, list_constellation_files, constellation_names, out_csv, reference_json, ref_char=None, + output_counts=False, label=None, append_genotypes=False, mutations_list=None, dry_run=False, + combination=False, interspersion=False): + reference_seq, features_dict = load_feature_coordinates(reference_json) + + constellation_dict, name_dict, rule_dict, mrca_lineage_dict, \ + incompatible_dict, parent_lineage_dict, lineage_name_dict = load_constellations(list_constellation_files, + constellation_names, + reference_seq, + features_dict, + label, + include_ancestral=False, + rules_required=False) + if mutations_list: + mutations_list, mutation_variants = parse_mutations_list(mutations_list, reference_seq, features_dict) if len(constellation_dict) == 1 and "mutations" not in constellation_dict: constellation = list(constellation_dict)[0] new_constellation = "%s+%s" %(constellation, 
'|'.join(mutations_list)) constellation_dict[new_constellation] = constellation_dict[constellation] + mutation_variants del constellation_dict[constellation] - else: constellation_dict["mutations"] = mutation_variants + name_dict["mutations"] = "mutations" if dry_run: return + if combination: + constellation_dict, constellation_count_dict = combine_constellations(constellation_dict) + name_dict["union"] = "union" + else: + constellation_count_dict = None + + logging.info("\n") + logging.info("Update constellation dict") + constellation_dict = combine_constellations_by_name(constellation_dict, name_dict) + logging.debug(constellation_dict) + logging.info("Have %i constellations to type: %s" % (len(constellation_dict), list(constellation_dict.keys()))) + if constellation_count_dict: + logging.info("Have %i candidate constellations to collect counts: %s" % (len(constellation_count_dict), + list(constellation_count_dict.keys()))) + variants_out = None if len(constellation_dict) > 1 or not (output_counts or append_genotypes): variants_out = open(out_csv, "w") @@ -655,6 +855,8 @@ def type_constellations(in_fasta, list_constellation_files, constellation_names, columns.append("ref_count,alt_count,ambig_count,other_count,support,conflict") if append_genotypes: columns.extend([var["name"] for var in constellation_dict[constellation]]) + if combination: + columns.append("notes") counts_out[constellation].write("%s\n" % ','.join(columns)) with open(in_fasta, "r") as f: @@ -666,7 +868,7 @@ def type_constellations(in_fasta, list_constellation_files, constellation_names, out_list = [record.id] for constellation in constellation_dict: - barcode_list, counts = generate_barcode(record.seq, constellation_dict[constellation], ref_char) + barcode_list, counts, constellation_count_dict, sorted_alt_sites = generate_barcode(record.seq, constellation_dict[constellation], ref_char, constellation_count_dict=constellation_count_dict) if output_counts or append_genotypes: columns = [record.id] 
def combine_counts_call_notes(counts1, call1, note1, counts2, call2, note2):
    """Merge the classification results of two constellations into one.

    The per-category tallies of both results are added together, the
    ``support`` and ``conflict`` fractions are recomputed from the merged
    tallies, the calls are and-ed and the notes are concatenated.

    Args:
        counts1: Count dict of the first result. Must provide the keys
            ``ref``, ``alt``, ``ambig``, ``oth`` and ``rules`` plus the nested
            ``substitution`` and ``indel`` dicts (each with ``ref``, ``alt``,
            ``ambig`` and ``oth``).
        call1: Call of the first result.
        note1: Note string of the first result (may be empty).
        counts2: Count dict of the second result, same layout as ``counts1``.
        call2: Call of the second result.
        note2: Note string of the second result (may be empty).

    Returns:
        A ``(counts, call, note)`` tuple. ``counts`` is a new dict (the inputs
        are not modified), ``call`` is ``call1 and call2`` and ``note`` is the
        non-empty notes joined with ``";"``.
    """
    site_keys = ('ref', 'alt', 'ambig', 'oth')

    counts = {key: counts1[key] + counts2[key] for key in site_keys + ('rules',)}
    for variant_type in ('substitution', 'indel'):
        counts[variant_type] = {
            key: counts1[variant_type][key] + counts2[variant_type][key]
            for key in site_keys
        }

    # support/conflict are fractions of the merged totals, so they are derived
    # from the summed tallies rather than added up from the two inputs.
    total = sum(counts[key] for key in site_keys)
    if total:
        counts['support'] = round(counts['alt'] / float(total), 4)
        counts['conflict'] = round(counts['ref'] / float(total), 4)
    else:
        # No sites were counted at all: report zero instead of dividing by zero.
        counts['support'] = 0.0
        counts['conflict'] = 0.0

    call = call1 and call2

    # Keep every non-empty note; a lone note1 must survive an empty note2.
    note = ";".join(part for part in (note1, note2) if part)

    return counts, call, note
in entry: - new_mutations_list.extend(parse_mutations_in(entry)) # this is a file - else: - new_mutations_list.append(entry) - mutations_list = new_mutations_list - mutation_variants = parse_mutations(reference_seq, features_dict, mutations_list) + logging.debug("parent_dict: %s" %parent_lineage_dict) + logging.debug("lineage_name_dict: %s" %lineage_name_dict) + if mutations_list: + mutations_list, mutation_variants = parse_mutations_list(mutations_list, reference_seq, features_dict) if dry_run: return variants_out = open(out_csv, "w") - columns = ["query","constellations","mrca_lineage"] + columns = ["query", "constellations", "mrca_lineage"] if list_incompatible: columns.append("incompatible_lineages") if long and not call_all: columns.extend(["ref_count","alt_count","ambig_count","other_count","rule_count","support","conflict"]) + columns.append("constellation_name") if mutations_list: columns.extend(mutations_list) - variants_out.write("%s\n" %",".join(columns)) + variants_out.write("%s\n" % ",".join(columns)) counts_out = {} if output_counts: for constellation in constellation_dict: - clean_name = re.sub("[^a-zA-Z0-9_\-.]","_",constellation) - counts_out[constellation] = open("%s.%s_counts.csv" % (out_csv.replace(".csv", ""), clean_name), "w") - counts_out[constellation].write("query,ref_count,alt_count,ambig_count,other_count,rule_count,support," - "conflict,call\n") + constellation_name = name_dict[constellation] + if not constellation_name: + continue + if constellation_name not in counts_out: + clean_name = re.sub("[^a-zA-Z0-9_\-.]", "_", constellation_name) + counts_out[constellation_name] = open("%s.%s_counts.csv" % (out_csv.replace(".csv", ""), clean_name), "w") + counts_out[constellation_name].write("query,ref_count,alt_count,ambig_count,other_count,rule_count,support," + "conflict,call,constellation_name,note\n") with open(in_fasta, "r") as f: for record in SeqIO.parse(f, "fasta"): @@ -754,39 +978,75 @@ def classify_constellations(in_fasta, 
list_constellation_files, constellation_na sys.exit(1) lineages = [] + names = [] best_constellation = None best_support = 0 best_conflict = 1 best_counts = None + scores = {} + children = {} for constellation in constellation_dict: - counts, call = count_and_classify(record.seq, + constellation_name = name_dict[constellation] + parents = [] + if not constellation_name: + continue + counts, call, note = count_and_classify(record.seq, constellation_dict[constellation], rule_dict[constellation]) + current_constellation = constellation + while current_constellation in parent_lineage_dict: + logging.debug("Current constellation %s in parent dict" % current_constellation) + current_constellation = name_dict[lineage_name_dict[parent_lineage_dict[current_constellation]]] + parents.append(current_constellation) + parent_counts, parent_call, parent_note = count_and_classify(record.seq, + constellation_dict[current_constellation], + rule_dict[current_constellation]) + counts, call, note = combine_counts_call_notes(counts, call, note, parent_counts, parent_call, parent_note) + + for parent in parents: + if parent not in children: + children[parent] = [] + children[parent].append(constellation) + if call: if call_all: - lineages.append(constellation) + lineages.append(constellation_name) + names.append(constellation) + elif constellation in children and best_constellation in children[constellation]: + continue elif (not best_constellation) \ or (counts['support'] > best_support) \ or (counts['support'] == best_support and counts['conflict'] < best_conflict)\ - or (counts['support'] == best_support and counts['conflict'] == best_conflict and counts['rules'] > best_counts["rules"]): + or (counts['support'] == best_support and counts['conflict'] == best_conflict and counts['rules'] > best_counts["rules"])\ + or (best_constellation in parents): best_constellation = constellation best_support = counts['support'] best_conflict = counts['conflict'] best_counts = counts - elif 
len(constellation_dict) == 1: - best_counts = counts + + if interspersion: + if counts["alt"] > 1: + summary = constellation + score = counts["support"] + scores[score] = summary if output_counts: - counts_out[constellation].write( - "%s,%i,%i,%i,%i,%i,%f,%f,%s\n" % (record.id, counts['ref'], counts['alt'], counts['ambig'], + counts_out[constellation_name].write( + "%s,%i,%i,%i,%i,%i,%f,%f,%s,%s,%s\n" % (record.id, counts['ref'], counts['alt'], counts['ambig'], counts['oth'], counts['rules'], counts['support'], - counts['conflict'], call)) + counts['conflict'], call, constellation, note)) if not call_all and best_constellation: - lineages.append(best_constellation) - - out_entries = [record.id, "|".join(lineages),"|".join([mrca_lineage_dict[l] for l in lineages])] + lineages.append(name_dict[best_constellation]) + names.append(best_constellation) + + out_entries = [record.id, "|".join(lineages), "|".join([mrca_lineage_dict[n] for n in names])] + if interspersion: + sorted_scores = sorted(scores, key=lambda x: float(x), reverse=True) + top_scoring = scores[sorted_scores[0]] + barcode_list, counts, constellation_count_dict, sorted_alt_sites = generate_barcode(record.seq, constellation_dict[constellation], ref_char="-") + get_interspersion(top_scoring, barcode_list) if list_incompatible: - out_entries.append("|".join([incompatible_dict[l] for l in lineages])) + out_entries.append("|".join([incompatible_dict[n] for n in names])) if long and best_counts is not None: out_entries.append("%i,%i,%i,%i,%i,%f,%f" % (best_counts['ref'], best_counts['alt'], best_counts['ambig'], @@ -794,9 +1054,10 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na best_counts['support'], best_counts['conflict'])) elif long and not call_all: out_entries.append(",,,,,,") + out_entries.append("|".join(names)) if mutations_list: - barcode_list, counts = generate_barcode(record.seq, mutation_variants) + barcode_list, counts, ignore, ignore2 = 
generate_barcode(record.seq, mutation_variants) out_entries.extend(barcode_list) variants_out.write("%s\n" % ",".join(out_entries)) @@ -811,18 +1072,20 @@ def list_constellations(list_constellation_files, constellation_names, reference reference_seq, features_dict = load_feature_coordinates(reference_json) - list_of_constellations = [] + list_of_constellations = set() for constellation_file in list_constellation_files: - constellation, variants, ignore, mrca_lineage, incompatible_lineage_calls = parse_variants_in(reference_seq, features_dict, constellation_file, constellation_names, label=label) + constellation, output_name, variants, ignore, mrca_lineage, \ + incompatible_lineage_calls, parent_lineage, lineage_name = \ + parse_variants_in(reference_seq, features_dict, constellation_file, constellation_names, label=label) if not constellation: continue - if constellation_names and constellation not in constellation_names: + if constellation_names and constellation not in constellation_names and output_name not in constellation_names: continue if len(variants) > 0 and mrca_lineage: - list_of_constellations.append(mrca_lineage) + list_of_constellations.add(output_name) + list_of_constellations.add(mrca_lineage) elif len(variants) > 0: - list_of_constellations.append(constellation) - print("\n".join(list_of_constellations)) + list_of_constellations.add(output_name) def parse_args(): parser = argparse.ArgumentParser(description="""Type an alignment at specific sites and classify with a barcode.""", diff --git a/scorpio/subcommands/define.py b/scorpio/subcommands/define.py index d53109f..9b9e18d 100644 --- a/scorpio/subcommands/define.py +++ b/scorpio/subcommands/define.py @@ -13,4 +13,5 @@ def run(options): options.subset, options.threshold_common, options.threshold_intermediate, - options.outgroups) + options.outgroups, + options.protein) diff --git a/scorpio/subcommands/haplotype.py b/scorpio/subcommands/haplotype.py index 8e06f76..8ce5831 100644 --- 
a/scorpio/subcommands/haplotype.py +++ b/scorpio/subcommands/haplotype.py @@ -14,4 +14,5 @@ def run(options): options.label, options.append_genotypes, options.mutations, - options.dry_run) + options.dry_run, + options.combination) diff --git a/scorpio/tests/data/type_constellations/expected.classified.csv b/scorpio/tests/data/type_constellations/expected.classified.csv index 536caaf..574cdb1 100644 --- a/scorpio/tests/data/type_constellations/expected.classified.csv +++ b/scorpio/tests/data/type_constellations/expected.classified.csv @@ -1,8 +1,8 @@ -query,constellations,mrca_lineage -Flatland01,Lineage_X,B.1.1.7 -Flatland02,Lineage_X,B.1.1.7 -Flatland03,Lineage_X,B.1.1.7 -Flatland04,, -Flatland05,, -Flatland06,, -Reference,, +query,constellations,mrca_lineage,constellation_name +Flatland01,Lineage_X,B.1.1.7,Lineage_X +Flatland02,Lineage_X,B.1.1.7,Lineage_X +Flatland03,Lineage_X,B.1.1.7,Lineage_X +Flatland04,,, +Flatland05,,, +Flatland06,,, +Reference,,, diff --git a/scorpio/tests/data/type_constellations/expected.classified.incompatible.csv b/scorpio/tests/data/type_constellations/expected.classified.incompatible.csv index 7da69cd..ae6efd9 100644 --- a/scorpio/tests/data/type_constellations/expected.classified.incompatible.csv +++ b/scorpio/tests/data/type_constellations/expected.classified.incompatible.csv @@ -1,8 +1,8 @@ -query,constellations,mrca_lineage,incompatible_lineages -Flatland01,Lineage_X,B.1.1.7,A|B.1.351 -Flatland02,Lineage_X,B.1.1.7,A|B.1.351 -Flatland03,Lineage_X,B.1.1.7,A|B.1.351 -Flatland04,,, -Flatland05,,, -Flatland06,,, -Reference,,, +query,constellations,mrca_lineage,incompatible_lineages,constellation_name +Flatland01,Lineage_X,B.1.1.7,A|B.1.351,Lineage_X +Flatland02,Lineage_X,B.1.1.7,A|B.1.351,Lineage_X +Flatland03,Lineage_X,B.1.1.7,A|B.1.351,Lineage_X +Flatland04,,,, +Flatland05,,,, +Flatland06,,,, +Reference,,,, diff --git a/scorpio/tests/data/type_constellations/expected.typed.csv 
b/scorpio/tests/data/type_constellations/expected.typed.csv index 2d5c0b3..4bb2cef 100644 --- a/scorpio/tests/data/type_constellations/expected.typed.csv +++ b/scorpio/tests/data/type_constellations/expected.typed.csv @@ -1,8 +1,8 @@ query,Lineage_X -Flatland01,-IDTT3TT-21YDHIAH-*ICLFK -Flatland02,-IDTT3TT-21YDHIAH-*ICLFK -Flatland03,-IDTT3TT-21YDHIAH-*ICLFK -Flatland04,-IDTT3TT-2XYDHIAH-*ICLF- -Flatland05,-----3-----Y------XX---K -Flatland06,-----3-----X-----------X +Flatland01,--IDTT3TT-21KYDHIAH*ICLF +Flatland02,--IDTT3TT-21KYDHIAH*ICLF +Flatland03,--IDTT3TT-21KYDHIAH*ICLF +Flatland04,--IDTT3TT-2X-YDHIAH*ICLF +Flatland05,------3-----KY-----XX--- +Flatland06,------3-----XX---------- Reference,------------------------ diff --git a/scorpio/tests/data/type_constellations/lineage_X.json b/scorpio/tests/data/type_constellations/lineage_X.json index 72ffabf..358ddc7 100644 --- a/scorpio/tests/data/type_constellations/lineage_X.json +++ b/scorpio/tests/data/type_constellations/lineage_X.json @@ -7,6 +7,7 @@ }, "sites": [ "nuc:C912T", + "nuc:T2680C", "1ab:T1001I", "1ab:A1708D", "nuc:C5986T", @@ -17,19 +18,18 @@ "nuc:C16175T", "s:HV69-", "s:Y144-", + "s:E484K", "s:N501Y", "s:A570D", "s:P681H", "s:T716I", "s:S982A", "s:D1118H", - "nuc:T2680C", "8:Q27*", "8:R52I", "8:Y73C", "N:D3L", - "N:S235F", - "s:E484K" + "N:S235F" ], "rules": { "min_alt": 4, diff --git a/scorpio/tests/type_constellations_test.py b/scorpio/tests/type_constellations_test.py index 2a82cfb..b1f4e54 100644 --- a/scorpio/tests/type_constellations_test.py +++ b/scorpio/tests/type_constellations_test.py @@ -141,7 +141,7 @@ def test_variant_to_variant_record(): def test_parse_json_in(): variants_file = "%s/lineage_X.json" % data_dir - variant_list, name, rules, mrca_lineage, incompatible_lineages = parse_json_in(refseq, features_dict, variants_file) + variant_list, name, output_name, rules, mrca_lineage, incompatible_lineages, parent_lineage, lineage_name = parse_json_in(refseq, features_dict, variants_file) 
assert len(variant_list) == 24 assert len([v for v in variant_list if v["type"] == "snp"]) == 6 assert len([v for v in variant_list if v["type"] == "del"]) == 3 @@ -185,7 +185,7 @@ def test_parse_variants_in(): results = [] for i in range(len(in_files)): - name, variant_list, rule_dict, mrca_lineage, incompatible_lineages = parse_variants_in(refseq, features_dict, in_files[i]) + name, output_name, variant_list, rule_dict, mrca_lineage, incompatible_lineages, parent_lineage, lineage_name = parse_variants_in(refseq, features_dict, in_files[i]) assert expect_names[i] == name assert expect_rules[i] == rule_dict results.append(variant_list) @@ -245,12 +245,13 @@ def test_count_and_classify(): rules = {"min_alt": 1, "max_ref": 1, "snp2": "alt"} expect_classify = [False, False, True, False] - expect_counts = [{"ref": 5, "alt": 0, "ambig": 0, "oth": 1, "rules": 0, "support": 0.0, "conflict": 0.8333}, - {"ref": 1, "alt": 4, "ambig": 0, "oth": 1, "rules": 0, "support": 0.6667, "conflict": 0.1667}, - {"ref": 0, "alt": 5, "ambig": 0, "oth": 1, "rules": 3, "support": 0.8333, "conflict": 0.0}, - {"ref": 0, "alt": 1, "ambig": 0, "oth": 5, "rules": 0, "support": 0.1667, "conflict": 0.0}] + expect_counts = [{"ref": 5, "alt": 0, "ambig": 0, "oth": 1, "rules": 0, 'substitution': {'ref': 4, 'alt': 0, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 1, 'alt': 0, 'ambig': 0, 'oth': 1}, "support": 0.0, "conflict": 0.8333}, + {"ref": 1, "alt": 4, "ambig": 0, "oth": 1, "rules": 0, 'substitution': {'ref': 1, 'alt': 3, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 1}, "support": 0.6667, "conflict": 0.1667}, + {"ref": 0, "alt": 5, "ambig": 0, "oth": 1, "rules": 3, 'substitution': {'ref': 0, 'alt': 4, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 1}, "support": 0.8333, "conflict": 0.0}, + {"ref": 0, "alt": 1, "ambig": 0, "oth": 5, "rules": 0, 'substitution': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 3}, 'indel': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 
2}, "support": 0.1667, "conflict": 0.0}] + for i in range(len(seqs)): - counts, classify = count_and_classify(seqs[i], variants, rules) + counts, classify, note = count_and_classify(seqs[i], variants, rules) print(i, counts, classify) assert classify == expect_classify[i] assert counts == expect_counts[i] @@ -276,23 +277,23 @@ def test_generate_barcode(): {"ref": 0, "alt": 1, "ambig": 0, "oth": 5, "support": 0.1667, "conflict": 0.0}] for i in range(len(seqs)): - barcode_list, counts = generate_barcode(seqs[i], variants, ref_char="-", ins_char="?", oth_char="X") + barcode_list, counts, constellation_count_dict, sorted_alt_constellations = generate_barcode(seqs[i], variants, ref_char="-", ins_char="?", oth_char="X") barcode = ''.join(barcode_list) print(i, barcode, counts) assert barcode == expect_barcode_dash[i] assert counts == expect_counts[i] - barcode_list, counts = generate_barcode(seqs[i], variants, ref_char=None, ins_char="?", oth_char="X") + barcode_list, counts, constellation_count_dict, sorted_alt_constellations = generate_barcode(seqs[i], variants, ref_char=None, ins_char="?", oth_char="X") barcode = ''.join(barcode_list) print(i, barcode, counts) assert barcode == expect_barcode_ref[i] - barcode_list, counts = generate_barcode(seqs[i], variants, ref_char=None, ins_char="?", oth_char=None) + barcode_list, counts, constellation_count_dict, sorted_alt_constellations = generate_barcode(seqs[i], variants, ref_char=None, ins_char="?", oth_char=None) barcode = ''.join(barcode_list) print(i, barcode, counts) assert barcode == expect_barcode_ref_oth[i] - barcode_list, counts = generate_barcode(seqs[i], variants, ref_char="-", ins_char="$", oth_char="X") + barcode_list, counts, constellation_count_dict, sorted_alt_constellations = generate_barcode(seqs[i], variants, ref_char="-", ins_char="$", oth_char="X") barcode = ''.join(barcode_list) print(i, barcode, counts) assert barcode == expect_barcode_ins[i]