From 7399694529e0bf3484b0a18aca7737d31f460805 Mon Sep 17 00:00:00 2001 From: Lennart Espe <3391295+lnsp@users.noreply.github.com> Date: Mon, 27 Nov 2023 00:32:45 +0100 Subject: [PATCH] Add subplot view on comparison page. --- trace_explorer/compare.py | 9 +-- trace_explorer/visualize.py | 47 +++++++++---- trace_explorer/web-templates/compare.html | 85 ++++++++++++++--------- trace_explorer/web.py | 33 ++++----- 4 files changed, 108 insertions(+), 66 deletions(-) diff --git a/trace_explorer/compare.py b/trace_explorer/compare.py index ee3da87..84406d8 100644 --- a/trace_explorer/compare.py +++ b/trace_explorer/compare.py @@ -21,6 +21,7 @@ def by_limiting_columns( figsize=(10, 10), cluster_figsize=(10, 30), cluster_path: str = 'cluster_%d.pdf', + cluster_subplots: bool = True, separate_overview: bool = False, highlight_clusters=[], highlight_path='highlight.pdf', @@ -76,7 +77,7 @@ def by_limiting_columns( legend=(1.04, 0.5), legendloc='center left', legendtitle=legendtitle) visualize.visualize(tsne, labels_auto, clusters_auto, - cluster_labels_auto, cluster_path % -1, + cluster_labels_auto, cluster_path % 'all', figsize=figsize, legend=(1.04, 0.5), legendloc='center left') else: @@ -89,10 +90,10 @@ def by_limiting_columns( if cluster_path is None: return - visualize.inspect_clusters(concatenated, pcad, tsne, + generated_cluster_plots = visualize.inspect_clusters(concatenated, pcad, tsne, cluster_figsize, cluster_path, clusters_auto, - cluster_labels_auto, labels_auto) + cluster_labels_auto, labels_auto, cluster_subplots) if len(highlight_clusters) != 0: visualize.highlight_clusters( @@ -100,7 +101,7 @@ def by_limiting_columns( highlight_path, highlight_clusters, clusters_auto, highlight_labels, labels_auto) - return len(cluster_labels_auto) + return len(cluster_labels_auto), generated_cluster_plots def by_imputing_columns(superset: pd.DataFrame, subset: pd.DataFrame, diff --git a/trace_explorer/visualize.py b/trace_explorer/visualize.py index 096c929..4319441 100644 --- a/trace_explorer/visualize.py +++ b/trace_explorer/visualize.py @@ -388,16 +388,22 @@ def inspect_clusters( original: pd.DataFrame, pcad: pd.DataFrame, embedding: pd.DataFrame, figsize: tuple[int], cluster_path: str, clusters: np.ndarray, cluster_names: list[str], - labels: np.ndarray): + labels: np.ndarray, as_subplots=True): # Generate N smaller subplots for each cluster, could be useful + plot_paths = {} for i in range(len(clusters)): - fig = plt.figure(figsize=figsize) - - # Generate label graph - gs = fig.add_gridspec(3, 1) - ax1 = fig.add_subplot(gs[0, 0]) - ax2 = fig.add_subplot(gs[1, 0]) - ax3 = fig.add_subplot(gs[2, 0]) + if as_subplots: + fig = plt.figure(figsize=figsize) + + # Generate label graph + gs = fig.add_gridspec(3, 1) + ax1 = fig.add_subplot(gs[0, 0]) + ax2 = fig.add_subplot(gs[1, 0]) + ax3 = fig.add_subplot(gs[2, 0]) + else: + fig1, ax1 = plt.subplots(figsize=figsize) + fig2, ax2 = plt.subplots(figsize=figsize) + fig3, ax3 = plt.subplots(figsize=figsize) labels_iter = (1 if labels[j] == clusters[i] else 0 for j in range(len(labels))) @@ -405,13 +411,28 @@ def inspect_clusters( clusters_local = np.array([0, 1]) description_local = np.array(['all', cluster_names[i]]) - _plot_clusters(ax1, embedding, labels_local, - clusters_local, description_local, show_legend=None) + lgd1 = _plot_clusters(ax1, embedding, labels_local, + clusters_local, description_local) lgd2 = _visualize_traits_as_barchart(ax2, original, labels, clusters[i]) lgd3 = _visualize_traits_grouped_by_pca(ax3, original, pcad, labels, clusters[i]) - plt.savefig(cluster_path % i, bbox_extra_artists=(lgd2, lgd3), - bbox_inches='tight') - plt.close(fig) + if as_subplots: + plot_paths[i] = cluster_path % i + plt.savefig(plot_paths[i], bbox_extra_artists=(lgd1, lgd2, lgd3), + bbox_inches='tight') + plt.close(fig) + else: + plot_paths[i] = { + 'cluster': cluster_path % ('%d_scatterplot' % i), + 'traits': cluster_path % ('%d_traits' % i), + 'traits_pca': cluster_path % ('%d_traits_pca' % i), + } + fig1.savefig(plot_paths[i]['cluster'], bbox_inches='tight', bbox_extra_artists=(lgd1,)) + plt.close(fig1) + fig2.savefig(plot_paths[i]['traits'], bbox_inches='tight', bbox_extra_artists=(lgd2,)) + plt.close(fig2) + fig3.savefig(plot_paths[i]['traits_pca'], bbox_inches='tight', bbox_extra_artists=(lgd3,)) + plt.close(fig3) + return plot_paths diff --git a/trace_explorer/web-templates/compare.html b/trace_explorer/web-templates/compare.html index 7f63090..63140cb 100644 --- a/trace_explorer/web-templates/compare.html +++ b/trace_explorer/web-templates/compare.html @@ -45,15 +45,20 @@ }); $data.plots.push({ title: 'All clusters', - content: response.clusters_overview, + content: response.clusters_all, }); for (var i = 0; i < response.clusters.length; i++) { $data.plots.push({ title: `Cluster ${i}`, - content: response.clusters[i], + subplots: [ + { title: 'Overview', content: response.clusters[i].cluster }, + { title: 'Traits', content: response.clusters[i].traits }, + { title: 'Traits (PCA)', content: response.clusters[i].traits_pca }, + ] }); } $data.selectedPlot = 0; + $data.selectedSubplot = 0; } catch (e) { showErrorModal(e); } finally { @@ -64,9 +69,8 @@ navigator.clipboard.writeText($('#cmd-string').text()) } -
+" + x-init="listSources($data)">
@@ -199,20 +205,37 @@
Output
- +
+ + +
+
@@ -236,25 +259,25 @@
- trace_explorer compare - - + trace_explorer compare + + + + + +
diff --git a/trace_explorer/web.py b/trace_explorer/web.py index 0865b6b..942cf0a 100644 --- a/trace_explorer/web.py +++ b/trace_explorer/web.py @@ -167,30 +167,27 @@ def compare_with_params(): # generate compare files datasets = [pd.read_parquet(s) for s in sources] overview_path = os.path.join(tempdir, 'overview.png') - cluster_path = os.path.join(tempdir, 'cluster_%d.png') - n_clusters = compare.by_limiting_columns( + cluster_path = os.path.join(tempdir, 'cluster_%s.png') + n_clusters, cluster_path_set = compare.by_limiting_columns( datasets, excluded_columns, overview_path, iterations, perplexity, sources, threshold, - cluster_path=cluster_path, separate_overview=True) + cluster_path=cluster_path, separate_overview=True, + cluster_subplots=False, cluster_figsize=(10, 10)) # read pngs into base64 - overview_data = None - cluster_overview_data = None - cluster_data = [] - + response = {} with open(overview_path, 'rb') as f: - overview_data = f.read() - with open(cluster_path % -1, 'rb') as f: - cluster_overview_data = f.read() + response['overview'] = b64encode(f.read()).decode('utf-8') + with open(cluster_path % 'all', 'rb') as f: + response['clusters_all'] = b64encode(f.read()).decode('utf-8') + response['clusters'] = [] for i in range(n_clusters): - with open(cluster_path % i, 'rb') as f: - cluster_data.append(f.read()) - - return { - 'overview': b64encode(overview_data).decode('utf-8'), - 'clusters_overview': b64encode(cluster_overview_data).decode('utf-8'), - 'clusters': [b64encode(c).decode('utf-8') for c in cluster_data], - } + cluster_data = {} + for (plot_type, path) in cluster_path_set[i].items(): + with open(path, 'rb') as f: + cluster_data[plot_type] = b64encode(f.read()).decode('utf-8') + response['clusters'].append(cluster_data) + return response @app.route('/visualize', methods=['POST'])