From 7399694529e0bf3484b0a18aca7737d31f460805 Mon Sep 17 00:00:00 2001 From: Lennart Espe <3391295+lnsp@users.noreply.github.com> Date: Mon, 27 Nov 2023 00:32:45 +0100 Subject: [PATCH] Add subplot view on comparison page. --- trace_explorer/compare.py | 9 +-- trace_explorer/visualize.py | 47 +++++++++---- trace_explorer/web-templates/compare.html | 85 ++++++++++++++--------- trace_explorer/web.py | 33 ++++----- 4 files changed, 108 insertions(+), 66 deletions(-) diff --git a/trace_explorer/compare.py b/trace_explorer/compare.py index ee3da87..84406d8 100644 --- a/trace_explorer/compare.py +++ b/trace_explorer/compare.py @@ -21,6 +21,7 @@ def by_limiting_columns( figsize=(10, 10), cluster_figsize=(10, 30), cluster_path: str = 'cluster_%d.pdf', + cluster_subplots: bool = True, separate_overview: bool = False, highlight_clusters=[], highlight_path='highlight.pdf', @@ -76,7 +77,7 @@ def by_limiting_columns( legend=(1.04, 0.5), legendloc='center left', legendtitle=legendtitle) visualize.visualize(tsne, labels_auto, clusters_auto, - cluster_labels_auto, cluster_path % -1, + cluster_labels_auto, cluster_path % 'all', figsize=figsize, legend=(1.04, 0.5), legendloc='center left') else: @@ -89,10 +90,10 @@ def by_limiting_columns( if cluster_path is None: return - visualize.inspect_clusters(concatenated, pcad, tsne, + generated_cluster_plots = visualize.inspect_clusters(concatenated, pcad, tsne, cluster_figsize, cluster_path, clusters_auto, - cluster_labels_auto, labels_auto) + cluster_labels_auto, labels_auto, cluster_subplots) if len(highlight_clusters) != 0: visualize.highlight_clusters( @@ -100,7 +101,7 @@ def by_limiting_columns( highlight_path, highlight_clusters, clusters_auto, highlight_labels, labels_auto) - return len(cluster_labels_auto) + return len(cluster_labels_auto), generated_cluster_plots def by_imputing_columns(superset: pd.DataFrame, subset: pd.DataFrame, diff --git a/trace_explorer/visualize.py b/trace_explorer/visualize.py index 096c929..4319441 100644 --- a/trace_explorer/visualize.py +++ b/trace_explorer/visualize.py @@ -388,16 +388,22 @@ def inspect_clusters( original: pd.DataFrame, pcad: pd.DataFrame, embedding: pd.DataFrame, figsize: tuple[int], cluster_path: str, clusters: np.ndarray, cluster_names: list[str], - labels: np.ndarray): + labels: np.ndarray, as_subplots=True): # Generate N smaller subplots for each cluster, could be useful + plot_paths = {} for i in range(len(clusters)): - fig = plt.figure(figsize=figsize) - - # Generate label graph - gs = fig.add_gridspec(3, 1) - ax1 = fig.add_subplot(gs[0, 0]) - ax2 = fig.add_subplot(gs[1, 0]) - ax3 = fig.add_subplot(gs[2, 0]) + if as_subplots: + fig = plt.figure(figsize=figsize) + + # Generate label graph + gs = fig.add_gridspec(3, 1) + ax1 = fig.add_subplot(gs[0, 0]) + ax2 = fig.add_subplot(gs[1, 0]) + ax3 = fig.add_subplot(gs[2, 0]) + else: + fig1, ax1 = plt.subplots(figsize=figsize) + fig2, ax2 = plt.subplots(figsize=figsize) + fig3, ax3 = plt.subplots(figsize=figsize) labels_iter = (1 if labels[j] == clusters[i] else 0 for j in range(len(labels))) @@ -405,13 +411,28 @@ def inspect_clusters( clusters_local = np.array([0, 1]) description_local = np.array(['all', cluster_names[i]]) - _plot_clusters(ax1, embedding, labels_local, - clusters_local, description_local, show_legend=None) + lgd1 = _plot_clusters(ax1, embedding, labels_local, + clusters_local, description_local) lgd2 = _visualize_traits_as_barchart(ax2, original, labels, clusters[i]) lgd3 = _visualize_traits_grouped_by_pca(ax3, original, pcad, labels, clusters[i]) - plt.savefig(cluster_path % i, bbox_extra_artists=(lgd2, lgd3), - bbox_inches='tight') - plt.close(fig) + if as_subplots: + plot_paths[i] = cluster_path % i + plt.savefig(plot_paths[i], bbox_extra_artists=(lgd1, lgd2, lgd3), + bbox_inches='tight') + plt.close(fig) + else: + plot_paths[i] = { + 'cluster': cluster_path % ('%d_scatterplot' % i), + 'traits': cluster_path % ('%d_traits' % i), + 'traits_pca': cluster_path % ('%d_traits_pca' % i), + } + fig1.savefig(plot_paths[i]['cluster'], bbox_inches='tight', bbox_extra_artists=(lgd1,)) + plt.close(fig1) + fig2.savefig(plot_paths[i]['traits'], bbox_inches='tight', bbox_extra_artists=(lgd2,)) + plt.close(fig2) + fig3.savefig(plot_paths[i]['traits_pca'], bbox_inches='tight', bbox_extra_artists=(lgd3,)) + plt.close(fig3) + return plot_paths diff --git a/trace_explorer/web-templates/compare.html b/trace_explorer/web-templates/compare.html index 7f63090..63140cb 100644 --- a/trace_explorer/web-templates/compare.html +++ b/trace_explorer/web-templates/compare.html @@ -45,15 +45,20 @@ }); $data.plots.push({ title: 'All clusters', - content: response.clusters_overview, + content: response.clusters_all, }); for (var i = 0; i < response.clusters.length; i++) { $data.plots.push({ title: `Cluster ${i}`, - content: response.clusters[i], + subplots: [ + { title: 'Overview', content: response.clusters[i].cluster }, + { title: 'Traits', content: response.clusters[i].traits }, + { title: 'Traits (PCA)', content: response.clusters[i].traits_pca }, + ] }); } $data.selectedPlot = 0; + $data.selectedSubplot = 0; } catch (e) { showErrorModal(e); } finally { @@ -64,9 +69,8 @@ navigator.clipboard.writeText($('#cmd-string').text()) } -