Commit

Merge pull request #36 from mancusolab/dev
Dev
zeyunlu authored Mar 5, 2024
2 parents 8b15576 + 7d1eca1 commit 23beaf5
Showing 2 changed files with 12 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/manual.rst
@@ -58,7 +58,7 @@ We provide example data in ``./data/`` folder to test out SuShiE. All the data a

The genotype is the high-quality `HapMap <https://www.genome.gov/10001688/international-hapmap-project>`_ SNPs in some random gene 1M base-pair window, which contains 123, 129, and 113 SNPs for EUR, AFR, and EAS respectively in `1000G <https://www.internationalgenome.org/>`_ project. We provide genotype data in `plink 1 <https://www.cog-genomics.org/plink/1.9/input#bed>`_, `vcf <https://en.wikipedia.org/wiki/Variant_Call_Format>`_, and `bgen <https://www.well.ox.ac.uk/~gav/bgen_format/>`_ 1.3 format.

Using ``./data/make_example.py``, we simulated phenotype data (2 causal QTLs, cis-SNP heritability:0.5 and effect size correlation 0.8), random covariate data for each ancestry. It also outputs ``all.pheno`` file that row-binds simulated phenotype across ancestries, ``all.ancestry.index`` file that specifies ancestry index if using ``all.pheno``, ``all.covar``, and ``.\plink\all`` triplets, ``keep.subject`` file that specifies subjects to be included in the inference.
Using ``./data/make_example.py``, we simulated phenotype data (2 causal QTLs, cis-SNP heritability: 0.5 and effect size correlation 0.8), random covariate data for each ancestry. The two QTL rsID are rs1886340 and rs10914958. It also outputs ``all.pheno`` file that row-binds simulated phenotype across ancestries, ``all.ancestry.index`` file that specifies ancestry index if using ``all.pheno``, ``all.covar``, and ``.\plink\all`` triplets, ``keep.subject`` file that specifies subjects to be included in the inference.

As for the format requirement, see :ref:`Param` for detailed explanations.

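The simulation described in the updated manual (two causal QTLs, cis-SNP heritability 0.5, cross-ancestry effect-size correlation 0.8) can be illustrated with a minimal sketch; the toy genotypes, sample sizes, and helper names below are assumptions for illustration, not the contents of ``./data/make_example.py``.

import numpy as np

rng = np.random.default_rng(0)

n_subj, n_snp = 400, 123        # toy sample size and SNP count
h2, rho = 0.5, 0.8              # cis-SNP heritability, cross-ancestry effect correlation

# toy genotypes for two ancestries
geno_eur = rng.binomial(2, 0.3, size=(n_subj, n_snp)).astype(float)
geno_afr = rng.binomial(2, 0.3, size=(n_subj, n_snp)).astype(float)

# two causal QTLs with correlated effect sizes across ancestries
causal_idx = rng.choice(n_snp, size=2, replace=False)
effect_cov = np.array([[1.0, rho], [rho, 1.0]])
betas = rng.multivariate_normal(np.zeros(2), effect_cov, size=2)  # 2 SNPs x 2 ancestries

def simulate_pheno(geno, beta):
    g = geno[:, causal_idx] @ beta
    g = (g - g.mean()) / g.std() * np.sqrt(h2)        # genetic variance = h2
    e = rng.normal(scale=np.sqrt(1 - h2), size=geno.shape[0])
    return g + e                                      # total variance ~ 1

pheno_eur = simulate_pheno(geno_eur, betas[:, 0])
pheno_afr = simulate_pheno(geno_afr, betas[:, 1])
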
33 changes: 11 additions & 22 deletions sushie/infer.py
@@ -98,8 +98,7 @@ class _AbstractOptFunc(eqx.Module, metaclass=ABCMeta):
@abstractmethod
def __call__(
self,
beta_hat: ArrayLike,
shat2: ArrayLike,
rTZDinv: ArrayLike,
inv_shat2: ArrayLike,
priors: Prior,
posteriors: Posterior,
@@ -122,35 +121,29 @@ class _LResult(NamedTuple):
class _EMOptFunc(_AbstractOptFunc):
def __call__(
self,
beta_hat: ArrayLike,
shat2: ArrayLike,
rTZDinv: ArrayLike,
inv_shat2: ArrayLike,
priors: Prior,
posteriors: Posterior,
prior_adjustor: _PriorAdjustor,
l_iter: int,
) -> Prior:
priors, _ = _compute_posterior(
beta_hat, shat2, inv_shat2, priors, posteriors, l_iter
)
priors, _ = _compute_posterior(rTZDinv, inv_shat2, priors, posteriors, l_iter)

return priors


class _NoopOptFunc(_AbstractOptFunc):
def __call__(
self,
beta_hat: ArrayLike,
shat2: ArrayLike,
rTZDinv: ArrayLike,
inv_shat2: ArrayLike,
priors: Prior,
posteriors: Posterior,
prior_adjustor: _PriorAdjustor,
l_iter: int,
) -> Prior:
priors, _ = _compute_posterior(
beta_hat, shat2, inv_shat2, priors, posteriors, l_iter
)
priors, _ = _compute_posterior(rTZDinv, inv_shat2, priors, posteriors, l_iter)
priors = priors._replace(
effect_covar=priors.effect_covar.at[l_iter].set(
priors.effect_covar[l_iter] * prior_adjustor.times + prior_adjustor.plus
@@ -623,33 +616,29 @@ def _ssr(
shat2 = jnp.eye(n_pop) * shat2[:, jnp.newaxis]
inv_shat2 = jnp.eye(n_pop) * inv_shat2[:, jnp.newaxis]

priors = opt_v_func(
beta_hat, shat2, inv_shat2, priors, posteriors, prior_adjustor, l_iter
)
rTZDinv = beta_hat / jnp.diagonal(shat2, axis1=1, axis2=2)

_, posteriors = _compute_posterior(
beta_hat, shat2, inv_shat2, priors, posteriors, l_iter
)
priors = opt_v_func(rTZDinv, inv_shat2, priors, posteriors, prior_adjustor, l_iter)

_, posteriors = _compute_posterior(rTZDinv, inv_shat2, priors, posteriors, l_iter)

return priors, posteriors


def _compute_posterior(
beta_hat: ArrayLike,
shat2: ArrayLike,
rTZDinv: ArrayLike,
inv_shat2: ArrayLike,
priors: Prior,
posteriors: Posterior,
l_iter: int,
) -> Tuple[Prior, Posterior]:
n_snps, n_pop = beta_hat.shape
n_snps, n_pop, _ = inv_shat2.shape

# prior_covar is kxk
prior_covar = priors.effect_covar[l_iter]
# post_covar is pxkxk
post_covar = jnp.linalg.inv(inv_shat2 + jnp.linalg.inv(prior_covar))
# pxk
rTZDinv = beta_hat / jnp.diagonal(shat2, axis1=1, axis2=2)

# dim m = dim k for the next two lines
post_mean = jnp.einsum("pkm,pm->pk", post_covar, rTZDinv)
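This commit replaces the ``beta_hat``/``shat2`` arguments with a precomputed ``rTZDinv`` (keeping ``inv_shat2``), so ``rTZDinv = beta_hat / jnp.diagonal(shat2, axis1=1, axis2=2)`` is computed once in ``_ssr`` and passed to both ``opt_v_func`` and ``_compute_posterior`` instead of being recomputed inside ``_compute_posterior``. A minimal standalone sketch of the resulting posterior update with toy shapes and a toy identity prior; the random inputs and the elementwise construction of ``inv_shat2`` are assumptions, not sushie's API.

import jax.numpy as jnp
from jax import random

n_snps, n_pop = 5, 3                             # p SNPs, k ancestries
key_b, key_s = random.split(random.PRNGKey(0))
beta_hat = random.normal(key_b, (n_snps, n_pop))                  # marginal effects (p x k)
shat2_vec = jnp.abs(random.normal(key_s, (n_snps, n_pop))) + 1.0  # their variances (p x k)

# as in _ssr: expand to p x k x k diagonal (co)variance and precision matrices
shat2 = jnp.eye(n_pop) * shat2_vec[:, jnp.newaxis]
inv_shat2 = jnp.eye(n_pop) * (1.0 / shat2_vec)[:, jnp.newaxis]
rTZDinv = beta_hat / jnp.diagonal(shat2, axis1=1, axis2=2)        # p x k

# as in _compute_posterior: k x k prior, p x k x k posterior covariance, p x k posterior mean
prior_covar = jnp.eye(n_pop)
post_covar = jnp.linalg.inv(inv_shat2 + jnp.linalg.inv(prior_covar))
post_mean = jnp.einsum("pkm,pm->pk", post_covar, rTZDinv)

print(post_covar.shape, post_mean.shape)   # (5, 3, 3) (5, 3)
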

0 comments on commit 23beaf5
