Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local alleles #266

Merged
merged 13 commits into from
Jul 10, 2024
13 changes: 13 additions & 0 deletions bio2zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ def list_commands(self, ctx):
help="An approximate bound on overall memory usage (e.g. 10G),",
)

# Reusable click option shared by the conversion commands (applied as
# @local_alleles). Controls whether local-allele (LAA) fields are added
# to reduce the storage required by the output; enabled by default.
local_alleles = click.option(
    "--local-alleles/--no-local-alleles",
    show_default=True,
    default=True,
    help="Use local allele fields to reduce the storage requirements of the output.",
)


def setup_logging(verbosity):
level = "WARNING"
Expand Down Expand Up @@ -214,6 +221,7 @@ def show_work_summary(work_summary, json):
@compressor
@progress
@worker_processes
@local_alleles
def explode(
vcfs,
icf_path,
Expand All @@ -223,6 +231,7 @@ def explode(
compressor,
progress,
worker_processes,
local_alleles,
):
"""
Convert VCF(s) to intermediate columnar format
Expand All @@ -236,6 +245,7 @@ def explode(
column_chunk_size=column_chunk_size,
compressor=get_compressor(compressor),
show_progress=progress,
local_alleles=local_alleles,
)


Expand All @@ -250,6 +260,7 @@ def explode(
@verbose
@progress
@worker_processes
@local_alleles
def dexplode_init(
vcfs,
icf_path,
Expand All @@ -261,6 +272,7 @@ def dexplode_init(
verbose,
progress,
worker_processes,
local_alleles,
):
"""
Initial step for distributed conversion of VCF(s) to intermediate columnar format
Expand All @@ -277,6 +289,7 @@ def dexplode_init(
worker_processes=worker_processes,
compressor=get_compressor(compressor),
show_progress=progress,
local_alleles=local_alleles,
)
show_work_summary(work_summary, json)

Expand Down
142 changes: 138 additions & 4 deletions bio2zarr/vcf2zarr/icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def make_field_def(name, vcf_type, vcf_number):
return fields


def scan_vcf(path, target_num_partitions):
def scan_vcf(path, target_num_partitions, *, local_alleles):
with vcf_utils.IndexedVcf(path) as indexed_vcf:
vcf = indexed_vcf.vcf
filters = []
Expand All @@ -236,6 +236,9 @@ def scan_vcf(path, target_num_partitions):
pass_filter = filters.pop(pass_index)
filters.insert(0, pass_filter)

# Indicates whether vcf2zarr can introduce local alleles
can_localize = False
should_add_laa_field = True
fields = fixed_vcf_field_definitions()
for h in vcf.header_iter():
if h["HeaderType"] in ["INFO", "FORMAT"]:
Expand All @@ -244,6 +247,23 @@ def scan_vcf(path, target_num_partitions):
field.vcf_type = "Integer"
field.vcf_number = "."
fields.append(field)
if field.category == "FORMAT":
if field.name == "PL":
can_localize = True
if field.name == "LAA":
should_add_laa_field = False

if local_alleles and can_localize and should_add_laa_field:
laa_field = VcfField(
category="FORMAT",
name="LAA",
vcf_type="Integer",
vcf_number=".",
description="1-based indices into ALT, indicating which alleles"
" are relevant (local) for the current sample",
summary=VcfFieldSummary(),
)
fields.append(laa_field)

try:
contig_lengths = vcf.seqlens
Expand Down Expand Up @@ -280,7 +300,14 @@ def scan_vcf(path, target_num_partitions):
return metadata, vcf.raw_header


def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
def scan_vcfs(
paths,
show_progress,
target_num_partitions,
worker_processes=1,
*,
local_alleles,
):
logger.info(
f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
f" partitions."
Expand All @@ -300,7 +327,12 @@ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
)
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
for path in paths:
pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths)))
pwm.submit(
scan_vcf,
path,
max(1, target_num_partitions // len(paths)),
local_alleles=local_alleles,
)
results = list(pwm.results_as_completed())

# Sort to make the ordering deterministic
Expand Down Expand Up @@ -458,6 +490,95 @@ def sanitise_value_int_2d(buff, j, value):
buff[j, :, : value.shape[1]] = value


def compute_laa_field(variant) -> np.ndarray:
    """
    Compute the value of the LAA field for every sample of a variant.

    The LAA field lists, per sample, the one-based indices into the ALT
    alleles of the alternate alleles observed in that sample. Which
    alleles are observed is inferred from the GT, AD, and PL fields:
    an allele counts as observed if it appears in the genotype, has a
    nonzero allele depth, or is referenced by a nonzero PL entry.
    """
    num_samples = variant.num_called + variant.num_unknown
    num_alt = len(variant.ALT)
    num_alleles = num_alt + 1
    # Per-sample tally of how often each allele index is referenced.
    counts = np.zeros((num_samples, num_alleles), dtype=int)

    if "GT" in variant.FORMAT:
        # The last element of each sample's genotype is the phasing
        # indicator, not an allele index, so drop it.
        gt = variant.genotype.array()[:, :-1]
        gt.clip(0, None, out=gt)
        counts += np.apply_along_axis(
            np.bincount, axis=1, arr=gt, minlength=num_alleles
        )

    if "AD" in variant.FORMAT:
        ad = variant.format("AD")
        ad.clip(0, None, out=ad)

        def count_nonzero_positions(row, *, minlength):
            # nonzero() returns the indices of the nonzero elements per axis.
            return np.bincount(row.nonzero()[0], minlength=minlength)

        counts += np.apply_along_axis(
            count_nonzero_positions, axis=1, arr=ad, minlength=num_alleles
        )

    if "PL" in variant.FORMAT:
        pl = variant.format("PL")
        pl.clip(0, None, out=pl)
        # n holds the column index of every nonzero likelihood entry;
        # entries whose likelihood is zero are mapped to index 0.
        n = np.tile(np.arange(pl.shape[1]), (pl.shape[0], 1))
        assert n.shape == pl.shape
        n[pl <= 0] = 0
        ploidy = variant.ploidy

        if ploidy == 1:
            a = n
            b = np.zeros_like(a)
        elif ploidy == 2:
            # The VCF PL ordering gives n = b(b + 1) / 2 + a for the
            # genotype (a, b); invert that to recover a and b.
            b = np.ceil(np.sqrt(2 * n + 9 / 4) - 3 / 2).astype(int)
            a = (n - b * (b + 1) / 2).astype(int)
        else:
            # TODO: Handle all possible ploidy
            raise ValueError(f"Cannot handle ploidy = {ploidy}")

        counts_a = np.apply_along_axis(
            np.bincount, axis=1, arr=a, minlength=num_alleles
        )
        counts_b = np.apply_along_axis(
            np.bincount, axis=1, arr=b, minlength=num_alleles
        )
        assert counts_a.shape == counts_b.shape == counts.shape
        counts += counts_a
        counts += counts_b

    counts[:, 0] = 0  # The reference allele is never reported in LAA.
    widest_row = 1

    def indices_padded(row, *, length):
        nonlocal widest_row
        observed = row.nonzero()[0]
        widest_row = max(widest_row, len(observed))
        return np.pad(
            observed,
            (0, length - len(observed)),
            mode="constant",
            constant_values=constants.INT_FILL,
        )

    result = np.apply_along_axis(
        indices_padded, axis=1, arr=counts, length=max(1, num_alt)
    )
    # Trim the padding down to the widest row actually observed.
    return result[:, :widest_row]


missing_value_map = {
"Integer": constants.INT_MISSING,
"Float": constants.FLOAT32_MISSING,
Expand Down Expand Up @@ -962,6 +1083,7 @@ def init(
target_num_partitions=None,
show_progress=False,
compressor=None,
local_alleles,
):
if self.path.exists():
raise ValueError(f"ICF path already exists: {self.path}")
Expand All @@ -976,6 +1098,7 @@ def init(
worker_processes=worker_processes,
show_progress=show_progress,
target_num_partitions=target_num_partitions,
local_alleles=local_alleles,
)
check_field_clobbering(icf_metadata)
self.metadata = icf_metadata
Expand Down Expand Up @@ -1085,8 +1208,15 @@ def process_partition(self, partition_index):
val = variant.genotype.array()
tcw.append("FORMAT/GT", val)
for field in format_fields:
val = variant.format(field.name)
if (
field.full_name == "FORMAT/LAA"
and "LAA" not in variant.FORMAT
):
val = compute_laa_field(variant)
else:
val = variant.format(field.name)
tcw.append(field.full_name, val)

# Note: an issue with updating the progress per variant here like
# this is that we get a significant pause at the end of the counter
# while all the "small" fields get flushed. Possibly not much to be
Expand Down Expand Up @@ -1180,6 +1310,7 @@ def explode(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=True,
):
writer = IntermediateColumnarFormatWriter(icf_path)
writer.init(
Expand All @@ -1190,6 +1321,7 @@ def explode(
show_progress=show_progress,
column_chunk_size=column_chunk_size,
compressor=compressor,
local_alleles=local_alleles,
)
writer.explode(worker_processes=worker_processes, show_progress=show_progress)
writer.finalise()
Expand All @@ -1205,6 +1337,7 @@ def explode_init(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=True,
):
writer = IntermediateColumnarFormatWriter(icf_path)
return writer.init(
Expand All @@ -1214,6 +1347,7 @@ def explode_init(
show_progress=show_progress,
column_chunk_size=column_chunk_size,
compressor=compressor,
local_alleles=local_alleles,
)


Expand Down
2 changes: 1 addition & 1 deletion bio2zarr/vcf2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def from_field(
if array_name is None:
array_name = prefix + vcf_field.name
# TODO make an option to add in the empty extra dimension
if vcf_field.summary.max_number > 1:
if vcf_field.summary.max_number > 1 or vcf_field.full_name == "FORMAT/LAA":
shape.append(vcf_field.summary.max_number)
chunks.append(vcf_field.summary.max_number)
# TODO we should really be checking this to see if the named dimensions
Expand Down
2 changes: 2 additions & 0 deletions bio2zarr/vcf2zarr/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def verify(vcf_path, zarr_path, show_progress=False):
for colname in root.keys():
if colname.startswith("call") and not colname.startswith("call_genotype"):
vcf_name = colname.split("_", 1)[1]
if vcf_name == "LAA" and vcf_name not in format_headers:
continue # LAA could have been computed during the explode step.
vcf_type = format_headers[vcf_name]["Type"]
vcf_number = format_headers[vcf_name]["Number"]
format_fields[vcf_name] = vcf_type, vcf_number, iter(root[colname])
Expand Down
Binary file added tests/data/vcf/local_alleles.vcf.gz
Binary file not shown.
Binary file added tests/data/vcf/local_alleles.vcf.gz.csi
Binary file not shown.
Binary file added tests/data/vcf/triploid.vcf.gz
Binary file not shown.
Binary file added tests/data/vcf/triploid.vcf.gz.csi
Binary file not shown.
25 changes: 25 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
compressor=None,
worker_processes=1,
show_progress=True,
local_alleles=True,
)

DEFAULT_DEXPLODE_PARTITION_ARGS = dict()
Expand All @@ -23,6 +24,7 @@
column_chunk_size=64,
compressor=None,
show_progress=True,
local_alleles=True,
)

DEFAULT_ENCODE_ARGS = dict(
Expand Down Expand Up @@ -287,6 +289,29 @@ def test_vcf_explode_missing_and_existing_vcf(self, mocked, tmp_path):
assert "'no_such_file' does not exist" in result.stderr
mocked.assert_not_called()

@pytest.mark.parametrize("local_alleles", [False, True])
@mock.patch("bio2zarr.vcf2zarr.explode")
def test_vcf_explode_local_alleles(self, mocked, tmp_path, local_alleles):
    # Check that both spellings of the flag are forwarded to explode().
    out_path = tmp_path / "icf"
    flag = "--local-alleles" if local_alleles else "--no-local-alleles"
    runner = ct.CliRunner(mix_stderr=False)
    result = runner.invoke(
        cli.vcf2zarr_main,
        f"explode {self.vcf_path} {out_path} {flag}",
        catch_exceptions=False,
    )
    # A successful run is silent on both streams.
    assert result.exit_code == 0
    assert len(result.stdout) == 0
    assert len(result.stderr) == 0
    expected = {**DEFAULT_EXPLODE_ARGS, "local_alleles": local_alleles}
    mocked.assert_called_once_with(str(out_path), (self.vcf_path,), **expected)

@pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")])
@mock.patch("bio2zarr.vcf2zarr.explode_init", return_value=FakeWorkSummary(5))
def test_vcf_dexplode_init(self, mocked, tmp_path, progress, flag):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def test_5_chunk_1(self, n, expected):
# It works in CI on Linux, but it'll probably break at some point.
# It's also necessary to update these numbers each time a new data
# file gets added
("tests/data", 4974951),
("tests/data/vcf", 4962814),
("tests/data", 4976351),
("tests/data/vcf", 4964214),
("tests/data/vcf/sample.vcf.gz", 1089),
],
)
Expand Down
Loading
Loading