Skip to content

Commit

Permalink
Plot H12 GWSS traces for multiple cohorts together in a single figure (
Browse files Browse the repository at this point in the history
…#595)

* First draft for multiple traces of H12. There are still things that I am not too happy about.

* Forgot one of the shared params.

* Solving linting issues

* Solving more linting issues

* Trying to get around assertions.

* Trying to get around linting.

* Trying to get around linting.

* Trying to make the params for the multiple tracks make sense

* Corrected a typo

* More arguments moved to sample queries

* Forgot a few selves

* Started working on the docs

* Followed Alistair's advice (I think). Still need to add tests and a better way to handle titles.

* Messed up the type definition of window_sizes

* Added some tests and made some little changes

* Improved the def of window_size

* add examples of multi plots

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update malariagen_data/anoph/h12_params.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Hoping a tab might solve the issue.

* add example using cohorts and window_size as dict

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update malariagen_data/anoph/h12.py

Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>

* Update h12.py

* More cohorts for the H12 tests

* Reorganized the tests

* Added a test for the dict version of window_size

* Update malariagen_data/anoph/h12.py

* add h12 multi functions to API docs

---------

Co-authored-by: Lee <4256466+leehart@users.noreply.github.com>
Co-authored-by: Alistair Miles <am23@sanger.ac.uk>
Co-authored-by: Alistair Miles <alimanfoo@googlemail.com>
  • Loading branch information
4 people authored Oct 23, 2024
1 parent eb6693e commit 20efee0
Show file tree
Hide file tree
Showing 7 changed files with 578 additions and 14 deletions.
2 changes: 2 additions & 0 deletions docs/source/Af1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ Genome-wide selection scans
plot_h12_calibration
h12_gwss
plot_h12_gwss
plot_h12_gwss_multi_panel
plot_h12_gwss_multi_overlay
h1x_gwss
plot_h1x_gwss
g123_calibration
Expand Down
2 changes: 2 additions & 0 deletions docs/source/Ag3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ Genome-wide selection scans
plot_h12_calibration
h12_gwss
plot_h12_gwss
plot_h12_gwss_multi_panel
plot_h12_gwss_multi_overlay
h1x_gwss
plot_h1x_gwss
g123_calibration
Expand Down
12 changes: 11 additions & 1 deletion malariagen_data/anoph/gplt_params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Parameters for genome plotting functions. N.B., genome plots are always
plotted with bokeh."""

from typing import Literal, Mapping, Optional, Union
from typing import Literal, Mapping, Optional, Union, Sequence

import bokeh.models
from typing_extensions import Annotated, TypeAlias
Expand Down Expand Up @@ -83,6 +83,14 @@
"A bokeh figure (only returned if show=False).",
]

def_figure: TypeAlias = Annotated[
# Non-optional figure type: the annotated type is the bare bokeh Model,
# so it suits values that are always present (e.g. elements of a list of
# figures).
# Use quite a broad type here to accommodate both single-panel figures
# created via bokeh.plotting and multi-panel figures created via
# bokeh.layouts.
bokeh.model.Model,
"A bokeh figure.",
]

output_backend: TypeAlias = Annotated[
Literal["canvas", "webgl", "svg"],
"""
Expand All @@ -103,3 +111,5 @@
Mapping,
"Passed through to bokeh line() function.",
]

# Sequence of color strings for plots with several traces. The H12 overlay
# track indexes it as colors[i % len(colors)], so colors are reused cyclically
# when there are more traces than colors.
colors: TypeAlias = Annotated[Sequence[str], "List of colors."]
301 changes: 301 additions & 0 deletions malariagen_data/anoph/h12.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,307 @@ def plot_h12_gwss(
else:
return fig

@check_types
@doc(
    summary="Plot h12 GWSS data track with multiple traces overlaid.",
)
def plot_h12_gwss_multi_overlay_track(
    self,
    contig: base_params.contig,
    cohorts: base_params.cohorts,
    window_size: h12_params.multi_window_size,
    cohort_size: Optional[base_params.cohort_size] = h12_params.cohort_size_default,
    sample_query: Optional[base_params.sample_query] = None,
    analysis: hap_params.analysis = base_params.DEFAULT,
    min_cohort_size: Optional[
        base_params.min_cohort_size
    ] = h12_params.min_cohort_size_default,
    max_cohort_size: Optional[
        base_params.max_cohort_size
    ] = h12_params.max_cohort_size_default,
    title: Optional[gplt_params.title] = None,
    sample_query_options: Optional[base_params.sample_query_options] = None,
    sample_sets: Optional[base_params.sample_sets] = None,
    colors: gplt_params.colors = bokeh.palettes.d3["Category10"][10],
    random_seed: base_params.random_seed = 42,
    sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
    width: gplt_params.width = gplt_params.width_default,
    height: gplt_params.height = 200,
    show: gplt_params.show = True,
    x_range: Optional[gplt_params.x_range] = None,
    output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
) -> gplt_params.figure:
    # Draw one bokeh figure holding an H12 GWSS scatter trace per cohort.
    #
    # `window_size` may be a single int (applied to every cohort) or a
    # mapping keyed by cohort label. Returns None after displaying the
    # figure when `show` is true, otherwise returns the figure. Raises
    # ValueError when a `window_size` mapping does not have exactly the
    # same keys as the resolved cohorts.
    #
    # NOTE(review): no literal docstring is added here because the @doc
    # decorator presumably assembles one from the annotated parameter
    # types - confirm before adding one.

    # Resolve `cohorts` into a mapping of cohort label -> sample query.
    cohort_queries = self._setup_cohort_queries(
        cohorts=cohorts,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        cohort_size=cohort_size,
        min_cohort_size=None,
    )

    # Normalise window_size to a per-cohort mapping.
    if isinstance(window_size, int):
        window_size = {label: window_size for label in cohort_queries}
    elif isinstance(window_size, Mapping):
        if set(window_size) != set(cohort_queries):
            raise ValueError("Cohorts and window_sizes should have the same keys.")

    # Compute H12 for each cohort; each result is an (x, h12) pair.
    res = {}
    for cohort_label, cohort_query in cohort_queries.items():
        res[cohort_label] = self.h12_gwss(
            contig=contig,
            analysis=analysis,
            window_size=window_size[cohort_label],
            cohort_size=cohort_size,
            min_cohort_size=min_cohort_size,
            max_cohort_size=max_cohort_size,
            sample_query=cohort_query,
            sample_sets=sample_sets,
            random_seed=random_seed,
        )

    # Determine X axis range. Use the extremes over *all* cohorts rather
    # than just the first one: with bounds="auto" the user cannot pan
    # beyond the initial range, so points of any cohort whose windows
    # extend further (e.g. a smaller window size) would be unreachable.
    if x_range is None:
        x_min = min(x[0] for x, _ in res.values())
        x_max = max(x[-1] for x, _ in res.values())
        x_range = bokeh.models.Range1d(x_min, x_max, bounds="auto")

    # Create a figure.
    xwheel_zoom = bokeh.models.WheelZoomTool(
        dimensions="width", maintain_focus=False
    )

    fig = bokeh.plotting.figure(
        title=title,
        tools=[
            "xpan",
            "xzoom_in",
            "xzoom_out",
            xwheel_zoom,
            "reset",
            "save",
            "crosshair",
        ],
        active_inspect=None,
        active_scroll=xwheel_zoom,
        active_drag="xpan",
        sizing_mode=sizing_mode,
        width=width,
        height=height,
        toolbar_location="above",
        x_range=x_range,
        y_range=(0, 1),
        output_backend=output_backend,
    )

    # Plot H12, one trace per cohort; colors are reused cyclically when
    # there are more cohorts than colors.
    for i, (cohort_label, (x, h12)) in enumerate(res.items()):
        fig.scatter(
            x=x,
            y=h12,
            marker="circle",
            size=3,
            line_width=1,
            line_color=colors[i % len(colors)],
            fill_color=None,
            legend_label=cohort_label,
        )

    # Tidy up the plot.
    fig.yaxis.axis_label = "H12"
    fig.yaxis.ticker = [0, 1]
    self._bokeh_style_genome_xaxis(fig, contig)

    if show:  # pragma: no cover
        bokeh.plotting.show(fig)
        return None
    else:
        return fig

@check_types
@doc(
    summary="Plot h12 GWSS data with multiple traces overlaid.",
)
def plot_h12_gwss_multi_overlay(
    self,
    contig: base_params.contig,
    cohorts: base_params.cohorts,
    window_size: h12_params.multi_window_size,
    cohort_size: Optional[base_params.cohort_size] = h12_params.cohort_size_default,
    sample_query: Optional[base_params.sample_query] = None,
    analysis: hap_params.analysis = base_params.DEFAULT,
    min_cohort_size: Optional[
        base_params.min_cohort_size
    ] = h12_params.min_cohort_size_default,
    max_cohort_size: Optional[
        base_params.max_cohort_size
    ] = h12_params.max_cohort_size_default,
    sample_query_options: Optional[base_params.sample_query_options] = None,
    sample_sets: Optional[base_params.sample_sets] = None,
    colors: gplt_params.colors = bokeh.palettes.d3["Category10"][10],
    random_seed: base_params.random_seed = 42,
    title: Optional[gplt_params.title] = None,
    sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
    width: gplt_params.width = gplt_params.width_default,
    track_height: gplt_params.track_height = 170,
    genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
    show: gplt_params.show = True,
    output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
) -> gplt_params.figure:
    # Stack the overlaid multi-cohort H12 track on top of a genes track in
    # a single-column grid. Displays the grid and returns None when `show`
    # is true; otherwise returns the grid.

    # Keyword arguments common to both sub-plots. Both are built with
    # show=False so that only the combined grid is ever displayed.
    shared_kwargs = dict(
        sizing_mode=sizing_mode,
        width=width,
        show=False,
        output_backend=output_backend,
    )

    # GWSS track with one trace per cohort.
    gwss_track = self.plot_h12_gwss_multi_overlay_track(
        contig=contig,
        cohorts=cohorts,
        window_size=window_size,
        cohort_size=cohort_size,
        sample_query=sample_query,
        analysis=analysis,
        min_cohort_size=min_cohort_size,
        max_cohort_size=max_cohort_size,
        sample_query_options=sample_query_options,
        sample_sets=sample_sets,
        colors=colors,
        random_seed=random_seed,
        title=title,
        height=track_height,
        **shared_kwargs,
    )

    # Hide the track's own X axis and configure its legend so that
    # clicking an entry hides the corresponding trace.
    gwss_track.xaxis.visible = False
    gwss_track.legend.location = "top_right"
    gwss_track.legend.click_policy = "hide"

    # Genes track, sharing the GWSS track's X range.
    genes_track = self.plot_genes(
        region=contig,
        height=genes_height,
        x_range=gwss_track.x_range,
        **shared_kwargs,
    )

    # Combine both tracks into a single figure.
    grid = bokeh.layouts.gridplot(
        [gwss_track, genes_track],
        ncols=1,
        toolbar_location="above",
        merge_tools=True,
        sizing_mode=sizing_mode,
        toolbar_options=dict(active_inspect=None),
    )

    if not show:
        return grid

    bokeh.plotting.show(grid)  # pragma: no cover
    return None

@check_types
@doc(
    summary="Plot h12 GWSS data with multiple tracks.",
)
def plot_h12_gwss_multi_panel(
    self,
    contig: base_params.contig,
    cohorts: base_params.cohorts,
    window_size: h12_params.multi_window_size,
    cohort_size: Optional[base_params.cohort_size] = h12_params.cohort_size_default,
    sample_query: Optional[base_params.sample_query] = None,
    analysis: hap_params.analysis = base_params.DEFAULT,
    min_cohort_size: Optional[
        base_params.min_cohort_size
    ] = h12_params.min_cohort_size_default,
    max_cohort_size: Optional[
        base_params.max_cohort_size
    ] = h12_params.max_cohort_size_default,
    sample_query_options: Optional[base_params.sample_query_options] = None,
    sample_sets: Optional[base_params.sample_sets] = None,
    random_seed: base_params.random_seed = 42,
    sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
    width: gplt_params.width = gplt_params.width_default,
    track_height: gplt_params.track_height = 170,
    genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
    show: gplt_params.show = True,
    output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
) -> gplt_params.figure:
    # Draw one H12 GWSS track per cohort, stacked vertically above a genes
    # track, all sharing the first track's X range. Displays the grid and
    # returns None when `show` is true; otherwise returns the grid. Raises
    # ValueError when a `window_size` mapping does not have exactly the
    # same keys as the resolved cohorts.

    # Resolve `cohorts` into a mapping of cohort label -> sample query.
    cohort_queries = self._setup_cohort_queries(
        cohorts=cohorts,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        cohort_size=cohort_size,
        min_cohort_size=None,
    )

    # A single int applies to every cohort; a mapping must cover exactly
    # the resolved cohort labels.
    if isinstance(window_size, int):
        window_size = dict.fromkeys(cohort_queries, window_size)
    elif isinstance(window_size, Mapping) and set(window_size) != set(
        cohort_queries
    ):
        raise ValueError("Cohorts and window_sizes should have the same keys.")

    # Keyword arguments that are identical for every cohort track.
    track_defaults = dict(
        contig=contig,
        analysis=analysis,
        sample_sets=sample_sets,
        cohort_size=cohort_size,
        min_cohort_size=min_cohort_size,
        max_cohort_size=max_cohort_size,
        random_seed=random_seed,
        sizing_mode=sizing_mode,
        width=width,
        height=track_height,
        show=False,
        output_backend=output_backend,
    )

    # Plot GWSS tracks, one per cohort.
    panels: list[gplt_params.def_figure] = []
    for cohort_label, cohort_query in cohort_queries.items():
        track_kwargs = dict(
            track_defaults,
            window_size=window_size[cohort_label],
            sample_query=cohort_query,
            title=cohort_label,  # Deal with a choice of titles later
        )
        # Every track after the first reuses the first track's X range so
        # that all panels pan and zoom together.
        if panels:
            track_kwargs["x_range"] = panels[0].x_range
        track = self.plot_h12_gwss_track(**track_kwargs)
        track.xaxis.visible = False
        panels.append(track)

    # Plot genes at the bottom, on the same shared X range.
    genes_track = self.plot_genes(
        region=contig,
        sizing_mode=sizing_mode,
        width=width,
        height=genes_height,
        x_range=panels[0].x_range,
        show=False,
        output_backend=output_backend,
    )
    panels.append(genes_track)

    # Combine plots into a single figure.
    grid = bokeh.layouts.gridplot(
        panels,
        ncols=1,
        toolbar_location="above",
        merge_tools=True,
        sizing_mode=sizing_mode,
        toolbar_options=dict(active_inspect=None),
    )

    if not show:
        return grid

    bokeh.plotting.show(grid)  # pragma: no cover
    return None


def haplotype_frequencies(h):
"""Compute haplotype frequencies, returning a dictionary that maps
Expand Down
12 changes: 11 additions & 1 deletion malariagen_data/anoph/h12_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Parameter definitions for H12 analysis functions."""

from typing import Optional, Sequence
from typing import Optional, Sequence, Union

from typing_extensions import Annotated, TypeAlias

Expand All @@ -22,6 +22,16 @@
""",
]

multi_window_size: TypeAlias = Annotated[
# Either the single-value window_size type or a per-cohort dict. The
# multi-cohort H12 plotting functions expand a single int to every cohort
# and raise ValueError unless a mapping's keys equal the cohort labels.
# NOTE(review): the annotation is dict[str, int] while the description
# below says "mapping" and the consumers test isinstance(..., Mapping) -
# confirm whether non-dict mappings are meant to be accepted.
Union[window_size, dict[str, int]],
"""
The size of windows (number of SNPs) used to calculate statistics within. Can
be a single value, in which case the same window size will be used for all
cohorts. Can also be a mapping from cohort identifiers to values, in case
you need to provide different window sizes for different cohorts.
""",
]

cohort_size_default: Optional[base_params.cohort_size] = None

min_cohort_size_default: base_params.min_cohort_size = 15
Expand Down
Loading

0 comments on commit 20efee0

Please sign in to comment.