Skip to content

Commit

Permalink
Fixups for PL
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jan 16, 2025
1 parent 0f89efb commit abed893
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 78 deletions.
127 changes: 79 additions & 48 deletions bio2zarr/vcf2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,21 @@ def convert_local_allele_field_types(fields):
gt = fields_by_name["call_genotype"]
if gt.shape[-1] != 2:
raise ValueError("Local alleles only supported on diploid data")
# TODO check if LAA is already in here

# TODO check if LA is already in here

shape = gt.shape[:-1]
chunks = gt.chunks[:-1]

laa = ZarrArraySpec.new(
la = ZarrArraySpec.new(
vcf_field=None,
name="call_LAA",
name="call_LA",
dtype="i1",
shape=gt.shape,
chunks=gt.chunks,
dimensions=gt.dimensions, # FIXME
description=(
"1-based indices into ALT, indicating which alleles"
"0-based indices into REF+ALT, indicating which alleles"
" are relevant (local) for the current sample"
),
)
Expand All @@ -224,16 +225,16 @@ def convert_local_allele_field_types(fields):
ad.description += " (local-alleles)"
# TODO fix dimensions

# pl = fields_by_name.get("call_PL", None)
# if pl is not None:
# # TODO check if call_LPL is in the list already
# pl.name = "call_LPL"
# pl.vcf_field = None
# pl.shape = (*shape, 3)
# pl.chunks = (*chunks, 3)
# pl.description += " (local-alleles)"
# # TODO fix dimensions
return [*fields, laa]
pl = fields_by_name.get("call_PL", None)
if pl is not None:
# TODO check if call_LPL is in the list already
pl.name = "call_LPL"
pl.vcf_field = None
pl.shape = (*shape, 3)
pl.chunks = (*chunks, 3)
pl.description += " (local-alleles)"
# TODO fix dimensions
return [*fields, la]


@dataclasses.dataclass
Expand Down Expand Up @@ -523,50 +524,66 @@ def fromdict(d):
return ret


def compute_laa_field(genotypes):
def compute_la_field(genotypes):
"""
Computes the value of the LAA field for each sample given the genotypes
for a variant.
The LAA field is a list of one-based indices into the ALT alleles
that indicates which alternate alleles are observed in the sample.
Computes the value of the LA field for each sample given the genotypes
for a variant. The LA field lists the unique alleles observed for
each sample, including the REF.
"""
v = 2**31 - 1
if np.any(genotypes >= v):
raise ValueError("Extreme allele value not supported")
G = genotypes.astype(np.int32)
if len(G) > 0:
# Anything <=0 gets mapped to -2 (pad) in the output, which comes last.
# Anything < 0 gets mapped to -2 (pad) in the output, which comes last.
# So, to get this sorting correctly, we remap to the largest value for
# sorting, then map back. We promote the genotypes up to 32 bit for convenience
# here, assuming that we'll never have an allele of 2**31 - 1.
assert np.all(G != v)
G[G <= 0] = v
G[G < 0] = v
G.sort(axis=1)
# Equal non-zero values result in padding also
G[G[:, 0] == G[:, 1], 1] = -2
# Equal values result in padding also
G[G == v] = -2
return G.astype(genotypes.dtype)


def compute_lad_field(ad, laa):
try:
lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
ref_ref = np.where((laa[:, 0] == -2) & (laa[:, 1] == -2))[0]
lad[ref_ref, 0] = ad[ref_ref, 0]
ref_alt = np.where((laa[:, 0] != -2) & (laa[:, 1] == -2))[0]
lad[ref_alt, 0] = ad[ref_alt, 0]
lad[ref_alt, 1] = ad[ref_alt, laa[ref_alt, 0]]
alt_alt = np.where((laa[:, 0] != -2) & (laa[:, 1] != -2))[0]
lad[alt_alt, 0] = ad[alt_alt, laa[alt_alt, 0]]
lad[alt_alt, 1] = ad[alt_alt, laa[alt_alt, 1]]
except Exception as e:
print("ad = ", ad)
print("laa = ", laa)
raise e
def compute_lad_field(ad, la):
assert ad.shape[0] == la.shape[0]
assert la.shape[1] == 2
lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
lad[homs, 0] = ad[homs, la[homs, 0]]
hets = np.where(la[:, 1] != -2)
lad[hets, 0] = ad[hets, la[hets, 0]]
lad[hets, 1] = ad[hets, la[hets, 1]]
return lad


def pl_index(a, b):
    """
    Returns the PL index for alleles a and b.

    This is the position of the unordered diploid genotype a/b within a
    PL vector, computed as ``b * (b + 1) // 2 + a``. The formula assumes
    ``a <= b``. The arguments may be plain integers or integer arrays of
    matching (broadcastable) shape; the result is always 64-bit.
    """
    # Promote before doing any arithmetic: allele indexes commonly arrive
    # as int8 (the dtype used for call_LA), and b * (b + 1) silently wraps
    # around in 8 bits once b >= 11, producing a bogus index.
    a = np.asarray(a, dtype=np.int64)
    b = np.asarray(b, dtype=np.int64)
    return b * (b + 1) // 2 + a


def compute_lpl_field(pl, la):
    """
    Computes the LPL (local-allele PL) values for each sample of a variant.

    ``pl`` holds one row of PL values per sample and ``la`` is the matching
    two-column array of local allele indexes, padded with -2. Each output
    row has three slots, initialised to -2:

    - a sample with a single local allele ``a`` gets PL(a/a) in slot 0;
    - a sample with two local alleles ``a`` and ``b`` gets PL(a/a),
      PL(a/b) and PL(b/b) in slots 0, 1 and 2;
    - a sample with no local alleles is left fully padded.

    The PL column for a genotype is looked up with ``pl_index``, which
    expects the smaller allele first, so each row of ``la`` is assumed to
    be sorted in ascending order.
    """
    lpl = np.full((pl.shape[0], 3), -2, dtype=pl.dtype)

    # Widen the allele indexes once, up front. LA values are typically int8
    # and the triangular-number arithmetic in pl_index would otherwise wrap
    # for allele indexes of 11 and above.
    alleles = la.astype(np.int64)
    first = alleles[:, 0]
    second = alleles[:, 1]

    # Exactly one local allele: only the homozygous likelihood applies.
    homs = np.where((first != -2) & (second == -2))[0]
    a = first[homs]
    lpl[homs, 0] = pl[homs, pl_index(a, a)]

    # Two local alleles: both homozygotes plus the heterozygote.
    hets = np.where(second != -2)[0]
    a = first[hets]
    b = second[hets]
    lpl[hets, 0] = pl[hets, pl_index(a, a)]
    lpl[hets, 1] = pl[hets, pl_index(a, b)]
    lpl[hets, 2] = pl[hets, pl_index(b, b)]

    return lpl


@dataclasses.dataclass
class VcfZarrWriteSummary(core.JsonDataclass):
num_partitions: int
Expand Down Expand Up @@ -601,7 +618,7 @@ def has_genotypes(self):

def has_local_alleles(self):
for field in self.schema.fields:
if field.name == "call_LAA" and field.vcf_field is None:
if field.name == "call_LA" and field.vcf_field is None:
return True
return False

Expand Down Expand Up @@ -872,35 +889,49 @@ def encode_genotypes_partition(self, partition_index):

def encode_local_alleles_partition(self, partition_index):
partition = self.metadata.partitions[partition_index]
call_LAA_array = self.init_partition_array(partition_index, "call_LAA")
call_LAA = core.BufferedArray(call_LAA_array, partition.start)
call_LA_array = self.init_partition_array(partition_index, "call_LA")
call_LA = core.BufferedArray(call_LA_array, partition.start)

call_LAD_array = self.init_partition_array(partition_index, "call_LAD")
call_LAD = core.BufferedArray(call_LAD_array, partition.start)
call_AD_source = self.icf.fields["FORMAT/AD"].iter_values(
partition.start, partition.stop
)
call_LPL_array = self.init_partition_array(partition_index, "call_LPL")
call_LPL = core.BufferedArray(call_LPL_array, partition.start)
call_PL_source = self.icf.fields["FORMAT/PL"].iter_values(
partition.start, partition.stop
)

gt_array = zarr.open_array(
store=self.wip_partition_array_path(partition_index, "call_genotype"),
mode="r",
)
for genotypes in core.first_dim_slice_iter(
gt_array, partition.start, partition.stop
):
laa = compute_laa_field(genotypes)
j = call_LAA.next_buffer_row()
call_LAA.buff[j] = laa
la = compute_la_field(genotypes)
j = call_LA.next_buffer_row()
call_LA.buff[j] = la

ad = next(call_AD_source)
ad = icf.sanitise_int_array(ad, 2, ad.dtype)
k = call_LAD.next_buffer_row()
assert j == k
lad = compute_lad_field(ad, laa)
call_LAD.buff[j] = lad
call_LAD.buff[j] = compute_lad_field(ad, la)

pl = next(call_PL_source)
pl = icf.sanitise_int_array(pl, 2, pl.dtype)
k = call_LPL.next_buffer_row()
assert j == k
call_LPL.buff[j] = compute_lpl_field(pl, la)

call_LAA.flush()
self.finalise_partition_array(partition_index, "call_LAA")
call_LA.flush()
self.finalise_partition_array(partition_index, "call_LA")
call_LAD.flush()
self.finalise_partition_array(partition_index, "call_LAD")
call_LPL.flush()
self.finalise_partition_array(partition_index, "call_LPL")

def encode_alleles_partition(self, partition_index):
array_name = "variant_allele"
Expand Down
92 changes: 68 additions & 24 deletions tests/test_local_alleles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,107 @@
import numpy.testing as nt
import pytest

from bio2zarr.vcf2zarr.vcz import compute_laa_field, compute_lad_field
from bio2zarr.vcf2zarr import vcz


class TestComputeLAA:
class TestComputeLA:
@pytest.mark.parametrize(
("genotypes", "expected"),
[
([], []),
([[0, 0]], [[-2, -2]]),
([[0, 0], [0, 0]], [[-2, -2], [-2, -2]]),
([[1, 1], [0, 0]], [[1, -2], [-2, -2]]),
([[0, 1], [3, 2], [3, 0]], [[1, -2], [2, 3], [3, -2]]),
([[0, 0], [2, 3]], [[-2, -2], [2, 3]]),
([[2, 3], [0, 0]], [[2, 3], [-2, -2]]),
([[128, 0], [6, 5]], [[128, -2], [5, 6]]),
([[0, -1], [-1, 5]], [[-2, -2], [5, -2]]),
([[0, 0]], [[0, -2]]),
([[0, 0], [0, 0]], [[0, -2], [0, -2]]),
([[1, 1], [0, 0]], [[1, -2], [0, -2]]),
([[0, 1], [3, 2], [3, 0]], [[0, 1], [2, 3], [0, 3]]),
([[0, 0], [2, 3]], [[0, -2], [2, 3]]),
([[2, 3], [0, 0]], [[2, 3], [0, -2]]),
([[128, 0], [6, 5]], [[0, 128], [5, 6]]),
([[0, -1], [-1, 5]], [[0, -2], [5, -2]]),
([[-1, -1], [-1, 5]], [[-2, -2], [5, -2]]),
],
)
def test_simple_examples(self, genotypes, expected):
G = np.array(genotypes)
result = compute_laa_field(G)
result = vcz.compute_la_field(G)
nt.assert_array_equal(result, expected)

def test_extreme_value(self):
G = np.array([[0, 2**32 - 1]])
with pytest.raises(ValueError, match="Extreme"):
compute_laa_field(G)
vcz.compute_la_field(G)


class TestComputeLAD:
@pytest.mark.parametrize(
("ad", "laa", "expected"),
("ad", "la", "expected"),
[
# Missing data
([[0, 0]], [[-2, -2]], [[-2, -2]]),
# 0/0 calls
([[10, 0]], [[-2, -2]], [[10, -2]]),
([[10, 0, 0], [11, 0, 0]], [[-2, -2], [-2, -2]], [[10, -2], [11, -2]]),
([[10, 0]], [[0, -2]], [[10, -2]]),
([[10, 0, 0]], [[0, -2]], [[10, -2]]),
([[10, 0, 0], [11, 0, 0]], [[0, -2], [0, -2]], [[10, -2], [11, -2]]),
# 0/1 calls
([[10, 11]], [[1, -2]], [[10, 11]]),
([[10, 11], [12, 0]], [[1, -2], [-2, -2]], [[10, 11], [12, -2]]),
([[10, 11]], [[0, 1]], [[10, 11]]),
([[10, 11], [12, 0]], [[0, 1], [0, -2]], [[10, 11], [12, -2]]),
# 0/2 calls
([[10, 0, 11]], [[2, -2]], [[10, 11]]),
([[10, 0, 11], [10, 11, 0]], [[2, -2], [1, -2]], [[10, 11], [10, 11]]),
([[10, 0, 11]], [[0, 2]], [[10, 11]]),
([[10, 0, 11], [10, 11, 0]], [[0, 2], [0, 1]], [[10, 11], [10, 11]]),
(
[[10, 0, 11], [10, 11, 0], [12, 0, 0]],
[[2, -2], [1, -2], [-2, -2]],
[[0, 2], [0, 1], [0, -2]],
[[10, 11], [10, 11], [12, -2]],
),
# 1/2 calls
([[0, 10, 11]], [[1, 2]], [[10, 11]]),
([[0, 10, 11], [12, 0, 13]], [[1, 2], [2, -2]], [[10, 11], [12, 13]]),
([[0, 10, 11], [12, 0, 13]], [[1, 2], [0, 2]], [[10, 11], [12, 13]]),
(
[[0, 10, 11], [12, 0, 13], [14, 0, 0]],
[[1, 2], [2, -2], [-2, -2]],
[[1, 2], [0, 2], [0, -2]],
[[10, 11], [12, 13], [14, -2]],
),
],
)
def test_simple_examples(self, ad, laa, expected):
result = compute_lad_field(np.array(ad), np.array(laa))
def test_simple_examples(self, ad, la, expected):
result = vcz.compute_lad_field(np.array(ad), np.array(la))
nt.assert_array_equal(result, expected)


# PL translation indexes:
# a b i
# 0 0 0
# 0 1 1
# 0 2 3
# 0 3 6
# 1 1 2
# 1 2 4
# 1 3 7
# 2 2 5
# 2 3 8
# 3 3 9


class TestComputeLPL:
    # Each example is (pl, la, expected LPL); -2 marks padding throughout.
    _EXAMPLES = [
        # No local alleles: the whole row stays padded
        ([range(3)], [[-2, -2]], [[-2, -2, -2]]),
        # 0/0 calls
        ([range(3)], [[0, -2]], [[0, -2, -2]]),
        # 0/0 call whose PL values are all -1
        ([[-1, -1, -1]], [[0, -2]], [[-1, -2, -2]]),
        # 1/1 calls
        ([range(3)], [[1, -2]], [[2, -2, -2]]),
        ([range(3), range(3)], [[0, -2], [1, -2]], [[0, -2, -2], [2, -2, -2]]),
        # 2/2 calls
        ([range(6)], [[2, -2]], [[5, -2, -2]]),
        # 0/1 calls
        ([range(3)], [[0, 1]], [[0, 1, 2]]),
        # 0/2 calls
        ([range(6)], [[0, 2]], [[0, 3, 5]]),
    ]

    @pytest.mark.parametrize(("pl", "la", "expected"), _EXAMPLES)
    def test_simple_examples(self, pl, la, expected):
        pl_array = np.array(pl)
        la_array = np.array(la)
        nt.assert_array_equal(vcz.compute_lpl_field(pl_array, la_array), expected)
13 changes: 7 additions & 6 deletions tests/test_vcf_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,14 +725,15 @@ def test_call_LAD(self, ds):
]
nt.assert_array_equal(ds.call_LAD.values, call_LAD)

def test_call_LAA(self, ds):
def test_call_LA(self, ds):
# All the genotypes are 0/0
call_LAA = np.full((23, 3, 2), -2)
nt.assert_array_equal(ds.call_LAA.values, call_LAA)
call_LA = np.full((23, 3, 2), -2)
call_LA[:, :, 0] = 0
nt.assert_array_equal(ds.call_LA.values, call_LA)

# def test_call_LPL(self, ds):
# call_LPL = np.tile([0, -2, -2], (23, 3, 1))
# nt.assert_array_equal(ds.call_LPL.values, call_LPL)
def test_call_LPL(self, ds):
    # Every (variant, sample) entry is expected to be [0, -2, -2]: one
    # value followed by two padding slots.
    expected = np.tile([0, -2, -2], (23, 3, 1))
    observed = ds.call_LPL.values
    nt.assert_array_equal(observed, expected)


class Test1000G2020AnnotationsExample:
Expand Down

0 comments on commit abed893

Please sign in to comment.