diff --git a/docs/source/Af1.rst b/docs/source/Af1.rst index a4e14fd57..ace32c2b1 100644 --- a/docs/source/Af1.rst +++ b/docs/source/Af1.rst @@ -164,6 +164,8 @@ Genome-wide selection scans plot_h12_calibration h12_gwss plot_h12_gwss + plot_h12_gwss_multi_panel + plot_h12_gwss_multi_overlay h1x_gwss plot_h1x_gwss g123_calibration diff --git a/docs/source/Ag3.rst b/docs/source/Ag3.rst index 12869a5c7..17dbc522c 100644 --- a/docs/source/Ag3.rst +++ b/docs/source/Ag3.rst @@ -174,6 +174,8 @@ Genome-wide selection scans plot_h12_calibration h12_gwss plot_h12_gwss + plot_h12_gwss_multi_panel + plot_h12_gwss_multi_overlay h1x_gwss plot_h1x_gwss g123_calibration diff --git a/malariagen_data/anoph/gplt_params.py b/malariagen_data/anoph/gplt_params.py index 6612e981d..7e916ef4f 100644 --- a/malariagen_data/anoph/gplt_params.py +++ b/malariagen_data/anoph/gplt_params.py @@ -1,7 +1,7 @@ """Parameters for genome plotting functions. N.B., genome plots are always plotted with bokeh.""" -from typing import Literal, Mapping, Optional, Union +from typing import Literal, Mapping, Optional, Union, Sequence import bokeh.models from typing_extensions import Annotated, TypeAlias @@ -83,6 +83,14 @@ "A bokeh figure (only returned if show=False).", ] +def_figure: TypeAlias = Annotated[ + # Use quite a broad type here to accommodate both single-panel figures + # created via bokeh.plotting and multi-panel figures created via + # bokeh.layouts. + bokeh.model.Model, + "A bokeh figure.", +] + output_backend: TypeAlias = Annotated[ Literal["canvas", "webgl", "svg"], """ @@ -103,3 +111,5 @@ Mapping, "Passed through to bokeh line() function.", ] + +colors: TypeAlias = Annotated[Sequence[str], "List of colors."] diff --git a/malariagen_data/anoph/h12.py b/malariagen_data/anoph/h12.py index 3f984808c..f2db55c13 100644 --- a/malariagen_data/anoph/h12.py +++ b/malariagen_data/anoph/h12.py @@ -514,6 +514,307 @@ def plot_h12_gwss( else: return fig + @check_types + @doc( + summary="Plot h12 GWSS data track with multiple traces overlaid.", + ) + def plot_h12_gwss_multi_overlay_track( + self, + contig: base_params.contig, + cohorts: base_params.cohorts, + window_size: h12_params.multi_window_size, + cohort_size: Optional[base_params.cohort_size] = h12_params.cohort_size_default, + sample_query: Optional[base_params.sample_query] = None, + analysis: hap_params.analysis = base_params.DEFAULT, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = h12_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = h12_params.max_cohort_size_default, + title: Optional[gplt_params.title] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + sample_sets: Optional[base_params.sample_sets] = None, + colors: gplt_params.colors = bokeh.palettes.d3["Category10"][10], + random_seed: base_params.random_seed = 42, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + height: gplt_params.height = 200, + show: gplt_params.show = True, + x_range: Optional[gplt_params.x_range] = None, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + ) -> gplt_params.figure: + cohort_queries = self._setup_cohort_queries( + cohorts=cohorts, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + cohort_size=cohort_size, + min_cohort_size=None, + ) + + if isinstance(window_size, int): + window_size = {k: window_size for k in cohort_queries.keys()} + elif isinstance(window_size, Mapping): + if set(window_size.keys()) != set(cohort_queries.keys()): + raise ValueError("Cohorts and window_sizes should have the same keys.") + + # Compute H12. + res = {} + for cohort_label, cohort_query in cohort_queries.items(): + res[cohort_label] = self.h12_gwss( + contig=contig, + analysis=analysis, + window_size=window_size[cohort_label], + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + sample_query=cohort_query, + sample_sets=sample_sets, + random_seed=random_seed, + ) + + # Determine X axis range. + x, _ = res[list(cohort_queries.keys())[0]] + x_min = x[0] + x_max = x[-1] + if x_range is None: + x_range = bokeh.models.Range1d(x_min, x_max, bounds="auto") + + # Create a figure. + xwheel_zoom = bokeh.models.WheelZoomTool( + dimensions="width", maintain_focus=False + ) + + fig = bokeh.plotting.figure( + title=title, + tools=[ + "xpan", + "xzoom_in", + "xzoom_out", + xwheel_zoom, + "reset", + "save", + "crosshair", + ], + active_inspect=None, + active_scroll=xwheel_zoom, + active_drag="xpan", + sizing_mode=sizing_mode, + width=width, + height=height, + toolbar_location="above", + x_range=x_range, + y_range=(0, 1), + output_backend=output_backend, + ) + + # Plot H12. + for i, (cohort_label, (x, h12)) in enumerate(res.items()): + fig.scatter( + x=x, + y=h12, + marker="circle", + size=3, + line_width=1, + line_color=colors[i % len(colors)], + fill_color=None, + legend_label=cohort_label, + ) + + # Tidy up the plot. + fig.yaxis.axis_label = "H12" + fig.yaxis.ticker = [0, 1] + self._bokeh_style_genome_xaxis(fig, contig) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + + @check_types + @doc( + summary="Plot h12 GWSS data with multiple traces overlaid.", + ) + def plot_h12_gwss_multi_overlay( + self, + contig: base_params.contig, + cohorts: base_params.cohorts, + window_size: h12_params.multi_window_size, + cohort_size: Optional[base_params.cohort_size] = h12_params.cohort_size_default, + sample_query: Optional[base_params.sample_query] = None, + analysis: hap_params.analysis = base_params.DEFAULT, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = h12_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = h12_params.max_cohort_size_default, + sample_query_options: Optional[base_params.sample_query_options] = None, + sample_sets: Optional[base_params.sample_sets] = None, + colors: gplt_params.colors = bokeh.palettes.d3["Category10"][10], + random_seed: base_params.random_seed = 42, + title: Optional[gplt_params.title] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + track_height: gplt_params.track_height = 170, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + ) -> gplt_params.figure: + # Plot GWSS track. + fig1 = self.plot_h12_gwss_multi_overlay_track( + contig=contig, + sample_query=sample_query, + cohorts=cohorts, + cohort_size=cohort_size, + window_size=window_size, + analysis=analysis, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + colors=colors, + random_seed=random_seed, + title=title, + sizing_mode=sizing_mode, + width=width, + height=track_height, + show=False, + output_backend=output_backend, + ) + + fig1.xaxis.visible = False + fig1.legend.location = "top_right" + fig1.legend.click_policy = "hide" + + # Plot genes. + fig2 = self.plot_genes( + region=contig, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=fig1.x_range, + show=False, + output_backend=output_backend, + ) + + # Combine plots into a single figure. + fig = bokeh.layouts.gridplot( + [fig1, fig2], + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + toolbar_options=dict(active_inspect=None), + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + + @check_types + @doc( + summary="Plot h12 GWSS data with multiple tracks.", + ) + def plot_h12_gwss_multi_panel( + self, + contig: base_params.contig, + cohorts: base_params.cohorts, + window_size: h12_params.multi_window_size, + cohort_size: Optional[base_params.cohort_size] = h12_params.cohort_size_default, + sample_query: Optional[base_params.sample_query] = None, + analysis: hap_params.analysis = base_params.DEFAULT, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = h12_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = h12_params.max_cohort_size_default, + sample_query_options: Optional[base_params.sample_query_options] = None, + sample_sets: Optional[base_params.sample_sets] = None, + random_seed: base_params.random_seed = 42, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + track_height: gplt_params.track_height = 170, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + ) -> gplt_params.figure: + cohort_queries = self._setup_cohort_queries( + cohorts=cohorts, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + cohort_size=cohort_size, + min_cohort_size=None, + ) + + if isinstance(window_size, int): + window_size = {k: window_size for k in cohort_queries.keys()} + elif isinstance(window_size, Mapping): + if set(window_size.keys()) != set(cohort_queries.keys()): + raise ValueError("Cohorts and window_sizes should have the same keys.") + + # Plot GWSS track. + figs: list[gplt_params.def_figure] = [] + for i, (cohort_label, cohort_query) in enumerate(cohort_queries.items()): + params = dict( + contig=contig, + analysis=analysis, + window_size=window_size[cohort_label], + sample_sets=sample_sets, + sample_query=cohort_query, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + title=cohort_label, # Deal with a choice of titles later + sizing_mode=sizing_mode, + width=width, + height=track_height, + show=False, + output_backend=output_backend, + ) + if i > 0: + track = self.plot_h12_gwss_track(x_range=figs[0].x_range, **params) + else: + track = self.plot_h12_gwss_track(**params) + track.xaxis.visible = False + figs.append(track) + + # Plot genes. + fig2 = self.plot_genes( + region=contig, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=figs[0].x_range, + show=False, + output_backend=output_backend, + ) + + figs.append(fig2) + + # Combine plots into a single figure. + fig = bokeh.layouts.gridplot( + figs, + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + toolbar_options=dict(active_inspect=None), + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + def haplotype_frequencies(h): """Compute haplotype frequencies, returning a dictionary that maps diff --git a/malariagen_data/anoph/h12_params.py b/malariagen_data/anoph/h12_params.py index c15d42648..0116bad23 100644 --- a/malariagen_data/anoph/h12_params.py +++ b/malariagen_data/anoph/h12_params.py @@ -1,6 +1,6 @@ """Parameter definitions for H12 analysis functions.""" -from typing import Optional, Sequence +from typing import Optional, Sequence, Union from typing_extensions import Annotated, TypeAlias @@ -22,6 +22,16 @@ """, ] +multi_window_size: TypeAlias = Annotated[ + Union[window_size, dict[str, int]], + """ + The size of windows (number of SNPs) used to calculate statistics within. Can + be a single value, in which case the same window size will be used for all + cohorts. Can also be a mapping from cohort identifiers to values, in case + you need to provide different window sizes for different cohorts. + """, +] + cohort_size_default: Optional[base_params.cohort_size] = None min_cohort_size_default: base_params.min_cohort_size = 15 diff --git a/notebooks/plot_h12_h1x.ipynb b/notebooks/plot_h12_h1x.ipynb index 500ac63cc..3586907de 100644 --- a/notebooks/plot_h12_h1x.ipynb +++ b/notebooks/plot_h12_h1x.ipynb @@ -4,7 +4,9 @@ "cell_type": "code", "execution_count": null, "id": "ced974cd", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "import numpy as np" @@ -14,7 +16,9 @@ "cell_type": "code", "execution_count": null, "id": "9d17216f", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "import malariagen_data" @@ -32,7 +36,9 @@ "cell_type": "code", "execution_count": null, "id": "65b7ff2f", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "ag3 = malariagen_data.Ag3(\n", @@ -48,7 +54,9 @@ "cell_type": "code", "execution_count": null, "id": "ac276861-43fb-4f09-b004-57487d75aa41", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "!rm -rf results_cache" @@ -58,7 +66,9 @@ "cell_type": "code", "execution_count": null, "id": "a7c10446", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "coh1 = \"ML-2_Kati_colu_2014\"\n", @@ -72,7 +82,9 @@ "cell_type": "code", "execution_count": null, "id": "b0d308fe-4e8a-412f-9487-7f1c29d3323a", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "ag3.plot_h12_calibration(\n", @@ -87,7 +99,9 @@ "cell_type": "code", "execution_count": null, "id": "63cb17ff", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "ag3.plot_h12_gwss(\n", @@ -104,7 +118,9 @@ "cell_type": "code", "execution_count": null, "id": "cd125828", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "ag3.plot_h12_gwss(\n", @@ -122,7 +138,9 @@ "cell_type": "code", "execution_count": null, "id": "5716e777", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "ag3.plot_h12_gwss(\n", @@ -139,7 +157,9 @@ "cell_type": "code", "execution_count": null, "id": "fe11bb35", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "ag3.plot_h1x_gwss(\n", @@ -153,6 +173,56 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "73e0e746-add1-4774-bced-461c9742bd95", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ag3.count_samples(sample_sets=\"AG1000G-ML-A\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67f48075-6b79-4722-82f7-278bd0868410", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ag3.plot_h12_gwss_multi_overlay(\n", + " contig=contig,\n", + " window_size=2000,\n", + " cohorts=\"admin2_year\",\n", + " sample_sets=\"AG1000G-ML-A\",\n", + " analysis=\"gamb_colu\",\n", + " cohort_size=20,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3e4622b-31e9-444a-8a09-1590862709fe", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ag3.plot_h12_gwss_multi_panel(\n", + " contig=contig,\n", + " window_size=2000,\n", + " cohorts=\"admin2_year\",\n", + " sample_sets=\"AG1000G-ML-A\",\n", + " analysis=\"gamb_colu\",\n", + " cohort_size=20,\n", + ")" + ] + }, { "cell_type": "markdown", "id": "2801861c", @@ -223,6 +293,61 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "981ba052-1960-4f2e-ba91-72ac442ca178", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "cohorts = {\n", + " \"cohort1\": coh1_query,\n", + " \"cohort2\": coh2_query,\n", + "}\n", + "window_size = {\n", + " \"cohort1\": 2000,\n", + " \"cohort2\": 1500,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3296b4c-cd72-436f-b21a-d210c7520c39", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "af1.plot_h12_gwss_multi_overlay(\n", + " contig=contig,\n", + " window_size=window_size,\n", + " cohorts=cohorts,\n", + " sample_sets=\"1.0\",\n", + " cohort_size=20,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a60cc83-cf30-4152-9eb0-dae7c00ac41b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "af1.plot_h12_gwss_multi_panel(\n", + " contig=contig,\n", + " window_size=window_size,\n", + " cohorts=cohorts,\n", + " sample_sets=\"1.0\",\n", + " cohort_size=20,\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -250,8 +375,14 @@ } ], "metadata": { + "environment": { + "kernel": "python3", + "name": "workbench-notebooks.m125", + "type": "gcloud", + "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125" + }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "malariagen-data-python", "language": "python", "name": "python3" }, @@ -265,7 +396,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.15" }, "vscode": { "interpreter": { diff --git a/tests/anoph/test_h12.py b/tests/anoph/test_h12.py index d0a2d9980..ae47b74e4 100644 --- a/tests/anoph/test_h12.py +++ b/tests/anoph/test_h12.py @@ -149,6 +149,13 @@ def check_h12_gwss(*, api, h12_params): assert isinstance(fig, bokeh.models.GridPlot) +def check_h12_gwss_multi(*, api, h12_params): + fig = api.plot_h12_gwss_multi_overlay(**h12_params, show=False) + assert isinstance(fig, bokeh.models.GridPlot) + fig = api.plot_h12_gwss_multi_panel(**h12_params, show=False) + assert isinstance(fig, bokeh.models.GridPlot) + + @parametrize_with_cases("fixture,api", cases=".") def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis): # Set up test parameters. @@ -211,3 +218,104 @@ def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis): window_size=window_size, min_cohort_size=n_samples + 1, ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis): + # Set up test parameters. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].unique().tolist() + country1, country2 = random.sample(all_countries, 2) + cohort1_query = f"country == '{country1}'" + cohort2_query = f"country == '{country2}'" + h12_params = dict( + contig=random.choice(api.contigs), + sample_sets=all_sample_sets, + window_size=random.randint(100, 500), + min_cohort_size=1, + cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, + ) + + # Run checks. + check_h12_gwss_multi(api=api, h12_params=h12_params) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis): + # Set up test parameters. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].unique().tolist() + country1, country2 = random.sample(all_countries, 2) + cohort1_query = f"country == '{country1}'" + cohort2_query = f"country == '{country2}'" + h12_params = dict( + contig=random.choice(api.contigs), + sample_sets=all_sample_sets, + window_size={ + "cohort1": random.randint(100, 500), + "cohort2": random.randint(100, 500), + }, + min_cohort_size=1, + cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, + ) + + # Run checks. + check_h12_gwss_multi(api=api, h12_params=h12_params) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis): + # Set up test parameters. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].unique().tolist() + country1, country2 = random.sample(all_countries, 2) + cohort1_query = f"country == '{country1}'" + cohort2_query = f"country == '{country2}'" + contig = random.choice(api.contigs) + + for analysis in api.phasing_analysis_ids: + # Check if any samples available for the given phasing analysis. + try: + ds_hap1 = api.haplotypes( + sample_sets=all_sample_sets, + sample_query=cohort1_query, + analysis=analysis, + region=contig, + ) + except ValueError: + n1 = 0 + else: + n1 = ds_hap1.sizes["samples"] + try: + ds_hap2 = api.haplotypes( + sample_sets=all_sample_sets, + sample_query=cohort2_query, + analysis=analysis, + region=contig, + ) + except ValueError: + n2 = 0 + else: + n2 = ds_hap2.sizes["samples"] + + if n1 > 0 and n2 > 0: + # Samples are available, run full checks. + h12_params = dict( + analysis=analysis, + contig=contig, + sample_sets=all_sample_sets, + window_size=random.randint(100, 500), + min_cohort_size=min(n1, n2), + cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, + ) + check_h12_gwss_multi(api=api, h12_params=h12_params) + + # Check min_cohort_size behaviour. + params = h12_params.copy() + params["min_cohort_size"] = n1 + 1 + with pytest.raises(ValueError): + api.plot_h12_gwss_multi_overlay(**params) + params = h12_params.copy() + params["min_cohort_size"] = n2 + 1 + with pytest.raises(ValueError): + api.plot_h12_gwss_multi_panel(**params)