Skip to content

Commit

Permalink
Merge pull request #67 from rmcar17/main
Browse files Browse the repository at this point in the history
ENH: use numba for kmer processing
  • Loading branch information
GavinHuttley authored Oct 24, 2024
2 parents d6b893d + ec33063 commit 546bb54
Showing 1 changed file with 83 additions and 17 deletions.
100 changes: 83 additions & 17 deletions src/diverse_seq/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Literal, TypeAlias

import cogent3.app.typing as c3_types
import numba
import numpy as np
from cogent3.app.composable import define_app
from cogent3.evolve.fast_distance import DistanceMatrix
Expand Down Expand Up @@ -287,10 +288,15 @@ def mash_sketch(
BottomSketch
The bottom sketch for the given sequence seq_array.
"""
kmer_hashes = {
hash_kmer(kmer, mash_canonical=mash_canonical)
for kmer in get_kmers(seq_array, k, num_states)
}
kmer_hashes = set(
get_kmer_hashes(
seq_array,
k,
num_states,
mash_canonical=mash_canonical,
),
)

heap = []
for kmer_hash in kmer_hashes:
if len(heap) < sketch_size:
Expand All @@ -300,12 +306,15 @@ def mash_sketch(
return sorted(-kmer_hash for kmer_hash in heap)


def get_kmers(
@numba.njit
def get_kmer_hashes(
seq: np.ndarray,
k: int,
num_states: int,
) -> list[np.ndarray]:
"""Get the kmers comprising a sequence.
*,
mash_canonical: bool,
) -> list[int]:
"""Get the kmer hashes comprising a sequence.
Parameters
----------
Expand All @@ -315,13 +324,19 @@ def get_kmers(
kmer size.
num_states
Number of states allowed for sequence type.
mash_canonical
Whether to use the mash canonical representation of kmers.
Returns
-------
list[numpy.ndarray]
kmers for the sequence.
list[int]
kmer hashes for the sequence.
"""
kmers = []
seq = seq.astype(np.int64)

kmer_hashes = [0]
kmer_hashes.pop() # numba requires list to be pre-populated to infer type

skip_until = 0
for i in range(k):
if seq[i] >= num_states:
Expand All @@ -333,11 +348,52 @@ def get_kmers(

if i < skip_until:
continue
kmers.append(seq[i : i + k])
return kmers
kmer_hashes.append(hash_kmer(seq[i : i + k], mash_canonical))
return kmer_hashes


@numba.njit
def murmurhash3_32(data: np.ndarray, seed: int = 0x9747B28C) -> int:
"""MurmurHash3 32-bit implementation for an array of integers.
Parameters
----------
data : np.ndarray
The input array to hash.
seed : int
A seed for the hash function.
Returns
-------
int
The computed hash value.
"""
length = data.size
h = seed ^ length

def hash_kmer(kmer: np.ndarray, *, mash_canonical: bool) -> int:
for i in range(length):
k = data[i]

# Mix the hash
k *= 0xCC9E2D51
k = (k << 15) | (k >> (32 - 15)) # Rotate left
k *= 0x1B873593

h ^= k
h = (h << 13) | (h >> (32 - 13)) # Rotate left
h = h * 5 + 0xE6546B64

h ^= h >> 16
h *= 0x85EBCA6B
h ^= h >> 13
h *= 0xC2B2AE35
h ^= h >> 16

return h & 0xFFFFFFFF # Return as 32-bit integer


@numba.njit
def hash_kmer(kmer: np.ndarray, mash_canonical: bool) -> int:
"""Hash a kmer, optionally use the mash canonical representation.
Parameters
Expand All @@ -351,15 +407,25 @@ def hash_kmer(kmer: np.ndarray, *, mash_canonical: bool) -> int:
-------
int
The hash of a kmer.
Notes
-----
Uses MurmurHash3 32-bit implementation.
"""
tuple_kmer = tuple(map(int, kmer))
smallest = kmer
if mash_canonical:
reverse = tuple(map(int, reverse_complement(kmer)))
tuple_kmer = min(reverse, tuple_kmer)
reverse = reverse_complement(kmer)
for i in range(kmer.size):
if kmer[i] < reverse[i]:
break
if kmer[i] > reverse[i]:
smallest = reverse
break

return hash(tuple_kmer)
return murmurhash3_32(smallest)


@numba.njit
def reverse_complement(kmer: np.ndarray) -> np.ndarray:
"""Take the reverse complement of a kmer.
Expand Down

0 comments on commit 546bb54

Please sign in to comment.