From cf1bd5e699b1d907cb8663e6d4a42ad4d57fdb4d Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Thu, 26 Sep 2019 16:33:17 -0700 Subject: [PATCH 01/14] Updates to normalize SNP database. Main command runs, but variant output does not appear logically correct. --- megalodon/aggregate.py | 1 - megalodon/megalodon.py | 2 +- megalodon/megalodon_helper.py | 3 + megalodon/mods.py | 59 ++--- megalodon/snps.py | 486 +++++++++++++++++++++++++++------- 5 files changed, 429 insertions(+), 122 deletions(-) diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py index bc873c4..21e2c21 100644 --- a/megalodon/aggregate.py +++ b/megalodon/aggregate.py @@ -267,7 +267,6 @@ def aggregate_stats( 0, 0, queue.Queue(), queue.Queue()) if mh.SNP_NAME in outputs: snps_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_SNP_NAME) - logger.info('Computing number of unique variants.') num_snps = snps.AggSnps(snps_db_fn).num_uniq() logger.info('Spawning variant aggregation processes.') # create process to collect snp stats from workers diff --git a/megalodon/megalodon.py b/megalodon/megalodon.py index 313e046..4892961 100644 --- a/megalodon/megalodon.py +++ b/megalodon/megalodon.py @@ -89,7 +89,7 @@ def process_read( if snps_q is not None: handle_errors( - func=snps.call_read_snps, + func=snps.call_read_vars, args=(snps_data, r_ref_pos, np_ref_seq, mapped_rl_cumsum, r_to_q_poss, r_post, post_mapped_start), r_vals=(read_id, r_ref_pos.chrm, r_ref_pos.strand, diff --git a/megalodon/megalodon_helper.py b/megalodon/megalodon_helper.py index 68925ad..a78b90c 100644 --- a/megalodon/megalodon_helper.py +++ b/megalodon/megalodon_helper.py @@ -37,6 +37,9 @@ _MAX_QUEUE_SIZE = 10000 +# allow 64GB for memory mapped sqlite file access +MEMORY_MAP_LIMIT = 64000000000 + # VCF spec text MIN_GL_VALUE = -999 MAX_PL_VALUE = 999 diff --git a/megalodon/mods.py b/megalodon/mods.py index 959e420..5ffa12f 100644 --- a/megalodon/mods.py +++ b/megalodon/mods.py @@ -37,9 +37,6 @@ OUT_BUFFER_LIMIT = 10000 -# allow 64GB for memory mapped sqlite file access -MEMORY_MAP_LIMIT = 64000000000 - ################### ##### Mods DB ##### @@ -51,29 +48,30 @@ class ModsDb(object): # when inserting into the data table forces a scan of a very # large table or maintainance of a very large pos table index # both of which slow data base speed - # thus foreign constraint must be handled by the class - chrm_tbl = OrderedDict(( - ('chrm_id', 'INTEGER PRIMARY KEY'), - ('chrm', 'TEXT'))) - pos_tbl = OrderedDict(( - ('pos_id', 'INTEGER PRIMARY KEY'), - ('pos_chrm', 'INTEGER'), - ('strand', 'INTEGER'), - ('pos', 'INTEGER'))) - mod_tbl = OrderedDict(( - ('mod_id', 'INTEGER PRIMARY KEY'), - ('mod_base', 'TEXT'), - ('motif', 'TEXT'), - ('motif_pos', 'INTEGER'), - ('raw_motif', 'TEXT'))) - read_tbl = OrderedDict(( - ('read_id', 'INTEGER PRIMARY KEY'), - ('uuid', 'TEXT'))) - data_tbl = OrderedDict(( - ('score', 'FLOAT'), - ('score_pos', 'INTEGER'), - ('score_mod', 'INTEGER'), - ('score_read', 'INTEGER'))) + # thus foreign key constraint must be handled by the class + db_tables = OrderedDict(( + ('chrm', OrderedDict(( + ('chrm_id', 'INTEGER PRIMARY KEY'), + ('chrm', 'TEXT')))), + ('pos', OrderedDict(( + ('pos_id', 'INTEGER PRIMARY KEY'), + ('pos_chrm', 'INTEGER'), + ('strand', 'INTEGER'), + ('pos', 'INTEGER')))), + ('mod', OrderedDict(( + ('mod_id', 'INTEGER PRIMARY KEY'), + ('mod_base', 'TEXT'), + ('motif', 'TEXT'), + ('motif_pos', 'INTEGER'), + ('raw_motif', 'TEXT')))), + ('read', OrderedDict(( + ('read_id', 'INTEGER PRIMARY KEY'), + ('uuid', 'TEXT')))), + ('data', 
OrderedDict(( + ('score', 'FLOAT'), + ('score_pos', 'INTEGER'), + ('score_mod', 'INTEGER'), + ('score_read', 'INTEGER')))))) # namedtuple for returning mods from a single position mod_data = namedtuple('mod_data', [ @@ -105,7 +103,7 @@ def __init__(self, fn, read_only=True, db_safety=1, self.cur = self.db.cursor() if read_only: # use memory mapped file access - self.db.execute('PRAGMA mmap_size = {}'.format(MEMORY_MAP_LIMIT)) + self.db.execute('PRAGMA mmap_size = {}'.format(mh.MEMORY_MAP_LIMIT)) if self.cm_idx_in_mem: self.load_chrm_read_index() self.load_mod_read_index() @@ -120,10 +118,7 @@ def __init__(self, fn, read_only=True, db_safety=1, self.db.execute('PRAGMA journal_mode = OFF') # create tables - for tbl_name, tbl in ( - ('chrm', self.chrm_tbl), ('pos', self.pos_tbl), - ('mod', self.mod_tbl), ('read', self.read_tbl), - ('data', self.data_tbl)): + for tbl_name, tbl in self.db_tables.items(): try: self.db.execute("CREATE TABLE {} ({})".format( tbl_name, ','.join(( @@ -636,6 +631,8 @@ def store_mod_call( if mods_txt_fp is not None: mods_txt_fp.close() if pr_refs_fn is not None: pr_refs_fp.close() mods_db.create_mod_index() + if not mods_db.pos_idx_in_mem: + mods_db.create_pos_index() mods_db.create_data_covering_index() mods_db.close() diff --git a/megalodon/snps.py b/megalodon/snps.py index 5f7d498..a0bd899 100755 --- a/megalodon/snps.py +++ b/megalodon/snps.py @@ -5,9 +5,8 @@ import datetime from time import sleep from array import array -import multiprocessing as mp from operator import itemgetter -from itertools import product, combinations, groupby +from itertools import product, combinations from collections import defaultdict, namedtuple, OrderedDict import pysam @@ -20,49 +19,17 @@ _DEBUG_PER_READ = False -_RAISE_VARIANT_PROCESSING_ERRORS = False +_RAISE_VARIANT_PROCESSING_ERRORS = True VARIANT_DATA = namedtuple('VARIANT_DATA', ( 'np_ref', 'np_alts', 'id', 'chrom', 'start', 'stop', - 'ref', 'alts', 'ref_start', )) + 'ref', 'alts', 'ref_start')) # set default value of None for ref, alts and ref_start VARIANT_DATA.__new__.__defaults__ = (None, None, None) DIPLOID_MODE = 'diploid' HAPLIOD_MODE = 'haploid' -FIELD_NAMES = ('read_id', 'chrm', 'strand', 'pos', 'score', - 'ref_seq', 'alt_seq', 'snp_id', 'test_start', 'test_end') -SNP_DATA = namedtuple('SNP_DATA', FIELD_NAMES) -CREATE_SNPS_TBLS = """ -CREATE TABLE snps ( - {} TEXT, - {} TEXT, - {} INTEGER, - {} INTEGER, - {} FLOAT, - {} TEXT, - {} TEXT, - {} TEXT, - {} INTEGER, - {} INTEGER -)""".format(*FIELD_NAMES) - -SET_NO_ROLLBACK_MODE='PRAGMA journal_mode = OFF' -SET_ASYNC_MODE='PRAGMA synchronous = OFF' - -ADDMANY_SNPS = "INSERT INTO snps VALUES (?,?,?,?,?,?,?,?,?,?)" -CREATE_SNPS_IDX = ''' -CREATE INDEX snp_pos ON snps (chrm, test_start, test_end)''' - -COUNT_UNIQ_SNPS = """ -SELECT COUNT(*) FROM ( -SELECT DISTINCT chrm, test_start, test_end FROM snps)""" -SEL_UNIQ_SNP_ID = ''' -SELECT DISTINCT chrm, test_start, test_end FROM snps''' -SEL_SNP_STATS = ''' -SELECT * FROM snps WHERE chrm IS ? AND test_start IS ? AND test_end IS ?''' - SAMPLE_NAME = 'SAMPLE' # specified by sam format spec WHATSHAP_MAX_QUAL = 40 @@ -84,6 +51,360 @@ 'alleles (semi-colon separated)">') +####################### +##### Variants DB ##### +####################### + +class VarsDb(object): + # note foreign key constraint is not applied here as this + # drastically reduces efficiency. 
Namely the search for pos_id
+    # when inserting into the data table forces a scan of a very
+    # large table or maintenance of a very large pos table index
+    # both of which slow database speed
+    # thus foreign key constraint must be handled by the class
+    db_tables = OrderedDict((
+        ('chrm', OrderedDict((
+            ('chrm_id', 'INTEGER PRIMARY KEY'),
+            ('chrm', 'TEXT')))),
+        ('loc', OrderedDict((
+            ('loc_id', 'INTEGER PRIMARY KEY'),
+            ('loc_chrm', 'INTEGER'),
+            ('test_start', 'INTEGER'),
+            ('test_end', 'INTEGER'),
+            ('var_name', 'TEXT'),
+            ('pos', 'INTEGER'),
+            ('ref_seq', 'TEXT')))),
+        ('alt', OrderedDict((
+            ('alt_id', 'INTEGER PRIMARY KEY'),
+            ('alt_seq', 'TEXT')))),
+        ('read', OrderedDict((
+            ('read_id', 'INTEGER PRIMARY KEY'),
+            ('uuid', 'TEXT'),
+            ('strand', 'INTEGER')))),
+        ('data', OrderedDict((
+            ('score', 'FLOAT'),
+            ('score_loc', 'INTEGER'),
+            ('score_alt', 'INTEGER'),
+            ('score_read', 'INTEGER'))))))
+
+    # namedtuple for returning var info from a single position
+    var_data = namedtuple('var_data', [
+        'score', 'read_id', 'chrm', 'alt_seq', 'pos', 'ref_seq', 'var_name'])
+
+    def __init__(self, fn, read_only=True, db_safety=1,
+                 loc_index_in_memory=False, chrm_index_in_memory=True,
+                 alt_index_in_memory=True):
+        """ Interface to database containing sequence variant statistics.
+
+        Default settings are for optimal read_only performance.
+        """
+        self.fn = mh.resolve_path(fn)
+        self.read_only = read_only
+        self.loc_idx_in_mem = loc_index_in_memory
+        self.chrm_idx_in_mem = chrm_index_in_memory
+        self.alt_idx_in_mem = alt_index_in_memory
+
+        if read_only:
+            if not os.path.exists(fn):
+                logger = logging.get_logger('vars')
+                logger.error((
+                    'Variant per-read database file ({}) does ' +
+                    'not exist.').format(fn))
+                raise mh.MegaError('Invalid variant DB filename.')
+            self.db = sqlite3.connect('file:' + fn + '?mode=ro', uri=True)
+        else:
+            self.db = sqlite3.connect(fn)
+
+        self.cur = self.db.cursor()
+        if self.read_only:
+            # use memory mapped file access
+            self.db.execute('PRAGMA mmap_size = {}'.format(mh.MEMORY_MAP_LIMIT))
+            if self.chrm_idx_in_mem:
+                self.load_chrm_read_index()
+            if self.loc_idx_in_mem:
+                self.load_loc_read_index()
+            if self.alt_idx_in_mem:
+                self.load_alt_read_index()
+        else:
+            if db_safety < 2:
+                # set asynchronous mode to off for max speed
+                self.db.execute('PRAGMA synchronous = OFF')
+            if db_safety < 1:
+                # set no rollback mode
+                self.db.execute('PRAGMA journal_mode = OFF')
+
+            # create tables
+            for tbl_name, tbl in self.db_tables.items():
+                try:
+                    self.db.execute("CREATE TABLE {} ({})".format(
+                        tbl_name, ','.join((
+                            '{} {}'.format(*ft) for ft in tbl.items()))))
+                except sqlite3.OperationalError:
+                    raise mh.MegaError(
+                        'Sequence variants database already exists. 
Either ' + + 'provide location for new database or open in ' + + 'read_only mode.') + + if self.loc_idx_in_mem: + self.loc_idx = {} + else: + self.create_loc_index() + if self.chrm_idx_in_mem: + self.chrm_idx = {} + else: + self.create_chrm_index() + if self.alt_idx_in_mem: + self.alt_idx = {} + else: + self.create_alt_index() + + return + + def insert_chrm(self, chrm): + self.cur.execute('INSERT INTO chrm (chrm) VALUES (?)', (chrm,)) + if self.chrm_idx_in_mem: + self.chrm_idx[chrm] = self.cur.lastrowid + return self.cur.lastrowid + + def get_chrm_id(self, chrm): + try: + if self.chrm_idx_in_mem: + chrm_id = self.chrm_idx[chrm] + else: + chrm_id = self.cur.execute( + 'SELECT chrm_id FROM chrm WHERE chrm=?', + (chrm,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Reference record (chromosome) not found in ' + + 'database.') + return chrm_id + + def get_chrm(self, chrm_id): + try: + if self.chrm_idx_in_mem: + chrm = self.chrm_read_idx[chrm_id] + else: + chrm = self.cur.execute( + 'SELECT chrm FROM chrm WHERE chrm_id=?', + (chrm_id,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Reference record (chromosome) not found in ' + + 'vars database.') + return chrm + + def get_loc_id(self, chrm, test_start, test_end, chrm_id=None): + if chrm_id is None: + chrm_id = self.get_chrm_id(chrm) + + try: + if self.loc_idx_in_mem: + loc_id = self.loc_idx[(chrm_id, test_start, test_end)] + else: + loc_id = self.cur.execute( + 'SELECT loc_id FROM loc WHERE loc_chrm=? AND ' + + 'test_start=? AND test_end=?', ( + chrm_id, test_start, test_end)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError( + 'Reference position not found in database.') + + return loc_id + + def get_loc_id_or_insert(self, chrm, test_start, test_end, var_name, + pos, ref_seq, chrm_id=None): + if chrm_id is None: + chrm_id = self.get_chrm_id(chrm) + try: + loc_id = self.get_loc_id(None, test_start, test_end, chrm_id) + except mh.MegaError: + self.cur.execute( + 'INSERT INTO loc (loc_chrm, test_start, test_end, var_name, ' + + 'pos, ref_seq) VALUES (?,?,?,?,?,?)', + (chrm_id, test_start, test_end,var_name, pos, ref_seq)) + loc_id = self.cur.lastrowid + if self.loc_idx_in_mem: + self.loc_idx[(chrm_id, test_start, test_end)] = loc_id + return loc_id + + def get_loc_data(self, loc_id): + try: + if self.loc_idx_in_mem: + loc_data = self.loc_read_idx[chrm_id] + else: + loc_data = self.cur.execute( + 'SELECT pos, ref_seq, var_name FROM loc ' + + 'WHERE loc_id = ?', (loc_id, )).fetchone() + except (TypeError, KeyError): + raise mh.MegaError('Variant location data not found in ' + + 'vars database.') + return loc_data + + def get_alt_id(self, alt_seq): + try: + if self.alt_idx_in_mem: + alt_id = self.alt_idx[alt_seq] + else: + alt_id = self.cur.execute( + 'SELECT alt_id FROM var WHERE alt_seq=?', + (var_name, )).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Variant not found in var table.') + return alt_id + + def get_alt_id_or_insert(self, alt_seq): + try: + alt_id = self.get_alt_id(alt_seq) + except mh.MegaError: + self.cur.execute( + 'INSERT INTO alt (alt_seq) VALUES (?)', (alt_seq, )) + alt_id = self.cur.lastrowid + if self.alt_idx_in_mem: + self.alt_idx[alt_seq] = alt_id + return alt_id + + def get_alt_seq(self, alt_id): + try: + if self.alt_idx_in_mem: + alt_seq = self.alt_read_idx[alt_id] + else: + alt_seq = self.cur.execute( + 'SELECT alt_seq FROM alt WHERE alt_id=?', + (alt_id,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Alt sequence 
not found in vars database.') + return alt_seq + + def insert_read_scores(self, r_var_scores, uuid, chrm, strand): + self.cur.execute('INSERT INTO read (uuid, strand) VALUES (?,?)', + (uuid, strand)) + read_id = self.cur.lastrowid + chrm_id = self.get_chrm_id(chrm) + read_insert_data = [] + for (pos, alt_lps, snp_ref_seq, snp_alt_seqs, var_name, + test_start, test_end) in r_var_scores: + loc_id = self.get_loc_id_or_insert( + None, test_start, test_end, var_name, pos, snp_ref_seq, chrm_id) + for alt_lp, alt_seq in zip(alt_lps, snp_alt_seqs): + alt_id = self.get_alt_id_or_insert(alt_seq) + read_insert_data.append((alt_lp, loc_id, alt_id, read_id)) + + self.cur.executemany( + 'INSERT INTO data VALUES (?,?,?,?)', read_insert_data) + return + + def create_chrm_index(self): + self.cur.execute('CREATE UNIQUE INDEX chrm_idx ON chrm(chrm)') + return + + def load_chrm_read_index(self): + self.cur.execute('SELECT chrm_id, chrm FROM chrm') + self.chrm_read_idx = dict(self.cur.fetchall()) + return + + def create_alt_index(self): + self.cur.execute('CREATE UNIQUE INDEX alt_idx ON alt(alt_seq)') + return + + def load_alt_read_index(self): + self.cur.execute('SELECT alt_id, alt_seq FROM alt') + self.alt_read_idx = dict(self.cur.fetchall()) + return + + def create_loc_index(self): + self.cur.execute('CREATE UNIQUE INDEX loc_idx ON loc' + + '(loc_chrm, test_start, test_end)') + return + + def load_loc_read_index(self): + self.cur.execute('SELECT loc_id, pos, ref_seq, var_name FROM loc') + self.loc_read_idx = dict( + (loc_id, (pos, ref_seq, var_name)) + for loc_id, pos, ref_seq, var_name in self.cur) + return + + def create_data_covering_index(self): + self.cur.execute('CREATE INDEX data_cov_idx ON data(' + + 'score_loc, score_alt, score_read, score)') + return + + def close(self): + self.db.commit() + self.db.close() + return + + def get_num_uniq_var_loc(self): + return self.cur.execute('SELECT MAX(loc_id) FROM loc').fetchone()[0] + + def iter_loc_id(self): + self.cur.execute('SELECT loc_id FROM loc') + for loc_id in self.cur: + yield loc_id[0] + + return + + def iter_locs(self): + self.cur.execute( + 'SELECT loc_id, loc_chrm, test_start, test_end FROM loc') + for loc in self.cur: + yield loc + + return + + def iter_loc_id_ordered(self): + self.cur.execute('SELECT loc_id FROM loc ORDER BY loc_id') + for loc_id in self.cur: + yield loc_id[0] + + return + + def iter_loc_ordered(self): + self.cur.execute( + 'SELECT loc_id, loc_chrm, test_start, test_end FROM loc ' + + 'ORDER BY loc_id') + for loc in self.cur: + yield loc + + return + + def get_loc_stats(self, loc_data, return_uuids=False): + read_id_conv = self.get_uuid if return_uuids else lambda x: x + loc_id, chrm_id, test_start, test_end = loc_data + self.cur.execute( + 'SELECT score, score_read, score_alt, score_loc FROM data ' + + 'WHERE score_loc=?', (loc_id, )) + return [ + self.var_data(score, read_id_conv(read_id), + self.get_chrm(chrm_id), + self.get_alt_seq(alt_id), + *self.get_loc_data(loc_id)) + for score, read_id, alt_id, loc_id in self.cur] + + def get_read_id(self, uuid): + try: + read_id = self.cur.execute( + 'SELECT read_id FROM read WHERE uuid=?', (uuid,)).fetchone()[0] + except TypeError: + raise mh.MegaError('Read ID not found in vars database.') + return read_id + + def get_read_id_or_insert(self, uuid): + try: + read_id = self.get_read_id(uuid) + except mh.MegaError: + self.cur.execute('INSERT INTO read (uuid) VALUES (?)', (uuid,)) + read_id = self.cur.lastrowid + return read_id + + def create_data_read_index(self): + 
self.cur.execute('CREATE INDEX data_read_idx ON data(score_read)') + return + + def get_read_stats(self, uuid): + # TODO implement this for API + raise NotImplementedError + return + + ############################ ##### Helper Functions ##### ############################ @@ -152,10 +473,10 @@ def score_seq(tpost, seq, tpost_start=0, tpost_end=None, return score -def call_read_snps( - snps_data, read_ref_pos, strand_read_np_ref_seq, rl_cumsum, r_to_q_poss, +def call_read_vars( + vars_data, read_ref_pos, strand_read_np_ref_seq, rl_cumsum, r_to_q_poss, r_post, post_mapped_start): - if read_ref_pos.end - read_ref_pos.start <= 2 * snps_data.edge_buffer: + if read_ref_pos.end - read_ref_pos.start <= 2 * vars_data.edge_buffer: raise mh.MegaError('Mapped region too short for variant calling.') # convert to forward strand sequence in order to annotate with variants @@ -163,15 +484,15 @@ def call_read_snps( mh.revcomp_np(strand_read_np_ref_seq)) # call all snps overlapping this read r_snp_calls = [] - logger = logging.get_logger('per_read_snps') + logger = logging.get_logger('per_read_vars') read_cached_scores = {} - read_variants = snps_data.fetch_read_variants( + read_variants = vars_data.fetch_read_variants( read_ref_pos, read_ref_fwd_seq) filt_read_variants = [] # first pass over variants assuming the reference ground truth # (not including context variants) for (np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, - s_ref_start, s_ref_end, variant) in snps_data.iter_snps( + s_ref_start, s_ref_end, variant) in vars_data.iter_snps( read_variants, read_ref_pos, read_ref_fwd_seq, context_max_dist=0): blk_start = rl_cumsum[r_to_q_poss[s_ref_start]] blk_end = rl_cumsum[r_to_q_poss[s_ref_end]] @@ -189,7 +510,7 @@ def call_read_snps( np_s_context_seqs[0][0], np_s_snp_ref_seq, np_s_context_seqs[0][1]]) loc_ref_lp = score_seq( r_post, np_ref_seq, post_mapped_start + blk_start, - post_mapped_start + blk_end, snps_data.all_paths) + post_mapped_start + blk_end, vars_data.all_paths) loc_alt_lps = [] loc_alt_llrs = [] @@ -202,12 +523,12 @@ def call_read_snps( np_s_context_seqs[0][1]]) loc_alt_lp = score_seq( r_post, np_alt_seq, post_mapped_start + blk_start, - post_mapped_start + blk_end, snps_data.all_paths) + post_mapped_start + blk_end, vars_data.all_paths) loc_alt_lps.append(loc_alt_lp) if _DEBUG_PER_READ: loc_contexts_alts_lps.append(np.array([loc_alt_lp,])) # calibrate log probs - loc_alt_llrs.append(snps_data.calibrate_llr( + loc_alt_llrs.append(vars_data.calibrate_llr( loc_ref_lp - loc_alt_lp, variant.ref, var_alt_seq)) # due to calibration mutli-allelic log likelihoods could result in @@ -220,7 +541,7 @@ def call_read_snps( np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, np.array([loc_ref_lp,]), loc_contexts_alts_lps, False, logger) - if sum(np.exp(loc_alt_log_ps)) >= snps_data.context_min_alt_prob: + if sum(np.exp(loc_alt_log_ps)) >= vars_data.context_min_alt_prob: filt_read_variants.append(variant) read_cached_scores[(variant.id, variant.start, variant.stop)] = ( loc_ref_lp, loc_alt_lps) @@ -233,7 +554,7 @@ def call_read_snps( # second round for variants with some evidence for alternative alleles # process with other potential variants as context for (np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, - s_ref_start, s_ref_end, variant) in snps_data.iter_snps( + s_ref_start, s_ref_end, variant) in vars_data.iter_snps( filt_read_variants, read_ref_pos, read_ref_fwd_seq): ref_cntxt_ref_lp, ref_cntxt_alt_lps = read_cached_scores[( variant.id, variant.start, variant.stop)] @@ -261,7 
+582,7 @@ def call_read_snps( for up_context_seq, dn_context_seq in np_s_context_seqs[1:]) loc_contexts_ref_lps = np.array([ref_cntxt_ref_lp] + [score_seq( r_post, ref_seq, post_mapped_start + blk_start, - post_mapped_start + blk_end, snps_data.all_paths) + post_mapped_start + blk_end, vars_data.all_paths) for ref_seq in ref_context_seqs]) loc_ref_lp = logsumexp(loc_contexts_ref_lps) @@ -276,13 +597,13 @@ def call_read_snps( for up_context_seq, dn_context_seq in np_s_context_seqs[1:]) loc_contexts_alt_lps = np.array([ref_cntxt_alt_lp,] + [ score_seq(r_post, alt_seq, post_mapped_start + blk_start, - post_mapped_start + blk_end, snps_data.all_paths) + post_mapped_start + blk_end, vars_data.all_paths) for alt_seq in alt_context_seqs]) loc_alt_lp = logsumexp(loc_contexts_alt_lps) if _DEBUG_PER_READ: loc_contexts_alts_lps.append(loc_contexts_alt_lps) # calibrate log probs - loc_alt_llrs.append(snps_data.calibrate_llr( + loc_alt_llrs.append(vars_data.calibrate_llr( loc_ref_lp - loc_alt_lp, variant.ref, var_alt_seq)) # due to calibration mutli-allelic log likelihoods could result in @@ -443,7 +764,7 @@ def annotate_snps(r_start, ref_seq, r_snp_calls, strand): return snp_seq, snp_quals, snp_cigar def _get_snps_queue( - snps_q, snps_conn, snps_db_fn, snps_txt_fn, db_safety, pr_refs_fn, + snps_q, snps_conn, vars_db_fn, snps_txt_fn, db_safety, pr_refs_fn, pr_ref_filts, whatshap_map_fn, ref_names_and_lens, ref_fn): def write_whatshap_alignment( read_id, snp_seq, snp_quals, chrm, strand, r_st, snp_cigar): @@ -471,13 +792,7 @@ def write_whatshap_alignment( def get_snp_call( r_snp_calls, read_id, chrm, strand, r_start, ref_seq, read_len, q_st, q_en, cigar): - # note strand is +1 for fwd or -1 for rev - snps_db.executemany(ADDMANY_SNPS, [ - (read_id, chrm, strand, pos, alt_lp, - snp_ref_seq, snp_alt_seq, snp_id, test_start, test_end) - for pos, alt_lps, snp_ref_seq, snp_alt_seqs, snp_id, - test_start, test_end in r_snp_calls - for alt_lp, snp_alt_seq in zip(alt_lps, snp_alt_seqs)]) + snps_db.insert_read_scores(r_snp_calls, read_id, chrm, strand) if snps_txt_fp is not None and len(r_snp_calls) > 0: snp_out_text = '' for (pos, alt_lps, snp_ref_seq, snp_alt_seqs, snp_id, @@ -507,13 +822,13 @@ def get_snp_call( return - logger = logging.get_logger('snps_getter') - snps_db = sqlite3.connect(snps_db_fn) - if db_safety < 2: - snps_db.execute(SET_ASYNC_MODE) - if db_safety < 1: - snps_db.execute(SET_NO_ROLLBACK_MODE) - snps_db.execute(CREATE_SNPS_TBLS) + logger = logging.get_logger('vars_getter') + # TODO convert loc index to command line option + snps_db = VarsDb(vars_db_fn, db_safety=db_safety, read_only=False, + loc_index_in_memory=True) + for ref_name in ref_names_and_lens[0]: + snps_db.insert_chrm(ref_name) + snps_db.create_chrm_index() if snps_txt_fn is None: snps_txt_fp = None else: @@ -579,8 +894,10 @@ def get_snp_call( if snps_txt_fp is not None: snps_txt_fp.close() if pr_refs_fn is not None: pr_refs_fp.close() if whatshap_map_fn is not None: whatshap_map_fp.close() - snps_db.execute(CREATE_SNPS_IDX) - snps_db.commit() + snps_db.create_alt_index() + if not snps_db.loc_idx_in_mem: + snps_db.create_loc_index() + snps_db.create_data_covering_index() snps_db.close() return @@ -602,7 +919,7 @@ def check_vars_match_ref( ref_seq = aligner.seq(contig, var_data.start, var_data.stop) if ref_seq != var_data.ref: # variant reference sequence does not match reference - logger = logging.get_logger() + logger = logging.get_logger('vars') logger.debug(( 'Reference sequence does not match variant reference ' + 'sequence 
at {} expected "{}" got "{}"').format( @@ -618,7 +935,7 @@ def __init__( keep_snp_fp_open=False, do_validate_reference=True, edge_buffer=mh.DEFAULT_EDGE_BUFFER, context_min_alt_prob=mh.DEFAULT_CONTEXT_MIN_ALT_PROB): - logger = logging.get_logger('snps') + logger = logging.get_logger('vars') self.max_indel_size = max_indel_size self.all_paths = all_paths self.write_snps_txt = write_snps_txt @@ -1112,7 +1429,7 @@ def extract_variant_contexts(variant, context_vars): context_seqs) - logger = logging.get_logger('snps') + logger = logging.get_logger('vars') for variant, context_variants in self.iter_context_variants( read_variants, context_max_dist): (context_ref_start, context_read_start, context_read_end, @@ -1225,7 +1542,7 @@ def add_haploid_probs(self, probs, gts): try: qual = int(np.around(np.minimum(raw_pl[0], mh.MAX_PL_VALUE))) except ValueError: - logger = logging.get_logger() + logger = logging.get_logger('vars') logger.debug( 'NAN quality value encountered. gts:{}, probs:{}'.format( str(gts), str(probs))) @@ -1253,7 +1570,7 @@ def add_diploid_probs(self, probs, gts): try: qual = int(np.minimum(np.around(raw_pl[0]), mh.MAX_PL_VALUE)) except ValueError: - logger = logging.get_logger() + logger = logging.get_logger('vars') logger.debug( 'NAN quality value encountered. gts:{}, probs:{}'.format( str(gts), str(probs))) @@ -1337,34 +1654,25 @@ class AggSnps(mh.AbstractAggregationClass): """ Class to assist in database queries for per-site aggregation of SNP calls over reads. """ - def __init__(self, snps_db_fn, write_vcf_log_probs=False): + def __init__(self, vars_db_fn, write_vcf_log_probs=False, + loc_index_in_memory=False): # open as read only database - if not os.path.exists(snps_db_fn): - logger = logging.get_logger('snps') - logger.error(( - 'SNP per-read database file ({}) does ' + - 'not exist.').format(snps_db_fn)) - raise mh.MegaError('Invalid snps DB filename.') - self.snps_db = sqlite3.connect(snps_db_fn, uri=True) + self.snps_db = VarsDb( + vars_db_fn, loc_index_in_memory=loc_index_in_memory) self.n_uniq_snps = None self.write_vcf_log_probs = write_vcf_log_probs return def num_uniq(self): if self.n_uniq_snps is None: - self.n_uniq_snps = self.snps_db.execute( - COUNT_UNIQ_SNPS).fetchone()[0] + self.n_uniq_snps = self.snps_db.get_num_uniq_var_loc() return self.n_uniq_snps def iter_uniq(self): - for q_val in self.snps_db.execute(SEL_UNIQ_SNP_ID): + for q_val in self.snps_db.iter_locs(): yield q_val return - def get_per_read_snp_stats(self, snp_loc): - return [SNP_DATA(*snp_stats) for snp_stats in self.snps_db.execute( - SEL_SNP_STATS, snp_loc)] - def compute_diploid_probs(self, ref_lps, alts_lps, het_factor=1.0): def compute_het_lp(a1, a2): # order by the inverse log likelihood ratio @@ -1414,7 +1722,7 @@ def compute_snp_stats( assert call_mode in (HAPLIOD_MODE, DIPLOID_MODE), ( 'Invalid SNP aggregation ploidy call mode: {}.'.format(call_mode)) - pr_snp_stats = self.get_per_read_snp_stats(snp_loc) + pr_snp_stats = self.snps_db.get_loc_stats(snp_loc) alt_seqs = sorted(set(r_stats.alt_seq for r_stats in pr_snp_stats)) pr_alt_lps = defaultdict(dict) for r_stats in pr_snp_stats: @@ -1440,7 +1748,7 @@ def compute_snp_stats( r0_stats = pr_snp_stats[0] snp_var = Variant( chrom=r0_stats.chrm, pos=r0_stats.pos, ref=r0_stats.ref_seq, - alts=alt_seqs, id=r0_stats.snp_id) + alts=alt_seqs, id=r0_stats.var_name) snp_var.add_tag('DP', '{}'.format(ref_lps.shape[0])) snp_var.add_sample_field('DP', '{}'.format(ref_lps.shape[0])) From 3323963e0b4677caa53bfa5f607a8afea8a50ad5 Mon Sep 17 00:00:00 2001 
From: Marcus Stoiber
Date: Fri, 27 Sep 2019 08:48:50 -0700
Subject: [PATCH 02/14] Fix variant pipeline and remove cruft from per-read db
 classes.

---
 megalodon/aggregate.py |   1 -
 megalodon/mods.py      | 151 +++++++++++------------------------
 megalodon/snps.py      |  71 +++----------------
 3 files changed, 48 insertions(+), 175 deletions(-)

diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py
index 21e2c21..f9616b7 100644
--- a/megalodon/aggregate.py
+++ b/megalodon/aggregate.py
@@ -202,7 +202,6 @@ def _agg_prog_worker(
         mod_bar = tqdm(desc='Mods', unit=' sites', total=num_mods,
                        position=0, smoothing=0, dynamic_ncols=True)
 
-    logger = logging.get_logger()
     while True:
         try:
             snp_prog_q.get(block=False)
diff --git a/megalodon/mods.py b/megalodon/mods.py
index 5ffa12f..1e275f4 100644
--- a/megalodon/mods.py
+++ b/megalodon/mods.py
@@ -108,7 +108,7 @@ def __init__(self, fn, read_only=True, db_safety=1,
             self.load_chrm_read_index()
             self.load_mod_read_index()
             if self.pos_idx_in_mem:
-                self.load_pos_index()
+                self.load_pos_read_index()
         else:
             if db_safety < 2:
                 # set asynchronous mode to off for max speed
@@ -163,9 +163,12 @@ def get_chrm_id(self, chrm):
 
     def get_chrm(self, chrm_id):
         try:
-            chrm = self.cur.execute(
-                'SELECT chrm FROM chrm WHERE chrm_id=?',
-                (chrm_id,)).fetchone()[0]
+            if self.cm_idx_in_mem:
+                chrm = self.chrm_read_idx[chrm_id]
+            else:
+                chrm = self.cur.execute(
+                    'SELECT chrm FROM chrm WHERE chrm_id=?',
+                    (chrm_id,)).fetchone()[0]
         except TypeError:
             raise mh.MegaError('Reference record (chromosome) not found in ' +
                                'mods database.')
@@ -215,6 +218,18 @@ def get_mod_base_id(self, mod_base, motif, motif_pos, raw_motif):
             raise mh.MegaError('Modified base not found in mods database.')
         return mod_id
 
+    def get_mod_base_data(self, mod_id):
+        try:
+            if self.cm_idx_in_mem:
+                mod_base_data = self.mod_read_idx[mod_id]
+            else:
+                mod_base_data = self.cur.execute(
+                    'SELECT mod_base, motif, motif_pos, raw_motif FROM mod ' +
+                    'WHERE mod_id=?', (mod_id,)).fetchone()
+        except (TypeError, KeyError):
+            raise mh.MegaError('Modified base not found in mods database.')
+        return mod_base_data
+
     def get_mod_base_id_or_insert(self, mod_base, motif, motif_pos, raw_motif):
         try:
             mod_base_id = self.get_mod_base_id(
@@ -236,6 +251,7 @@ def insert_read_scores(self, r_mod_scores, uuid, chrm, strand):
         read_insert_data = []
         for (pos, mod_lps, mod_bases, ref_motif, rel_pos,
              raw_motif) in r_mod_scores:
+            # TODO convert to execute many for pos and mod_base inserts
             pos_id = self.get_pos_id_or_insert(None, strand, pos, chrm_id)
             for mod_lp, mod_base in zip(mod_lps, mod_bases):
                 mod_base_id = self.get_mod_base_id_or_insert(
@@ -251,10 +267,8 @@ def create_chrm_index(self):
         return
 
     def load_chrm_read_index(self):
-        self.chrm_read_idx = {}
         self.cur.execute('SELECT chrm_id, chrm FROM chrm')
-        for chrm_id, chrm in self.cur:
-            self.chrm_read_idx[chrm_id] = chrm
+        self.chrm_read_idx = dict(self.cur.fetchall())
         return
 
     def create_mod_index(self):
@@ -263,11 +277,11 @@ def create_mod_index(self):
         return
 
     def load_mod_read_index(self):
-        self.mod_read_idx = {}
         self.cur.execute(
             'SELECT mod_id, mod_base, motif, motif_pos, raw_motif FROM mod')
-        for mod_id, mod_base, motif, motif_pos, raw_motif in self.cur:
-            self.mod_read_idx[mod_id] = (mod_base, motif, motif_pos, raw_motif)
+        self.mod_read_idx = dict((
+            (mod_id, (mod_base, motif, motif_pos, raw_motif))
+            for mod_id, mod_base, motif, motif_pos, raw_motif in self.cur))
         return
 
     def create_pos_index(self):
@@ -275,11 +289,11 @@ def create_pos_index(self):
                          '(pos_chrm, strand, pos)')
         return
 
- def load_pos_index(self): - self.pos_read_idx = {} + def load_pos_read_index(self): self.cur.execute('SELECT pos_id, pos_chrm, strand, pos FROM pos') - for pos_id, chrm_id, strand, pos in self.cur: - self.pos_read_idx[pos_id] = (chrm_id, strand, pos) + self.pos_read_idx = dict(( + (pos_id, (chrm_id, strand, pos)) + for pos_id, chrm_id, strand, pos in self.cur)) return def create_data_covering_index(self): @@ -295,110 +309,28 @@ def close(self): def get_num_uniq_mod_pos(self): return self.cur.execute('SELECT MAX(pos_id) FROM pos').fetchone()[0] - def iter_pos_id(self): - self.cur.execute('SELECT pos_id FROM pos') - for pos_id in self.cur: - yield pos_id[0] - - return - def iter_pos(self): self.cur.execute('SELECT pos_id, pos_chrm, strand, pos FROM pos') for pos in self.cur: yield pos - - return - - def iter_pos_id_ordered(self): - self.cur.execute('SELECT pos_id FROM pos ORDER BY pos_id') - for pos_id in self.cur: - yield pos_id[0] - - return - - def iter_pos_ordered(self): - self.cur.execute('SELECT pos_id, pos_chrm, strand, pos FROM pos ' + - 'ORDER BY pos_id') - for pos in self.cur: - yield pos - return def get_pos_stats(self, pos_data, return_uuids=False): + read_id_conv = self.get_uuid if return_uuids else lambda x: x + # these attributes are specified in self.iter_pos pos_id, chrm_id, strand, pos = pos_data - if self.cm_idx_in_mem: - if return_uuids: - self.cur.execute( - 'SELECT read.uuid, data.score, data.score_mod ' + - 'FROM data ' + - 'INNER JOIN read ON read.read_id = data.score_read ' + - 'WHERE score_pos=?', (pos_id, )) - else: - # simplest query using covering index - self.cur.execute( - 'SELECT score_read, score, score_mod ' + - 'FROM data ' + - 'WHERE score_pos=?', (pos_id, )) - return [ - self.mod_data(read_id, self.chrm_read_idx[chrm_id], strand, - pos, score, *self.mod_read_idx[mod_id]) - for read_id, score, mod_id in self.cur] - - # perform full query from on-disk database - chrm = self.cur.execute( - 'SELECT chrm FROM chrm WHERE chrm_id=?', (chrm_id,)).fetchone()[0] - if return_uuids: - self.cur.execute( - 'SELECT read.uuid, data.score, mod.mod_base, mod.motif, ' + - 'mod.motif_pos, mod.raw_motif ' + - 'FROM data ' + - 'INNER JOIN read ON read.read_id = data.score_read ' + - 'INNER JOIN mod ON mod.mod_id = data.score_mod ' + - 'WHERE score_pos=?', (pos_id, )) - else: - self.cur.execute( - 'SELECT data.score_read, data.score, mod.mod_base, ' + - 'mod.motif, mod.motif_pos, mod.raw_motif ' + - 'FROM data ' + - 'INNER JOIN mod ON mod.mod_id = data.score_mod ' + - 'WHERE score_pos=?', (pos_id, )) - return [self.mod_data(read_id, chrm, strand, pos, score, mod_base, - motif, motif_pos, raw_motif) - for read_id, score, mod_base, motif, motif_pos, raw_motif in - self.cur] - - def get_read_id(self, uuid): - try: - read_id = self.cur.execute( - 'SELECT read_id FROM read WHERE uuid=?', (uuid,)).fetchone()[0] - except TypeError: - raise mh.MegaError('Read ID not found in mods data base.') - return read_id - - def get_read_id_or_insert(self, uuid): - try: - read_id = self.get_read_id(uuid) - except mh.MegaError: - self.cur.execute('INSERT INTO read (uuid) VALUES (?)', (uuid,)) - read_id = self.cur.lastrowid - return read_id - - def create_data_read_index(self): - self.cur.execute('CREATE INDEX data_read_idx ON data(score_read)') - return + self.cur.execute( + 'SELECT score_read, score, score_mod FROM data ' + + 'WHERE score_pos=?', (pos_id, )) + return [ + self.mod_data(read_id_conv(read_id), self.get_chrm(chrm_id), + strand, pos, score, *self.get_mod_base_data(mod_id)) + for 
read_id, score, mod_id in self.cur] def get_read_stats(self, uuid): - # TODO optimize this as with position query - self.cur.execute( - 'SELECT uuid, chrm.chrm, pos.strand, pos.pos, data.score, ' + - ' mod.mod_base, mod.motif, mod.motif_pos, mod.raw_motif ' + - 'FROM read ' + - 'INNER JOIN data ON data.score_read = read.uuid ' + - 'INNER JOIN pos ON pos.pos_id = data.score_pos ' + - 'INNER JOIN mod ON mod.mod_id = data.score_mod ' + - 'INNER JOIN chrm ON chrm.chrm_id = pos.pos_chrm ' + - 'WHERE uuid=?', (uuid,)) - return [self.mod_data(*pos_data_i) for pos_data_i in self.cur] + # TODO implement this for API + raise NotImplementedError + return ################################ @@ -1090,9 +1022,6 @@ def num_uniq(self): return self.n_uniq_mods def iter_uniq(self): - # fill queue with only pos_ids for faster queue filling - # and let workers extract pos info from an order query - #for q_val in self.mods_db.iter_pos_id_ordered(): # fill queue with full position information to make # workers avoid the ordered pos data extraction for q_val in self.mods_db.iter_pos(): diff --git a/megalodon/snps.py b/megalodon/snps.py index a0bd899..72e9d0d 100755 --- a/megalodon/snps.py +++ b/megalodon/snps.py @@ -89,7 +89,7 @@ class VarsDb(object): # namedtuple for returning var info from a single position var_data = namedtuple('var_data', [ - 'score', 'read_id', 'chrm', 'alt_seq', 'pos', 'ref_seq', 'var_name']) + 'score', 'pos', 'ref_seq', 'var_name', 'read_id', 'chrm', 'alt_seq']) def __init__(self, fn, read_only=True, db_safety=1, loc_index_in_memory=False, chrm_index_in_memory=True, @@ -226,19 +226,6 @@ def get_loc_id_or_insert(self, chrm, test_start, test_end, var_name, self.loc_idx[(chrm_id, test_start, test_end)] = loc_id return loc_id - def get_loc_data(self, loc_id): - try: - if self.loc_idx_in_mem: - loc_data = self.loc_read_idx[chrm_id] - else: - loc_data = self.cur.execute( - 'SELECT pos, ref_seq, var_name FROM loc ' + - 'WHERE loc_id = ?', (loc_id, )).fetchone() - except (TypeError, KeyError): - raise mh.MegaError('Variant location data not found in ' + - 'vars database.') - return loc_data - def get_alt_id(self, alt_seq): try: if self.alt_idx_in_mem: @@ -282,6 +269,7 @@ def insert_read_scores(self, r_var_scores, uuid, chrm, strand): read_insert_data = [] for (pos, alt_lps, snp_ref_seq, snp_alt_seqs, var_name, test_start, test_end) in r_var_scores: + # TODO convert to execute many for loc and alt inserts loc_id = self.get_loc_id_or_insert( None, test_start, test_end, var_name, pos, snp_ref_seq, chrm_id) for alt_lp, alt_seq in zip(alt_lps, snp_alt_seqs): @@ -335,32 +323,9 @@ def close(self): def get_num_uniq_var_loc(self): return self.cur.execute('SELECT MAX(loc_id) FROM loc').fetchone()[0] - def iter_loc_id(self): - self.cur.execute('SELECT loc_id FROM loc') - for loc_id in self.cur: - yield loc_id[0] - - return - def iter_locs(self): self.cur.execute( - 'SELECT loc_id, loc_chrm, test_start, test_end FROM loc') - for loc in self.cur: - yield loc - - return - - def iter_loc_id_ordered(self): - self.cur.execute('SELECT loc_id FROM loc ORDER BY loc_id') - for loc_id in self.cur: - yield loc_id[0] - - return - - def iter_loc_ordered(self): - self.cur.execute( - 'SELECT loc_id, loc_chrm, test_start, test_end FROM loc ' + - 'ORDER BY loc_id') + 'SELECT loc_id, loc_chrm, pos, ref_seq, var_name FROM loc') for loc in self.cur: yield loc @@ -368,37 +333,17 @@ def iter_loc_ordered(self): def get_loc_stats(self, loc_data, return_uuids=False): read_id_conv = self.get_uuid if return_uuids else lambda x: x - 
loc_id, chrm_id, test_start, test_end = loc_data + # these attributes are specified in self.iter_locs + loc_id, chrm_id, pos, ref_seq, var_name = loc_data self.cur.execute( 'SELECT score, score_read, score_alt, score_loc FROM data ' + 'WHERE score_loc=?', (loc_id, )) return [ - self.var_data(score, read_id_conv(read_id), - self.get_chrm(chrm_id), - self.get_alt_seq(alt_id), - *self.get_loc_data(loc_id)) + self.var_data( + score, pos, ref_seq, var_name, read_id_conv(read_id), + self.get_chrm(chrm_id), self.get_alt_seq(alt_id)) for score, read_id, alt_id, loc_id in self.cur] - def get_read_id(self, uuid): - try: - read_id = self.cur.execute( - 'SELECT read_id FROM read WHERE uuid=?', (uuid,)).fetchone()[0] - except TypeError: - raise mh.MegaError('Read ID not found in vars database.') - return read_id - - def get_read_id_or_insert(self, uuid): - try: - read_id = self.get_read_id(uuid) - except mh.MegaError: - self.cur.execute('INSERT INTO read (uuid) VALUES (?)', (uuid,)) - read_id = self.cur.lastrowid - return read_id - - def create_data_read_index(self): - self.cur.execute('CREATE INDEX data_read_idx ON data(score_read)') - return - def get_read_stats(self, uuid): # TODO implement this for API raise NotImplementedError From bfdade6d05c904c14a02b91ef2964462a65dbe48 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Fri, 27 Sep 2019 10:44:47 -0700 Subject: [PATCH 03/14] Added flags in aggregation classes to avoid loading in memory indices when not necessary. --- megalodon/aggregate.py | 22 ++++++---------------- megalodon/mods.py | 36 ++++++++++++++++++++++-------------- megalodon/snps.py | 12 ++++++++---- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py index f9616b7..2eecb29 100644 --- a/megalodon/aggregate.py +++ b/megalodon/aggregate.py @@ -78,32 +78,20 @@ def _get_snp_stats_queue( def _agg_mods_worker( locs_q, mod_stats_q, mod_prog_q, mods_db_fn, mod_agg_info, valid_read_ids, write_mod_lp): - def get_pos_id(): - return locs_q.get(block=False) - def get_loc_data_from_id(q_pos_id): - # function for profiling purposes - if q_pos_id is None: return None - for pos_id, chrm_id, strand, pos in locs_iter: - if q_pos_id == pos_id: - return pos_id, chrm_id, strand, pos + # functions for profiling purposes def get_loc_data(): return locs_q.get(block=False) def put_mod_site(mod_site): - # function for profiling purposes mod_stats_q.put(mod_site) return def do_sleep(): - # function for profiling purposes sleep(0.0001) return - # needed if only pos id is loaded into queue - #locs_iter = mods.ModsDb(mods_db_fn).iter_pos_ordered() agg_mods = mods.AggMods(mods_db_fn, mod_agg_info, write_mod_lp) while True: try: - #loc_data = get_loc_data_from_id(get_pos_id()) loc_data = get_loc_data() except queue.Empty: do_sleep() @@ -236,7 +224,7 @@ def _agg_prog_worker( return def _fill_locs_queue(locs_q, db_fn, agg_class, num_ps, limit=None): - agg_db = agg_class(db_fn) + agg_db = agg_class(db_fn, load_in_mem_indices=False) for i, loc in enumerate(agg_db.iter_uniq()): locs_q.put(loc) if limit is not None and i >= limit: break @@ -266,7 +254,8 @@ def aggregate_stats( 0, 0, queue.Queue(), queue.Queue()) if mh.SNP_NAME in outputs: snps_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_SNP_NAME) - num_snps = snps.AggSnps(snps_db_fn).num_uniq() + num_snps = snps.AggSnps( + snps_db_fn, load_in_mem_indices=False).num_uniq() logger.info('Spawning variant aggregation processes.') # create process to collect snp stats from workers snp_stats_q, snp_stats_p, 
main_snp_stats_conn = mh.create_getter_q( @@ -292,7 +281,8 @@ def aggregate_stats( if mh.MOD_NAME in outputs: mods_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_MOD_NAME) - num_mods = mods.AggMods(mods_db_fn).num_uniq() + num_mods = mods.AggMods( + mods_db_fn, load_in_mem_indices=False).num_uniq() logger.info('Spawning modified base aggregation processes.') # create process to collect mods stats from workers mod_stats_q, mod_stats_p, main_mod_stats_conn = mh.create_getter_q( diff --git a/megalodon/mods.py b/megalodon/mods.py index 1e275f4..181d36e 100644 --- a/megalodon/mods.py +++ b/megalodon/mods.py @@ -79,7 +79,8 @@ class ModsDb(object): 'motif_pos', 'raw_motif']) def __init__(self, fn, read_only=True, db_safety=1, - pos_index_in_memory=False, mod_chrm_index_in_memory=True): + pos_index_in_memory=False, chrm_index_in_memory=True, + mod_index_in_memory=True): """ Interface to database containing modified base statistics. Default settings are for optimal read_only performance. @@ -87,7 +88,8 @@ def __init__(self, fn, read_only=True, db_safety=1, self.fn = mh.resolve_path(fn) self.read_only = read_only self.pos_idx_in_mem = pos_index_in_memory - self.cm_idx_in_mem = mod_chrm_index_in_memory + self.chrm_idx_in_mem = chrm_index_in_memory + self.mod_idx_in_mem = mod_index_in_memory if read_only: if not os.path.exists(fn): @@ -104,8 +106,9 @@ def __init__(self, fn, read_only=True, db_safety=1, if read_only: # use memory mapped file access self.db.execute('PRAGMA mmap_size = {}'.format(mh.MEMORY_MAP_LIMIT)) - if self.cm_idx_in_mem: + if self.chrm_idx_in_mem: self.load_chrm_read_index() + if self.mod_idx_in_mem: self.load_mod_read_index() if self.pos_idx_in_mem: self.load_pos_read_index() @@ -133,24 +136,26 @@ def __init__(self, fn, read_only=True, db_safety=1, self.pos_idx = {} else: self.create_pos_index() - if self.cm_idx_in_mem: + if self.chrm_idx_in_mem: self.chrm_idx = {} + else: + self.create_chrm_index() + if self.mod_idx_in_mem: self.mod_idx = {} else: self.create_mod_index() - self.create_chrm_index() return def insert_chrm(self, chrm): self.cur.execute('INSERT INTO chrm (chrm) VALUES (?)', (chrm,)) - if self.cm_idx_in_mem: + if self.chrm_idx_in_mem: self.chrm_idx[chrm] = self.cur.lastrowid return self.cur.lastrowid def get_chrm_id(self, chrm): try: - if self.cm_idx_in_mem: + if self.chrm_idx_in_mem: chrm_id = self.chrm_idx[chrm] else: chrm_id = self.cur.execute( @@ -163,7 +168,7 @@ def get_chrm_id(self, chrm): def get_chrm(self, chrm_id): try: - if self.cm_idx_in_mem: + if self.chrm_idx_in_mem: chrm = self.chrm_read_idx[chrm_id] else: chrm = self.cur.execute( @@ -207,7 +212,7 @@ def get_pos_id_or_insert(self, chrm, strand, pos, chrm_id=None): def get_mod_base_id(self, mod_base, motif, motif_pos, raw_motif): try: - if self.cm_idx_in_mem: + if self.mod_idx_in_mem: mod_id = self.mod_idx[(mod_base, motif, motif_pos, raw_motif)] else: mod_id = self.cur.execute( @@ -220,7 +225,7 @@ def get_mod_base_id(self, mod_base, motif, motif_pos, raw_motif): def get_mod_base_data(self, mod_id): try: - if self.cm_idx_in_mem: + if self.mod_idx_in_mem: mod_base_data = self.mod_read_idx[mod_id] else: mod_base_data = self.cur.execute( @@ -239,7 +244,7 @@ def get_mod_base_id_or_insert(self, mod_base, motif, motif_pos, raw_motif): 'INSERT INTO mod (mod_base, motif, motif_pos, raw_motif) ' + 'VALUES (?,?,?,?)', (mod_base, motif, motif_pos, raw_motif)) mod_base_id = self.cur.lastrowid - if self.cm_idx_in_mem: + if self.mod_idx_in_mem: self.mod_idx[(mod_base, motif, motif_pos, raw_motif)] = mod_base_id return 
mod_base_id @@ -1005,10 +1010,13 @@ class AggMods(mh.AbstractAggregationClass): CpG sites). """ def __init__(self, mods_db_fn, agg_info=DEFAULT_AGG_INFO, - write_mod_lp=False, pos_index_in_memory=False): + write_mod_lp=False, load_in_mem_indices=True): # open as read only database (default) - self.mods_db = ModsDb( - mods_db_fn, pos_index_in_memory=pos_index_in_memory) + if load_in_mem_indices: + self.mods_db = ModsDb(mods_db_fn) + else: + self.mods_db = ModsDb(mods_db_fn, chrm_index_in_memory=False, + mod_index_in_memory=False) self.n_uniq_mods = None assert agg_info.method in AGG_METHOD_NAMES self.agg_method = agg_info.method diff --git a/megalodon/snps.py b/megalodon/snps.py index 72e9d0d..33d4ddd 100755 --- a/megalodon/snps.py +++ b/megalodon/snps.py @@ -1599,11 +1599,15 @@ class AggSnps(mh.AbstractAggregationClass): """ Class to assist in database queries for per-site aggregation of SNP calls over reads. """ - def __init__(self, vars_db_fn, write_vcf_log_probs=False, - loc_index_in_memory=False): + def __init__( + self, vars_db_fn, write_vcf_log_probs=False, + load_in_mem_indices=True): # open as read only database - self.snps_db = VarsDb( - vars_db_fn, loc_index_in_memory=loc_index_in_memory) + if load_in_mem_indices: + self.snps_db = VarsDb(vars_db_fn) + else: + self.snps_db = VarsDb(vars_db_fn, chrm_index_in_memory=False, + alt_index_in_memory=False) self.n_uniq_snps = None self.write_vcf_log_probs = write_vcf_log_probs return From 80ae0db1f71e7effa7a1a601b4c17dac0a149a00 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Fri, 27 Sep 2019 11:10:22 -0700 Subject: [PATCH 04/14] Added command line argument to toggle mod position and variant locations tables stored in memory to allow potentially more environment and test set use cases. --- docs/advanced_arguments.rst | 6 ++++++ megalodon/megalodon.py | 25 +++++++++++++++++++++---- megalodon/mods.py | 9 +++++---- megalodon/snps.py | 12 +++++++----- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/docs/advanced_arguments.rst b/docs/advanced_arguments.rst index 2747e91..e87c835 100644 --- a/docs/advanced_arguments.rst +++ b/docs/advanced_arguments.rst @@ -45,6 +45,9 @@ SNP Arguments - ``--variant-context-bases`` - Context bases for single base SNP and indel calling. Default: [10, 30] +- ``--variant-locations-on-disk`` + + - Force sequence variant locations to be stored only within on disk database table. This option will reduce the RAM memory requirement, but may drastically slow processing. Default: Store locations in memory and on disk. - ``--write-vcf-log-probs`` - Write per-read alt log probabilities out in non-standard VCF field. @@ -99,6 +102,9 @@ Modified Base Arguments - A genotype field ``VALID_DP`` indicates the number of reads included in the proportion modified calculation. - Modified base proportion estimates are stored in genotype fields specified by the single letter modified base encodings (definied in the model file). +- ``--mod-positions-on-disk`` + + - Force modified base positions to be stored only within on disk database table. This option will reduce the RAM memory requirement, but may drastically slow processing. Default: Store positions in memory and on disk. - ``--write-mod-log-probs`` - Write per-read modified base log probabilities out in non-standard VCF field. 
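
The in-memory versus on-disk behavior these two options expose is the same
index pattern used throughout the ModsDb and VarsDb classes above: either
cache the id-lookup table in a Python dict, or create a unique index and pay
for a SELECT on every lookup. A minimal runnable sketch of that pattern
follows (the _DemoDb class and its table and column names are hypothetical,
written for illustration only, not Megalodon's actual API):

import sqlite3

class _DemoDb(object):
    """Sketch of the *_index_in_memory trade-off (illustrative only)."""
    def __init__(self, fn, pos_index_in_memory=True):
        self.pos_idx_in_mem = pos_index_in_memory
        self.db = sqlite3.connect(fn)
        self.cur = self.db.cursor()
        self.cur.execute(
            'CREATE TABLE IF NOT EXISTS pos ' +
            '(pos_id INTEGER PRIMARY KEY, pos INTEGER)')
        if self.pos_idx_in_mem:
            # dict cache: O(1) lookups, but RAM grows with unique positions
            self.pos_idx = {}
        else:
            # on-disk index: flat RAM, but every lookup requires a SELECT
            self.cur.execute(
                'CREATE UNIQUE INDEX IF NOT EXISTS pos_idx ON pos(pos)')

    def get_pos_id_or_insert(self, pos):
        if self.pos_idx_in_mem:
            try:
                return self.pos_idx[pos]
            except KeyError:
                self.cur.execute('INSERT INTO pos (pos) VALUES (?)', (pos,))
                self.pos_idx[pos] = self.cur.lastrowid
                return self.pos_idx[pos]
        row = self.cur.execute(
            'SELECT pos_id FROM pos WHERE pos=?', (pos,)).fetchone()
        if row is not None:
            return row[0]
        self.cur.execute('INSERT INTO pos (pos) VALUES (?)', (pos,))
        return self.cur.lastrowid

With the index in memory, each repeated lookup is a dict access; with the
on-disk flags set, each lookup becomes a SELECT served by the unique index,
which keeps the RAM requirement flat at the cost of processing speed.
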
diff --git a/megalodon/megalodon.py b/megalodon/megalodon.py index d001a39..3b1322a 100644 --- a/megalodon/megalodon.py +++ b/megalodon/megalodon.py @@ -480,7 +480,8 @@ def process_all_reads( snps._get_snps_queue, ( mh.get_megalodon_fn(out_dir, mh.PR_SNP_NAME), snps_txt_fn, db_safety, pr_refs_fn, pr_ref_filts, - whatshap_map_fn, aligner.ref_names_and_lens, aligner.ref_fn)) + whatshap_map_fn, aligner.ref_names_and_lens, aligner.ref_fn, + snps_data.loc_index_in_memory)) if mh.PR_MOD_NAME in outputs: pr_refs_fn = mh.get_megalodon_fn(out_dir, mh.PR_REF_NAME) if ( mh.PR_REF_NAME in outputs and mods_info.do_pr_ref_mods) else None @@ -490,7 +491,7 @@ def process_all_reads( mods._get_mods_queue, ( mh.get_megalodon_fn(out_dir, mh.PR_MOD_NAME), db_safety, aligner.ref_names_and_lens, mods_txt_fn, - pr_refs_fn, pr_ref_filts)) + pr_refs_fn, pr_ref_filts, mods_info.pos_index_in_memory)) proc_reads_ps, map_conns = [], [] for device in model_info.process_devices: @@ -617,7 +618,8 @@ def snps_validation(args, is_cat_mod, output_size, aligner): args.variant_context_bases, snp_calib_fn, snps.HAPLIOD_MODE if args.haploid else snps.DIPLOID_MODE, args.refs_include_snps, aligner, edge_buffer=args.edge_buffer, - context_min_alt_prob=args.context_min_alt_prob) + context_min_alt_prob=args.context_min_alt_prob, + loc_index_in_memory=not args.variant_locations_on_disk) except mh.MegaError as e: logger.error(str(e)) sys.exit(1) @@ -663,7 +665,8 @@ def mods_validation(args, model_info): model_info, args.mod_motif, args.mod_all_paths, args.write_mods_text, args.mod_context_bases, mh.BC_MODS_NAME in args.outputs, args.refs_include_mods, mod_calib_fn, - args.mod_output_formats, args.edge_buffer) + args.mod_output_formats, args.edge_buffer, + not args.mod_positions_on_disk) return args, mods_info def parse_pr_ref_output(args): @@ -832,6 +835,13 @@ def hidden_help(help_msg): 'SNP scores. As created by ' + 'megalodon/scripts/calibrate_snp_llr_scores.py. ' + 'Default: Load default calibration file.')) + snp_grp.add_argument( + '--variant-locations-on-disk', action='store_true', + help=hidden_help('Force sequence variant locations to be stored ' + + 'only within on disk database table. This option ' + + 'will reduce the RAM memory requirement, but may ' + + 'drastically slow processing. Default: Store ' + + 'locations in memory and on disk.')) snp_grp.add_argument( '--variant-context-bases', type=int, nargs=2, default=[mh.DEFAULT_SNV_CONTEXT, mh.DEFAULT_INDEL_CONTEXT], @@ -886,6 +896,13 @@ def hidden_help(help_msg): choices=tuple(mh.MOD_OUTPUT_FMTS.keys()), help=hidden_help('Modified base aggregated output format(s). ' + 'Default: %(default)s')) + mod_grp.add_argument( + '--mod-positions-on-disk', action='store_true', + help=hidden_help('Force modified base positions to be stored ' + + 'only within on disk database table. This option ' + + 'will reduce the RAM memory requirement, but may ' + + 'drastically slow processing. 
Default: Store ' + + 'positions in memory and on disk.')) mod_grp.add_argument( '--write-mod-log-probs', action='store_true', help=hidden_help('Write per-read modified base log probabilities ' + diff --git a/megalodon/mods.py b/megalodon/mods.py index 181d36e..02ef880 100644 --- a/megalodon/mods.py +++ b/megalodon/mods.py @@ -472,7 +472,7 @@ def annotate_mods(r_start, ref_seq, r_mod_scores, strand): def _get_mods_queue( mods_q, mods_conn, mods_db_fn, db_safety, ref_names_and_lens, - mods_txt_fn, pr_refs_fn, pr_ref_filts): + mods_txt_fn, pr_refs_fn, pr_ref_filts, pos_index_in_memory): def store_mod_call( r_mod_scores, read_id, chrm, strand, r_start, ref_seq, read_len, q_st, q_en, cigar, been_warned): @@ -517,7 +517,7 @@ def store_mod_call( been_warned = False mods_db = ModsDb(mods_db_fn, db_safety=db_safety, read_only=False, - pos_index_in_memory=True) + pos_index_in_memory=pos_index_in_memory) for ref_name in ref_names_and_lens[0]: mods_db.insert_chrm(ref_name) mods_db.create_chrm_index() @@ -568,7 +568,7 @@ def store_mod_call( if mods_txt_fp is not None: mods_txt_fp.close() if pr_refs_fn is not None: pr_refs_fp.close() mods_db.create_mod_index() - if not mods_db.pos_idx_in_mem: + if mods_db.pos_idx_in_mem: mods_db.create_pos_index() mods_db.create_data_covering_index() mods_db.close() @@ -648,7 +648,7 @@ def __init__( write_mods_txt=None, mod_context_bases=None, do_output_mods=False, do_pr_ref_mods=False, mods_calib_fn=None, mod_output_fmts=[mh.MOD_BEDMETHYL_NAME], - edge_buffer=mh.DEFAULT_EDGE_BUFFER): + edge_buffer=mh.DEFAULT_EDGE_BUFFER, pos_index_in_memory=True): logger = logging.get_logger() # this is pretty hacky, but these attributes are stored here as # they are generally needed alongside other alphabet info @@ -663,6 +663,7 @@ def __init__( self.calib_table = calibration.ModCalibrator(mods_calib_fn) self.mod_output_fmts = mod_output_fmts self.edge_buffer = edge_buffer + self.pos_index_in_memory = pos_index_in_memory self.alphabet = model_info.can_alphabet self.ncan_base = len(self.alphabet) diff --git a/megalodon/snps.py b/megalodon/snps.py index 33d4ddd..2a9bfd4 100755 --- a/megalodon/snps.py +++ b/megalodon/snps.py @@ -710,7 +710,8 @@ def annotate_snps(r_start, ref_seq, r_snp_calls, strand): def _get_snps_queue( snps_q, snps_conn, vars_db_fn, snps_txt_fn, db_safety, pr_refs_fn, - pr_ref_filts, whatshap_map_fn, ref_names_and_lens, ref_fn): + pr_ref_filts, whatshap_map_fn, ref_names_and_lens, ref_fn, + loc_index_in_memory): def write_whatshap_alignment( read_id, snp_seq, snp_quals, chrm, strand, r_st, snp_cigar): a = pysam.AlignedSegment() @@ -768,9 +769,8 @@ def get_snp_call( logger = logging.get_logger('vars_getter') - # TODO convert loc index to command line option snps_db = VarsDb(vars_db_fn, db_safety=db_safety, read_only=False, - loc_index_in_memory=True) + loc_index_in_memory=loc_index_in_memory) for ref_name in ref_names_and_lens[0]: snps_db.insert_chrm(ref_name) snps_db.create_chrm_index() @@ -840,7 +840,7 @@ def get_snp_call( if pr_refs_fn is not None: pr_refs_fp.close() if whatshap_map_fn is not None: whatshap_map_fp.close() snps_db.create_alt_index() - if not snps_db.loc_idx_in_mem: + if snps_db.loc_idx_in_mem: snps_db.create_loc_index() snps_db.create_data_covering_index() snps_db.close() @@ -879,7 +879,8 @@ def __init__( call_mode=DIPLOID_MODE, do_pr_ref_snps=False, aligner=None, keep_snp_fp_open=False, do_validate_reference=True, edge_buffer=mh.DEFAULT_EDGE_BUFFER, - context_min_alt_prob=mh.DEFAULT_CONTEXT_MIN_ALT_PROB): + 
context_min_alt_prob=mh.DEFAULT_CONTEXT_MIN_ALT_PROB, + loc_index_in_memory=True): logger = logging.get_logger('vars') self.max_indel_size = max_indel_size self.all_paths = all_paths @@ -895,6 +896,7 @@ def __init__( self.do_pr_ref_snps = do_pr_ref_snps self.edge_buffer = edge_buffer self.context_min_alt_prob = context_min_alt_prob + self.loc_index_in_memory = loc_index_in_memory self.variant_fn = variant_fn self.variants_idx = None if self.variant_fn is None: From b494ae950ce558e2e5f0d678d7955a86b0f98107 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Mon, 30 Sep 2019 06:29:26 -0700 Subject: [PATCH 05/14] Converted alt sequence and location DB inserts to executemany calls on full read data. --- megalodon/snps.py | 172 ++++++++++++++++++++++++++++------------------ 1 file changed, 107 insertions(+), 65 deletions(-) diff --git a/megalodon/snps.py b/megalodon/snps.py index 2a9bfd4..3b4e53b 100755 --- a/megalodon/snps.py +++ b/megalodon/snps.py @@ -192,62 +192,98 @@ def get_chrm(self, chrm_id): 'vars database.') return chrm - def get_loc_id(self, chrm, test_start, test_end, chrm_id=None): - if chrm_id is None: - chrm_id = self.get_chrm_id(chrm) - - try: - if self.loc_idx_in_mem: - loc_id = self.loc_idx[(chrm_id, test_start, test_end)] - else: - loc_id = self.cur.execute( - 'SELECT loc_id FROM loc WHERE loc_chrm=? AND ' + - 'test_start=? AND test_end=?', ( - chrm_id, test_start, test_end)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError( - 'Reference position not found in database.') - - return loc_id - - def get_loc_id_or_insert(self, chrm, test_start, test_end, var_name, - pos, ref_seq, chrm_id=None): - if chrm_id is None: - chrm_id = self.get_chrm_id(chrm) - try: - loc_id = self.get_loc_id(None, test_start, test_end, chrm_id) - except mh.MegaError: - self.cur.execute( - 'INSERT INTO loc (loc_chrm, test_start, test_end, var_name, ' + - 'pos, ref_seq) VALUES (?,?,?,?,?,?)', - (chrm_id, test_start, test_end,var_name, pos, ref_seq)) - loc_id = self.cur.lastrowid - if self.loc_idx_in_mem: - self.loc_idx[(chrm_id, test_start, test_end)] = loc_id - return loc_id - - def get_alt_id(self, alt_seq): - try: - if self.alt_idx_in_mem: - alt_id = self.alt_idx[alt_seq] - else: - alt_id = self.cur.execute( - 'SELECT alt_id FROM var WHERE alt_seq=?', - (var_name, )).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError('Variant not found in var table.') - return alt_id + def get_loc_ids_or_insert_locations(self, r_var_scores, chrm_id): + """ Extract all location IDs and add those locations not currently + found in the database + """ + r_locs = dict(( + ((chrm_id, test_start, test_end), (pos, ref_seq, var_name)) + for pos, _, ref_seq, _, var_name, + test_start, test_end in r_var_scores)) + if self.loc_idx_in_mem: + locs_to_add = list(set(r_locs).difference(self.loc_idx)) + else: + test_starts, test_ends = map(set, list(zip(*r_uniq_locs))[1:]) + loc_ids = dict(( + ((chrm_id, test_start, test_end), loc_id) + for chrm_id, test_start, test_end, loc_id in self.cur.execute( + ('SELECT loc_chrm, test_start, test_end, loc_id ' + + 'FROM loc WHERE loc_chrm=? 
AND ' + + 'test_start in ({0}) AND test_end in ({1})').format( + ','.join(['?',] * len(test_starts)), + ','.join(['?',] * len(test_ends))), + (chrm_id, *test_starts, *test_ends)).fetchall())) + locs_to_add = list(set(r_locs).difference(loc_ids)) + + if len(locs_to_add) > 0: + next_loc_id = self.get_num_uniq_var_loc() + 1 + self.cur.executemany( + 'INSERT INTO loc (loc_chrm, test_start, test_end, ' + + 'pos, ref_seq, var_name) VALUES (?,?,?,?,?,?)', + ((*loc_key, *r_locs[loc_key]) for loc_key in locs_to_add)) + + if self.loc_idx_in_mem: + if len(locs_to_add) > 0: + self.loc_idx.update(zip( + locs_to_add, + range(next_loc_id, next_loc_id + len(locs_to_add)))) + r_loc_ids = [ + self.loc_idx[(chrm_id, test_start, test_end)] + for _, _, _, _, _, test_start, test_end in r_var_scores] + else: + if len(locs_to_add) > 0: + loc_ids.update(zip( + locs_to_add, + range(next_loc_id, next_loc_id + len(locs_to_add)))) + r_loc_ids = [ + loc_ids[(chrm_id, test_start, test_end)] + for _, _, _, _, _, test_start, test_end in r_var_scores] + + return r_loc_ids + + def get_alt_ids_or_insert_alt_seqs(self, r_var_scores): + r_seqs_and_lps = [ + tuple(zip(alt_seqs, alt_lps)) + for _, alt_lps, _, alt_seqs, _, _, _ in r_var_scores] + r_uniq_seqs = set((seq_lp_i[0] for loc_seqs_lps in r_seqs_and_lps + for seq_lp_i in loc_seqs_lps)) + if self.alt_idx_in_mem: + alts_to_add = list(r_uniq_seqs.difference(self.alt_idx)) + else: + alt_ids = dict(( + (alt_seq, alt_id) + for alt_seq, alt_id in self.cur.execute( + ('SELECT alt_seq, alt_id ' + + 'FROM alt WHERE alt_seq in ({})').format( + ','.join(['?',] * len(r_uniq_seqs))), + r_uniq_seqs).fetchall())) + alts_to_add = list(r_uniq_seqs.difference(alt_ids)) + + if len(alts_to_add) > 0: + next_alt_id = self.get_num_uniq_alt_seqs() + 1 + self.cur.executemany( + 'INSERT INTO alt (alt_seq) VALUES (?)', alts_to_add) + + if self.alt_idx_in_mem: + if len(alts_to_add) > 0: + self.alt_idx.update(zip( + alts_to_add, + range(next_alt_id, next_alt_id + len(alts_to_add)))) + r_alt_ids = [ + tuple((self.alt_idx[alt_seq], alt_lp) + for alt_seq, alt_lp in loc_seqs_lps) + for loc_seqs_lps in r_seqs_and_lps] + else: + if len(alts_to_add) > 0: + alt_ids.update(zip( + alts_to_add, + range(next_alt_id, next_alt_id + len(alts_to_add)))) + r_alt_ids = [ + tuple((alt_ids[alt_seq], alt_lp) + for alt_seq, alt_lp in loc_seqs_lps) + for loc_seqs_lps in r_seqs_and_lps] - def get_alt_id_or_insert(self, alt_seq): - try: - alt_id = self.get_alt_id(alt_seq) - except mh.MegaError: - self.cur.execute( - 'INSERT INTO alt (alt_seq) VALUES (?)', (alt_seq, )) - alt_id = self.cur.lastrowid - if self.alt_idx_in_mem: - self.alt_idx[alt_seq] = alt_id - return alt_id + return r_alt_ids def get_alt_seq(self, alt_id): try: @@ -266,15 +302,12 @@ def insert_read_scores(self, r_var_scores, uuid, chrm, strand): (uuid, strand)) read_id = self.cur.lastrowid chrm_id = self.get_chrm_id(chrm) - read_insert_data = [] - for (pos, alt_lps, snp_ref_seq, snp_alt_seqs, var_name, - test_start, test_end) in r_var_scores: - # TODO convert to execute many for loc and alt inserts - loc_id = self.get_loc_id_or_insert( - None, test_start, test_end, var_name, pos, snp_ref_seq, chrm_id) - for alt_lp, alt_seq in zip(alt_lps, snp_alt_seqs): - alt_id = self.get_alt_id_or_insert(alt_seq) - read_insert_data.append((alt_lp, loc_id, alt_id, read_id)) + loc_ids = self.get_loc_ids_or_insert_locations(r_var_scores, chrm_id) + alt_ids = self.get_alt_ids_or_insert_alt_seqs(r_var_scores) + + read_insert_data = ((alt_lp, loc_id, alt_id, read_id) + for loc_id, 
loc_alts in zip(loc_ids, alt_ids) + for alt_id, alt_lp in loc_alts) self.cur.executemany( 'INSERT INTO data VALUES (?,?,?,?)', read_insert_data) @@ -321,7 +354,16 @@ def close(self): return def get_num_uniq_var_loc(self): - return self.cur.execute('SELECT MAX(loc_id) FROM loc').fetchone()[0] + num_locs = self.cur.execute('SELECT MAX(loc_id) FROM loc').fetchone()[0] + if num_locs is None: + num_locs = 0 + return num_locs + + def get_num_uniq_alt_seqs(self): + num_alts = self.cur.execute('SELECT MAX(alt_id) FROM alt').fetchone()[0] + if num_alts is None: + num_alts = 0 + return num_alts def iter_locs(self): self.cur.execute( From 8ea4847502d8021d83ccae7550eb25d6f6ac2605 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Mon, 30 Sep 2019 13:35:04 -0700 Subject: [PATCH 06/14] Convert modified base position and mod type to batch inserts along with DB class cleanup and testing of several run modes. --- megalodon/aggregate.py | 12 +- megalodon/mods.py | 282 +++++++++++++++++++++++------------------ megalodon/snps.py | 190 +++++++++++++-------------- 3 files changed, 259 insertions(+), 225 deletions(-) diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py index 2eecb29..d6d2010 100644 --- a/megalodon/aggregate.py +++ b/megalodon/aggregate.py @@ -76,11 +76,11 @@ def _get_snp_stats_queue( return def _agg_mods_worker( - locs_q, mod_stats_q, mod_prog_q, mods_db_fn, mod_agg_info, + pos_q, mod_stats_q, mod_prog_q, mods_db_fn, mod_agg_info, valid_read_ids, write_mod_lp): # functions for profiling purposes - def get_loc_data(): - return locs_q.get(block=False) + def get_pos_data(): + return pos_q.get(block=False) def put_mod_site(mod_site): mod_stats_q.put(mod_site) return @@ -92,16 +92,16 @@ def do_sleep(): while True: try: - loc_data = get_loc_data() + pos_data = get_pos_data() except queue.Empty: do_sleep() continue - if loc_data is None: + if pos_data is None: break try: mod_site = agg_mods.compute_mod_stats( - loc_data, valid_read_ids=valid_read_ids) + pos_data, valid_read_ids=valid_read_ids) put_mod_site(mod_site) except mh.MegaError: # no valid reads cover location diff --git a/megalodon/mods.py b/megalodon/mods.py index 02ef880..7eeb32f 100644 --- a/megalodon/mods.py +++ b/megalodon/mods.py @@ -147,126 +147,100 @@ def __init__(self, fn, read_only=True, db_safety=1, return - def insert_chrm(self, chrm): - self.cur.execute('INSERT INTO chrm (chrm) VALUES (?)', (chrm,)) + # insert data functions + def insert_chrms(self, chrms): + next_chrm_id = self.get_num_uniq_chrms() + 1 + self.cur.executemany('INSERT INTO chrm (chrm) VALUES (?)', + [(chrm,) for chrm in chrms]) if self.chrm_idx_in_mem: - self.chrm_idx[chrm] = self.cur.lastrowid - return self.cur.lastrowid - - def get_chrm_id(self, chrm): - try: - if self.chrm_idx_in_mem: - chrm_id = self.chrm_idx[chrm] - else: - chrm_id = self.cur.execute( - 'SELECT chrm_id FROM chrm WHERE chrm=?', - (chrm,)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError('Reference record (chromosome) not found in ' + - 'database.') - return chrm_id - - def get_chrm(self, chrm_id): - try: - if self.chrm_idx_in_mem: - chrm = self.chrm_read_idx[chrm_id] - else: - chrm = self.cur.execute( - 'SELECT chrm FROM chrm WHERE chrm_id=?', - (chrm_id,)).fetchone()[0] - except TypeError: - raise mh.MegaError('Reference record (chromosome) not found in ' + - 'mods database.') - return chrm - - def get_pos_id(self, chrm, strand, pos, chrm_id=None): - if chrm_id is None: - chrm_id = self.get_chrm_id(chrm) - - try: - if self.pos_idx_in_mem: - pos_id = 
self.pos_idx[(chrm_id, strand, pos)] - else: - pos_id = self.cur.execute( - 'SELECT pos_id FROM pos WHERE pos_chrm=? AND strand=? ' + - 'AND pos=?', (chrm_id, strand, pos)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError( - 'Reference position not found in database.') - - return pos_id + self.chrm_idx.update(zip( + chrms, range(next_chrm_id, next_chrm_id + len(chrms)))) + return - def get_pos_id_or_insert(self, chrm, strand, pos, chrm_id=None): - if chrm_id is None: - chrm_id = self.get_chrm_id(chrm) - try: - pos_id = self.get_pos_id(chrm, strand, pos, chrm_id) - except mh.MegaError: - self.cur.execute( + def get_pos_ids_or_insert(self, r_mod_scores, chrm_id, strand): + r_pos = list(zip(*r_mod_scores))[0] + r_uniq_pos = set(((chrm_id, strand, pos) for pos in r_pos)) + if self.pos_idx_in_mem: + pos_to_add = list(r_uniq_pos.difference(self.pos_idx)) + else: + pos_ids = dict( + ((chrm_id, strand, pos_and_id[0]), pos_and_id[1]) + for pos_key in r_uniq_pos + for pos_and_id in self.cur.execute( + 'SELECT pos, pos_id FROM pos ' + + 'WHERE pos_chrm=? AND strand=? AND pos=?', + pos_key).fetchall()) + pos_to_add = list(r_uniq_pos.difference(pos_ids)) + + if len(pos_to_add) > 0: + next_pos_id = self.get_num_uniq_mod_pos() + 1 + self.cur.executemany( 'INSERT INTO pos (pos_chrm, strand, pos) VALUES (?,?,?)', - (chrm_id, strand, pos)) - pos_id = self.cur.lastrowid - if self.pos_idx_in_mem: - self.pos_idx[(chrm_id, strand, pos)] = pos_id - return pos_id + pos_to_add) + + pos_idx = self.pos_idx if self.pos_idx_in_mem else pos_ids + if len(pos_to_add) > 0: + pos_idx.update(zip( + pos_to_add, + range(next_pos_id, next_pos_id + len(pos_to_add)))) + r_pos_ids = [pos_idx[(chrm_id, strand, pos)] for pos in r_pos] + + return r_pos_ids + + def get_mod_base_ids_or_insert(self, r_mod_scores): + r_mod_bases = [ + [((mod_base, motif, motif_pos, raw_motif), mod_lp) + for mod_lp, mod_base in zip(mod_lps, mod_bases)] + for _, mod_lps, mod_bases, motif, motif_pos, raw_motif in + r_mod_scores] + r_uniq_mod_bases = set(( + mod_key for pos_mods in r_mod_bases for mod_key, _ in pos_mods)) + if self.mod_idx_in_mem: + mod_bases_to_add = list(r_uniq_mod_bases.difference(self.mod_idx)) + else: + mod_base_ids = dict( + (mod_data_w_id[:-1], mod_data_w_id[-1]) + for mod_data in r_uniq_mod_bases for mod_data_w_id in + self.cur.execute( + 'SELECT mod_base, motif, motif_pos, raw_motif, ' + + 'mod_id FROM mod WHERE mod_base=? AND motif=? AND ' + + 'motif_pos=? AND raw_motif=?', mod_data).fetchall()) + mod_bases_to_add = list(r_uniq_mod_bases.difference(mod_base_ids)) + + if len(mod_bases_to_add) > 0: + next_mod_base_id = self.get_num_uniq_mod_bases() + 1 + self.cur.executemany( + 'INSERT INTO mod (mod_base, motif, motif_pos, raw_motif) ' + + 'VALUES (?,?,?,?)', mod_bases_to_add) - def get_mod_base_id(self, mod_base, motif, motif_pos, raw_motif): - try: - if self.mod_idx_in_mem: - mod_id = self.mod_idx[(mod_base, motif, motif_pos, raw_motif)] - else: - mod_id = self.cur.execute( - 'SELECT mod_id FROM mod WHERE mod_base=? AND motif=? AND ' + - 'motif_pos=? 
AND raw_motif=?', - (mod_base, motif, motif_pos, raw_motif)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError('Modified base not found in mods database.') - return mod_id + mod_idx = self.mod_idx if self.mod_idx_in_mem else mod_base_ids + if len(mod_bases_to_add) > 0: + mod_idx.update(zip( + mod_bases_to_add, + range(next_mod_base_id, + next_mod_base_id + len(mod_bases_to_add)))) + r_mod_base_ids = [ + [(mod_idx[mod_key], mod_lp) for mod_key, mod_lp in pos_mods] + for pos_mods in r_mod_bases] - def get_mod_base_data(self, mod_id): - try: - if self.mod_idx_in_mem: - mod_base_data = self.mod_read_idx[mod_id] - else: - mod_base_data = self.cur.execute( - 'SELECT mod_base, motif, motif_pos, raw_motif FROM mod ' + - 'WHERE mod_id=?', (mod_id,)).fetchone() - except (TypeError, KeyError): - raise mh.MegaError('Modified base not found in mods database.') - return mod_base_data - - def get_mod_base_id_or_insert(self, mod_base, motif, motif_pos, raw_motif): - try: - mod_base_id = self.get_mod_base_id( - mod_base, motif, motif_pos, raw_motif) - except mh.MegaError: - self.cur.execute( - 'INSERT INTO mod (mod_base, motif, motif_pos, raw_motif) ' + - 'VALUES (?,?,?,?)', (mod_base, motif, motif_pos, raw_motif)) - mod_base_id = self.cur.lastrowid - if self.mod_idx_in_mem: - self.mod_idx[(mod_base, motif, motif_pos, - raw_motif)] = mod_base_id - return mod_base_id + return r_mod_base_ids def insert_read_scores(self, r_mod_scores, uuid, chrm, strand): self.cur.execute('INSERT INTO read (uuid) VALUES (?)', (uuid,)) read_id = self.cur.lastrowid chrm_id = self.get_chrm_id(chrm) - read_insert_data = [] - for (pos, mod_lps, mod_bases, ref_motif, rel_pos, - raw_motif) in r_mod_scores: - # TODO convert to execute many for pos and mod_base inserts - pos_id = self.get_pos_id_or_insert(None, strand, pos, chrm_id) - for mod_lp, mod_base in zip(mod_lps, mod_bases): - mod_base_id = self.get_mod_base_id_or_insert( - mod_base, ref_motif, rel_pos, raw_motif) - read_insert_data.append((mod_lp, pos_id, mod_base_id, read_id)) + pos_ids = self.get_pos_ids_or_insert(r_mod_scores, chrm_id, strand) + mod_base_ids = self.get_mod_base_ids_or_insert(r_mod_scores) + + read_insert_data = [(mod_lp, pos_id, mod_base_id, read_id) + for pos_id, pos_mods in zip(pos_ids, mod_base_ids) + for mod_base_id, mod_lp in pos_mods] self.cur.executemany( 'INSERT INTO data VALUES (?,?,?,?)', read_insert_data) return + # create and load index functions def create_chrm_index(self): self.cur.execute('CREATE UNIQUE INDEX chrm_idx ON chrm(chrm)') return @@ -306,13 +280,73 @@ def create_data_covering_index(self): 'score_pos, score_mod, score_read, score)') return - def close(self): - self.db.commit() - self.db.close() - return + # reader functions + def get_chrm_id(self, chrm): + try: + if self.chrm_idx_in_mem: + chrm_id = self.chrm_idx[chrm] + else: + chrm_id = self.cur.execute( + 'SELECT chrm_id FROM chrm WHERE chrm=?', + (chrm,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Reference record (chromosome) not found in ' + + 'database.') + return chrm_id + + def get_chrm(self, chrm_id): + try: + if self.chrm_idx_in_mem: + chrm = self.chrm_read_idx[chrm_id] + else: + chrm = self.cur.execute( + 'SELECT chrm FROM chrm WHERE chrm_id=?', + (chrm_id,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Reference record (chromosome ID) not found ' + + 'in mods database.') + return chrm + + def get_mod_base_data(self, mod_id): + try: + if self.mod_idx_in_mem: + mod_base_data = self.mod_read_idx[mod_id] + 
else:
+                mod_base_data = self.cur.execute(
+                    'SELECT mod_base, motif, motif_pos, raw_motif FROM mod ' +
+                    'WHERE mod_id=?', (mod_id,)).fetchone()
+        except (TypeError, KeyError):
+            raise mh.MegaError('Modified base not found in mods database.')
+        return mod_base_data
+
+    def get_uuid(self, read_id):
+        try:
+            uuid = self.cur.execute(
+                'SELECT uuid FROM read WHERE read_id=?',
+                (read_id,)).fetchone()[0]
+        except TypeError:
+            raise mh.MegaError('Read ID not found in mods database.')
+        return uuid
 
     def get_num_uniq_mod_pos(self):
-        return self.cur.execute('SELECT MAX(pos_id) FROM pos').fetchone()[0]
+        num_pos = self.cur.execute('SELECT MAX(pos_id) FROM pos').fetchone()[0]
+        if num_pos is None:
+            num_pos = 0
+        return num_pos
+
+    def get_num_uniq_mod_bases(self):
+        num_mod_bases = self.cur.execute(
+            'SELECT MAX(mod_id) FROM mod').fetchone()[0]
+        if num_mod_bases is None:
+            num_mod_bases = 0
+        return num_mod_bases
+
+    def get_num_uniq_chrms(self):
+        num_chrms = self.cur.execute(
+            'SELECT MAX(chrm_id) FROM chrm').fetchone()[0]
+        if num_chrms is None:
+            num_chrms = 0
+        return num_chrms
 
     def iter_pos(self):
         self.cur.execute('SELECT pos_id, pos_chrm, strand, pos FROM pos')
@@ -324,17 +358,17 @@ def get_pos_stats(self, pos_data, return_uuids=False):
         read_id_conv = self.get_uuid if return_uuids else lambda x: x
         # these attributes are specified in self.iter_pos
         pos_id, chrm_id, strand, pos = pos_data
-        self.cur.execute(
-            'SELECT score_read, score, score_mod FROM data ' +
-            'WHERE score_pos=?', (pos_id, ))
+        chrm = self.get_chrm(chrm_id)
         return [
-            self.mod_data(read_id_conv(read_id), self.get_chrm(chrm_id),
-                          strand, pos, score, *self.get_mod_base_data(mod_id))
-            for read_id, score, mod_id in self.cur]
+            self.mod_data(read_id_conv(read_id), chrm, strand, pos, score,
+                          *self.get_mod_base_data(mod_id))
+            for read_id, score, mod_id in self.cur.execute(
+                'SELECT score_read, score, score_mod FROM data ' +
+                'WHERE score_pos=?', (pos_id,)).fetchall()]
 
-    def get_read_stats(self, uuid):
-        # TODO implement this for API
-        raise NotImplementedError
+    def close(self):
+        self.db.commit()
+        self.db.close()
         return
 
 
@@ -518,8 +552,7 @@ def store_mod_call(
 
     mods_db = ModsDb(mods_db_fn, db_safety=db_safety, read_only=False,
                      pos_index_in_memory=pos_index_in_memory)
-    for ref_name in ref_names_and_lens[0]:
-        mods_db.insert_chrm(ref_name)
+    mods_db.insert_chrms(ref_names_and_lens[0])
     mods_db.create_chrm_index()
 
     if mods_txt_fn is None:
@@ -567,7 +600,8 @@ def store_mod_call(
 
     if mods_txt_fp is not None: mods_txt_fp.close()
     if pr_refs_fn is not None: pr_refs_fp.close()
-    mods_db.create_mod_index()
+    if mods_db.mod_idx_in_mem:
+        mods_db.create_mod_index()
     if mods_db.pos_idx_in_mem:
         mods_db.create_pos_index()
     mods_db.create_data_covering_index()
@@ -1103,7 +1137,7 @@ def emp_em(self, pos_scores, max_iters):
 
         return curr_mix_prop, pos_scores.shape[0]
 
-    def compute_mod_stats(self, mod_loc, agg_method=None, valid_read_ids=None):
+    def compute_mod_stats(self, mod_pos, agg_method=None, valid_read_ids=None):
         if agg_method is None:
             agg_method = self.agg_method
         if agg_method not in AGG_METHOD_NAMES:
@@ -1112,7 +1146,7 @@ def compute_mod_stats(self, mod_loc, agg_method=None, valid_read_ids=None):
                 agg_method))
 
         pr_mod_stats = self.mods_db.get_pos_stats(
-            mod_loc, return_uuids=valid_read_ids is not None)
+            mod_pos, return_uuids=valid_read_ids is not None)
         mod_type_stats = defaultdict(dict)
         for r_stats in pr_mod_stats:
             if (valid_read_ids is not None and
diff --git a/megalodon/snps.py b/megalodon/snps.py
index 3b4e53b..428b146 100755
---
a/megalodon/snps.py +++ b/megalodon/snps.py @@ -160,39 +160,17 @@ def __init__(self, fn, read_only=True, db_safety=1, return - def insert_chrm(self, chrm): - self.cur.execute('INSERT INTO chrm (chrm) VALUES (?)', (chrm,)) + # insert data function + def insert_chrms(self, chrms): + next_chrm_id = self.get_num_uniq_chrms() + 1 + self.cur.executemany('INSERT INTO chrm (chrm) VALUES (?)', + [(chrm,) for chrm in chrms]) if self.chrm_idx_in_mem: - self.chrm_idx[chrm] = self.cur.lastrowid - return self.cur.lastrowid - - def get_chrm_id(self, chrm): - try: - if self.chrm_idx_in_mem: - chrm_id = self.chrm_idx[chrm] - else: - chrm_id = self.cur.execute( - 'SELECT chrm_id FROM chrm WHERE chrm=?', - (chrm,)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError('Reference record (chromosome) not found in ' + - 'database.') - return chrm_id - - def get_chrm(self, chrm_id): - try: - if self.chrm_idx_in_mem: - chrm = self.chrm_read_idx[chrm_id] - else: - chrm = self.cur.execute( - 'SELECT chrm FROM chrm WHERE chrm_id=?', - (chrm_id,)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError('Reference record (chromosome) not found in ' + - 'vars database.') - return chrm + self.chrm_idx.update(zip( + chrms, range(next_chrm_id, next_chrm_id + len(chrms)))) + return - def get_loc_ids_or_insert_locations(self, r_var_scores, chrm_id): + def get_loc_ids_or_insert(self, r_var_scores, chrm_id): """ Extract all location IDs and add those locations not currently found in the database """ @@ -203,7 +181,8 @@ def get_loc_ids_or_insert_locations(self, r_var_scores, chrm_id): if self.loc_idx_in_mem: locs_to_add = list(set(r_locs).difference(self.loc_idx)) else: - test_starts, test_ends = map(set, list(zip(*r_uniq_locs))[1:]) + test_starts, test_ends = map( + set, list(zip(*r_locs.keys()))[1:]) loc_ids = dict(( ((chrm_id, test_start, test_end), loc_id) for chrm_id, test_start, test_end, loc_id in self.cur.execute( @@ -222,26 +201,18 @@ def get_loc_ids_or_insert_locations(self, r_var_scores, chrm_id): 'pos, ref_seq, var_name) VALUES (?,?,?,?,?,?)', ((*loc_key, *r_locs[loc_key]) for loc_key in locs_to_add)) - if self.loc_idx_in_mem: - if len(locs_to_add) > 0: - self.loc_idx.update(zip( - locs_to_add, - range(next_loc_id, next_loc_id + len(locs_to_add)))) - r_loc_ids = [ - self.loc_idx[(chrm_id, test_start, test_end)] - for _, _, _, _, _, test_start, test_end in r_var_scores] - else: - if len(locs_to_add) > 0: - loc_ids.update(zip( - locs_to_add, - range(next_loc_id, next_loc_id + len(locs_to_add)))) - r_loc_ids = [ - loc_ids[(chrm_id, test_start, test_end)] - for _, _, _, _, _, test_start, test_end in r_var_scores] + loc_idx = self.loc_idx if self.loc_idx_in_mem else loc_ids + if len(locs_to_add) > 0: + loc_idx.update(zip( + locs_to_add, + range(next_loc_id, next_loc_id + len(locs_to_add)))) + r_loc_ids = [ + loc_idx[(chrm_id, test_start, test_end)] + for _, _, _, _, _, test_start, test_end in r_var_scores] return r_loc_ids - def get_alt_ids_or_insert_alt_seqs(self, r_var_scores): + def get_alt_ids_or_insert(self, r_var_scores): r_seqs_and_lps = [ tuple(zip(alt_seqs, alt_lps)) for _, alt_lps, _, alt_seqs, _, _, _ in r_var_scores] @@ -264,46 +235,25 @@ def get_alt_ids_or_insert_alt_seqs(self, r_var_scores): self.cur.executemany( 'INSERT INTO alt (alt_seq) VALUES (?)', alts_to_add) - if self.alt_idx_in_mem: - if len(alts_to_add) > 0: - self.alt_idx.update(zip( - alts_to_add, - range(next_alt_id, next_alt_id + len(alts_to_add)))) - r_alt_ids = [ - tuple((self.alt_idx[alt_seq], alt_lp) - for alt_seq, 
alt_lp in loc_seqs_lps) - for loc_seqs_lps in r_seqs_and_lps] - else: - if len(alts_to_add) > 0: - alt_ids.update(zip( - alts_to_add, - range(next_alt_id, next_alt_id + len(alts_to_add)))) - r_alt_ids = [ - tuple((alt_ids[alt_seq], alt_lp) - for alt_seq, alt_lp in loc_seqs_lps) - for loc_seqs_lps in r_seqs_and_lps] + alt_idx = self.alt_idx if self.alt_idx_in_mem else alt_ids + if len(alts_to_add) > 0: + alt_idx.update(zip( + alts_to_add, + range(next_alt_id, next_alt_id + len(alts_to_add)))) + r_alt_ids = [ + tuple((alt_idx[alt_seq], alt_lp) + for alt_seq, alt_lp in loc_seqs_lps) + for loc_seqs_lps in r_seqs_and_lps] return r_alt_ids - def get_alt_seq(self, alt_id): - try: - if self.alt_idx_in_mem: - alt_seq = self.alt_read_idx[alt_id] - else: - alt_seq = self.cur.execute( - 'SELECT alt_seq FROM alt WHERE alt_id=?', - (alt_id,)).fetchone()[0] - except (TypeError, KeyError): - raise mh.MegaError('Alt sequence not found in vars database.') - return alt_seq - def insert_read_scores(self, r_var_scores, uuid, chrm, strand): self.cur.execute('INSERT INTO read (uuid, strand) VALUES (?,?)', (uuid, strand)) read_id = self.cur.lastrowid chrm_id = self.get_chrm_id(chrm) - loc_ids = self.get_loc_ids_or_insert_locations(r_var_scores, chrm_id) - alt_ids = self.get_alt_ids_or_insert_alt_seqs(r_var_scores) + loc_ids = self.get_loc_ids_or_insert(r_var_scores, chrm_id) + alt_ids = self.get_alt_ids_or_insert(r_var_scores) read_insert_data = ((alt_lp, loc_id, alt_id, read_id) for loc_id, loc_alts in zip(loc_ids, alt_ids) @@ -313,6 +263,7 @@ def insert_read_scores(self, r_var_scores, uuid, chrm, strand): 'INSERT INTO data VALUES (?,?,?,?)', read_insert_data) return + # create and load index functions def create_chrm_index(self): self.cur.execute('CREATE UNIQUE INDEX chrm_idx ON chrm(chrm)') return @@ -348,10 +299,60 @@ def create_data_covering_index(self): 'score_loc, score_alt, score_read, score)') return - def close(self): - self.db.commit() - self.db.close() - return + # reader functions + def get_chrm_id(self, chrm): + try: + if self.chrm_idx_in_mem: + chrm_id = self.chrm_idx[chrm] + else: + chrm_id = self.cur.execute( + 'SELECT chrm_id FROM chrm WHERE chrm=?', + (chrm,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Reference record (chromosome) not found in ' + + 'database.') + return chrm_id + + def get_chrm(self, chrm_id): + try: + if self.chrm_idx_in_mem: + chrm = self.chrm_read_idx[chrm_id] + else: + chrm = self.cur.execute( + 'SELECT chrm FROM chrm WHERE chrm_id=?', + (chrm_id,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Reference record (chromosome) not found in ' + + 'vars database.') + return chrm + + def get_alt_seq(self, alt_id): + try: + if self.alt_idx_in_mem: + alt_seq = self.alt_read_idx[alt_id] + else: + alt_seq = self.cur.execute( + 'SELECT alt_seq FROM alt WHERE alt_id=?', + (alt_id,)).fetchone()[0] + except (TypeError, KeyError): + raise mh.MegaError('Alt sequence not found in vars database.') + return alt_seq + + def get_uuid(self, read_id): + try: + uuid = self.cur.execute( + 'SELECT uuid FROM read WHERE read_id=?', + (read_id,)).fetchone()[0] + except TypeError: + raise mh.MegaError('Read ID not found in vars database.') + return uuid + + def get_num_uniq_chrms(self): + num_chrms = self.cur.execute( + 'SELECT MAX(chrm_id) FROM chrm').fetchone()[0] + if num_chrms is None: + num_chrms = 0 + return num_chrms def get_num_uniq_var_loc(self): num_locs = self.cur.execute('SELECT MAX(loc_id) FROM loc').fetchone()[0] @@ -377,18 +378,18 @@ def 
get_loc_stats(self, loc_data, return_uuids=False): read_id_conv = self.get_uuid if return_uuids else lambda x: x # these attributes are specified in self.iter_locs loc_id, chrm_id, pos, ref_seq, var_name = loc_data - self.cur.execute( - 'SELECT score, score_read, score_alt, score_loc FROM data ' + - 'WHERE score_loc=?', (loc_id, )) + chrm = self.get_chrm(chrm_id) return [ self.var_data( score, pos, ref_seq, var_name, read_id_conv(read_id), - self.get_chrm(chrm_id), self.get_alt_seq(alt_id)) - for score, read_id, alt_id, loc_id in self.cur] + chrm, self.get_alt_seq(alt_id)) + for score, read_id, alt_id, loc_id in self.cur.execute( + 'SELECT score, score_read, score_alt, score_loc FROM data ' + + 'WHERE score_loc=?', (loc_id,)).fetchall()] - def get_read_stats(self, uuid): - # TODO implement this for API - raise NotImplementedError + def close(self): + self.db.commit() + self.db.close() return @@ -813,8 +814,7 @@ def get_snp_call( logger = logging.get_logger('vars_getter') snps_db = VarsDb(vars_db_fn, db_safety=db_safety, read_only=False, loc_index_in_memory=loc_index_in_memory) - for ref_name in ref_names_and_lens[0]: - snps_db.insert_chrm(ref_name) + snps_db.insert_chrms(ref_names_and_lens[0]) snps_db.create_chrm_index() if snps_txt_fn is None: snps_txt_fp = None From fd776ae20522239bcad52c99f6f95d0dc530c608 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Mon, 30 Sep 2019 14:25:44 -0700 Subject: [PATCH 07/14] Global conversion from snp to variant. --- README.rst | 24 +- docs/advanced_arguments.rst | 24 +- docs/algorithm_details.rst | 14 +- docs/common_arguments.rst | 8 +- docs/index.rst | 8 +- megalodon/aggregate.py | 145 +++--- megalodon/calibration.py | 14 +- megalodon/megalodon.py | 197 ++++---- megalodon/megalodon_helper.py | 45 +- ....npz => megalodon_variant_calibration.npz} | Bin megalodon/{snps.py => variants.py} | 440 +++++++++--------- scripts/filter_whatshap.py | 2 +- setup.py | 3 +- 13 files changed, 465 insertions(+), 459 deletions(-) rename megalodon/model_data/R941.min.high_acc.5mC_6mA_bio_cntxt/{megalodon_snp_calibration.npz => megalodon_variant_calibration.npz} (100%) rename megalodon/{snps.py => variants.py} (85%) diff --git a/README.rst b/README.rst index 7c4106c..ea6204a 100644 --- a/README.rst +++ b/README.rst @@ -6,10 +6,10 @@ Megalodon """"""""" -Megalodon provides "basecalling augmentation" for raw nanopore sequencing reads, including direct, reference-guided SNP and modified base calling. +Megalodon provides "basecalling augmentation" for raw nanopore sequencing reads, including direct, reference-guided sequence variant and modified base calling. Megalodon anchors the information rich neural network basecalling output to a reference genome. -Variants, modified bases or alternative canonical bases, are then proposed and scored in order to produce highly-accurate reference anchored modified base or SNP calls. +Variants, modified bases or alternative canonical bases, are then proposed and scored in order to produce highly-accurate reference anchored modified base or sequence variant calls. Detailed documentation for all ``megalodon`` arguments and algorithms can be found on the `megalodon documentation page `_. @@ -48,11 +48,11 @@ Megalodon is accessed via the command line interface ``megalodon`` command. 
# Example command calling variants and CpG methylation # Compute settings: GPU devices 0 and 1 with 8 CPU cores megalodon raw_fast5s/ \ - --outputs basecalls mappings snps mods \ + --outputs basecalls mappings variants mods \ --reference reference.fa --variant-filename variants.vcf.gz \ --mod-motif Z CG 0 --devices 0 1 --processes 8 --verbose-read-progress 3 -This command produces the ``megalodon_results`` output directory containing basecalls, mappings, SNP and modified base results. +This command produces the ``megalodon_results`` output directory containing basecalls, mappings, sequence variants and modified base results. The format for each output is described below. .. note:: @@ -76,7 +76,7 @@ Inputs - Format: VCF or BCF - If not indexed, indexing will be performed - - Megalodon currently requires a set of candidate variants for ``--outputs snps``. + - Megalodon currently requires a set of candidate variants for ``--outputs variants``. - Only small indels (default less than ``50`` bases) are tested by default. - Specify the ``--max-indel-size`` argument to process larger indels @@ -108,14 +108,14 @@ Outputs - Aggregated calls - Aggregated calls are output in either bedMethyl format (default; one file per modified base), a VCF variant format (including all modified bases) or wiggle format (one file per modified base/strand combination). -- SNP Variant Calls +- Sequence Variant Calls - - Per-read SNP Calls + - Per-read Variant Calls - - SQL DB containing scores at each tested reference location + - SQL DB containing scores for each tested variant - - Contains a single ``snps`` table indexed by reference position - - Tab-delimited output can be produced by adding the ``--write-snps-text`` flag + - Contains a single ``variants`` table indexed by reference position + - Tab-delimited output can be produced by adding the ``--write-variants-text`` flag - Aggregated calls - Format: VCF @@ -138,13 +138,13 @@ Note that the model parameters must (currently) be loaded into each GPU process The ``--chunk-size`` and ``--chunk-overlap`` arguments allow users to specify read chunking, but signal normalization is always carried out over the entire read. A number of helper processes will be spawned in order to perform more minor tasks, which should take minimal compute resources. -These include enumerating read ids and files, collecting and reporting progress information and getting data from read processing queues and writing outputs (basecalls, mappings, SNPs and modified bases). +These include enumerating read ids and files, collecting and reporting progress information and getting data from read processing queues and writing outputs (basecalls, mappings, sequence variants and modified bases). Model Compatibility ------------------- The model and calibration files included with megalodon are applicable only to MinION or GridION R9.4.1 flowcells. -New models trained with taiyaki can be used with megalodon, but in order to obtain the highest performance the megalodon (SNP and modified base) calibration files should be reproduced for any new model (TODO provide walkthrough). +New models trained with taiyaki can be used with megalodon, but in order to obtain the highest performance the megalodon (variant and modified base) calibration files should be reproduced for any new model (TODO provide walkthrough). The included model contains 5mC and 6mA capabilities. 5mC was trained only in the human (CpG) and E. coli (CCWGG) contexts while the 6mA was trained only on the E. coli (GATC) context. 
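For quick inspection, the per-read variant results described above can be queried directly with Python's built-in ``sqlite3`` module. The following is a minimal sketch against the normalized ``chrm``/``loc``/``alt``/``read``/``data`` schema defined by ``VarsDb`` in this patch series; the database path is hypothetical (the actual filename is resolved internally via ``mh.get_megalodon_fn``)::

    import sqlite3

    # hypothetical output path; not a name taken from these patches
    db = sqlite3.connect('megalodon_results/per_read_variant_calls.db')

    # join the normalized tables back into flat per-read variant records;
    # data columns are (score, score_loc, score_alt, score_read)
    query = '''
        SELECT read.uuid, chrm.chrm, loc.pos, loc.ref_seq,
               alt.alt_seq, data.score
        FROM data
        INNER JOIN loc ON data.score_loc = loc.loc_id
        INNER JOIN chrm ON loc.loc_chrm = chrm.chrm_id
        INNER JOIN alt ON data.score_alt = alt.alt_id
        INNER JOIN read ON data.score_read = read.read_id
        LIMIT 10'''
    for uuid, chrm, pos, ref_seq, alt_seq, score in db.execute(query):
        print(uuid, chrm, pos, ref_seq, alt_seq, score)
    db.close()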
diff --git a/docs/advanced_arguments.rst b/docs/advanced_arguments.rst
index e87c835..21fafce 100644
--- a/docs/advanced_arguments.rst
+++ b/docs/advanced_arguments.rst
@@ -17,15 +17,15 @@ Output Arguments
 - A file containing ``read_ids`` to process (one per line).
 - Used in the variant phasing pipeline.
 
--------------
-SNP Arguments
--------------
+--------------------------
+Sequence Variant Arguments
+--------------------------
 
-- ``--disable-snp-calibration``
+- ``--disable-variant-calibration``
 
-  - Use raw neural network SNP scores.
+  - Use raw neural network sequence variant scores.
   - This option should be set when calibrating a new model.
-  - Default: Calibrate scores as described in ``--snp-calibration-filename``
+  - Default: Calibrate scores as described in ``--variant-calibration-filename``
 
 - ``--heterozygous-factors``
 
   - Bayes factor used when computing heterozygous probabilities in diploid variant calling mode.
@@ -33,14 +33,14 @@ SNP Arguments
 - ``--max-indel-size``
 
   - Maximum indel size to include in testing. Default: 50
 
-- ``--snp-all-paths``
+- ``--variant-all-paths``
 
   - Compute the forward algorithm all paths score.
   - Default: Viterbi best-path score.
 
-- ``--snp-calibration-filename``
+- ``--variant-calibration-filename``
 
-  - File containing emperical calibration for SNP scores.
-  - As created by megalodon/scripts/calibrate_snp_llr_scores.py.
+  - File containing empirical calibration for sequence variant scores.
+  - As created by megalodon/scripts/calibrate_variant_llr_scores.py.
   - Default: Load default calibration file.
 
 - ``--variant-context-bases``
@@ -126,9 +126,9 @@ This output category is intended for use in generating reference sequences for t
 - ``--refs-include-mods``
 
   - Include modified base calls in per-read reference output.
-- ``--refs-include-snps``
+- ``--refs-include-variants``
 
-  - Include SNP calls in per-read reference output.
+  - Include sequence variant calls in per-read reference output.
 - ``--refs-percent-identity-threshold``
 
   - Only include reads with higher percent identity in per-read reference output.
diff --git a/docs/algorithm_details.rst b/docs/algorithm_details.rst
index 25d62aa..b32fe90 100644
--- a/docs/algorithm_details.rst
+++ b/docs/algorithm_details.rst
@@ -2,7 +2,7 @@
 Megalodon Algorithm Details
 ***************************
 
-This page describes the details of how megalodon processes the raw nanopore signal to produce highly-accurate modified base and SNP calls.
+This page describes the details of how megalodon processes the raw nanopore signal to produce highly-accurate modified base and sequence variant calls.
 
 ------------
 Base Calling
@@ -23,11 +23,11 @@ The neural network output is anchored to the reference via standard read mapping
 If no reference mapping is produced (using ``minimap2`` via the ``mappy`` python interface) that read is not processed further (basecalls will be output if requested).
 This standard read mapping is processed to produce a matching of each basecall with a reference position.
 Reference positions within an insertion or deletion are assigned to the previous mapped read position (left justified).
-This constitutes the reference anchoring used for modified base and SNP calling steps.
+This constitutes the reference anchoring used for modified base and sequence variant calling steps.
------------
-SNP Calling
------------
+------------------------
+Sequence Variant Calling
+------------------------
 
 Megalodon currently filters alleles over a certain maximum size (default 50) as performance on larger indels has not currently been validated.
 
@@ -43,14 +43,14 @@ The difference between these two scores is the assigned score for the proposed v
 Lower (negative) scores are evidence for the alternative sequence and higher (positive) scores are evidence for the reference sequence.
 
 These raw scores are softmax values over potential states, to match characteristics of a probability distribution.
-In practice, these scores do not match emperical probabilities for a SNP given a truth dataset.
+In practice, these scores do not match empirical probabilities for a variant given a truth dataset.
 Thus a calibration step is applied to convert these scores to estimated empirical probabilities.
 This enables more accurate aggregation across reads.
 
 Finally, calls across reads at each reference location are aggregated in order to make a sample-level call.
 These results will be output into a VCF format file.
-Currently ``diploid`` (default) and ``haploid`` SNP aggregation modes are available.
+Currently ``diploid`` (default) and ``haploid`` variant aggregation modes are available.
 In ``haploid`` mode the probabilities of the reference and alternative alleles are simply the normalized (via Bayes' theorem) product of the individual read probabilities.
 In ``diploid`` mode the probability of each genotype (homozygous reference, heterozygous and homozygous alternative) is computed.
 The probabilities for homozygous alleles are as in the ``haploid`` mode, while the heterozygous probability is given by the weighted sum of the maximal probabilities taken over the sampling distribution (binomial with ``p=0.5``) given a true diploid heterozygous allele.
diff --git a/docs/common_arguments.rst b/docs/common_arguments.rst
index bdf60db..b386c77 100644
--- a/docs/common_arguments.rst
+++ b/docs/common_arguments.rst
@@ -35,7 +35,7 @@ Output Arguments
 - ``--outputs``
 
   - Specify desired outputs.
-  - Options are ``basecalls``, ``mod_basecalls``, ``mappings``, ``whatshap_mappings``, ``per_read_snps``, ``per_read_mods``, ``snp``, and ``mods``.
+  - Options are ``basecalls``, ``mod_basecalls``, ``mappings``, ``whatshap_mappings``, ``per_read_variants``, ``per_read_mods``, ``variants``, and ``mods``.
   - ``mod_basecalls`` are currently output in an HDF5 file with a data set corresponding to each read (accessed via the ``read_id``).
   - ``whatshap_mappings`` are intended only for obtaining highly accurate phased variant genotypes.
@@ -77,11 +77,11 @@ Sequence Variant Arguments
 - Variants file must be sorted.
 - If variant file is not compressed and indexed this will be performed before further processing.
 - Variants must be matched to the ``--reference`` provided.
-- ``--write-snps-text``
+- ``--write-variants-text``
 
-  - Output per-read SNPs in text format.
+  - Output per-read variants in text format.
 
-  - Output includes columns: ``read_id``, ``chrm``, ``strand``, ``pos``, ``ref_log_prob``, ``alt_log_prob``, ``snp_ref_seq``, ``snp_alt_seq``, ``snp_id``
+  - Output includes columns: ``read_id``, ``chrm``, ``strand``, ``pos``, ``ref_log_prob``, ``alt_log_prob``, ``var_ref_seq``, ``var_alt_seq``, ``var_id``
 
   - Log probabilities are calibrated to match observed log-likelihood ratios from ground truth samples.
   - Reference log probabilities are included to make processing multiple alternative allele sites easier to process.
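To make the aggregation description above concrete, the ``haploid`` computation for a single bi-allelic site reduces to a product of per-read probabilities normalized via Bayes' theorem, most safely evaluated in log space. A minimal sketch of that arithmetic (an illustration of the text above, not megalodon's implementation; ``numpy`` is assumed)::

    import numpy as np

    def haploid_ref_prob(ref_lps, alt_lps):
        # per-read calibrated log probabilities of the reference and
        # alternative alleles at one bi-allelic site
        ref_lp, alt_lp = np.sum(ref_lps), np.sum(alt_lps)
        # Bayes' theorem with a flat prior over the two alleles
        return np.exp(ref_lp - np.logaddexp(ref_lp, alt_lp))

    # two reads mildly favoring the reference, one favoring the alternative
    print(haploid_ref_prob(np.log([0.8, 0.7, 0.1]), np.log([0.2, 0.3, 0.9])))

The ``diploid`` homozygous genotype probabilities take the same form, with the heterozygous genotype weighted by the binomial sampling distribution noted above.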
diff --git a/docs/index.rst b/docs/index.rst index 7a5c580..d642260 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -2,7 +2,7 @@ Welcome to Megalodon's documentation! ************************************* -Megalodon provides "basecalling augmentation" for raw nanopore sequencing reads, including direct, reference-guided SNP and modified base calling. +Megalodon provides "basecalling augmentation" for raw nanopore sequencing reads, including direct, reference-guided sequence variant and modified base calling. Megalodon anchors the information rich neural network basecalling output to a reference genome. Variants, either modified bases or alternative canonical bases, are then proposed and scored in order to produce highly-accurate reference anchored calls. @@ -45,14 +45,14 @@ Megalodon is accessed via the command line interface, ``megalodon`` command. # Example command calling variants and CpG methylation # Compute settings: GPU devices 0 and 1 with 8 CPU cores megalodon raw_fast5s/ \ - --outputs basecalls mappings snps mods \ + --outputs basecalls mappings variants mods \ --reference reference.fa --variant-filename variants.vcf \ --mod-motif Z CG 0 --devices 0 1 --processes 8 --verbose-read-progress 3 -This command produces the ``megalodon_results`` output directory containing basecalls, mappings, SNP and modified base results. +This command produces the ``megalodon_results`` output directory containing basecalls, mappings, sequence variant and modified base results. The majority of megalodon's functionality is accessed via the ``megalodon`` command (exemplified above), though a small number of additional scripts are found in the ``scripts`` directory of the code repository. -These include independent modified base or SNP aggregation (much faster than re-computing per-read calls), modified base result validation, and model statistic calibration. +These include independent modified base or variant aggregation (much faster than re-computing per-read calls), modified base result validation, and model statistic calibration. Helper scripts to perform sequence variant phasing (details here :doc:`variant_phasing`) are also included in the ``scripts`` directory of the repository. 
-------- diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py index d6d2010..bf42414 100644 --- a/megalodon/aggregate.py +++ b/megalodon/aggregate.py @@ -7,7 +7,7 @@ from tqdm import tqdm -from megalodon import logging, mods, snps, megalodon_helper as mh +from megalodon import logging, mods, variants, megalodon_helper as mh _DO_PROFILE_AGG_MOD = False _DO_PROFILE_GET_MODS = False @@ -16,62 +16,62 @@ _N_MOD_PROF = 200000 -####################################### -##### Aggregate SNP and Mod Stats ##### -####################################### +############################################ +##### Aggregate Variants and Mod Stats ##### +############################################ -def _agg_snps_worker( - locs_q, snp_stats_q, snp_prog_q, snps_db_fn, write_vcf_lp, +def _agg_vars_worker( + locs_q, var_stats_q, var_prog_q, vars_db_fn, write_vcf_lp, het_factors, call_mode, valid_read_ids): - agg_snps = snps.AggSnps(snps_db_fn, write_vcf_lp) + agg_vars = variants.AggVars(vars_db_fn, write_vcf_lp) while True: try: - snp_loc = locs_q.get(block=False) + var_loc = locs_q.get(block=False) except queue.Empty: sleep(0.001) continue - if snp_loc is None: + if var_loc is None: break try: - snp_var = agg_snps.compute_snp_stats( - snp_loc, het_factors, call_mode, valid_read_ids) - snp_stats_q.put(snp_var) + var_var = agg_vars.compute_var_stats( + var_loc, het_factors, call_mode, valid_read_ids) + var_stats_q.put(var_var) except mh.MegaError: # something not right with the stats at this loc pass - snp_prog_q.put(1) + var_prog_q.put(1) return -def _get_snp_stats_queue( - snp_stats_q, snp_conn, out_dir, ref_names_and_lens, out_suffix, +def _get_var_stats_queue( + var_stats_q, var_conn, out_dir, ref_names_and_lens, out_suffix, write_vcf_lp): - agg_snp_fn = mh.get_megalodon_fn(out_dir, mh.SNP_NAME) + agg_var_fn = mh.get_megalodon_fn(out_dir, mh.VAR_NAME) if out_suffix is not None: - base_fn, fn_ext = os.path.splitext(agg_snp_fn) - agg_snp_fn = base_fn + '.' + out_suffix + fn_ext - agg_snp_fp = snps.VcfWriter( - agg_snp_fn, 'w', ref_names_and_lens=ref_names_and_lens, + base_fn, fn_ext = os.path.splitext(agg_var_fn) + agg_var_fn = base_fn + '.' 
+ out_suffix + fn_ext + agg_var_fp = variants.VcfWriter( + agg_var_fn, 'w', ref_names_and_lens=ref_names_and_lens, write_vcf_lp=write_vcf_lp) while True: try: - snp_var = snp_stats_q.get(block=False) - if snp_var is None: continue - agg_snp_fp.write_variant(snp_var) + var_var = var_stats_q.get(block=False) + if var_var is None: continue + agg_var_fp.write_variant(var_var) except queue.Empty: - if snp_conn.poll(): + if var_conn.poll(): break sleep(0.001) continue - while not snp_stats_q.empty(): - snp_var = snp_stats_q.get(block=False) - agg_snp_fp.write_variant(snp_var) + while not var_stats_q.empty(): + var_var = var_stats_q.get(block=False) + agg_var_fp.write_variant(var_var) - agg_snp_fp.close() + agg_var_fp.close() return @@ -174,17 +174,17 @@ def _get_mod_stats_queue(*args): return def _agg_prog_worker( - snp_prog_q, mod_prog_q, num_snps, num_mods, prog_conn, + var_prog_q, mod_prog_q, num_vars, num_mods, prog_conn, suppress_progress): - snp_bar, mod_bar = None, None - if num_snps > 0: + var_bar, mod_bar = None, None + if num_vars > 0: if num_mods > 0 and not suppress_progress: mod_bar = tqdm(desc='Mods', unit=' sites', total=num_mods, position=1, smoothing=0, dynamic_ncols=True) - snp_bar = tqdm(desc='SNPs', unit=' sites', total=num_snps, + var_bar = tqdm(desc='Variants', unit=' sites', total=num_vars, position=0, smoothing=0, dynamic_ncols=True) elif not suppress_progress: - snp_bar = tqdm(desc='SNPs', unit=' sites', total=num_snps, + var_bar = tqdm(desc='Variants', unit=' sites', total=num_vars, position=0, smoothing=0, dynamic_ncols=True) elif num_mods > 0 and not suppress_progress: mod_bar = tqdm(desc='Mods', unit=' sites', total=num_mods, @@ -192,15 +192,15 @@ def _agg_prog_worker( while True: try: - snp_prog_q.get(block=False) + var_prog_q.get(block=False) if not suppress_progress: - if snp_bar is not None: snp_bar.update(1) + if var_bar is not None: var_bar.update(1) if mod_bar is not None: mod_bar.update(0) except queue.Empty: try: mod_prog_q.get(block=False) if not suppress_progress: - if snp_bar is not None: snp_bar.update(0) + if var_bar is not None: var_bar.update(0) if mod_bar is not None: mod_bar.update(1) except queue.Empty: sleep(0.001) @@ -208,17 +208,17 @@ def _agg_prog_worker( break continue - while not snp_prog_q.empty(): - snp_prog_q.get(block=False) - if not suppress_progress: snp_bar.update(1) + while not var_prog_q.empty(): + var_prog_q.get(block=False) + if not suppress_progress: var_bar.update(1) while not mod_prog_q.empty(): mod_prog_q.get(block=False) if not suppress_progress: mod_bar.update(1) - if snp_bar is not None: - snp_bar.close() + if var_bar is not None: + var_bar.close() if mod_bar is not None: mod_bar.close() - if num_mods > 0 and num_snps > 0 and not suppress_progress: + if num_mods > 0 and num_vars > 0 and not suppress_progress: sys.stderr.write('\n\n') return @@ -246,38 +246,39 @@ def aggregate_stats( mod_names, mod_agg_info, write_mod_lp, mod_output_fmts, suppress_progress, ref_names_and_lens, valid_read_ids=None, out_suffix=None): - if mh.SNP_NAME in outputs and mh.MOD_NAME in outputs: + if mh.VAR_NAME in outputs and mh.MOD_NAME in outputs: num_ps = max(num_ps // 2, 1) logger = logging.get_logger('agg') - num_snps, num_mods, snp_prog_q, mod_prog_q = ( + num_vars, num_mods, var_prog_q, mod_prog_q = ( 0, 0, queue.Queue(), queue.Queue()) - if mh.SNP_NAME in outputs: - snps_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_SNP_NAME) - num_snps = snps.AggSnps( - snps_db_fn, load_in_mem_indices=False).num_uniq() + if mh.VAR_NAME in outputs: + vars_db_fn 
= mh.get_megalodon_fn(out_dir, mh.PR_VAR_NAME) + num_vars = variants.AggVars( + vars_db_fn, load_in_mem_indices=False).num_uniq() logger.info('Spawning variant aggregation processes.') - # create process to collect snp stats from workers - snp_stats_q, snp_stats_p, main_snp_stats_conn = mh.create_getter_q( - _get_snp_stats_queue, ( + # create process to collect var stats from workers + var_stats_q, var_stats_p, main_var_stats_conn = mh.create_getter_q( + _get_var_stats_queue, ( out_dir, ref_names_and_lens, out_suffix, write_vcf_lp)) - # create process to fill snp locs queue - snp_filler_q = mp.Queue(maxsize=mh._MAX_QUEUE_SIZE) - snp_filler_p = mp.Process( + # create process to fill variant locs queue + var_filler_q = mp.Queue(maxsize=mh._MAX_QUEUE_SIZE) + var_filler_p = mp.Process( target=_fill_locs_queue, - args=(snp_filler_q, snps_db_fn, snps.AggSnps, num_ps), daemon=True) - snp_filler_p.start() - # create worker processes to aggregate snps - snp_prog_q = mp.Queue(maxsize=mh._MAX_QUEUE_SIZE) - agg_snps_ps = [] + args=(var_filler_q, vars_db_fn, variants.AggVars, num_ps), + daemon=True) + var_filler_p.start() + # create worker processes to aggregate variants + var_prog_q = mp.Queue(maxsize=mh._MAX_QUEUE_SIZE) + agg_vars_ps = [] for _ in range(num_ps): p = mp.Process( - target=_agg_snps_worker, - args=(snp_filler_q, snp_stats_q, snp_prog_q, snps_db_fn, + target=_agg_vars_worker, + args=(var_filler_q, var_stats_q, var_prog_q, vars_db_fn, write_vcf_lp, het_factors, call_mode, valid_read_ids), daemon=True) p.start() - agg_snps_ps.append(p) + agg_vars_ps.append(p) if mh.MOD_NAME in outputs: mods_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_MOD_NAME) @@ -311,25 +312,25 @@ def aggregate_stats( # create progress process logger.info( - 'Aggregating {} SNPs and {} mod sites over reads.'.format( - num_snps, num_mods)) + 'Aggregating {} variants and {} modified base sites over reads.'.format( + num_vars, num_mods)) main_prog_conn, prog_conn = mp.Pipe() prog_p = mp.Process( target=_agg_prog_worker, - args=(snp_prog_q, mod_prog_q, num_snps, num_mods, prog_conn, + args=(var_prog_q, mod_prog_q, num_vars, num_mods, prog_conn, suppress_progress), daemon=True) prog_p.start() # join filler processes first - if mh.SNP_NAME in outputs: - snp_filler_p.join() - for agg_snps_p in agg_snps_ps: - agg_snps_p.join() + if mh.VAR_NAME in outputs: + var_filler_p.join() + for agg_vars_p in agg_vars_ps: + agg_vars_p.join() # send to conn - if snp_stats_p.is_alive(): - main_snp_stats_conn.send(True) - snp_stats_p.join() + if var_stats_p.is_alive(): + main_var_stats_conn.send(True) + var_stats_p.join() if mh.MOD_NAME in outputs: for agg_mods_p in agg_mods_ps: agg_mods_p.join() diff --git a/megalodon/calibration.py b/megalodon/calibration.py index 6d068ac..25cbf32 100644 --- a/megalodon/calibration.py +++ b/megalodon/calibration.py @@ -11,7 +11,7 @@ DEFAULT_SMOOTH_NVALS = 1001 DEFAULT_MIN_DENSITY = 5e-6 -SNP_CALIB_TYPE = 'snp_type_indel_len' +VAR_CALIB_TYPE = 'snp_type_indel_len' GENERIC_BASE = 'N' SNP_CALIB_TMPLT = 'snp_{}_{}_calibration' SNP_LLR_RNG_TMPLT = 'snp_{}_{}_llr_range' @@ -192,11 +192,11 @@ def compute_log_probs(alt_llrs): ##### Calibration Readers ##### ############################### -class SnpCalibrator(object): +class VarCalibrator(object): def _load_calibration(self): calib_data = np.load(self.fn) self.stratify_type = str(calib_data['stratify_type']) - assert self.stratify_type == SNP_CALIB_TYPE + assert self.stratify_type == VAR_CALIB_TYPE self.num_calib_vals = np.int(calib_data['smooth_nvals']) 
self.max_indel_len = np.int(calib_data['max_indel_len']) @@ -246,15 +246,15 @@ def _load_calibration(self): return - def __init__(self, snps_calib_fn): - self.fn = snps_calib_fn + def __init__(self, vars_calib_fn): + self.fn = vars_calib_fn if self.fn is not None: self._load_calibration() self.calib_loaded = self.fn is not None return def calibrate_llr(self, llr, read_ref_seq, read_alt_seq): - def simplify_snp_seq(ref_seq, alt_seq): + def simplify_var_seq(ref_seq, alt_seq): while (len(ref_seq) > 0 and len(alt_seq) > 0 and ref_seq[0] == alt_seq[0]): ref_seq = ref_seq[1:] @@ -268,7 +268,7 @@ def simplify_snp_seq(ref_seq, alt_seq): if not self.calib_loaded: return llr if len(read_ref_seq) == len(read_alt_seq): - ref_seq, alt_seq = simplify_snp_seq(read_ref_seq, read_alt_seq) + ref_seq, alt_seq = simplify_var_seq(read_ref_seq, read_alt_seq) # default to a "generic" SNP type that is the total of all SNP types snp_type = ((ref_seq, alt_seq) if (ref_seq, alt_seq) in self.snp_calib_tables else diff --git a/megalodon/megalodon.py b/megalodon/megalodon.py index 3b1322a..bf92469 100644 --- a/megalodon/megalodon.py +++ b/megalodon/megalodon.py @@ -21,14 +21,14 @@ from tqdm._utils import _term_move_up from megalodon import ( - aggregate, backends, decode, fast5_io, logging, mapping, mods, snps, + aggregate, backends, decode, fast5_io, logging, mapping, mods, variants, megalodon_helper as mh) from megalodon._version import MEGALODON_VERSION _DO_PROFILE = False _UNEXPECTED_ERROR_CODE = 'Unexpected error' -_UNEXPECTED_ERROR_FN = 'unexpected_snp_calling_errors.{}.err' +_UNEXPECTED_ERROR_FN = 'unexpected_megalodon_errors.{}.err' _MAX_NUM_UNEXP_ERRORS = 50 @@ -52,7 +52,7 @@ def handle_errors(func, args, r_vals, out_q, fast5_fn, failed_reads_q): return def process_read( - raw_sig, read_id, model_info, bc_q, caller_conn, snps_data, snps_q, + raw_sig, read_id, model_info, bc_q, caller_conn, vars_data, vars_q, mods_q, mods_info, fast5_fn, failed_reads_q): """ Workhorse per-read megalodon function (connects all the parts) """ @@ -87,15 +87,15 @@ def process_read( mapped_rl_cumsum = rl_cumsum[ r_ref_pos.q_trim_start:r_ref_pos.q_trim_end + 1] - post_mapped_start - if snps_q is not None: + if vars_q is not None: handle_errors( - func=snps.call_read_vars, - args=(snps_data, r_ref_pos, np_ref_seq, mapped_rl_cumsum, + func=variants.call_read_vars, + args=(vars_data, r_ref_pos, np_ref_seq, mapped_rl_cumsum, r_to_q_poss, r_post, post_mapped_start), r_vals=(read_id, r_ref_pos.chrm, r_ref_pos.strand, r_ref_pos.start, r_ref_seq, len(r_seq), r_ref_pos.q_trim_start, r_ref_pos.q_trim_end, r_cigar), - out_q=snps_q, fast5_fn=fast5_fn + ':::' + read_id, + out_q=vars_q, fast5_fn=fast5_fn + ':::' + read_id, failed_reads_q=failed_reads_q) if mods_q is not None: handle_errors( @@ -165,10 +165,10 @@ def _get_bc_queue( return def _process_reads_worker( - read_file_q, bc_q, snps_q, failed_reads_q, mods_q, caller_conn, - model_info, snps_data, mods_info, device): + read_file_q, bc_q, vars_q, failed_reads_q, mods_q, caller_conn, + model_info, vars_data, mods_info, device): model_info.prep_model_worker(device) - snps_data.reopen_variant_index() + vars_data.reopen_variant_index() logger = logging.get_logger('main') logger.debug('Starting read worker {}'.format(mp.current_process())) @@ -190,8 +190,8 @@ def _process_reads_worker( raw_sig = fast5_io.get_signal(fast5_fn, read_id, scale=True) process_read( - raw_sig, read_id, model_info, bc_q, caller_conn, snps_data, - snps_q, mods_q, mods_info, fast5_fn, failed_reads_q) + raw_sig, read_id, 
model_info, bc_q, caller_conn, vars_data, + vars_q, mods_q, mods_info, fast5_fn, failed_reads_q) failed_reads_q.put(( False, True, None, None, None, raw_sig.shape[0])) logger.debug('Successfully processed read {}'.format(read_id)) @@ -253,12 +253,12 @@ def post_process_mapping(out_dir, map_fmt, ref_fn): def post_process_aggregate( mods_info, outputs, mod_bin_thresh, out_dir, num_ps, write_vcf_lp, - het_factors, snps_data, write_mod_lp, supp_prog, ref_names_and_lens): + het_factors, vars_data, write_mod_lp, supp_prog, ref_names_and_lens): mod_names = mods_info.mod_long_names if mh.MOD_NAME in outputs else [] mod_agg_info = mods.AGG_INFO(mods.BIN_THRESH_NAME, mod_bin_thresh) aggregate.aggregate_stats( outputs, out_dir, num_ps, write_vcf_lp, het_factors, - snps_data.call_mode, mod_names, mod_agg_info, + vars_data.call_mode, mod_names, mod_agg_info, write_mod_lp, mods_info.mod_output_fmts, supp_prog, ref_names_and_lens) return @@ -430,7 +430,7 @@ def update_prog(reads_called, sig_called, unexp_err_fp): def process_all_reads( fast5s_dir, recursive, num_reads, read_ids_fn, model_info, outputs, - out_dir, bc_fmt, aligner, snps_data, num_ps, num_update_errors, + out_dir, bc_fmt, aligner, vars_data, num_ps, num_update_errors, suppress_progress, mods_info, db_safety, pr_ref_filts): logger = logging.get_logger() logger.info('Preparing workers to process reads.') @@ -452,8 +452,8 @@ def process_all_reads( suppress_progress), max_size=None) # start output type getters/writers - (bc_q, bc_p, main_bc_conn, mo_q, mo_p, main_mo_conn, snps_q, snps_p, - main_snps_conn, mods_q, mods_p, main_mods_conn) = [None,] * 12 + (bc_q, bc_p, main_bc_conn, mo_q, mo_p, main_mo_conn, vars_q, vars_p, + main_vars_conn, mods_q, mods_p, main_mods_conn) = [None,] * 12 if mh.BC_NAME in outputs or mh.BC_MODS_NAME in outputs: if mh.BC_NAME not in outputs: outputs.append(mh.BC_NAME) @@ -463,25 +463,25 @@ def process_all_reads( if mh.MAP_NAME in outputs: do_output_pr_refs = (mh.PR_REF_NAME in outputs and not mods_info.do_pr_ref_mods and - not snps_data.do_pr_ref_snps) + not vars_data.do_pr_ref_vars) mo_q, mo_p, main_mo_conn = mh.create_getter_q( mapping._get_map_queue, ( out_dir, aligner.ref_names_and_lens, aligner.out_fmt, aligner.ref_fn, do_output_pr_refs, pr_ref_filts)) - if mh.PR_SNP_NAME in outputs: + if mh.PR_VAR_NAME in outputs: pr_refs_fn = mh.get_megalodon_fn(out_dir, mh.PR_REF_NAME) if ( - mh.PR_REF_NAME in outputs and snps_data.do_pr_ref_snps) else None + mh.PR_REF_NAME in outputs and vars_data.do_pr_ref_vars) else None whatshap_map_fn = ( mh.get_megalodon_fn(out_dir, mh.WHATSHAP_MAP_NAME) + '.' 
+ aligner.out_fmt) if mh.WHATSHAP_MAP_NAME in outputs else None - snps_txt_fn = (mh.get_megalodon_fn(out_dir, mh.PR_SNP_TXT_NAME) - if snps_data.write_snps_txt else None) - snps_q, snps_p, main_snps_conn = mh.create_getter_q( - snps._get_snps_queue, ( - mh.get_megalodon_fn(out_dir, mh.PR_SNP_NAME), - snps_txt_fn, db_safety, pr_refs_fn, pr_ref_filts, + vars_txt_fn = (mh.get_megalodon_fn(out_dir, mh.PR_VAR_TXT_NAME) + if vars_data.write_vars_txt else None) + vars_q, vars_p, main_vars_conn = mh.create_getter_q( + variants._get_variants_queue, ( + mh.get_megalodon_fn(out_dir, mh.PR_VAR_NAME), + vars_txt_fn, db_safety, pr_refs_fn, pr_ref_filts, whatshap_map_fn, aligner.ref_names_and_lens, aligner.ref_fn, - snps_data.loc_index_in_memory)) + vars_data.loc_index_in_memory)) if mh.PR_MOD_NAME in outputs: pr_refs_fn = mh.get_megalodon_fn(out_dir, mh.PR_REF_NAME) if ( mh.PR_REF_NAME in outputs and mods_info.do_pr_ref_mods) else None @@ -502,8 +502,8 @@ def process_all_reads( map_conns.append(map_conn) p = mp.Process( target=_process_reads_worker, args=( - read_file_q, bc_q, snps_q, failed_reads_q, mods_q, - caller_conn, model_info, snps_data, mods_info, device)) + read_file_q, bc_q, vars_q, failed_reads_q, mods_q, + caller_conn, model_info, vars_data, mods_info, device)) p.daemon = True p.start() proc_reads_ps.append(p) @@ -539,13 +539,13 @@ def process_all_reads( for on, p, main_conn in ( (mh.BC_NAME, bc_p, main_bc_conn), (mh.MAP_NAME, mo_p, main_mo_conn), - (mh.PR_SNP_NAME, snps_p, main_snps_conn), + (mh.PR_VAR_NAME, vars_p, main_vars_conn), (mh.PR_MOD_NAME, mods_p, main_mods_conn)): if on in outputs and p.is_alive(): main_conn.send(True) - if on == mh.PR_SNP_NAME: + if on == mh.PR_VAR_NAME: logger.info( - 'Waiting for snps database to complete indexing.') + 'Waiting for variants database to complete indexing.') elif on == mh.PR_MOD_NAME: logger.info( 'Waiting for mods database to complete indexing.') @@ -592,42 +592,42 @@ def aligner_validation(args): 'alignment was requested. 
Argument will be ignored.') return aligner -def snps_validation(args, is_cat_mod, output_size, aligner): +def vars_validation(args, is_cat_mod, output_size, aligner): logger = logging.get_logger() - if mh.WHATSHAP_MAP_NAME in args.outputs and not mh.SNP_NAME in args.outputs: - args.outputs.append(mh.SNP_NAME) - if mh.SNP_NAME in args.outputs and not mh.PR_SNP_NAME in args.outputs: - args.outputs.append(mh.PR_SNP_NAME) - if mh.PR_SNP_NAME in args.outputs and args.variant_filename is None: + if mh.WHATSHAP_MAP_NAME in args.outputs and not mh.VAR_NAME in args.outputs: + args.outputs.append(mh.VAR_NAME) + if mh.VAR_NAME in args.outputs and not mh.PR_VAR_NAME in args.outputs: + args.outputs.append(mh.PR_VAR_NAME) + if mh.PR_VAR_NAME in args.outputs and args.variant_filename is None: logger.error( - '{} output requested, '.format(mh.PR_SNP_NAME) + + '{} output requested, '.format(mh.PR_VAR_NAME) + 'but --variant-filename not provided.') sys.exit(1) - if mh.PR_SNP_NAME in args.outputs and not ( + if mh.PR_VAR_NAME in args.outputs and not ( is_cat_mod or mh.nstate_to_nbase(output_size) == 4): logger.error( - 'SNP calling from naive modified base flip-flop model is ' + + 'Variant calling from naive modified base flip-flop model is ' + 'not supported.') sys.exit(1) - snp_calib_fn = mh.get_snp_calibration_fn( - args.snp_calibration_filename, args.disable_snp_calibration) + var_calib_fn = mh.get_var_calibration_fn( + args.variant_calibration_filename, args.disable_variant_calibration) try: - snps_data = snps.SnpData( + vars_data = variants.VarData( args.variant_filename, args.max_indel_size, - args.snp_all_paths, args.write_snps_text, - args.variant_context_bases, snp_calib_fn, - snps.HAPLIOD_MODE if args.haploid else snps.DIPLOID_MODE, - args.refs_include_snps, aligner, edge_buffer=args.edge_buffer, + args.variant_all_paths, args.write_variants_text, + args.variant_context_bases, var_calib_fn, + variants.HAPLIOD_MODE if args.haploid else variants.DIPLOID_MODE, + args.refs_include_variants, aligner, edge_buffer=args.edge_buffer, context_min_alt_prob=args.context_min_alt_prob, loc_index_in_memory=not args.variant_locations_on_disk) except mh.MegaError as e: logger.error(str(e)) sys.exit(1) - if args.variant_filename is not None and mh.PR_SNP_NAME not in args.outputs: + if args.variant_filename is not None and mh.PR_VAR_NAME not in args.outputs: logger.warning( - '--snps-filename provided, but SNP output not requested ' + + '--variant-filename provided, but variants output not requested ' + '(via --outputs). 
Argument will be ignored.') - return args, snps_data + return args, vars_data def mods_validation(args, model_info): logger = logging.get_logger() @@ -673,24 +673,25 @@ def parse_pr_ref_output(args): logger = logging.get_logger() if args.output_per_read_references: args.outputs.append(mh.PR_REF_NAME) - if args.refs_include_snps and args.refs_include_mods: - logger.error('Cannot output both modified base and SNPs in ' + + if args.refs_include_variants and args.refs_include_mods: + logger.error('Cannot output both modified base and variants in ' + 'per-read references (remove one of ' + - '--refs-include-snps or --refs-include-mods).') + '--refs-include-variants or --refs-include-mods).') sys.exit(1) - if args.refs_include_snps and mh.PR_SNP_NAME not in args.outputs: - args.outputs.append(mh.PR_SNP_NAME) - logger.warning('--refs-include-snps set, so adding ' + - 'per_read_snps to --outputs.') + if args.refs_include_variants and mh.PR_VAR_NAME not in args.outputs: + args.outputs.append(mh.PR_VAR_NAME) + logger.warning('--refs-include-variants set, so adding ' + + 'per_read_variants to --outputs.') if args.refs_include_mods and mh.PR_MOD_NAME not in args.outputs: args.outputs.append(mh.PR_MOD_NAME) logger.warning('--refs-include-mods set, so adding ' + 'per_read_mods to --outputs.') else: - if args.refs_include_snps: + if args.refs_include_variants: logger.warning( - '--refs-include-snps but not --output-per-read-references ' + - 'set. Ignoring --refs-include-snps.') + '--refs-include-variants but not ' + + '--output-per-read-references set. Ignoring ' + + '--refs-include-variants.') if args.refs_include_mods: logger.warning( '--refs-include-mods but not --output-per-read-references ' + @@ -791,63 +792,64 @@ def hidden_help(help_msg): help='Reference FASTA or minimap2 index file used for mapping ' + 'called reads.') - snp_grp = parser.add_argument_group('SNP Arguments') - snp_grp.add_argument( + var_grp = parser.add_argument_group('Sequence Variant Arguments') + var_grp.add_argument( '--haploid', action='store_true', - help='Compute SNP aggregation for haploid genotypes. Default: diploid') - snp_grp.add_argument( + help='Compute variant aggregation for haploid genotypes. ' + + 'Default: diploid') + var_grp.add_argument( '--variant-filename', help='Sequence variants to call for each read in VCF/BCF format ' + '(required for variant output).') - snp_grp.add_argument( - '--write-snps-text', action='store_true', - help='Write per-read SNP calls out to a text file. Default: ' + - 'Only ouput to database.') + var_grp.add_argument( + '--write-variants-text', action='store_true', + help='Write per-read sequence variant calls out to a text file. ' + + 'Default: Only output to database.') - snp_grp.add_argument( + var_grp.add_argument( '--context-min-alt-prob', type=float, default=mh.DEFAULT_CONTEXT_MIN_ALT_PROB, help=hidden_help('Minimum alternative alleles probability to ' + 'include variant in computation of nearby variants. ' + 'Default: %(default)f')) - snp_grp.add_argument( - '--disable-snp-calibration', action='store_true', - help=hidden_help('Use raw SNP scores from the network. ' + + var_grp.add_argument( + '--disable-variant-calibration', action='store_true', + help=hidden_help('Use raw variant scores from the network. 
' + 'Default: Calibrate score with ' + - '--snp-calibration-filename')) - snp_grp.add_argument( + '--variant-calibration-filename')) + var_grp.add_argument( '--heterozygous-factors', type=float, nargs=2, default=[mh.DEFAULT_SNV_HET_FACTOR, mh.DEFAULT_INDEL_HET_FACTOR], help=hidden_help('Bayesian prior factor for snv and indel ' + 'heterozygous calls (compared to 1.0 for hom ' + 'ref/alt). Default: %(default)s')) - snp_grp.add_argument( + var_grp.add_argument( '--max-indel-size', type=int, default=50, help=hidden_help('Maximum difference in number of reference and ' + 'alternate bases. Default: %(default)d')) - snp_grp.add_argument( - '--snp-all-paths', action='store_true', + var_grp.add_argument( + '--variant-all-paths', action='store_true', help=hidden_help('Compute forwards algorithm all paths score. ' + '(Default: Viterbi best-path score)')) - snp_grp.add_argument( - '--snp-calibration-filename', + var_grp.add_argument( + '--variant-calibration-filename', help=hidden_help('File containing emperical calibration for ' + - 'SNP scores. As created by ' + - 'megalodon/scripts/calibrate_snp_llr_scores.py. ' + + 'variant scores. As created by ' + + 'megalodon/scripts/calibrate_variant_llr_scores.py. ' + 'Default: Load default calibration file.')) - snp_grp.add_argument( + var_grp.add_argument( '--variant-locations-on-disk', action='store_true', help=hidden_help('Force sequence variant locations to be stored ' + 'only within on disk database table. This option ' + 'will reduce the RAM memory requirement, but may ' + 'drastically slow processing. Default: Store ' + 'locations in memory and on disk.')) - snp_grp.add_argument( + var_grp.add_argument( '--variant-context-bases', type=int, nargs=2, default=[mh.DEFAULT_SNV_CONTEXT, mh.DEFAULT_INDEL_CONTEXT], - help=hidden_help('Context bases for single base SNP and indel ' + + help=hidden_help('Context bases for single base variant and indel ' + 'calling. 
Default: %(default)s')) - snp_grp.add_argument( + var_grp.add_argument( '--write-vcf-log-probs', action='store_true', help=hidden_help('Write per-read alt log probabilities out in ' + 'non-standard VCF field.')) @@ -934,9 +936,8 @@ def hidden_help(help_msg): help=hidden_help('Include modified base calls in per-read ' + 'reference output.')) refout_grp.add_argument( - '--refs-include-snps', action='store_true', - help=hidden_help('Include SNP calls in per-read ' + - 'reference output.')) + '--refs-include-variants', action='store_true', + help=hidden_help('Include variant calls in per-read reference output.')) refout_grp.add_argument( '--refs-percent-identity-threshold', type=float, help=hidden_help('Only include reads with higher percent identity ' + @@ -1007,13 +1008,13 @@ def _main(): args.max_concurrent_chunks) args, mods_info = mods_validation(args, model_info) aligner = aligner_validation(args) - args, snps_data = snps_validation( + args, vars_data = vars_validation( args, model_info.is_cat_mod, model_info.output_size, aligner) process_all_reads( args.fast5s_dir, not args.not_recursive, args.num_reads, args.read_ids_filename, model_info, args.outputs, - args.output_directory, args.basecalls_format, aligner, snps_data, + args.output_directory, args.basecalls_format, aligner, vars_data, args.processes, args.verbose_read_progress, args.suppress_progress, mods_info, args.database_safety, pr_ref_filts) @@ -1027,27 +1028,27 @@ def _main(): whatshap_sort_fn, whatshap_p = post_process_whatshap( args.output_directory, aligner.out_fmt, aligner.ref_fn) - if mh.SNP_NAME in args.outputs or mh.MOD_NAME in args.outputs: + if mh.VAR_NAME in args.outputs or mh.MOD_NAME in args.outputs: post_process_aggregate( mods_info, args.outputs, args.mod_binary_threshold, args.output_directory, args.processes, args.write_vcf_log_probs, - args.heterozygous_factors, snps_data, args.write_mod_log_probs, + args.heterozygous_factors, vars_data, args.write_mod_log_probs, args.suppress_progress, aligner.ref_names_and_lens) - if mh.SNP_NAME in args.outputs: + if mh.VAR_NAME in args.outputs: logger.info('Sorting output variant file') - variant_fn = mh.get_megalodon_fn(args.output_directory, mh.SNP_NAME) + variant_fn = mh.get_megalodon_fn(args.output_directory, mh.VAR_NAME) sort_variant_fn = mh.add_fn_suffix(variant_fn, 'sorted') - snps.sort_variants(variant_fn, sort_variant_fn) + variants.sort_variants(variant_fn, sort_variant_fn) logger.info('Indexing output variant file') - index_variant_fn = snps.index_variants(sort_variant_fn) + index_variant_fn = variants.index_variants(sort_variant_fn) if mh.WHATSHAP_MAP_NAME in args.outputs: if whatshap_p.is_alive(): logger.info('Waiting for whatshap mappings sort') while whatshap_p.is_alive(): sleep(0.001) - logger.info(snps.get_whatshap_command( + logger.info(variants.get_whatshap_command( index_variant_fn, whatshap_sort_fn, mh.add_fn_suffix(variant_fn, 'phased'))) diff --git a/megalodon/megalodon_helper.py b/megalodon/megalodon_helper.py index a78b90c..ad11bc1 100644 --- a/megalodon/megalodon_helper.py +++ b/megalodon/megalodon_helper.py @@ -56,10 +56,10 @@ MAP_NAME = 'mappings' MAP_SUMM_NAME = 'mappings_summary' MAP_OUT_FMTS = ('bam', 'cram', 'sam') -PR_SNP_NAME = 'per_read_snps' -PR_SNP_TXT_NAME = 'per_read_snps_text' +PR_VAR_NAME = 'per_read_variants' +PR_VAR_TXT_NAME = 'per_read_variants_text' WHATSHAP_MAP_NAME = 'whatshap_mappings' -SNP_NAME = 'snps' +VAR_NAME = 'variants' PR_MOD_NAME = 'per_read_mods' PR_MOD_TXT_NAME = 'per_read_mods_text' # TOOD add wig/bedgraph modified 
base output @@ -70,9 +70,9 @@ BC_MODS_NAME:'basecalls.modified_base_scores.hdf5', MAP_NAME:'mappings', MAP_SUMM_NAME:'mappings.summary.txt', - PR_SNP_NAME:'per_read_snp_calls.db', - PR_SNP_TXT_NAME:'per_read_snp_calls.txt', - SNP_NAME:'variants.vcf', + PR_VAR_NAME:'per_read_variant_calls.db', + PR_VAR_TXT_NAME:'per_read_variant_calls.txt', + VAR_NAME:'variants.vcf', WHATSHAP_MAP_NAME:'whatshap_mappings', PR_MOD_NAME:'per_read_modified_base_calls.db', PR_MOD_TXT_NAME:'per_read_modified_base_calls.txt', @@ -85,9 +85,10 @@ BC_NAME:'Called bases (FASTA)', BC_MODS_NAME:'Basecall-anchored modified base scores (HDF5)', MAP_NAME:'Mapped reads (BAM/CRAM/SAM)', - PR_SNP_NAME:'Per-read, per-site SNP scores database', - SNP_NAME:'Sample-level aggregated SNP calls (VCF)', - WHATSHAP_MAP_NAME:'SNP annotated mappings for use with whatshap', + PR_VAR_NAME:'Per-read, per-site sequence variant scores database', + VAR_NAME:'Sample-level aggregated sequence variant calls (VCF)', + WHATSHAP_MAP_NAME:( + 'Sequence variant annotated mappings for use with whatshap'), PR_MOD_NAME:'Per-read, per-site modified base scores database', MOD_NAME:'Sample-level aggregated modified base calls (modVCF)' } @@ -107,7 +108,7 @@ MOD_WIG_NAME:'wig' } -ALIGN_OUTPUTS = set((MAP_NAME, PR_REF_NAME, PR_SNP_NAME, SNP_NAME, +ALIGN_OUTPUTS = set((MAP_NAME, PR_REF_NAME, PR_VAR_NAME, VAR_NAME, WHATSHAP_MAP_NAME, PR_MOD_NAME, MOD_NAME)) PR_REF_FILTERS = namedtuple( @@ -122,7 +123,7 @@ DEFAULT_MODEL_PRESET = MODEL_PRESETS[0] MODEL_DATA_DIR_NAME = 'model_data' MODEL_FN = 'model.checkpoint' -SNP_CALIBRATION_FN = 'megalodon_snp_calibration.npz' +VAR_CALIBRATION_FN = 'megalodon_variant_calibration.npz' MOD_CALIBRATION_FN = 'megalodon_mod_calibration.npz' @@ -196,22 +197,22 @@ def add_fn_suffix(fn, suffix): ##### Calibration File Loading ##### #################################### -def get_snp_calibration_fn( - snp_calib_fn=None, disable_snp_calib=False, preset_str=None): - if disable_snp_calib: +def get_var_calibration_fn( + var_calib_fn=None, disable_var_calib=False, preset_str=None): + if disable_var_calib: return None - elif snp_calib_fn is not None: - return resolve_path(snp_calib_fn) + elif var_calib_fn is not None: + return resolve_path(var_calib_fn) elif preset_str is not None: if preset_str not in MODEL_PRESETS: raise MegaError('Invalid model preset: {}'.format(preset_str)) resolve_path(pkg_resources.resource_filename( 'megalodon', os.path.join( - MODEL_DATA_DIR_NAME, preset_str, SNP_CALIBRATION_FN))) - # else return default snp calibration file + MODEL_DATA_DIR_NAME, preset_str, VAR_CALIBRATION_FN))) + # else return default variant calibration file return resolve_path(pkg_resources.resource_filename( 'megalodon', os.path.join( - MODEL_DATA_DIR_NAME, DEFAULT_MODEL_PRESET, SNP_CALIBRATION_FN))) + MODEL_DATA_DIR_NAME, DEFAULT_MODEL_PRESET, VAR_CALIBRATION_FN))) def get_mod_calibration_fn( mod_calib_fn=None, disable_mod_calib=False, preset_str=None): @@ -224,8 +225,8 @@ def get_mod_calibration_fn( raise MegaError('Invalid model preset: {}'.format(preset_str)) resolve_path(pkg_resources.resource_filename( 'megalodon', os.path.join( - MODEL_DATA_DIR_NAME, preset_str, SNP_CALIBRATION_FN))) - # else return default snp calibration file + MODEL_DATA_DIR_NAME, preset_str, MOD_CALIBRATION_FN))) + # else return default modified base calibration file return resolve_path(pkg_resources.resource_filename( 'megalodon', os.path.join( MODEL_DATA_DIR_NAME, DEFAULT_MODEL_PRESET, MOD_CALIBRATION_FN))) @@ -239,7 +240,7 @@ def get_model_fn(model_fn=None, 
preset_str=None): resolve_path(pkg_resources.resource_filename( 'megalodon', os.path.join( MODEL_DATA_DIR_NAME, preset_str, MODEL_FN))) - # else return default snp calibration file + # else return default model file return resolve_path(pkg_resources.resource_filename( 'megalodon', os.path.join( MODEL_DATA_DIR_NAME, DEFAULT_MODEL_PRESET, MODEL_FN))) diff --git a/megalodon/model_data/R941.min.high_acc.5mC_6mA_bio_cntxt/megalodon_snp_calibration.npz b/megalodon/model_data/R941.min.high_acc.5mC_6mA_bio_cntxt/megalodon_variant_calibration.npz similarity index 100% rename from megalodon/model_data/R941.min.high_acc.5mC_6mA_bio_cntxt/megalodon_snp_calibration.npz rename to megalodon/model_data/R941.min.high_acc.5mC_6mA_bio_cntxt/megalodon_variant_calibration.npz diff --git a/megalodon/snps.py b/megalodon/variants.py similarity index 85% rename from megalodon/snps.py rename to megalodon/variants.py index 428b146..6b8d7d9 100755 --- a/megalodon/snps.py +++ b/megalodon/variants.py @@ -402,18 +402,18 @@ def logsumexp(x): return np.log(np.sum(np.exp(x - x_max))) + x_max -################################ -##### Per-read SNP Scoring ##### -################################ +#################################### +##### Per-read Variant Scoring ##### +#################################### def write_per_read_debug( - snp_ref_pos, snp_id, read_ref_pos, np_s_snp_ref_seq, np_s_snp_alt_seqs, + var_ref_pos, var_id, read_ref_pos, np_s_var_ref_seq, np_s_var_alt_seqs, np_s_context_seqs, loc_contexts_ref_lps, loc_contexts_alts_lps, w_context, logger): - ref_seq = mh.int_to_seq(np_s_snp_ref_seq) + ref_seq = mh.int_to_seq(np_s_var_ref_seq) if read_ref_pos.strand == -1: ref_seq = mh.revcomp(ref_seq) - alts_seq = [mh.int_to_seq(np_alt) for np_alt in np_s_snp_alt_seqs] + alts_seq = [mh.int_to_seq(np_alt) for np_alt in np_s_var_alt_seqs] if read_ref_pos.strand == -1: alts_seq = [mh.revcomp(alt_seq) for alt_seq in alts_seq] ','.join(alts_seq) @@ -430,8 +430,8 @@ def write_per_read_debug( loc_contexts_ref_lps, zip(*loc_contexts_alts_lps), context_seqs): out_txt += ('VARIANT_FULL_DATA: {}\t{}\t{}\t{}\t{}[{}]{}\t{}\t' + '{:.2f}\t{}\t{}\n').format( - read_ref_pos.chrm, read_ref_pos.strand, snp_ref_pos, - snp_id, up_seq, ref_seq, dn_seq, ','.join(alts_seq), + read_ref_pos.chrm, read_ref_pos.strand, var_ref_pos, + var_id, up_seq, ref_seq, dn_seq, ','.join(alts_seq), ref_lp, ','.join(('{:.2f}'.format(alt_lp) for alt_lp in alt_lps)), 'WITH_CONTEXT' if w_context else 'NO_CONTEXT') @@ -470,8 +470,8 @@ def call_read_vars( # convert to forward strand sequence in order to annotate with variants read_ref_fwd_seq = (strand_read_np_ref_seq if read_ref_pos.strand == 1 else mh.revcomp_np(strand_read_np_ref_seq)) - # call all snps overlapping this read - r_snp_calls = [] + # call all variants overlapping this read + r_var_calls = [] logger = logging.get_logger('per_read_vars') read_cached_scores = {} read_variants = vars_data.fetch_read_variants( read_ref_pos, read_ref_fwd_seq) filt_read_variants = [] # first pass over variants assuming the reference ground truth # (not including context variants) - for (np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, - s_ref_start, s_ref_end, variant) in vars_data.iter_snps( + for (np_s_var_ref_seq, np_s_var_alt_seqs, np_s_context_seqs, + s_ref_start, s_ref_end, variant) in vars_data.iter_vars( read_variants, read_ref_pos, read_ref_fwd_seq, context_max_dist=0): blk_start = rl_cumsum[r_to_q_poss[s_ref_start]] blk_end = rl_cumsum[r_to_q_poss[s_ref_end]] if blk_end - blk_start <= max( 
len(up_seq) + len(dn_seq) for up_seq, dn_seq in np_s_context_seqs) + max( - np_s_snp_ref_seq.shape[0], max( - snp_alt_seq.shape[0] - for snp_alt_seq in np_s_snp_alt_seqs)): + np_s_var_ref_seq.shape[0], max( + var_alt_seq.shape[0] + for var_alt_seq in np_s_var_alt_seqs)): # no valid mapping over large inserted query bases # i.e. need as many "events/strides" as bases for valid mapping continue np_ref_seq = np.concatenate([ - np_s_context_seqs[0][0], np_s_snp_ref_seq, np_s_context_seqs[0][1]]) + np_s_context_seqs[0][0], np_s_var_ref_seq, np_s_context_seqs[0][1]]) loc_ref_lp = score_seq( r_post, np_ref_seq, post_mapped_start + blk_start, post_mapped_start + blk_end, vars_data.all_paths) @@ -504,10 +504,10 @@ def call_read_vars( loc_alt_llrs = [] if _DEBUG_PER_READ: loc_contexts_alts_lps = [] - for np_s_snp_alt_seq, var_alt_seq in zip( - np_s_snp_alt_seqs, variant.alts): + for np_s_var_alt_seq, var_alt_seq in zip( + np_s_var_alt_seqs, variant.alts): np_alt_seq = np.concatenate([ - np_s_context_seqs[0][0], np_s_snp_alt_seq, + np_s_context_seqs[0][0], np_s_var_alt_seq, np_s_context_seqs[0][1]]) loc_alt_lp = score_seq( r_post, np_alt_seq, post_mapped_start + blk_start, @@ -526,7 +526,7 @@ def call_read_vars( if _DEBUG_PER_READ: write_per_read_debug( variant.start, variant.id, read_ref_pos, - np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, + np_s_var_ref_seq, np_s_var_alt_seqs, np_s_context_seqs, np.array([loc_ref_lp,]), loc_contexts_alts_lps, False, logger) if sum(np.exp(loc_alt_log_ps)) >= vars_data.context_min_alt_prob: @@ -534,15 +534,15 @@ def call_read_vars( read_cached_scores[(variant.id, variant.start, variant.stop)] = ( loc_ref_lp, loc_alt_lps) else: - r_snp_calls.append(( + r_var_calls.append(( variant.ref_start, loc_alt_log_ps, variant.ref, variant.alts, variant.id, variant.start, variant.start + variant.np_ref.shape[0])) # second round for variants with some evidence for alternative alleles # process with other potential variants as context - for (np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, - s_ref_start, s_ref_end, variant) in vars_data.iter_snps( + for (np_s_var_ref_seq, np_s_var_alt_seqs, np_s_context_seqs, + s_ref_start, s_ref_end, variant) in vars_data.iter_vars( filt_read_variants, read_ref_pos, read_ref_fwd_seq): ref_cntxt_ref_lp, ref_cntxt_alt_lps = read_cached_scores[( variant.id, variant.start, variant.stop)] @@ -552,13 +552,13 @@ def call_read_vars( if blk_end - blk_start <= max( len(up_seq) + len(dn_seq) for up_seq, dn_seq in np_s_context_seqs) + max( - np_s_snp_ref_seq.shape[0], max( - snp_alt_seq.shape[0] - for snp_alt_seq in np_s_snp_alt_seqs)): + np_s_var_ref_seq.shape[0], max( + var_alt_seq.shape[0] + for var_alt_seq in np_s_var_alt_seqs)): # if some context sequences are too long for signal # just use cached lps # TODO could also filter out invalid context sequences - r_snp_calls.append(( + r_var_calls.append(( variant.start, ref_cntxt_alt_lps, variant.ref, variant.alts, variant.id, variant.start, variant.start + variant.np_ref.shape[0])) @@ -566,7 +566,7 @@ def call_read_vars( # skip first (reference) context seq as this was cached ref_context_seqs = ( - np.concatenate([up_context_seq, np_s_snp_ref_seq, dn_context_seq]) + np.concatenate([up_context_seq, np_s_var_ref_seq, dn_context_seq]) for up_context_seq, dn_context_seq in np_s_context_seqs[1:]) loc_contexts_ref_lps = np.array([ref_cntxt_ref_lp] + [score_seq( r_post, ref_seq, post_mapped_start + blk_start, @@ -577,11 +577,11 @@ def call_read_vars( loc_alt_llrs = [] if _DEBUG_PER_READ: 
loc_contexts_alts_lps = [] - for np_s_snp_alt_seq, var_alt_seq, ref_cntxt_alt_lp in zip( - np_s_snp_alt_seqs, variant.alts, ref_cntxt_alt_lps): + for np_s_var_alt_seq, var_alt_seq, ref_cntxt_alt_lp in zip( + np_s_var_alt_seqs, variant.alts, ref_cntxt_alt_lps): alt_context_seqs = ( np.concatenate([ - up_context_seq, np_s_snp_alt_seq, dn_context_seq]) + up_context_seq, np_s_var_alt_seq, dn_context_seq]) for up_context_seq, dn_context_seq in np_s_context_seqs[1:]) loc_contexts_alt_lps = np.array([ref_cntxt_alt_lp,] + [ score_seq(r_post, alt_seq, post_mapped_start + blk_start, @@ -601,27 +601,27 @@ def call_read_vars( if _DEBUG_PER_READ: write_per_read_debug( variant.start, variant.id, read_ref_pos, - np_s_snp_ref_seq, np_s_snp_alt_seqs, np_s_context_seqs, + np_s_var_ref_seq, np_s_var_alt_seqs, np_s_context_seqs, loc_contexts_ref_lps, loc_contexts_alts_lps, True, logger) - r_snp_calls.append(( + r_var_calls.append(( variant.ref_start, loc_alt_log_ps, variant.ref, variant.alts, variant.id, variant.start, variant.start + variant.np_ref.shape[0])) # re-sort variants after adding context-included computations - return sorted(r_snp_calls, key=lambda x: x[0]) + return sorted(r_var_calls, key=lambda x: x[0]) -############################### -##### Per-read SNP Output ##### -############################### +################################### +##### Per-read Variant Output ##### +################################### def log_prob_to_phred(log_prob): with np.errstate(divide='ignore'): return -10 * np.log10(1 - np.exp(log_prob)) -def simplify_snp_seq(ref_seq, alt_seq): +def simplify_var_seq(ref_seq, alt_seq): trim_before = trim_after = 0 while (len(ref_seq) > 0 and len(alt_seq) > 0 and ref_seq[0] == alt_seq[0]): @@ -636,194 +636,195 @@ def simplify_snp_seq(ref_seq, alt_seq): return ref_seq, alt_seq, trim_before, trim_after -def iter_non_overlapping_snps(r_snp_calls): - def get_max_prob_allele_snp(snp_grp): - """ For overlapping SNPs return the snp with the highest probability - single allele as this one will be added to the reference sequence. +def iter_non_overlapping_variants(r_var_calls): + def get_max_prob_allele_var(var_grp): + """ For overlapping variants return the variant with the highest + probability single allele as this one will be added to the reference + sequence. - More complex chained SNPs could be handled, but are not here. - For example, a 5 base deletion covering 2 single base swap SNPs could - validly result in 2 alternative single base swap alleles, but the + More complex chained variants could be handled, but are not here. + For example, a 5 base deletion covering 2 single base swap variants + could validly result in 2 alternative single base swap alleles, but the logic here would only allow one of those alternatives since they are covered by the same reference deletion. There are certainly many more edge cases than this and each one would require specific logic. This likely covers the majority of valid cases and limiting to 50 base indels by default limits the scope of this issue. 
""" - most_prob_snp = None - for snp_data in snp_grp: + most_prob_var = None + for var_data in var_grp: with np.errstate(divide='ignore'): - ref_lp = np.log1p(-np.exp(snp_data[1]).sum()) - snp_max_lp = max(ref_lp, max(snp_data[1])) - if most_prob_snp is None or snp_max_lp > most_prob_snp[0]: - most_prob_snp = (snp_max_lp, ref_lp, snp_data) - - _, ref_lp, (snp_pos, alt_lps, snp_ref_seq, - snp_alt_seqs, _, _, _) = most_prob_snp - return snp_pos, alt_lps, snp_ref_seq, snp_alt_seqs, ref_lp - - - if len(r_snp_calls) == 0: return - r_snp_calls_iter = iter(r_snp_calls) - # initialize snp_grp with first snp - snp_data = next(r_snp_calls_iter) - prev_snp_end = snp_data[0] + len(snp_data[2]) - snp_grp = [snp_data] - for snp_data in sorted(r_snp_calls_iter, key=itemgetter(0)): - if snp_data[0] < prev_snp_end: - prev_snp_end = max(snp_data[0] + len(snp_data[2]), prev_snp_end) - snp_grp.append(snp_data) + ref_lp = np.log1p(-np.exp(var_data[1]).sum()) + var_max_lp = max(ref_lp, max(var_data[1])) + if most_prob_var is None or var_max_lp > most_prob_var[0]: + most_prob_var = (var_max_lp, ref_lp, var_data) + + _, ref_lp, (var_pos, alt_lps, var_ref_seq, + var_alt_seqs, _, _, _) = most_prob_var + return var_pos, alt_lps, var_ref_seq, var_alt_seqs, ref_lp + + + if len(r_var_calls) == 0: return + r_var_calls_iter = iter(r_var_calls) + # initialize var_grp with first var + var_data = next(r_var_calls_iter) + prev_var_end = var_data[0] + len(var_data[2]) + var_grp = [var_data] + for var_data in sorted(r_var_calls_iter, key=itemgetter(0)): + if var_data[0] < prev_var_end: + prev_var_end = max(var_data[0] + len(var_data[2]), prev_var_end) + var_grp.append(var_data) else: - yield get_max_prob_allele_snp(snp_grp) - prev_snp_end = snp_data[0] + len(snp_data[2]) - snp_grp = [snp_data] + yield get_max_prob_allele_var(var_grp) + prev_var_end = var_data[0] + len(var_data[2]) + var_grp = [var_data] - # yeild last snp grp data - yield get_max_prob_allele_snp(snp_grp) + # yeild last var grp data + yield get_max_prob_allele_var(var_grp) return -def annotate_snps(r_start, ref_seq, r_snp_calls, strand): - """ Annotate reference sequence with called snps. +def annotate_variants(r_start, ref_seq, r_var_calls, strand): + """ Annotate reference sequence with called variants. - Note: Reference sequence is in read orientation and snp calls are in + Note: Reference sequence is in read orientation and variant calls are in genome coordiates. 
""" - snp_seqs, snp_quals, snp_cigar = [], [], [] + var_seqs, var_quals, var_cigar = [], [], [] prev_pos, curr_match = 0, 0 - # ref_seq is read-centric so flop order to process snps in genomic order + # ref_seq is read-centric so flop order to process vars in genomic order if strand == -1: ref_seq = ref_seq[::-1] - for (snp_pos, alt_lps, snp_ref_seq, snp_alt_seqs, - ref_lp) in iter_non_overlapping_snps(r_snp_calls): - prev_len = snp_pos - r_start - prev_pos + for (var_pos, alt_lps, var_ref_seq, var_alt_seqs, + ref_lp) in iter_non_overlapping_variants(r_var_calls): + prev_len = var_pos - r_start - prev_pos # called canonical if ref_lp >= max(alt_lps): - snp_seqs.append( - ref_seq[prev_pos:snp_pos - r_start + len(snp_ref_seq)]) - snp_quals.extend( + var_seqs.append( + ref_seq[prev_pos:var_pos - r_start + len(var_ref_seq)]) + var_quals.extend( ([WHATSHAP_MAX_QUAL] * prev_len) + ([min(log_prob_to_phred(ref_lp), WHATSHAP_MAX_QUAL)] * - len(snp_ref_seq))) - curr_match += prev_len + len(snp_ref_seq) + len(var_ref_seq))) + curr_match += prev_len + len(var_ref_seq) else: - alt_seq = snp_alt_seqs[np.argmax(alt_lps)] + alt_seq = var_alt_seqs[np.argmax(alt_lps)] # complement since ref_seq is complement seq # (not reversed; see loop init) read_alt_seq = alt_seq if strand == 1 else mh.comp(alt_seq) - snp_seqs.append(ref_seq[prev_pos:snp_pos - r_start] + read_alt_seq) - snp_quals.extend( + var_seqs.append(ref_seq[prev_pos:var_pos - r_start] + read_alt_seq) + var_quals.extend( ([WHATSHAP_MAX_QUAL] * prev_len) + ([min(log_prob_to_phred(max(alt_lps)), WHATSHAP_MAX_QUAL)] * len(alt_seq))) - # add cigar information for snp or indel - t_ref_seq, t_alt_seq, t_before, t_after = simplify_snp_seq( - snp_ref_seq, alt_seq) + # add cigar information for variant + t_ref_seq, t_alt_seq, t_before, t_after = simplify_var_seq( + var_ref_seq, alt_seq) curr_match += t_before - snp_cigar.append((7, curr_match + prev_len)) + var_cigar.append((7, curr_match + prev_len)) if len(t_alt_seq) == len(t_ref_seq): - snp_cigar.append((8, len(t_alt_seq))) + var_cigar.append((8, len(t_alt_seq))) elif len(t_alt_seq) > len(t_ref_seq): # left justify mismatch bases in complex insertion if len(t_ref_seq) != 0: - snp_cigar.append((8, len(t_ref_seq))) - snp_cigar.append((1, len(t_alt_seq) - len(t_ref_seq))) + var_cigar.append((8, len(t_ref_seq))) + var_cigar.append((1, len(t_alt_seq) - len(t_ref_seq))) else: # left justify mismatch bases in complex deletion if len(t_alt_seq) != 0: - snp_cigar.append((8, len(t_alt_seq))) - snp_cigar.append((2, len(t_ref_seq) - len(t_alt_seq))) + var_cigar.append((8, len(t_alt_seq))) + var_cigar.append((2, len(t_ref_seq) - len(t_alt_seq))) curr_match = t_after - prev_pos = snp_pos - r_start + len(snp_ref_seq) + prev_pos = var_pos - r_start + len(var_ref_seq) - snp_seqs.append(ref_seq[prev_pos:]) - snp_seq = ''.join(snp_seqs) + var_seqs.append(ref_seq[prev_pos:]) + var_seq = ''.join(var_seqs) if strand == -1: - snp_seq = snp_seq[::-1] + var_seq = var_seq[::-1] len_remain = len(ref_seq) - prev_pos - snp_quals.extend([WHATSHAP_MAX_QUAL] * len_remain) + var_quals.extend([WHATSHAP_MAX_QUAL] * len_remain) if strand == -1: - snp_quals = snp_quals[::-1] - snp_quals = list(map(int, snp_quals)) - snp_cigar.append((7, len_remain + curr_match)) + var_quals = var_quals[::-1] + var_quals = list(map(int, var_quals)) + var_cigar.append((7, len_remain + curr_match)) if strand == -1: - snp_cigar = snp_cigar[::-1] + var_cigar = var_cigar[::-1] - return snp_seq, snp_quals, snp_cigar + return var_seq, var_quals, var_cigar -def 
_get_snps_queue( - snps_q, snps_conn, vars_db_fn, snps_txt_fn, db_safety, pr_refs_fn, +def _get_variants_queue( + vars_q, vars_conn, vars_db_fn, vars_txt_fn, db_safety, pr_refs_fn, pr_ref_filts, whatshap_map_fn, ref_names_and_lens, ref_fn, loc_index_in_memory): def write_whatshap_alignment( - read_id, snp_seq, snp_quals, chrm, strand, r_st, snp_cigar): + read_id, var_seq, var_quals, chrm, strand, r_st, var_cigar): a = pysam.AlignedSegment() a.query_name = read_id a.flag = 0 if strand == 1 else 16 a.reference_id = whatshap_map_fp.get_tid(chrm) a.reference_start = r_st - a.template_length = len(snp_seq) + a.template_length = len(var_seq) a.mapping_quality = WHATSHAP_MAX_QUAL a.set_tags([('RG', WHATSHAP_RG_ID)]) # convert to reference based sequence if strand == -1: - snp_seq = mh.revcomp(snp_seq) - snp_quals = snp_quals[::-1] - snp_cigar = snp_cigar[::-1] - a.query_sequence = snp_seq - a.query_qualities = array('B', snp_quals) - a.cigartuples = snp_cigar + var_seq = mh.revcomp(var_seq) + var_quals = var_quals[::-1] + var_cigar = var_cigar[::-1] + a.query_sequence = var_seq + a.query_qualities = array('B', var_quals) + a.cigartuples = var_cigar whatshap_map_fp.write(a) return - def get_snp_call( - r_snp_calls, read_id, chrm, strand, r_start, ref_seq, read_len, + def get_var_call( + r_var_calls, read_id, chrm, strand, r_start, ref_seq, read_len, q_st, q_en, cigar): - snps_db.insert_read_scores(r_snp_calls, read_id, chrm, strand) - if snps_txt_fp is not None and len(r_snp_calls) > 0: - snp_out_text = '' - for (pos, alt_lps, snp_ref_seq, snp_alt_seqs, snp_id, - test_start, test_end) in r_snp_calls: + vars_db.insert_read_scores(r_var_calls, read_id, chrm, strand) + if vars_txt_fp is not None and len(r_var_calls) > 0: + var_out_text = '' + for (pos, alt_lps, var_ref_seq, var_alt_seqs, var_id, + test_start, test_end) in r_var_calls: with np.errstate(divide='ignore'): ref_lp = np.log1p(-np.exp(alt_lps).sum()) - snp_out_text += '\n'.join(( + var_out_text += '\n'.join(( ('\t'.join('{}' for _ in field_names)).format( read_id, chrm, strand, pos, ref_lp, alt_lp, - snp_ref_seq, snp_alt_seq, snp_id) - for alt_lp, snp_alt_seq in zip( - alt_lps, snp_alt_seqs))) + '\n' - snps_txt_fp.write(snp_out_text) - if do_ann_snps: + var_ref_seq, var_alt_seq, var_id) + for alt_lp, var_alt_seq in zip( + alt_lps, var_alt_seqs))) + '\n' + vars_txt_fp.write(var_out_text) + if do_ann_vars: if not mapping.read_passes_filters( pr_ref_filts, read_len, q_st, q_en, cigar): return - snp_seq, snp_quals, snp_cigar = annotate_snps( - r_start, ref_seq, r_snp_calls, strand) + var_seq, var_quals, var_cigar = annotate_variants( + r_start, ref_seq, r_var_calls, strand) if pr_refs_fn is not None: - pr_refs_fp.write('>{}\n{}\n'.format(read_id, snp_seq)) + pr_refs_fp.write('>{}\n{}\n'.format(read_id, var_seq)) if whatshap_map_fn is not None: write_whatshap_alignment( - read_id, snp_seq, snp_quals, chrm, strand, r_start, - snp_cigar) + read_id, var_seq, var_quals, chrm, strand, r_start, + var_cigar) return logger = logging.get_logger('vars_getter') - snps_db = VarsDb(vars_db_fn, db_safety=db_safety, read_only=False, + vars_db = VarsDb(vars_db_fn, db_safety=db_safety, read_only=False, loc_index_in_memory=loc_index_in_memory) - snps_db.insert_chrms(ref_names_and_lens[0]) - snps_db.create_chrm_index() - if snps_txt_fn is None: - snps_txt_fp = None + vars_db.insert_chrms(ref_names_and_lens[0]) + vars_db.create_chrm_index() + if vars_txt_fn is None: + vars_txt_fp = None else: - snps_txt_fp = open(snps_txt_fn, 'w') + vars_txt_fp = open(vars_txt_fn, 
'w') field_names = ( 'read_id', 'chrm', 'strand', 'pos', 'ref_log_prob', 'alt_log_prob', - 'ref_seq', 'alt_seq', 'snp_id') - snps_txt_fp.write('\t'.join(field_names) + '\n') + 'ref_seq', 'alt_seq', 'var_id') + vars_txt_fp.write('\t'.join(field_names) + '\n') if pr_refs_fn is not None: pr_refs_fp = open(pr_refs_fn, 'w') @@ -843,49 +844,49 @@ def get_snp_call( whatshap_map_fp = pysam.AlignmentFile( whatshap_map_fn, w_mode, header=header, reference_filename=ref_fn) - do_ann_snps = whatshap_map_fn is not None or pr_refs_fn is not None + do_ann_vars = whatshap_map_fn is not None or pr_refs_fn is not None while True: try: - r_snp_calls, (read_id, chrm, strand, r_start, ref_seq, read_len, - q_st, q_en, cigar) = snps_q.get(block=False) + r_var_calls, (read_id, chrm, strand, r_start, ref_seq, read_len, + q_st, q_en, cigar) = vars_q.get(block=False) except queue.Empty: - if snps_conn.poll(): + if vars_conn.poll(): break sleep(0.001) continue try: - get_snp_call( - r_snp_calls, read_id, chrm, strand, r_start, ref_seq, read_len, + get_var_call( + r_var_calls, read_id, chrm, strand, r_start, ref_seq, read_len, q_st, q_en, cigar) except Exception as e: logger.debug(( 'Error processing variant output for read: {}\nSet' + - ' _RAISE_VARIANT_PROCESSING_ERRORS in megalodon/snps.py to ' + + ' _RAISE_VARIANT_PROCESSING_ERRORS in megalodon/variants.py to ' + 'see full error.\nError type: {}').format(read_id, str(e))) if _RAISE_VARIANT_PROCESSING_ERRORS: raise - while not snps_q.empty(): - r_snp_calls, (read_id, chrm, strand, r_start, ref_seq, read_len, - q_st, q_en, cigar) = snps_q.get(block=False) + while not vars_q.empty(): + r_var_calls, (read_id, chrm, strand, r_start, ref_seq, read_len, + q_st, q_en, cigar) = vars_q.get(block=False) try: - get_snp_call( - r_snp_calls, read_id, chrm, strand, r_start, ref_seq, read_len, + get_var_call( + r_var_calls, read_id, chrm, strand, r_start, ref_seq, read_len, q_st, q_en, cigar) except Exception as e: logger.debug(( 'Error processing variant output for read: {}\nSet' + - ' _RAISE_VARIANT_PROCESSING_ERRORS in megalodon/snps.py to ' + - 'see full error.\nError type: {}').format(read_id, str(e))) + ' _RAISE_VARIANT_PROCESSING_ERRORS in megalodon/variants.py ' + + 'to see full error.\nError type: {}').format(read_id, str(e))) if _RAISE_VARIANT_PROCESSING_ERRORS: raise - if snps_txt_fp is not None: snps_txt_fp.close() + if vars_txt_fp is not None: vars_txt_fp.close() if pr_refs_fn is not None: pr_refs_fp.close() if whatshap_map_fn is not None: whatshap_map_fp.close() - snps_db.create_alt_index() - if snps_db.loc_idx_in_mem: - snps_db.create_loc_index() - snps_db.create_data_covering_index() - snps_db.close() + vars_db.create_alt_index() + if vars_db.loc_idx_in_mem: + vars_db.create_loc_index() + vars_db.create_data_covering_index() + vars_db.close() return @@ -894,7 +895,7 @@ def get_snp_call( ##### VCF Reader ##### ###################### -class SnpData(object): +class VarData(object): def check_vars_match_ref( self, vars_idx, contigs, aligner, num_contigs=5, num_sites_per_contig=50): @@ -910,32 +911,32 @@ def check_vars_match_ref( logger.debug(( 'Reference sequence does not match variant reference ' + 'sequence at {} expected "{}" got "{}"').format( - snp_ref_pos, var_data.ref, ref_seq)) + ref_pos, var_data.ref, ref_seq)) return False return True def __init__( self, variant_fn, max_indel_size, all_paths, - write_snps_txt, context_bases, snps_calib_fn=None, - call_mode=DIPLOID_MODE, do_pr_ref_snps=False, aligner=None, - keep_snp_fp_open=False, do_validate_reference=True, 
+ write_vars_txt, context_bases, vars_calib_fn=None, + call_mode=DIPLOID_MODE, do_pr_ref_vars=False, aligner=None, + keep_var_fp_open=False, do_validate_reference=True, edge_buffer=mh.DEFAULT_EDGE_BUFFER, context_min_alt_prob=mh.DEFAULT_CONTEXT_MIN_ALT_PROB, loc_index_in_memory=True): logger = logging.get_logger('vars') self.max_indel_size = max_indel_size self.all_paths = all_paths - self.write_snps_txt = write_snps_txt - self.snps_calib_fn = snps_calib_fn - self.calib_table = calibration.SnpCalibrator(self.snps_calib_fn) + self.write_vars_txt = write_vars_txt + self.vars_calib_fn = vars_calib_fn + self.calib_table = calibration.VarCalibrator(self.vars_calib_fn) self.context_bases = context_bases if len(self.context_bases) != 2: raise mh.MegaError( - 'Must provide 2 context bases values (for single base SNPs ' + - 'and indels).') + 'Must provide 2 context bases values (for single base ' + + 'variants and indels).') self.call_mode = call_mode - self.do_pr_ref_snps = do_pr_ref_snps + self.do_pr_ref_vars = do_pr_ref_vars self.edge_buffer = edge_buffer self.context_min_alt_prob = context_min_alt_prob self.loc_index_in_memory = loc_index_in_memory @@ -961,7 +962,7 @@ def __init__( vars_idx = pysam.VariantFile(self.variant_fn) if aligner is None: raise mh.MegaError( - 'Must provide aligner if SNP filename is provided') + 'Must provide aligner if variants filename is provided') if len(set(aligner.ref_names_and_lens[0]).intersection(contigs)) == 0: raise mh.MegaError(( 'Reference and variant files contain no chromosomes/contigs ' + @@ -975,7 +976,7 @@ def __init__( 'Reference sequence file does not match reference sequence ' + 'from variants file.') - if keep_snp_fp_open: + if keep_var_fp_open: self.variants_idx = vars_idx else: vars_idx.close() @@ -1062,7 +1063,7 @@ def iter_alt_variant_seqs(variants): dist_vars = defaultdict(list) for context_var in context_variants: - var_dist = SnpData.compute_variant_distance(variant, context_var) + var_dist = VarData.compute_variant_distance(variant, context_var) if var_dist is not None: dist_vars[var_dist].append(context_var) @@ -1283,7 +1284,7 @@ def iter_atomized_variants( return # trim context bases from seq - np_ref_seq, np_alt_seq, start_trim, _ = simplify_snp_seq( + np_ref_seq, np_alt_seq, start_trim, _ = simplify_var_seq( np_ref_seq, np_alt_seq) var_start = var.start + start_trim try: @@ -1334,10 +1335,10 @@ def fetch_read_variants(self, read_ref_pos, read_ref_fwd_seq): grouped_read_vars, read_ref_fwd_seq, read_ref_pos) return read_variants - def iter_snps( + def iter_vars( self, read_variants, read_ref_pos, read_ref_fwd_seq, max_contexts=16, context_max_dist=mh.CONTEXT_MAX_DIST): - """Iterator over SNPs overlapping the read mapped position. + """Iterator over variants overlapping the read mapped position. Args: read_variants: List of variant objects (from fetch_read_variants) @@ -1349,8 +1350,8 @@ def iter_snps( include around each variant. Yields: - snp_ref_seq: Reference variant sequence on read strand - snp_alt_seqs: Alternative variant sequences on read strand + var_ref_seq: Reference variant sequence on read strand + var_alt_seqs: Alternative variant sequences on read strand context_seqs: Sequences surrounding the variant on read strand context_start: Start of variant context in read coordinates context_end: End of variant context in read coordinates @@ -1359,9 +1360,9 @@ def iter_snps( variant_id: string idnentifier for the variant pos: variant position (0-based coordinate) - SNPs within edge buffer of the end of the mapping will be ignored. 
+ Variants within edge buffer of the end of the mapping will be ignored. - If more than max_contexts snps exist within context_basss then only + If more than max_contexts variants exist within context_bases then only the max_contexts most proximal to the variant in question will be returned. """ @@ -1635,34 +1636,34 @@ def write_variant(self, variant): return -################################# -##### SNP Aggregation Class ##### -################################# +##################################### +##### Variant Aggregation Class ##### +##################################### -class AggSnps(mh.AbstractAggregationClass): +class AggVars(mh.AbstractAggregationClass): """ Class to assist in database queries for per-site aggregation of - SNP calls over reads. + variant calls over reads. """ def __init__( self, vars_db_fn, write_vcf_log_probs=False, load_in_mem_indices=True): # open as read only database if load_in_mem_indices: - self.snps_db = VarsDb(vars_db_fn) + self.vars_db = VarsDb(vars_db_fn) else: - self.snps_db = VarsDb(vars_db_fn, chrm_index_in_memory=False, + self.vars_db = VarsDb(vars_db_fn, chrm_index_in_memory=False, alt_index_in_memory=False) - self.n_uniq_snps = None + self.n_uniq_vars = None self.write_vcf_log_probs = write_vcf_log_probs return def num_uniq(self): - if self.n_uniq_snps is None: - self.n_uniq_snps = self.snps_db.get_num_uniq_var_loc() - return self.n_uniq_snps + if self.n_uniq_vars is None: + self.n_uniq_vars = self.vars_db.get_num_uniq_var_loc() + return self.n_uniq_vars def iter_uniq(self): - for q_val in self.snps_db.iter_locs(): + for q_val in self.vars_db.iter_locs(): yield q_val return @@ -1700,31 +1701,32 @@ def compute_het_lp(a1, a2): 0.0 if het_gt else all_lps.shape[1] * np.log(het_factor) for het_gt in het_gts]) log_prior_weights = log_prior_weights - logsumexp(log_prior_weights) - snp_lps = np.array(genotype_lps) + log_prior_weights - post_snp_lps = snp_lps - logsumexp(snp_lps) - return np.exp(post_snp_lps), gts + var_lps = np.array(genotype_lps) + log_prior_weights + post_var_lps = var_lps - logsumexp(var_lps) + return np.exp(post_var_lps), gts def compute_haploid_probs(self, ref_lps, alts_lps): - snp_lps = np.concatenate([[ref_lps.sum()], alts_lps.sum(axis=1)]) - post_snp_lps = snp_lps - logsumexp(snp_lps) - return np.exp(post_snp_lps), list(map(str, range(snp_lps.shape[0]))) + var_lps = np.concatenate([[ref_lps.sum()], alts_lps.sum(axis=1)]) + post_var_lps = var_lps - logsumexp(var_lps) + return np.exp(post_var_lps), list(map(str, range(var_lps.shape[0]))) - def compute_snp_stats( - self, snp_loc, het_factors, call_mode=DIPLOID_MODE, + def compute_var_stats( + self, var_loc, het_factors, call_mode=DIPLOID_MODE, valid_read_ids=None): assert call_mode in (HAPLIOD_MODE, DIPLOID_MODE), ( - 'Invalid SNP aggregation ploidy call mode: {}.'.format(call_mode)) + 'Invalid variant aggregation ploidy call mode: {}.'.format( + call_mode)) - pr_snp_stats = self.snps_db.get_loc_stats(snp_loc) - alt_seqs = sorted(set(r_stats.alt_seq for r_stats in pr_snp_stats)) + pr_var_stats = self.vars_db.get_loc_stats(var_loc) + alt_seqs = sorted(set(r_stats.alt_seq for r_stats in pr_var_stats)) pr_alt_lps = defaultdict(dict) - for r_stats in pr_snp_stats: + for r_stats in pr_var_stats: if (valid_read_ids is not None and r_stats.read_id not in valid_read_ids): continue pr_alt_lps[r_stats.read_id][r_stats.alt_seq] = r_stats.score if len(pr_alt_lps) == 0: - raise mh.MegaError('No valid reads cover SNP') + raise mh.MegaError('No valid reads cover variant') alt_seq_lps = [[] for 
_ in range(len(alt_seqs))] for read_lps in pr_alt_lps.values(): @@ -1733,20 +1735,20 @@ def compute_snp_stats( alt_seq_lps[i].append(read_lps[alt_seq]) except KeyError: raise mh.MegaError( - 'Alternative SNP seqence must exist for all reads.') + 'Alternative variant sequence must exist for all reads.') alts_lps = np.stack(alt_seq_lps, axis=0) with np.errstate(all='ignore'): ref_lps = np.log1p(-np.exp(alts_lps).sum(axis=0)) - r0_stats = pr_snp_stats[0] - snp_var = Variant( + r0_stats = pr_var_stats[0] + variant = Variant( chrom=r0_stats.chrm, pos=r0_stats.pos, ref=r0_stats.ref_seq, alts=alt_seqs, id=r0_stats.var_name) - snp_var.add_tag('DP', '{}'.format(ref_lps.shape[0])) - snp_var.add_sample_field('DP', '{}'.format(ref_lps.shape[0])) + variant.add_tag('DP', '{}'.format(ref_lps.shape[0])) + variant.add_sample_field('DP', '{}'.format(ref_lps.shape[0])) if self.write_vcf_log_probs: - snp_var.add_sample_field('LOG_PROBS', ','.join( + variant.add_sample_field('LOG_PROBS', ','.join( ';'.join('{:.2f}'.format(lp) for lp in alt_i_lps) for alt_i_lps in alts_lps)) @@ -1757,15 +1759,15 @@ def compute_snp_stats( het_factors[1]) diploid_probs, gts = self.compute_diploid_probs( ref_lps, alts_lps, het_factor) - snp_var.add_diploid_probs(diploid_probs, gts) + variant.add_diploid_probs(diploid_probs, gts) elif call_mode == HAPLIOD_MODE: haploid_probs, gts = self.compute_haploid_probs(ref_lps, alts_lps) - snp_var.add_haploid_probs(haploid_probs, gts) + variant.add_haploid_probs(haploid_probs, gts) - return snp_var + return variant def close(self): - self.snps_db.close() + self.vars_db.close() return diff --git a/scripts/filter_whatshap.py b/scripts/filter_whatshap.py index f1a309c..87f5079 100755 --- a/scripts/filter_whatshap.py +++ b/scripts/filter_whatshap.py @@ -16,7 +16,7 @@ def is_complex_variant(ref, alts): # single base swaps aren't complex if any(len(allele) > 1 for allele in alts + [ref]): for alt in alts: - simp_ref, simp_alt, _, _ = snps.simplify_snp_seq(ref, alt) + simp_ref, simp_alt, _, _ = snps.simplify_var_seq(ref, alt) # if an allele simplifies to a SNV continue if len(simp_ref) == 0 and len(simp_alt) == 0: continue diff --git a/setup.py b/setup.py index b012a5d..fa3499c 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,8 @@ url='http://www.nanoporetech.com', long_description=( 'Megalodon contains base calling augmentation capabilities, mainly ' + - 'including direct, reference-guided SNP and modified base detection.'), + 'including direct, reference-guided sequence variant and modified ' + + 'base detection.'), classifiers=[ 'Development Status :: 3 - Alpha', From b30db6cf922e4b23e889b6b1af5d367c7434b2eb Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Mon, 30 Sep 2019 14:39:17 -0700 Subject: [PATCH 08/14] Some additional snp to variant replacements. 
--- docs/computing_considerations.rst | 12 +- docs/conf.py | 2 +- docs/variant_phasing.rst | 6 +- scripts/calibrate_mod_llr_scores.py | 2 +- ...res.py => calibrate_variant_llr_scores.py} | 18 +- ...nerate_ground_truth_variant_llr_scores.py} | 199 +++++++++--------- 6 files changed, 120 insertions(+), 119 deletions(-) rename scripts/{calibrate_snp_llr_scores.py => calibrate_variant_llr_scores.py} (93%) rename scripts/{generate_ground_truth_snp_llr_scores.py => generate_ground_truth_variant_llr_scores.py} (68%) diff --git a/docs/computing_considerations.rst b/docs/computing_considerations.rst index c51ce26..27d834c 100644 --- a/docs/computing_considerations.rst +++ b/docs/computing_considerations.rst @@ -36,12 +36,12 @@ A separate thread is linked to each per-read processing worker in order to acces Thus users may notice threads opened for this processing. These threads will generally consume very little compute. ------------------------------ -SNP and Modified Base Calling ------------------------------ +--------------------------------- +Variant and Modified Base Calling +--------------------------------- -SNP and modified base calling is computed within the per-read processing workers using CPU resources. +Sequence variant and modified base calling is computed within the per-read processing workers using CPU resources. Generally, this portion of processing will comsume a minority of the compute resources. -Proposing many SNPs (e.g. all possible 3+ base indels) may show a bottle neck at this portion of processing. +Proposing many variants (e.g. all possible 3+ base indels) may show a bottleneck at this portion of processing. Internal testing shows that proposal of all possible single base substitutions shows minimal processing at this portion of per-read processing. -Note that the data bases storing per-read SNP variants may show slower indexing with very large proposed SNP sets (performed at the end of per-read processing). +Note that the database storing per-read variant scores may show slower indexing with a very large number of proposed variants (performed at the end of per-read processing). 
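
As a minimal sketch of the point above (editorial illustration, not code from this patch set; assumes pysam is installed, and the file names and two-base cutoff are hypothetical), a large proposed variant set can be thinned to SNVs and short indels before being passed to megalodon via --variant-filename:

    import pysam

    def drop_long_indels(in_vcf, out_vcf, max_indel_len=2):
        # Copy only records whose alleles differ in length from the
        # reference by at most max_indel_len bases; long indel proposals
        # are the main driver of the bottleneck described above.
        with pysam.VariantFile(in_vcf) as fin, \
                pysam.VariantFile(out_vcf, 'w', header=fin.header) as fout:
            for rec in fin:
                if rec.alts is None:
                    continue
                if all(abs(len(alt) - len(rec.ref)) <= max_indel_len
                       for alt in rec.alts):
                    fout.write(rec)

    drop_long_indels('variants.vcf.gz', 'variants.thinned.vcf')
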
diff --git a/docs/conf.py b/docs/conf.py index 83b3a97..4619671 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -265,7 +265,7 @@ texinfo_documents = [ ('index', project, u'{} Documentation'.format(project), u'Oxford Nanopore Technologies', project, - 'Megalodon provides "basecalling augmentation" for raw nanopore sequencing reads, including direct, reference-guided SNP and modified base calling.', + 'Megalodon provides "basecalling augmentation" for raw nanopore sequencing reads, including direct, reference-guided sequence variant and modified base calling.', 'Miscellaneous'), ] diff --git a/docs/variant_phasing.rst b/docs/variant_phasing.rst index 9fb6639..e981497 100644 --- a/docs/variant_phasing.rst +++ b/docs/variant_phasing.rst @@ -21,7 +21,7 @@ Workflow # run megalodon to produce whatshap_mappings megalodon \ - $reads_dir --outputs mappings snps whatshap_mappings \ + $reads_dir --outputs mappings variants whatshap_mappings \ --reference $ref --variant-filename $variants_vcf \ --output-directory $out_dir \ --processes $nproc --devices $gpu_devices \ @@ -58,12 +58,12 @@ Workflow $out_dir/whatshap_mappings python \ megalodon/scripts/run_aggregation.py \ - --outputs snps --haploid --output-suffix haplotype_1 \ + --outputs variants --haploid --output-suffix haplotype_1 \ --read-ids-filename $out_dir/whatshap_mappings.haplotype_1_read_ids.txt \ --reference $ref --processes $nproc python \ megalodon/scripts/run_aggregation.py \ - --outputs snps --haploid --output-suffix haplotype_2 \ + --outputs variants --haploid --output-suffix haplotype_2 \ --read-ids-filename $out_dir/whatshap_mappings.haplotype_2_read_ids.txt \ --reference $ref --processes $nproc diff --git a/scripts/calibrate_mod_llr_scores.py b/scripts/calibrate_mod_llr_scores.py index 77c1b37..bc24e64 100644 --- a/scripts/calibrate_mod_llr_scores.py +++ b/scripts/calibrate_mod_llr_scores.py @@ -134,7 +134,7 @@ def main(): plot_calib(pdf_fp, mod_base, *plot_data) if pdf_fp is not None: pdf_fp.close() - # save calibration table for reading into SNP table + # save calibration table for reading into mod calibration table sys.stderr.write('Saving calibrations to file.\n') mod_bases = list(mod_base_llrs.keys()) np.savez( diff --git a/scripts/calibrate_snp_llr_scores.py b/scripts/calibrate_variant_llr_scores.py similarity index 93% rename from scripts/calibrate_snp_llr_scores.py rename to scripts/calibrate_variant_llr_scores.py index 2517eaa..70d0943 100644 --- a/scripts/calibrate_snp_llr_scores.py +++ b/scripts/calibrate_variant_llr_scores.py @@ -15,7 +15,7 @@ def plot_calib( - pdf_fp, snp_type, smooth_ls, s_ref, sm_ref, s_alt, sm_alt, + pdf_fp, var_type, smooth_ls, s_ref, sm_ref, s_alt, sm_alt, mono_prob, prob_alt): f, axarr = plt.subplots(3, sharex=True, figsize=(11, 7)) axarr[0].plot(smooth_ls, s_ref, color='orange') @@ -24,7 +24,7 @@ def plot_calib( axarr[0].plot(smooth_ls, sm_alt, color='blue') axarr[0].set_ylabel( 'Probability Density\nred/orange=canonical\nblue/grey=modified') - axarr[0].set_title(snp_type + ' Calibration') + axarr[0].set_title(var_type + ' Calibration') axarr[1].plot(smooth_ls, mono_prob, color='orange') axarr[1].plot( smooth_ls, 1 / (np.exp(smooth_ls) + 1), color='purple') @@ -85,9 +85,9 @@ def prep_out(out_fn, overwrite): def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - '--ground-truth-llrs', default='snp_calibration_statistics.txt', + '--ground-truth-llrs', default='variant_calibration_statistics.txt', help='Ground truth log-likelihood ratio statistics (produced by ' + - 
'generate_ground_truth_snp_llr_scores.py). Default: %(default)s') + 'generate_ground_truth_variant_llr_scores.py). Default: %(default)s') parser.add_argument( '--max-input-llr', type=int, default=calibration.DEFAULT_SMOOTH_MAX, help='Maximum log-likelihood ratio to compute calibration. ' + @@ -106,7 +106,7 @@ def get_parser(): 'dynamically adjusts [--max-input-llr] when it is too large. ' + 'Default: %(default)f') parser.add_argument( - '--out-filename', default='megalodon_snp_calibration.npz', + '--out-filename', default='megalodon_variant_calibration.npz', help='Filename to output calibration values. Default: %(default)s') parser.add_argument( '--out-pdf', @@ -127,9 +127,9 @@ def main(): sys.stderr.write('Parsing log-likelihood ratios\n') snp_ref_llrs, ins_ref_llrs, del_ref_llrs = extract_llrs( args.ground_truth_llrs) - # add calibration for a generic SNP (mostly multiple SNPs + # add calibration for a generic variant (mostly multiple SNPs # as single variant; but not an indel) - generic_snp_llrs = [llr for snp_type_llrs in snp_ref_llrs.values() + generic_var_llrs = [llr for snp_type_llrs in snp_ref_llrs.values() for llr in snp_type_llrs] # downsample to same level as other snp types snp_ref_llrs[ @@ -186,7 +186,7 @@ def main(): if pdf_fp is not None: pdf_fp.close() - # save calibration table for reading into SNP table + # save calibration table for reading into variant calibration table sys.stderr.write('Saving calibrations to file.\n') snp_llr_range_save_data, snp_calib_save_data = {}, {} for (ref_seq, alt_seq), (snp_calib, snp_llr_range) in snp_calibs.items(): @@ -209,7 +209,7 @@ def main(): calibration.INS_LLR_RNG_TMPLT.format(ins_len)] = ins_llr_range np.savez( args.out_filename, - stratify_type=calibration.SNP_CALIB_TYPE, + stratify_type=calibration.VAR_CALIB_TYPE, smooth_nvals=args.num_calibration_values, max_indel_len=max_indel_len, **snp_calib_save_data, diff --git a/scripts/generate_ground_truth_snp_llr_scores.py b/scripts/generate_ground_truth_variant_llr_scores.py similarity index 68% rename from scripts/generate_ground_truth_snp_llr_scores.py rename to scripts/generate_ground_truth_variant_llr_scores.py index fbf0941..d769382 100644 --- a/scripts/generate_ground_truth_snp_llr_scores.py +++ b/scripts/generate_ground_truth_variant_llr_scores.py @@ -13,7 +13,7 @@ from megalodon import ( decode, fast5_io, megalodon_helper as mh, - megalodon, backends, mapping, snps) + megalodon, backends, mapping, variants) CONTEXT_BASES = [10, 30] @@ -29,44 +29,44 @@ _DO_PROFILE = False -def call_snp( - r_post, post_mapped_start, r_snp_pos, rl_cumsum, r_to_q_poss, - snp_ref_seq, snp_alt_seq, context_bases, all_paths, +def call_variant( + r_post, post_mapped_start, r_var_pos, rl_cumsum, r_to_q_poss, + var_ref_seq, var_alt_seq, context_bases, all_paths, np_ref_seq=None, ref_seq=None): - snp_context_bases = (context_bases[0] - if len(snp_ref_seq) == len(snp_alt_seq) else + var_context_bases = (context_bases[0] + if len(var_ref_seq) == len(var_alt_seq) else context_bases[1]) - pos_bb = min(snp_context_bases, r_snp_pos) + pos_bb = min(var_context_bases, r_var_pos) if ref_seq is None: - pos_ab = min(snp_context_bases, - np_ref_seq.shape[0] - r_snp_pos - len(snp_ref_seq)) - pos_ref_seq = np_ref_seq[r_snp_pos - pos_bb: - r_snp_pos + pos_ab + len(snp_ref_seq)] + pos_ab = min(var_context_bases, + np_ref_seq.shape[0] - r_var_pos - len(var_ref_seq)) + pos_ref_seq = np_ref_seq[r_var_pos - pos_bb: + r_var_pos + pos_ab + len(var_ref_seq)] else: - pos_ab = min(snp_context_bases, - len(ref_seq) - r_snp_pos - 
len(snp_ref_seq)) + pos_ab = min(var_context_bases, + len(ref_seq) - r_var_pos - len(var_ref_seq)) pos_ref_seq = mh.seq_to_int(ref_seq[ - r_snp_pos - pos_bb:r_snp_pos + pos_ab + len(snp_ref_seq)]]) + r_var_pos - pos_bb:r_var_pos + pos_ab + len(var_ref_seq)]]) pos_alt_seq = np.concatenate([ - pos_ref_seq[:pos_bb], mh.seq_to_int(snp_alt_seq), - pos_ref_seq[pos_bb + len(snp_ref_seq):]]) - blk_start = rl_cumsum[r_to_q_poss[r_snp_pos - pos_bb]] - blk_end = rl_cumsum[r_to_q_poss[r_snp_pos + pos_ab] + 1] + pos_ref_seq[:pos_bb], mh.seq_to_int(var_alt_seq), + pos_ref_seq[pos_bb + len(var_ref_seq):]]) + blk_start = rl_cumsum[r_to_q_poss[r_var_pos - pos_bb]] + blk_end = rl_cumsum[r_to_q_poss[r_var_pos + pos_ab] + 1] if blk_end - blk_start < max(len(pos_ref_seq), len(pos_alt_seq)): return np.NAN - loc_ref_score = snps.score_seq( + loc_ref_score = variants.score_seq( r_post, pos_ref_seq, post_mapped_start + blk_start, post_mapped_start + blk_end, all_paths) - loc_alt_score = snps.score_seq( + loc_alt_score = variants.score_seq( r_post, pos_alt_seq, post_mapped_start + blk_start, post_mapped_start + blk_end, all_paths) return loc_ref_score - loc_alt_score def call_alt_true_indel( - indel_size, r_snp_pos, true_ref_seq, r_seq, map_thr_buf, context_bases, + indel_size, r_var_pos, true_ref_seq, r_seq, map_thr_buf, context_bases, r_post, rl_cumsum, all_paths): def run_aligner(): return next(mappy.Aligner( @@ -76,27 +76,27 @@ def run_aligner(): if indel_size == 0: false_base = choice( - list(set(CAN_BASES).difference(true_ref_seq[r_snp_pos]))) + list(set(CAN_BASES).difference(true_ref_seq[r_var_pos]))) false_ref_seq = ( - true_ref_seq[:r_snp_pos] + false_base + - true_ref_seq[r_snp_pos + 1:]) - snp_ref_seq = false_base - snp_alt_seq = true_ref_seq[r_snp_pos] + true_ref_seq[:r_var_pos] + false_base + + true_ref_seq[r_var_pos + 1:]) + var_ref_seq = false_base + var_alt_seq = true_ref_seq[r_var_pos] elif indel_size > 0: # test alt truth reference insertion false_ref_seq = ( - true_ref_seq[:r_snp_pos + 1] + - true_ref_seq[r_snp_pos + indel_size + 1:]) - snp_ref_seq = true_ref_seq[r_snp_pos] - snp_alt_seq = true_ref_seq[r_snp_pos:r_snp_pos + indel_size + 1] + true_ref_seq[:r_var_pos + 1] + + true_ref_seq[r_var_pos + indel_size + 1:]) + var_ref_seq = true_ref_seq[r_var_pos] + var_alt_seq = true_ref_seq[r_var_pos:r_var_pos + indel_size + 1] else: # test alt truth reference deletion deleted_seq = ''.join(choice(CAN_BASES) for _ in range(-indel_size)) false_ref_seq = ( - true_ref_seq[:r_snp_pos + 1] + deleted_seq + - true_ref_seq[r_snp_pos + 1:]) - snp_ref_seq = true_ref_seq[r_snp_pos] + deleted_seq - snp_alt_seq = true_ref_seq[r_snp_pos] + true_ref_seq[:r_var_pos + 1] + deleted_seq + + true_ref_seq[r_var_pos + 1:]) + var_ref_seq = true_ref_seq[r_var_pos] + deleted_seq + var_alt_seq = true_ref_seq[r_var_pos] try: r_algn = run_aligner() @@ -108,19 +108,19 @@ def run_aligner(): raise mh.MegaError('Indel mapped read mapped to reverse strand.') r_to_q_poss = mapping.parse_cigar(r_algn.cigar, r_algn.strand) - if (r_algn.r_st > r_snp_pos - context_bases[1] or - r_algn.r_en < r_snp_pos + context_bases[1]): - raise mh.MegaError('Indel mapped read clipped snp position.') + if (r_algn.r_st > r_var_pos - context_bases[1] or + r_algn.r_en < r_var_pos + context_bases[1]): + raise mh.MegaError('Indel mapped read clipped variant position.') post_mapped_start = rl_cumsum[r_algn.q_st] mapped_rl_cumsum = rl_cumsum[ r_algn.q_st:r_algn.q_en + 1] - post_mapped_start - score = call_snp( - r_post, post_mapped_start, r_snp_pos, rl_cumsum, 
r_to_q_poss, - snp_ref_seq, snp_alt_seq, context_bases, all_paths, ref_seq=r_ref_seq) + score = call_variant( + r_post, post_mapped_start, r_var_pos, rl_cumsum, r_to_q_poss, + var_ref_seq, var_alt_seq, context_bases, all_paths, ref_seq=r_ref_seq) - return score, snp_ref_seq, snp_alt_seq + return score, var_ref_seq, var_alt_seq def process_read( raw_sig, read_id, model_info, caller_conn, map_thr_buf, do_false_ref, @@ -147,22 +147,22 @@ def process_read( mapped_rl_cumsum = rl_cumsum[ r_ref_pos.q_trim_start:r_ref_pos.q_trim_end + 1] - post_mapped_start - # candidate SNP locations within a read - snp_poss = list(range( + # candidate variant locations within a read + var_poss = list(range( edge_buffer, np_ref_seq.shape[0] - edge_buffer, every_n))[:max_pos_per_read] - read_snp_calls = [] + read_var_calls = [] if do_false_ref: # first process reference false calls (need to spoof an incorrect # reference for mapping and signal remapping) - for r_snp_pos in snp_poss: + for r_var_pos in var_poss: # first test single base swap SNPs try: - score, snp_ref_seq, snp_alt_seq = call_alt_true_indel( - 0, r_snp_pos, r_ref_seq, r_seq, map_thr_buf, + score, var_ref_seq, var_alt_seq = call_alt_true_indel( + 0, r_var_pos, r_ref_seq, r_seq, map_thr_buf, context_bases, r_post, rl_cumsum, all_paths) - read_snp_calls.append((False, score, snp_ref_seq, snp_alt_seq)) + read_var_calls.append((False, score, var_ref_seq, var_alt_seq)) except mh.MegaError: # introduced error either causes read not to map or # mapping trims the location of interest @@ -170,58 +170,58 @@ def process_read( # then test small indels for indel_size in range(1, max_indel_len + 1): try: - score, snp_ref_seq, snp_alt_seq = call_alt_true_indel( - indel_size, r_snp_pos, r_ref_seq, r_seq, map_thr_buf, + score, var_ref_seq, var_alt_seq = call_alt_true_indel( + indel_size, r_var_pos, r_ref_seq, r_seq, map_thr_buf, context_bases, r_post, rl_cumsum, all_paths) - read_snp_calls.append(( - False, score, snp_ref_seq, snp_alt_seq)) + read_var_calls.append(( + False, score, var_ref_seq, var_alt_seq)) except mh.MegaError: pass try: - score, snp_ref_seq, snp_alt_seq = call_alt_true_indel( - -indel_size, r_snp_pos, r_ref_seq, r_seq, map_thr_buf, + score, var_ref_seq, var_alt_seq = call_alt_true_indel( + -indel_size, r_var_pos, r_ref_seq, r_seq, map_thr_buf, context_bases, r_post, rl_cumsum, all_paths) - read_snp_calls.append(( - False, score, snp_ref_seq, snp_alt_seq)) + read_var_calls.append(( + False, score, var_ref_seq, var_alt_seq)) except mh.MegaError: pass - # now test reference correct SNPs - for r_snp_pos in snp_poss: + # now test reference correct variants + for r_var_pos in var_poss: # test simple SNP first - snp_ref_seq = r_ref_seq[r_snp_pos] - for snp_alt_seq in CAN_BASES_SET.difference(snp_ref_seq): - score = call_snp( - r_post, post_mapped_start, r_snp_pos, mapped_rl_cumsum, - r_to_q_poss, snp_ref_seq, snp_alt_seq, context_bases, all_paths, + var_ref_seq = r_ref_seq[r_var_pos] + for var_alt_seq in CAN_BASES_SET.difference(var_ref_seq): + score = call_variant( + r_post, post_mapped_start, r_var_pos, mapped_rl_cumsum, + r_to_q_poss, var_ref_seq, var_alt_seq, context_bases, all_paths, np_ref_seq=np_ref_seq) - read_snp_calls.append((True, score, snp_ref_seq, snp_alt_seq)) + read_var_calls.append((True, score, var_ref_seq, var_alt_seq)) # then test indels for indel_size in range(1, max_indel_len + 1): # test deletion - snp_ref_seq = r_ref_seq[r_snp_pos:r_snp_pos + indel_size + 1] - snp_alt_seq = r_ref_seq[r_snp_pos] - score = call_snp( - r_post, 
post_mapped_start, r_snp_pos, mapped_rl_cumsum, - r_to_q_poss, snp_ref_seq, snp_alt_seq, context_bases, + var_ref_seq = r_ref_seq[r_var_pos:r_var_pos + indel_size + 1] + var_alt_seq = r_ref_seq[r_var_pos] + score = call_variant( + r_post, post_mapped_start, r_var_pos, mapped_rl_cumsum, + r_to_q_poss, var_ref_seq, var_alt_seq, context_bases, all_paths, np_ref_seq=np_ref_seq) - read_snp_calls.append((True, score, snp_ref_seq, snp_alt_seq)) + read_var_calls.append((True, score, var_ref_seq, var_alt_seq)) # test random insertion - snp_ref_seq = r_ref_seq[r_snp_pos] - snp_alt_seq = snp_ref_seq + ''.join( + var_ref_seq = r_ref_seq[r_var_pos] + var_alt_seq = var_ref_seq + ''.join( choice(CAN_BASES) for _ in range(indel_size)) - score = call_snp( - r_post, post_mapped_start, r_snp_pos, mapped_rl_cumsum, - r_to_q_poss, snp_ref_seq, snp_alt_seq, context_bases, + score = call_variant( + r_post, post_mapped_start, r_var_pos, mapped_rl_cumsum, + r_to_q_poss, var_ref_seq, var_alt_seq, context_bases, all_paths, np_ref_seq=np_ref_seq) - read_snp_calls.append((True, score, snp_ref_seq, snp_alt_seq)) + read_var_calls.append((True, score, var_ref_seq, var_alt_seq)) - return read_snp_calls + return read_var_calls def _process_reads_worker( - fast5_q, snp_calls_q, caller_conn, model_info, device, do_false_ref): + fast5_q, var_calls_q, caller_conn, model_info, device, do_false_ref): model_info.prep_model_worker(device) map_thr_buf = mappy.ThreadBuffer() @@ -239,12 +239,12 @@ def _process_reads_worker( try: raw_sig = fast5_io.get_signal(fast5_fn, read_id) - read_snp_calls = process_read( + read_var_calls = process_read( raw_sig, read_id, model_info, caller_conn, map_thr_buf, do_false_ref) - snp_calls_q.put((True, read_snp_calls)) + var_calls_q.put((True, read_var_calls)) except Exception as e: - snp_calls_q.put((False, str(e))) + var_calls_q.put((False, str(e))) pass return @@ -254,12 +254,12 @@ def _process_reads_worker( def _process_reads_worker(*args): import cProfile cProfile.runctx('_process_reads_wrapper(*args)', globals(), locals(), - filename='snp_calibration.prof') + filename='variant_calibration.prof') return -def _get_snp_calls( - snp_calls_q, snp_calls_conn, out_fn, getter_num_reads_conn, +def _get_variant_calls( + var_calls_q, var_calls_conn, out_fn, getter_num_reads_conn, suppress_progress): out_fp = open(out_fn, 'w') bar = None @@ -269,13 +269,13 @@ def _get_snp_calls( err_types = defaultdict(int) while True: try: - valid_res, read_snp_calls = snp_calls_q.get(block=False) + valid_res, read_var_calls = var_calls_q.get(block=False) if valid_res: - for snp_call in read_snp_calls: - out_fp.write('{}\t{}\t{}\t{}\n'.format(*snp_call)) + for var_call in read_var_calls: + out_fp.write('{}\t{}\t{}\t{}\n'.format(*var_call)) out_fp.flush() else: - err_types[read_snp_calls] += 1 + err_types[read_var_calls] += 1 if not suppress_progress: bar.update(1) except queue.Empty: @@ -283,16 +283,16 @@ def _get_snp_calls( if getter_num_reads_conn.poll(): bar.total = getter_num_reads_conn.recv() else: - if snp_calls_conn.poll(): + if var_calls_conn.poll(): break sleep(0.01) continue - while not snp_calls_q.empty(): - valid_res, read_snp_calls = snp_calls_q.get(block=False) + while not var_calls_q.empty(): + valid_res, read_var_calls = var_calls_q.get(block=False) if valid_res: - for snp_call in read_snp_calls: - out_fp.write('{}\t{}\t{}\t{}\n'.format(*snp_call)) + for var_call in read_var_calls: + out_fp.write('{}\t{}\t{}\t{}\n'.format(*var_call)) out_fp.flush() else: err_types[str(e)] += 1 @@ -325,8 +325,9 @@ def 
process_all_reads( daemon=True) files_p.start() - snp_calls_q, snp_calls_p, main_sc_conn = mh.create_getter_q( - _get_snp_calls, (out_fn, getter_num_reads_conn, suppress_progress)) + var_calls_q, var_calls_p, main_sc_conn = mh.create_getter_q( + _get_variant_calls, + (out_fn, getter_num_reads_conn, suppress_progress)) proc_reads_ps, map_conns = [], [] for device in model_info.process_devices: @@ -337,7 +338,7 @@ def process_all_reads( map_conns.append(map_conn) p = mp.Process( target=_process_reads_worker, args=( - fast5_q, snp_calls_q, caller_conn, model_info, device, + fast5_q, var_calls_q, caller_conn, model_info, device, do_false_ref)) p.daemon = True p.start() @@ -358,9 +359,9 @@ def process_all_reads( if map_read_ts is not None: for map_t in map_read_ts: map_t.join() - if snp_calls_p.is_alive(): + if var_calls_p.is_alive(): main_sc_conn.send(True) - snp_calls_p.join() + var_calls_p.join() return @@ -383,7 +384,7 @@ def get_parser(): out_grp = parser.add_argument_group('Output Arguments') out_grp.add_argument( - '--output', default='snp_calibration_statistics.txt', + '--output', default='variant_calibration_statistics.txt', help='Filename to output statistics. Default: %(default)s') out_grp.add_argument( '--num-reads', type=int, From f1d205386ec22bffebf4abdc1a33f30750707189 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Mon, 30 Sep 2019 14:42:02 -0700 Subject: [PATCH 09/14] Final snp to variant conversion. --- scripts/merge_haploid_variants.py | 4 ++-- scripts/run_aggregation.py | 19 ++++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/scripts/merge_haploid_variants.py b/scripts/merge_haploid_variants.py index 4fa682b..87e1e58 100644 --- a/scripts/merge_haploid_variants.py +++ b/scripts/merge_haploid_variants.py @@ -3,7 +3,7 @@ import pysam import numpy as np -from megalodon import snps, megalodon_helper as mh +from megalodon import variats, megalodon_helper as mh HEADER = """##fileformat=VCFv4.1 @@ -219,7 +219,7 @@ def main(): out_vars.close() - index_var_fn = snps.index_variants(args.out_vcf) + index_var_fn = variants.index_variants(args.out_vcf) return diff --git a/scripts/run_aggregation.py b/scripts/run_aggregation.py index e095b96..4fd50a2 100644 --- a/scripts/run_aggregation.py +++ b/scripts/run_aggregation.py @@ -9,7 +9,7 @@ from time import sleep from megalodon import ( - aggregate, backends, logging, mapping, mods, snps, megalodon_helper as mh) + aggregate, backends, logging, mapping, mods, variants, megalodon_helper as mh) def get_parser(): @@ -20,12 +20,13 @@ def get_parser(): 'names). Default: Load default model ({})'.format(mh.MODEL_PRESET_DESC)) parser.add_argument( '--outputs', nargs='+', - default=[mh.SNP_NAME, mh.MOD_NAME], - choices=[mh.SNP_NAME, mh.MOD_NAME], + default=[mh.VAR_NAME, mh.MOD_NAME], + choices=[mh.VAR_NAME, mh.MOD_NAME], help='Output type(s) to produce. Default: %(default)s') parser.add_argument( '--haploid', action='store_true', - help='Compute SNP aggregation for haploid genotypes. Default: diploid') + help='Compute sequence variant aggregation for haploid genotypes. 
' + + 'Default: diploid') parser.add_argument( '--heterozygous-factors', type=float, nargs=2, default=[mh.DEFAULT_SNV_HET_FACTOR, mh.DEFAULT_INDEL_HET_FACTOR], @@ -100,21 +101,21 @@ def main(): aggregate.aggregate_stats( args.outputs, args.output_directory, args.processes, args.write_vcf_log_probs, args.heterozygous_factors, - snps.HAPLIOD_MODE if args.haploid else snps.DIPLOID_MODE, + variants.HAPLIOD_MODE if args.haploid else variants.DIPLOID_MODE, mod_names, mod_agg_info, args.write_mod_log_probs, args.mod_output_formats, args.suppress_progress, aligner.ref_names_and_lens, valid_read_ids, args.output_suffix) # note reference is required in order to annotate contigs for VCF writing - if mh.SNP_NAME in args.outputs and args.reference is not None: + if mh.VAR_NAME in args.outputs and args.reference is not None: logger.info('Sorting output variant file') variant_fn = mh.add_fn_suffix( - mh.get_megalodon_fn(args.output_directory, mh.SNP_NAME), + mh.get_megalodon_fn(args.output_directory, mh.VAR_NAME), args.output_suffix) sort_variant_fn = mh.add_fn_suffix(variant_fn, 'sorted') - snps.sort_variants(variant_fn, sort_variant_fn) + variants.sort_variants(variant_fn, sort_variant_fn) logger.info('Indexing output variant file') - index_var_fn = snps.index_variants(sort_variant_fn) + index_var_fn = variants.index_variants(sort_variant_fn) return From 83a804ae1b25362e2d3639273bb2c2b6afd5dd5b Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Mon, 30 Sep 2019 15:37:24 -0700 Subject: [PATCH 10/14] Minor typo and bug fixes throughout scripts. --- scripts/calibrate_variant_llr_scores.py | 2 +- scripts/filter_whatshap.py | 4 ++-- scripts/generate_ground_truth_variant_llr_scores.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/calibrate_variant_llr_scores.py b/scripts/calibrate_variant_llr_scores.py index 70d0943..9e31e1f 100644 --- a/scripts/calibrate_variant_llr_scores.py +++ b/scripts/calibrate_variant_llr_scores.py @@ -135,7 +135,7 @@ def main(): snp_ref_llrs[ (calibration.GENERIC_BASE, calibration.GENERIC_BASE)] = np.random.choice( - generic_snp_llrs, int(len(generic_snp_llrs) / 12), replace=False) + generic_var_llrs, int(len(generic_var_llrs) / 12), replace=False) max_indel_len = max(ins_ref_llrs) assert set(ins_ref_llrs) == set(del_ref_llrs), ( 'Must test same range of lengths for insertions and deletions') diff --git a/scripts/filter_whatshap.py b/scripts/filter_whatshap.py index 87f5079..9c86d81 100755 --- a/scripts/filter_whatshap.py +++ b/scripts/filter_whatshap.py @@ -3,7 +3,7 @@ from tqdm import tqdm -from megalodon import snps +from megalodon import variants parser = argparse.ArgumentParser( description='Remove variants incompatible with whatshap') @@ -16,7 +16,7 @@ def is_complex_variant(ref, alts): # single base swaps aren't complex if any(len(allele) > 1 for allele in alts + [ref]): for alt in alts: - simp_ref, simp_alt, _, _ = snps.simplify_var_seq(ref, alt) + simp_ref, simp_alt, _, _ = variants.simplify_var_seq(ref, alt) # if an allele simplifies to a SNV continue if len(simp_ref) == 0 and len(simp_alt) == 0: continue diff --git a/scripts/generate_ground_truth_variant_llr_scores.py b/scripts/generate_ground_truth_variant_llr_scores.py index d769382..7ac625b 100644 --- a/scripts/generate_ground_truth_variant_llr_scores.py +++ b/scripts/generate_ground_truth_variant_llr_scores.py @@ -46,7 +46,7 @@ def call_variant( pos_ab = min(var_context_bases, len(ref_seq) - r_var_pos - len(var_ref_seq)) pos_ref_seq = mh.seq_to_int(ref_seq[ - r_var_pos - 
pos_bb:r_var_pos + pos_ab + len(var_ref_seq)]]) + r_var_pos - pos_bb:r_var_pos + pos_ab + len(var_ref_seq)]) pos_alt_seq = np.concatenate([ pos_ref_seq[:pos_bb], mh.seq_to_int(var_alt_seq), @@ -430,7 +430,7 @@ def main(): sys.stderr.write('Loading model.\n') model_info = backends.ModelInfo( - args.taiyaki_model_filename, args.devices, + mh.get_model_fn(args.taiyaki_model_filename), args.devices, args.processes, args.chunk_size, args.chunk_overlap, args.max_concurrent_chunks) sys.stderr.write('Loading reference.\n') From df1adc399357f03d5880f94ac265b2991d919dc6 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Tue, 1 Oct 2019 00:11:05 +0100 Subject: [PATCH 11/14] Fix bug in alt_seq database entry and control tuples vs list a bit. --- megalodon/mods.py | 12 ++++++------ megalodon/variants.py | 18 ++++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/megalodon/mods.py b/megalodon/mods.py index 7eeb32f..784fd63 100644 --- a/megalodon/mods.py +++ b/megalodon/mods.py @@ -151,17 +151,17 @@ def __init__(self, fn, read_only=True, db_safety=1, def insert_chrms(self, chrms): next_chrm_id = self.get_num_uniq_chrms() + 1 self.cur.executemany('INSERT INTO chrm (chrm) VALUES (?)', - [(chrm,) for chrm in chrms]) + ((chrm,) for chrm in chrms)) if self.chrm_idx_in_mem: self.chrm_idx.update(zip( chrms, range(next_chrm_id, next_chrm_id + len(chrms)))) return def get_pos_ids_or_insert(self, r_mod_scores, chrm_id, strand): - r_pos = list(zip(*r_mod_scores))[0] + r_pos = tuple(zip(*r_mod_scores))[0] r_uniq_pos = set(((chrm_id, strand, pos) for pos in r_pos)) if self.pos_idx_in_mem: - pos_to_add = list(r_uniq_pos.difference(self.pos_idx)) + pos_to_add = tuple(r_uniq_pos.difference(self.pos_idx)) else: pos_ids = dict( ((chrm_id, strand, pos_and_id[0]), pos_and_id[1]) @@ -170,7 +170,7 @@ def get_pos_ids_or_insert(self, r_mod_scores, chrm_id, strand): 'SELECT pos, pos_id FROM pos ' + 'WHERE pos_chrm=? AND strand=? AND pos=?', pos_key).fetchall()) - pos_to_add = list(r_uniq_pos.difference(pos_ids)) + pos_to_add = tuple(r_uniq_pos.difference(pos_ids)) if len(pos_to_add) > 0: next_pos_id = self.get_num_uniq_mod_pos() + 1 @@ -196,7 +196,7 @@ def get_mod_base_ids_or_insert(self, r_mod_scores): r_uniq_mod_bases = set(( mod_key for pos_mods in r_mod_bases for mod_key, _ in pos_mods)) if self.mod_idx_in_mem: - mod_bases_to_add = list(r_uniq_mod_bases.difference(self.mod_idx)) + mod_bases_to_add = tuple(r_uniq_mod_bases.difference(self.mod_idx)) else: mod_base_ids = dict( (mod_data_w_id[:-1], mod_data_w_id[-1]) @@ -205,7 +205,7 @@ def get_mod_base_ids_or_insert(self, r_mod_scores): 'SELECT mod_base, motif, motif_pos, raw_motif, ' + 'mod_id FROM mod WHERE mod_base=? AND motif=? AND ' + 'motif_pos=? 
AND raw_motif=?', mod_data).fetchall()) - mod_bases_to_add = list(r_uniq_mod_bases.difference(mod_base_ids)) + mod_bases_to_add = tuple(r_uniq_mod_bases.difference(mod_base_ids)) if len(mod_bases_to_add) > 0: next_mod_base_id = self.get_num_uniq_mod_bases() + 1 diff --git a/megalodon/variants.py b/megalodon/variants.py index 6b8d7d9..0b74959 100755 --- a/megalodon/variants.py +++ b/megalodon/variants.py @@ -164,7 +164,7 @@ def __init__(self, fn, read_only=True, db_safety=1, def insert_chrms(self, chrms): next_chrm_id = self.get_num_uniq_chrms() + 1 self.cur.executemany('INSERT INTO chrm (chrm) VALUES (?)', - [(chrm,) for chrm in chrms]) + ((chrm,) for chrm in chrms)) if self.chrm_idx_in_mem: self.chrm_idx.update(zip( chrms, range(next_chrm_id, next_chrm_id + len(chrms)))) @@ -179,7 +179,7 @@ def get_loc_ids_or_insert(self, r_var_scores, chrm_id): for pos, _, ref_seq, _, var_name, test_start, test_end in r_var_scores)) if self.loc_idx_in_mem: - locs_to_add = list(set(r_locs).difference(self.loc_idx)) + locs_to_add = tuple(set(r_locs).difference(self.loc_idx)) else: test_starts, test_ends = map( set, list(zip(*r_locs.keys()))[1:]) @@ -192,7 +192,7 @@ def get_loc_ids_or_insert(self, r_var_scores, chrm_id): ','.join(['?',] * len(test_starts)), ','.join(['?',] * len(test_ends))), (chrm_id, *test_starts, *test_ends)).fetchall())) - locs_to_add = list(set(r_locs).difference(loc_ids)) + locs_to_add = tuple(set(r_locs).difference(loc_ids)) if len(locs_to_add) > 0: next_loc_id = self.get_num_uniq_var_loc() + 1 @@ -213,13 +213,14 @@ def get_loc_ids_or_insert(self, r_var_scores, chrm_id): return r_loc_ids def get_alt_ids_or_insert(self, r_var_scores): + logger = logging.get_logger() r_seqs_and_lps = [ tuple(zip(alt_seqs, alt_lps)) for _, alt_lps, _, alt_seqs, _, _, _ in r_var_scores] - r_uniq_seqs = set((seq_lp_i[0] for loc_seqs_lps in r_seqs_and_lps - for seq_lp_i in loc_seqs_lps)) + r_uniq_seqs = set(seq_lp_i[0] for loc_seqs_lps in r_seqs_and_lps + for seq_lp_i in loc_seqs_lps) if self.alt_idx_in_mem: - alts_to_add = list(r_uniq_seqs.difference(self.alt_idx)) + alts_to_add = tuple(r_uniq_seqs.difference(self.alt_idx)) else: alt_ids = dict(( (alt_seq, alt_id) @@ -228,12 +229,13 @@ def get_alt_ids_or_insert(self, r_var_scores): 'FROM alt WHERE alt_seq in ({})').format( ','.join(['?',] * len(r_uniq_seqs))), r_uniq_seqs).fetchall())) - alts_to_add = list(r_uniq_seqs.difference(alt_ids)) + alts_to_add = tuple(r_uniq_seqs.difference(alt_ids)) if len(alts_to_add) > 0: next_alt_id = self.get_num_uniq_alt_seqs() + 1 self.cur.executemany( - 'INSERT INTO alt (alt_seq) VALUES (?)', alts_to_add) + 'INSERT INTO alt (alt_seq) VALUES (?)', + ((alt_seq,) for alt_seq in alts_to_add)) alt_idx = self.alt_idx if self.alt_idx_in_mem else alt_ids if len(alts_to_add) > 0: From d2c314a8efac078b1510d75ccee021edaa5c5896 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Tue, 1 Oct 2019 00:23:48 +0100 Subject: [PATCH 12/14] Set full variant processing errors output flag to false. 
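
With the flag off, a failure while processing a single read is
collected and reported rather than raised, so one bad read cannot take
down a worker process. As a rough sketch (the helper name and queue
wiring here are illustrative, not megalodon's exact internals), the
flag is intended to gate error handling along these lines:

    # hypothetical sketch: a module-level debug flag gating whether
    # per-read processing errors propagate or are collected
    _RAISE_VARIANT_PROCESSING_ERRORS = False

    def _handle_read_error(func, args, err_q):
        # run one per-read processing step, recording rather than
        # raising errors unless the debug flag is switched back on
        try:
            return func(*args)
        except Exception as e:
            if _RAISE_VARIANT_PROCESSING_ERRORS:
                raise
            err_q.put(str(e))
            return None

Set the flag back to True when debugging variant processing to get
full tracebacks instead of aggregated error counts.
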
--- megalodon/variants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megalodon/variants.py b/megalodon/variants.py index 0b74959..ce9e90d 100755 --- a/megalodon/variants.py +++ b/megalodon/variants.py @@ -19,7 +19,7 @@ _DEBUG_PER_READ = False -_RAISE_VARIANT_PROCESSING_ERRORS = True +_RAISE_VARIANT_PROCESSING_ERRORS = False VARIANT_DATA = namedtuple('VARIANT_DATA', ( 'np_ref', 'np_alts', 'id', 'chrom', 'start', 'stop', From 00507d5f0be7fb42396acb0c6668d42f3a59adec Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Thu, 3 Oct 2019 17:58:21 +0100 Subject: [PATCH 13/14] Bug fixes and added chromosome lengths to database so aligner is not needed for re-aggregation. --- megalodon/aggregate.py | 16 +++--- megalodon/mapping.py | 5 -- megalodon/megalodon.py | 8 +-- megalodon/mods.py | 22 +++++--- megalodon/variants.py | 43 ++++++++++----- scripts/merge_haploid_variants.py | 4 +- scripts/run_aggregation.py | 88 ++++++++++++++++--------------- 7 files changed, 108 insertions(+), 78 deletions(-) diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py index bf42414..d83d813 100644 --- a/megalodon/aggregate.py +++ b/megalodon/aggregate.py @@ -244,8 +244,7 @@ def _fill_locs_queue(*args): def aggregate_stats( outputs, out_dir, num_ps, write_vcf_lp, het_factors, call_mode, mod_names, mod_agg_info, write_mod_lp, mod_output_fmts, - suppress_progress, ref_names_and_lens, valid_read_ids=None, - out_suffix=None): + suppress_progress, valid_read_ids=None, out_suffix=None): if mh.VAR_NAME in outputs and mh.MOD_NAME in outputs: num_ps = max(num_ps // 2, 1) @@ -254,8 +253,11 @@ def aggregate_stats( 0, 0, queue.Queue(), queue.Queue()) if mh.VAR_NAME in outputs: vars_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_VAR_NAME) - num_vars = variants.AggVars( - vars_db_fn, load_in_mem_indices=False).num_uniq() + agg_vars = variants.AggVars( + vars_db_fn, load_in_mem_indices=False) + num_vars = agg_vars.num_uniq() + ref_names_and_lens = agg_vars.get_all_chrm_and_lens() + agg_vars.close() logger.info('Spawning variant aggregation processes.') # create process to collect var stats from workers var_stats_q, var_stats_p, main_var_stats_conn = mh.create_getter_q( @@ -282,8 +284,10 @@ def aggregate_stats( if mh.MOD_NAME in outputs: mods_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_MOD_NAME) - num_mods = mods.AggMods( - mods_db_fn, load_in_mem_indices=False).num_uniq() + agg_mods = mods.AggMods(mods_db_fn, load_in_mem_indices=False) + num_mods = agg_mods.num_uniq() + ref_names_and_lens = agg_mods.get_all_chrm_and_lens() + agg_mods.close() logger.info('Spawning modified base aggregation processes.') # create process to collect mods stats from workers mod_stats_q, mod_stats_p, main_mod_stats_conn = mh.create_getter_q( diff --git a/megalodon/mapping.py b/megalodon/mapping.py index 769e787..820a16b 100644 --- a/megalodon/mapping.py +++ b/megalodon/mapping.py @@ -28,7 +28,6 @@ def add_ref_lens(self): return - def align_read(q_seq, aligner, map_thr_buf, read_id=None): try: # enumerate all alignments to avoid memory leak from mappy @@ -47,7 +46,6 @@ def align_read(q_seq, aligner, map_thr_buf, read_id=None): read_id, q_seq, r_algn.ctg, r_algn.strand, r_algn.r_st, r_algn.q_st, r_algn.q_en, r_algn.cigar) - def _map_read_worker(aligner, map_conn, mo_q): # get mappy aligner thread buffer map_thr_buf = mappy.ThreadBuffer() @@ -69,7 +67,6 @@ def _map_read_worker(aligner, map_conn, mo_q): return - def parse_cigar(r_cigar, strand): # get each base calls genomic position r_to_q_poss = {} @@ -98,7 +95,6 @@ def 
parse_cigar(r_cigar, strand): return r_to_q_poss - def map_read(q_seq, read_id, caller_conn): """Map read (query) sequence and return: 1) reference sequence (endcoded as int labels) @@ -162,7 +158,6 @@ def test_open_alignment_out_file(out_dir, map_fmt, ref_names_and_lens, ref_fn): map_fp.close() return - def _get_map_queue( mo_q, map_conn, out_dir, ref_names_and_lens, map_fmt, ref_fn, do_output_pr_refs, pr_ref_filts): diff --git a/megalodon/megalodon.py b/megalodon/megalodon.py index bf92469..f991f4d 100644 --- a/megalodon/megalodon.py +++ b/megalodon/megalodon.py @@ -253,14 +253,13 @@ def post_process_mapping(out_dir, map_fmt, ref_fn): def post_process_aggregate( mods_info, outputs, mod_bin_thresh, out_dir, num_ps, write_vcf_lp, - het_factors, vars_data, write_mod_lp, supp_prog, ref_names_and_lens): + het_factors, vars_data, write_mod_lp, supp_prog): mod_names = mods_info.mod_long_names if mh.MOD_NAME in outputs else [] mod_agg_info = mods.AGG_INFO(mods.BIN_THRESH_NAME, mod_bin_thresh) aggregate.aggregate_stats( outputs, out_dir, num_ps, write_vcf_lp, het_factors, vars_data.call_mode, mod_names, mod_agg_info, - write_mod_lp, mods_info.mod_output_fmts, supp_prog, - ref_names_and_lens) + write_mod_lp, mods_info.mod_output_fmts, supp_prog) return @@ -1018,6 +1017,7 @@ def _main(): args.processes, args.verbose_read_progress, args.suppress_progress, mods_info, args.database_safety, pr_ref_filts) + if aligner is not None: aligner.close() if mh.MAP_NAME in args.outputs: logger.info('Spawning process to sort mappings') map_p = post_process_mapping( @@ -1033,7 +1033,7 @@ def _main(): mods_info, args.outputs, args.mod_binary_threshold, args.output_directory, args.processes, args.write_vcf_log_probs, args.heterozygous_factors, vars_data, args.write_mod_log_probs, - args.suppress_progress, aligner.ref_names_and_lens) + args.suppress_progress) if mh.VAR_NAME in args.outputs: logger.info('Sorting output variant file') diff --git a/megalodon/mods.py b/megalodon/mods.py index 784fd63..b6e7190 100644 --- a/megalodon/mods.py +++ b/megalodon/mods.py @@ -52,7 +52,8 @@ class ModsDb(object): db_tables = OrderedDict(( ('chrm', OrderedDict(( ('chrm_id', 'INTEGER PRIMARY KEY'), - ('chrm', 'TEXT')))), + ('chrm', 'TEXT'), + ('chrm_len', 'INTEGER')))), ('pos', OrderedDict(( ('pos_id', 'INTEGER PRIMARY KEY'), ('pos_chrm', 'INTEGER'), @@ -80,7 +81,7 @@ class ModsDb(object): def __init__(self, fn, read_only=True, db_safety=1, pos_index_in_memory=False, chrm_index_in_memory=True, - mod_index_in_memory=True): + mod_index_in_memory=True, uuid_index_in_memory=True): """ Interface to database containing modified base statistics. Default settings are for optimal read_only performance. 
@@ -90,6 +91,7 @@ def __init__(self, fn, read_only=True, db_safety=1, self.pos_idx_in_mem = pos_index_in_memory self.chrm_idx_in_mem = chrm_index_in_memory self.mod_idx_in_mem = mod_index_in_memory + self.uuid_idx_in_mem = uuid_index_in_memory if read_only: if not os.path.exists(fn): @@ -148,13 +150,15 @@ def __init__(self, fn, read_only=True, db_safety=1, return # insert data functions - def insert_chrms(self, chrms): + def insert_chrms(self, chrm_names_and_lens): next_chrm_id = self.get_num_uniq_chrms() + 1 - self.cur.executemany('INSERT INTO chrm (chrm) VALUES (?)', - ((chrm,) for chrm in chrms)) + self.cur.executemany('INSERT INTO chrm (chrm, chrm_len) VALUES (?,?)', + zip(*chrm_names_and_lens)) if self.chrm_idx_in_mem: self.chrm_idx.update(zip( - chrms, range(next_chrm_id, next_chrm_id + len(chrms)))) + chrm_names_and_lens[0], + range(next_chrm_id, + next_chrm_id + len(chrm_names_and_lens[0])))) return def get_pos_ids_or_insert(self, r_mod_scores, chrm_id, strand): @@ -307,6 +311,10 @@ def get_chrm(self, chrm_id): 'in mods database.') return chrm + def get_all_chrm_and_lens(self): + return tuple(map(tuple, zip(*self.cur.execute( + 'SELECT chrm, chrm_len FROM chrm').fetchall()))) + def get_mod_base_data(self, mod_id): try: if self.mod_idx_in_mem: @@ -552,7 +560,7 @@ def store_mod_call( mods_db = ModsDb(mods_db_fn, db_safety=db_safety, read_only=False, pos_index_in_memory=pos_index_in_memory) - mods_db.insert_chrms(ref_names_and_lens[0]) + mods_db.insert_chrms(ref_names_and_lens) mods_db.create_chrm_index() if mods_txt_fn is None: diff --git a/megalodon/variants.py b/megalodon/variants.py index ce9e90d..b564d4e 100755 --- a/megalodon/variants.py +++ b/megalodon/variants.py @@ -65,7 +65,8 @@ class VarsDb(object): db_tables = OrderedDict(( ('chrm', OrderedDict(( ('chrm_id', 'INTEGER PRIMARY KEY'), - ('chrm', 'TEXT')))), + ('chrm', 'TEXT'), + ('chrm_len', 'INTEGER')))), ('loc', OrderedDict(( ('loc_id', 'INTEGER PRIMARY KEY'), ('loc_chrm', 'INTEGER'), @@ -93,7 +94,7 @@ class VarsDb(object): def __init__(self, fn, read_only=True, db_safety=1, loc_index_in_memory=False, chrm_index_in_memory=True, - alt_index_in_memory=True): + alt_index_in_memory=True, uuid_index_in_memory=True): """ Interface to database containing sequence variant statistics. Default settings are for optimal read_only performance. 
@@ -103,6 +104,7 @@ def __init__(self, fn, read_only=True, db_safety=1, self.loc_idx_in_mem = loc_index_in_memory self.chrm_idx_in_mem = chrm_index_in_memory self.alt_idx_in_mem = alt_index_in_memory + self.uuid_idx_in_mem = uuid_index_in_memory if read_only: if not os.path.exists(fn): @@ -125,6 +127,8 @@ def __init__(self, fn, read_only=True, db_safety=1, self.load_loc_read_index() if self.alt_idx_in_mem: self.load_alt_read_index() + if self.uuid_idx_in_mem: + self.load_uuid_read_index() else: if db_safety < 2: # set asynchronous mode to off for max speed @@ -160,14 +164,16 @@ def __init__(self, fn, read_only=True, db_safety=1, return - # insert data function - def insert_chrms(self, chrms): + # insert data functions + def insert_chrms(self, chrm_names_and_lens): next_chrm_id = self.get_num_uniq_chrms() + 1 - self.cur.executemany('INSERT INTO chrm (chrm) VALUES (?)', - ((chrm,) for chrm in chrms)) + self.cur.executemany('INSERT INTO chrm (chrm, chrm_len) VALUES (?,?)', + zip(*chrm_names_and_lens)) if self.chrm_idx_in_mem: self.chrm_idx.update(zip( - chrms, range(next_chrm_id, next_chrm_id + len(chrms)))) + chrm_names_and_lens[0], + range(next_chrm_id, + next_chrm_id + len(chrm_names_and_lens[0])))) return def get_loc_ids_or_insert(self, r_var_scores, chrm_id): @@ -275,6 +281,11 @@ def load_chrm_read_index(self): self.chrm_read_idx = dict(self.cur.fetchall()) return + def load_uuid_read_index(self): + self.cur.execute('SELECT read_id, uuid FROM read') + self.uuid_read_idx = dict(self.cur.fetchall()) + return + def create_alt_index(self): self.cur.execute('CREATE UNIQUE INDEX alt_idx ON alt(alt_seq)') return @@ -328,6 +339,10 @@ def get_chrm(self, chrm_id): 'vars database.') return chrm + def get_all_chrm_and_lens(self): + return tuple(map(tuple, zip(*self.cur.execute( + 'SELECT chrm, chrm_len FROM chrm').fetchall()))) + def get_alt_seq(self, alt_id): try: if self.alt_idx_in_mem: @@ -342,10 +357,13 @@ def get_alt_seq(self, alt_id): def get_uuid(self, read_id): try: - uuid = self.cur.execute( - 'SELECT uuid FROM read WHERE read_id=?', + if self.uuid_idx_in_mem: + uuid = self.uuid_read_idx[read_id] + else: + uuid = self.cur.execute( + 'SELECT uuid FROM read WHERE read_id=?', (read_id,)).fetchone()[0] - except TypeError: + except (TypeError, KeyError): raise mh.MegaError('Read ID not found in vars database.') return uuid @@ -817,7 +835,7 @@ def get_var_call( logger = logging.get_logger('vars_getter') vars_db = VarsDb(vars_db_fn, db_safety=db_safety, read_only=False, loc_index_in_memory=loc_index_in_memory) - vars_db.insert_chrms(ref_names_and_lens[0]) + vars_db.insert_chrms(ref_names_and_lens) vars_db.create_chrm_index() if vars_txt_fn is None: vars_txt_fp = None @@ -1719,7 +1737,8 @@ def compute_var_stats( 'Invalid variant aggregation ploidy call mode: {}.'.format( call_mode)) - pr_var_stats = self.vars_db.get_loc_stats(var_loc) + pr_var_stats = self.vars_db.get_loc_stats( + var_loc, valid_read_ids is not None) alt_seqs = sorted(set(r_stats.alt_seq for r_stats in pr_var_stats)) pr_alt_lps = defaultdict(dict) for r_stats in pr_var_stats: diff --git a/scripts/merge_haploid_variants.py b/scripts/merge_haploid_variants.py index 87e1e58..a9bd664 100644 --- a/scripts/merge_haploid_variants.py +++ b/scripts/merge_haploid_variants.py @@ -3,7 +3,7 @@ import pysam import numpy as np -from megalodon import variats, megalodon_helper as mh +from megalodon import variants, megalodon_helper as mh HEADER = """##fileformat=VCFv4.1 @@ -35,7 +35,7 @@ def get_parser(): help='Variant file for haplotype 1.') 
parser.add_argument(
         'haplotype2_variants',
-        help='Variant file for haplotype 1.')
+        help='Variant file for haplotype 2.')
     parser.add_argument(
         '--out-vcf', default='merged_haploid_variants.vcf',
         help='Output name for VCF. Default: %(default)s')
diff --git a/scripts/run_aggregation.py b/scripts/run_aggregation.py
index 4fd50a2..f55d597 100644
--- a/scripts/run_aggregation.py
+++ b/scripts/run_aggregation.py
@@ -9,69 +9,77 @@
 from time import sleep

 from megalodon import (
-    aggregate, backends, logging, mapping, mods, variants, megalodon_helper as mh)
+    aggregate, backends, logging, mapping, mods, variants,
+    megalodon_helper as mh)


 def get_parser():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--taiyaki-model-filename',
-        help='Taiyaki model checkpoint file (for loading modified base ' +
-        'names). Default: Load default model ({})'.format(mh.MODEL_PRESET_DESC))
-    parser.add_argument(
+    parser = argparse.ArgumentParser(
+        'Aggregate per-read, per-site statistics from previous megalodon call.')
+
+    out_grp = parser.add_argument_group('Output Arguments')
+    out_grp.add_argument(
         '--outputs', nargs='+',
         default=[mh.VAR_NAME, mh.MOD_NAME],
         choices=[mh.VAR_NAME, mh.MOD_NAME],
         help='Output type(s) to produce. Default: %(default)s')
-    parser.add_argument(
+    out_grp.add_argument(
+        '--megalodon-directory',
+        default='megalodon_results',
+        help='Megalodon output directory containing per-read database(s) ' +
+        'where aggregated results will be added. Default: %(default)s')
+    out_grp.add_argument(
+        '--output-suffix', default='re_aggregated',
+        help='Suffix to apply to aggregated results, to avoid ' +
+        'overwriting results. Default: %(default)s')
+    out_grp.add_argument(
+        '--read-ids-filename',
+        help='File containing read ids to process (one per ' +
+        'line). Default: All reads')
+
+    var_grp = parser.add_argument_group('Sequence Variant Arguments')
+    var_grp.add_argument(
         '--haploid', action='store_true',
         help='Compute sequence variant aggregation for haploid genotypes. ' +
         'Default: diploid')
-    parser.add_argument(
+    var_grp.add_argument(
         '--heterozygous-factors', type=float, nargs=2,
         default=[mh.DEFAULT_SNV_HET_FACTOR, mh.DEFAULT_INDEL_HET_FACTOR],
         help='Bayesian prior factor for snv and indel heterozygous calls ' +
         '(compared to 1.0 for hom ref/alt). Default: %(default)s')
-    parser.add_argument(
+    var_grp.add_argument(
+        '--write-vcf-log-probs', action='store_true',
+        help='Write alt log probabilities out in non-standard VCF field.')
+
+    mod_grp = parser.add_argument_group('Modified Base Arguments')
+    mod_grp.add_argument(
         '--mod-binary-threshold', type=float, nargs=1,
         default=mods.DEFAULT_AGG_INFO.binary_threshold,
         help='Threshold for modified base aggregation (probability of ' +
         'modified/canonical base). Default: %(default)s')
-    parser.add_argument(
+    mod_grp.add_argument(
         '--mod-output-formats', nargs='+',
         default=[mh.MOD_BEDMETHYL_NAME,],
         choices=tuple(mh.MOD_OUTPUT_FMTS.keys()),
         help='Modified base aggregated output format(s). Default: %(default)s')
-    parser.add_argument(
-        '--output-directory',
-        default='megalodon_results',
-        help='Directory to store output results. Default: %(default)s')
-    parser.add_argument(
-        '--output-suffix', default='re_aggregated',
-        help='Suffix to apply to aggregated results, to avoid ' +
-        'overwriting results. 
Default: %(default)s') - parser.add_argument( + mod_grp.add_argument( + '--write-mod-log-probs', action='store_true', + help='Write per-read modified base log probabilities ' + + 'out in non-standard modVCF field.') + + mdl_grp = parser.add_argument_group('Model Arguments') + mdl_grp.add_argument( + '--taiyaki-model-filename', + help='Taiyaki model checkpoint file (for loading modified base ' + + 'names). Default: Load default model ({})'.format(mh.MODEL_PRESET_DESC)) + + misc_grp = parser.add_argument_group('Miscellaneous Arguments') + misc_grp.add_argument( '--processes', type=int, default=1, help='Number of parallel processes. Default: %(default)d') - parser.add_argument( - '--read-ids-filename', - help='File containing read ids to process (one per ' + - 'line). Default: All reads') - parser.add_argument( - '--reference', - help='Reference FASTA or minimap2 index file used for mapping ' + - 'called reads. Used to annotate VCF file with contig names ' + - '(required for VCF sorting and indexing).') - parser.add_argument( + misc_grp.add_argument( '--suppress-progress', action='store_true', help='Suppress progress bar output.') - parser.add_argument( - '--write-mod-log-probs', action='store_true', - help='Write per-read modified base log probabilities ' + - 'out in non-standard modVCF field.') - parser.add_argument( - '--write-vcf-log-probs', action='store_true', - help='Write alt log prbabilities out in non-standard VCF field.') return parser @@ -90,10 +98,6 @@ def main(): mod_names = backends.ModelInfo(mh.get_model_fn( args.taiyaki_model_filename)).mod_long_names if args.reference is not None: logger.info('Loading reference.') - aligner = mapping.alignerPlus( - str(args.reference), preset=str('map-ont'), best_n=1) - if args.reference is not None: - aligner.add_ref_lens() valid_read_ids = None if args.read_ids_filename is not None: with open(args.read_ids_filename) as read_ids_fp: @@ -104,7 +108,7 @@ def main(): variants.HAPLIOD_MODE if args.haploid else variants.DIPLOID_MODE, mod_names, mod_agg_info, args.write_mod_log_probs, args.mod_output_formats, args.suppress_progress, - aligner.ref_names_and_lens, valid_read_ids, args.output_suffix) + valid_read_ids, args.output_suffix) # note reference is required in order to annotate contigs for VCF writing if mh.VAR_NAME in args.outputs and args.reference is not None: From 118d832eb02e6d98607b7c9475c0f1e5caeff9a8 Mon Sep 17 00:00:00 2001 From: Marcus Stoiber Date: Thu, 3 Oct 2019 10:09:38 -0700 Subject: [PATCH 14/14] Bug fixes and cleanup from last commit. 
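
The chrm table now carries contig lengths, so re-aggregation can
recover reference names and lengths from the per-read database itself
instead of re-opening the aligner. A minimal sketch of the lookup this
enables (the database file name below is illustrative):

    import sqlite3

    db = sqlite3.connect('megalodon_results/per_read_variant_calls.db')
    # chrm rows are (chrm_id, chrm, chrm_len); databases written before
    # this change lack chrm_len, which get_all_chrm_and_lens() reports
    # as an old-schema error
    names, lens = zip(*db.execute('SELECT chrm, chrm_len FROM chrm'))
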
---
 megalodon/aggregate.py     |  4 ++--
 megalodon/megalodon.py     | 10 +++++++---
 megalodon/mods.py          |  9 +++++++--
 megalodon/variants.py      |  9 +++++++--
 scripts/run_aggregation.py | 10 ++++------
 5 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/megalodon/aggregate.py b/megalodon/aggregate.py
index d83d813..5bf06df 100644
--- a/megalodon/aggregate.py
+++ b/megalodon/aggregate.py
@@ -256,7 +256,7 @@ def aggregate_stats(
         agg_vars = variants.AggVars(
             vars_db_fn, load_in_mem_indices=False)
         num_vars = agg_vars.num_uniq()
-        ref_names_and_lens = agg_vars.get_all_chrm_and_lens()
+        ref_names_and_lens = agg_vars.vars_db.get_all_chrm_and_lens()
         agg_vars.close()
         logger.info('Spawning variant aggregation processes.')
         # create process to collect var stats from workers
@@ -286,7 +286,7 @@ def aggregate_stats(
         mods_db_fn = mh.get_megalodon_fn(out_dir, mh.PR_MOD_NAME)
         agg_mods = mods.AggMods(mods_db_fn, load_in_mem_indices=False)
         num_mods = agg_mods.num_uniq()
-        ref_names_and_lens = agg_mods.get_all_chrm_and_lens()
+        ref_names_and_lens = agg_mods.mods_db.get_all_chrm_and_lens()
         agg_mods.close()
         logger.info('Spawning modified base aggregation processes.')
         # create process to collect mods stats from workers
diff --git a/megalodon/megalodon.py b/megalodon/megalodon.py
index f991f4d..41795d9 100644
--- a/megalodon/megalodon.py
+++ b/megalodon/megalodon.py
@@ -1017,16 +1017,20 @@ def _main():
         args.processes, args.verbose_read_progress, args.suppress_progress,
         mods_info, args.database_safety, pr_ref_filts)

-    if aligner is not None: aligner.close()
+    if aligner is not None:
+        ref_fn = aligner.ref_fn
+        map_out_fmt = aligner.out_fmt
+        del aligner
+
     if mh.MAP_NAME in args.outputs:
         logger.info('Spawning process to sort mappings')
         map_p = post_process_mapping(
-            args.output_directory, aligner.out_fmt, aligner.ref_fn)
+            args.output_directory, map_out_fmt, ref_fn)

     if mh.WHATSHAP_MAP_NAME in args.outputs:
         logger.info('Spawning process to sort whatshap mappings')
         whatshap_sort_fn, whatshap_p = post_process_whatshap(
-            args.output_directory, aligner.out_fmt, aligner.ref_fn)
+            args.output_directory, map_out_fmt, ref_fn)

     if mh.VAR_NAME in args.outputs or mh.MOD_NAME in args.outputs:
         post_process_aggregate(
diff --git a/megalodon/mods.py b/megalodon/mods.py
index b6e7190..2ff19d5 100644
--- a/megalodon/mods.py
+++ b/megalodon/mods.py
@@ -312,8 +312,13 @@ def get_chrm(self, chrm_id):
         return chrm

     def get_all_chrm_and_lens(self):
-        return tuple(map(tuple, zip(*self.cur.execute(
-            'SELECT chrm, chrm_len FROM chrm').fetchall())))
+        try:
+            return tuple(map(tuple, zip(*self.cur.execute(
+                'SELECT chrm, chrm_len FROM chrm').fetchall())))
+        except sqlite3.OperationalError:
+            raise mh.MegaError(
+                'Old megalodon database schema detected. Please re-run ' +
+                'megalodon processing or downgrade megalodon installation.')

     def get_mod_base_data(self, mod_id):
         try:
diff --git a/megalodon/variants.py b/megalodon/variants.py
index b564d4e..05b4d1d 100755
--- a/megalodon/variants.py
+++ b/megalodon/variants.py
@@ -340,8 +340,13 @@ def get_chrm(self, chrm_id):
         return chrm

     def get_all_chrm_and_lens(self):
-        return tuple(map(tuple, zip(*self.cur.execute(
-            'SELECT chrm, chrm_len FROM chrm').fetchall())))
+        try:
+            return tuple(map(tuple, zip(*self.cur.execute(
+                'SELECT chrm, chrm_len FROM chrm').fetchall())))
+        except sqlite3.OperationalError:
+            raise mh.MegaError(
+                'Old megalodon database schema detected. 
Please re-run ' + + 'megalodon processing or downgrade megalodon installation.') def get_alt_seq(self, alt_id): try: diff --git a/scripts/run_aggregation.py b/scripts/run_aggregation.py index f55d597..2e0cf0c 100644 --- a/scripts/run_aggregation.py +++ b/scripts/run_aggregation.py @@ -87,7 +87,7 @@ def main(): args = get_parser().parse_args() log_suffix = ('aggregation' if args.output_suffix is None else 'aggregation.' + args.output_suffix) - logging.init_logger(args.output_directory, out_suffix=log_suffix) + logging.init_logger(args.megalodon_directory, out_suffix=log_suffix) logger = logging.get_logger() mod_agg_info = mods.AGG_INFO( @@ -97,24 +97,22 @@ def main(): logger.info('Loading model.') mod_names = backends.ModelInfo(mh.get_model_fn( args.taiyaki_model_filename)).mod_long_names - if args.reference is not None: logger.info('Loading reference.') valid_read_ids = None if args.read_ids_filename is not None: with open(args.read_ids_filename) as read_ids_fp: valid_read_ids = set(line.strip() for line in read_ids_fp) aggregate.aggregate_stats( - args.outputs, args.output_directory, args.processes, + args.outputs, args.megalodon_directory, args.processes, args.write_vcf_log_probs, args.heterozygous_factors, variants.HAPLIOD_MODE if args.haploid else variants.DIPLOID_MODE, mod_names, mod_agg_info, args.write_mod_log_probs, args.mod_output_formats, args.suppress_progress, valid_read_ids, args.output_suffix) - # note reference is required in order to annotate contigs for VCF writing - if mh.VAR_NAME in args.outputs and args.reference is not None: + if mh.VAR_NAME in args.outputs: logger.info('Sorting output variant file') variant_fn = mh.add_fn_suffix( - mh.get_megalodon_fn(args.output_directory, mh.VAR_NAME), + mh.get_megalodon_fn(args.megalodon_directory, mh.VAR_NAME), args.output_suffix) sort_variant_fn = mh.add_fn_suffix(variant_fn, 'sorted') variants.sort_variants(variant_fn, sort_variant_fn)
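
With the reference requirement dropped, re-aggregation needs only the
megalodon output directory. A usage sketch in the style of the phasing
workflow above (paths and option values are illustrative):

    python megalodon/scripts/run_aggregation.py \
        --outputs variants --haploid \
        --megalodon-directory megalodon_results \
        --output-suffix haplotype_1 \
        --read-ids-filename whatshap_mappings.haplotype_1_read_ids.txt \
        --processes 8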