diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml index d432640a..f3e00df8 100644 --- a/.github/workflows/cubed.yml +++ b/.github/workflows/cubed.yml @@ -30,10 +30,11 @@ jobs: - name: Test with pytest run: | - pytest -v sgkit/tests/test_{aggregation,association,hwe,pca,window}.py \ + pytest -v sgkit/tests/test_{aggregation,association,hwe,pca,popgen,window}.py \ -k "test_count_call_alleles or \ test_gwas_linear_regression or \ test_hwep or \ + (test_popgen and diversity) or \ test_sample_stats or \ (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or \ (test_variant_stats and not test_variant_stats__chunks[chunks2-False]) or \ diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 9f386298..081e73d3 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -273,10 +273,8 @@ def count_cohort_alleles( ds, variables.call_allele_count, call_allele_count, count_call_alleles ) variables.validate(ds, {call_allele_count: variables.call_allele_count_spec}) - # ensure cohorts is a numpy array to minimize dask task - # dependencies between chunks in other dimensions - AC, SC = da.asarray(ds[call_allele_count]), ds[sample_cohort].values - n_cohorts = SC.max() + 1 # 0-based indexing + AC, SC = da.asarray(ds[call_allele_count]), da.asarray(ds[sample_cohort]) + n_cohorts = ds[sample_cohort].values.max() + 1 # 0-based indexing AC = cohort_sum(AC, SC, n_cohorts, axis=1) new_ds = create_dataset( {variables.cohort_allele_count: (("variants", "cohorts", "alleles"), AC)} @@ -1060,7 +1058,7 @@ def individual_heterozygosity( variables.validate(ds, {call_allele_count: variables.call_allele_count_spec}) AC = da.asarray(ds.call_allele_count) - K = AC.sum(axis=-1) + K = da.sum(AC, axis=-1) # use nan denominator to avoid divide by zero with K - 1 K2 = da.where(K > 1, K, np.nan) AF = AC / K2[..., None] diff --git a/sgkit/stats/cohort_numba_fns.py b/sgkit/stats/cohort_numba_fns.py index 91aa54dd..0a347eba 100644 --- a/sgkit/stats/cohort_numba_fns.py +++ b/sgkit/stats/cohort_numba_fns.py @@ -1,10 +1,11 @@ from functools import wraps from typing import Callable -import dask.array as da import numpy as np +import sgkit.distarray as da from sgkit.accelerate import numba_guvectorize +from sgkit.utils import swapaxes from ..typing import ArrayLike @@ -43,19 +44,20 @@ def cohort_reduction(gufunc: Callable) -> Callable: @wraps(gufunc) def func(x: ArrayLike, cohort: ArrayLike, n: int, axis: int = -1) -> ArrayLike: - x = da.swapaxes(da.asarray(x), axis, -1) + x = swapaxes(da.asarray(x), axis, -1) replaced = len(x.shape) - 1 chunks = x.chunks[0:-1] + (n,) out = da.map_blocks( gufunc, x, cohort, - np.empty(n, np.int8), + da.empty(n, dtype=np.int8), + dtype=x.dtype, # TODO: is this right? Might need to be passed in. chunks=chunks, drop_axis=replaced, new_axis=replaced, ) - return da.swapaxes(out, axis, -1) + return swapaxes(out, axis, -1) return func diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py index e201dfc9..834b18e4 100644 --- a/sgkit/stats/popgen.py +++ b/sgkit/stats/popgen.py @@ -2,10 +2,10 @@ import itertools from typing import Hashable, Optional, Sequence, Tuple, Union -import dask.array as da import numpy as np from xarray import Dataset +import sgkit.distarray as da from sgkit.cohorts import _cohorts_to_array from sgkit.stats.utils import assert_array_shape from sgkit.utils import ( @@ -90,14 +90,14 @@ def diversity( ) variables.validate(ds, {cohort_allele_count: variables.cohort_allele_count_spec}) - ac = ds[cohort_allele_count] + ac = ds[cohort_allele_count].astype(np.uint64) an = ac.sum(axis=2) - n_pairs = an * (an - 1) / 2 - n_same = (ac * (ac - 1) / 2).sum(axis=2) + n_pairs = an * (an - 1) // 2 + n_same = (ac * (ac - 1) // 2).sum(axis=2) n_diff = n_pairs - n_same # replace zeros to avoid divide by zero error n_pairs_na = n_pairs.where(n_pairs != 0) - pi = n_diff / n_pairs_na + pi = n_diff.astype(np.float64) / n_pairs_na.astype(np.float64) if has_windows(ds): div = window_statistic( diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 5ed538d0..c154591b 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -400,11 +400,11 @@ def count_biallelic_genotypes(calls, ploidy): calls = ds.call_genotype.values expect = count_biallelic_genotypes(calls, ploidy) if chunked: - # chunk each dim + # chunk each dim except ploidy chunks = ( (n_variant // 2, n_variant - n_variant // 2), (n_sample // 2, n_sample - n_sample // 2), - (ploidy // 2, ploidy - ploidy // 2), + -1, ) ds["call_genotype"] = ds["call_genotype"].chunk(chunks) actual = count_variant_genotypes(ds)["variant_genotype_count"].data diff --git a/sgkit/tests/test_cohort_numba_fns.py b/sgkit/tests/test_cohort_numba_fns.py index 2239e3e5..3a131258 100644 --- a/sgkit/tests/test_cohort_numba_fns.py +++ b/sgkit/tests/test_cohort_numba_fns.py @@ -1,7 +1,7 @@ -import dask.array as da import numpy as np import pytest +import sgkit.distarray as da from sgkit.stats.cohort_numba_fns import ( cohort_mean, cohort_nanmean, @@ -41,7 +41,7 @@ def _cohort_reduction(func, x, cohort, n, axis=-1): _random_cohort_data((20, 20), n=3, axis=-1, missing=0.3), _random_cohort_data((7, 103, 4), n=5, axis=1, scale=7, missing=0.3), _random_cohort_data( - ((3, 4), (50, 50, 3), 4), n=5, axis=1, scale=7, dtype=np.uint8 + ((4, 3), (50, 50, 3), 4), n=5, axis=1, scale=7, dtype=np.uint8 ), _random_cohort_data( ((6, 6), (50, 50, 7), (3, 1)), n=5, axis=1, scale=7, missing=0.3 diff --git a/sgkit/tests/test_popgen.py b/sgkit/tests/test_popgen.py index 251e0568..3d2be338 100644 --- a/sgkit/tests/test_popgen.py +++ b/sgkit/tests/test_popgen.py @@ -1,7 +1,6 @@ import itertools import allel -import dask.array as da import msprime # type: ignore import numpy as np import pytest @@ -11,6 +10,7 @@ from hypothesis import given, settings from hypothesis import strategies as st +import sgkit.distarray as da from sgkit import ( Fst, Garud_H, diff --git a/sgkit/utils.py b/sgkit/utils.py index f60db1f9..5106350f 100644 --- a/sgkit/utils.py +++ b/sgkit/utils.py @@ -3,11 +3,12 @@ from itertools import product from typing import Any, Callable, Hashable, List, Mapping, Optional, Set, Tuple, Union -import dask.array as da import numpy as np import xarray as xr from xarray import Dataset +import sgkit.distarray as da + from . import variables from .typing import ArrayLike, DType @@ -433,3 +434,15 @@ def has_keyword(func, keyword): return keyword in inspect.signature(func).parameters except Exception: # pragma: no cover return False + + +# an implementation of numpy.swapaxes since it is not defined in the array API +def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> ArrayLike: + """Interchange two axes of an array.""" + if axis1 < 0: + axis1 += a.ndim + if axis2 < 0: + axis2 += a.ndim + if axis1 == axis2: + return a + return da.moveaxis(a, [axis1, axis2], [axis2, axis1])