Skip to content

Commit

Permalink
Add subplot view on comparison page.
Browse files: browse the repository at this point in the history
  • Loading branch information
lnsp committed Nov 26, 2023
1 parent 9747135 commit 7399694
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 66 deletions.
9 changes: 5 additions & 4 deletions trace_explorer/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def by_limiting_columns(
figsize=(10, 10),
cluster_figsize=(10, 30),
cluster_path: str = 'cluster_%d.pdf',
cluster_subplots: bool = True,
separate_overview: bool = False,
highlight_clusters=[],
highlight_path='highlight.pdf',
Expand Down Expand Up @@ -76,7 +77,7 @@ def by_limiting_columns(
legend=(1.04, 0.5), legendloc='center left',
legendtitle=legendtitle)
visualize.visualize(tsne, labels_auto, clusters_auto,
cluster_labels_auto, cluster_path % -1,
cluster_labels_auto, cluster_path % 'all',
figsize=figsize,
legend=(1.04, 0.5), legendloc='center left')
else:
Expand All @@ -89,18 +90,18 @@ def by_limiting_columns(
if cluster_path is None:
return

visualize.inspect_clusters(concatenated, pcad, tsne,
generated_cluster_plots = visualize.inspect_clusters(concatenated, pcad, tsne,
cluster_figsize,
cluster_path, clusters_auto,
cluster_labels_auto, labels_auto)
cluster_labels_auto, labels_auto, cluster_subplots)

if len(highlight_clusters) != 0:
visualize.highlight_clusters(
concatenated, pcad, tsne, figsize,
highlight_path, highlight_clusters,
clusters_auto, highlight_labels, labels_auto)

return len(cluster_labels_auto)
return len(cluster_labels_auto), generated_cluster_plots


def by_imputing_columns(superset: pd.DataFrame, subset: pd.DataFrame,
Expand Down
47 changes: 34 additions & 13 deletions trace_explorer/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,30 +388,51 @@ def inspect_clusters(
original: pd.DataFrame, pcad: pd.DataFrame,
embedding: pd.DataFrame, figsize: tuple[int],
cluster_path: str, clusters: np.ndarray, cluster_names: list[str],
labels: np.ndarray):
labels: np.ndarray, as_subplots=True):
# Generate the plots for each cluster: one figure with three stacked subplots when as_subplots is set, otherwise three separate figures
plot_paths = {}
for i in range(len(clusters)):
fig = plt.figure(figsize=figsize)

# Generate label graph
gs = fig.add_gridspec(3, 1)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[2, 0])
if as_subplots:
fig = plt.figure(figsize=figsize)

# Generate label graph
gs = fig.add_gridspec(3, 1)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[2, 0])
else:
fig1, ax1 = plt.subplots(figsize=figsize)
fig2, ax2 = plt.subplots(figsize=figsize)
fig3, ax3 = plt.subplots(figsize=figsize)

labels_iter = (1 if labels[j] == clusters[i] else 0
for j in range(len(labels)))
labels_local = np.fromiter(labels_iter, dtype=int)
clusters_local = np.array([0, 1])
description_local = np.array(['all', cluster_names[i]])

_plot_clusters(ax1, embedding, labels_local,
clusters_local, description_local, show_legend=None)
lgd1 = _plot_clusters(ax1, embedding, labels_local,
clusters_local, description_local)
lgd2 = _visualize_traits_as_barchart(ax2, original,
labels, clusters[i])
lgd3 = _visualize_traits_grouped_by_pca(ax3, original, pcad,
labels, clusters[i])

plt.savefig(cluster_path % i, bbox_extra_artists=(lgd2, lgd3),
bbox_inches='tight')
plt.close(fig)
if as_subplots:
plot_paths[i] = cluster_path % i
plt.savefig(plot_paths[i], bbox_extra_artists=(lgd1, lgd2, lgd3),
bbox_inches='tight')
plt.close(fig)
else:
plot_paths[i] = {
'cluster': cluster_path % ('%d_scatterplot' % i),
'traits': cluster_path % ('%d_traits' % i),
'traits_pca': cluster_path % ('%d_traits_pca' % i),
}
fig1.savefig(plot_paths[i]['cluster'], bbox_inches='tight', bbox_extra_artists=(lgd1,))
plt.close(fig1)
fig2.savefig(plot_paths[i]['traits'], bbox_inches='tight', bbox_extra_artists=(lgd2,))
plt.close(fig2)
fig3.savefig(plot_paths[i]['traits_pca'], bbox_inches='tight', bbox_extra_artists=(lgd3,))
plt.close(fig3)
return plot_paths
85 changes: 54 additions & 31 deletions trace_explorer/web-templates/compare.html
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@
});
$data.plots.push({
title: 'All clusters',
content: response.clusters_overview,
content: response.clusters_all,
});
for (var i = 0; i < response.clusters.length; i++) {
$data.plots.push({
title: `Cluster ${i}`,
content: response.clusters[i],
subplots: [
{ title: 'Overview', content: response.clusters[i].cluster },
{ title: 'Traits', content: response.clusters[i].traits },
{ title: 'Traits (PCA)', content: response.clusters[i].traits_pca },
]
});
}
$data.selectedPlot = 0;
$data.selectedSubplot = 0;
} catch (e) {
showErrorModal(e);
} finally {
Expand All @@ -64,9 +69,8 @@
navigator.clipboard.writeText($('#cmd-string').text())
}
</script>
<div
class="flex relative flex-grow overflow-x-hidden"
x-data="
<div class="flex relative flex-grow overflow-x-hidden"
x-data="
{
sources: [],
selectedSources: [],
Expand All @@ -78,9 +82,11 @@
inflight: false,
plots: [],
selectedPlot: undefined,
selectedSubplot: undefined,
clipboard: '',
}
" x-init="listSources($data)">
"
x-init="listSources($data)">
<div class="p-4 border-r border-gray-300 shrink-0 w-96">
<form class="flex flex-col gap-4 h-full"
@submit.prevent="updateComparisonView($data)">
Expand Down Expand Up @@ -199,20 +205,37 @@
<div class="flex flex-grow flex-col min-h-full relative min-w-0">
<div class="text-lg font-medium p-4 border-b border-gray-300 flex justify-between">
<div>Output</div>
<select class="border border-gray-900 h-8 px-2 mx-2"
x-model="selectedPlot">
<template x-for="(plot, index) in plots">
<option x-text="plot.title"
:value="index"></option>
</template>
</select>
<div>
<select class="border border-gray-900 h-8 px-2 mx-2"
x-show="plots[selectedPlot].subplots"
x-model="selectedSubplot">
<template x-for="(plot, index) in plots[selectedPlot].subplots">
<option x-text="plot.title"
:value="index"></option>
</template>
</select>
<select class="border border-gray-900 h-8 px-2 mx-2"
x-model="selectedPlot">
<template x-for="(plot, index) in plots">
<option x-text="plot.title"
:value="index"></option>
</template>
</select>
</div>
</div>
<div class="w-full overflow-y-auto flex-grow relative">
<template x-if="plots[selectedPlot]">
<img id="output"
x-show="plots[selectedPlot].content"
:src="`data:image/png;base64,${plots[selectedPlot].content}`"
class="max-h-full">
</template>
<template x-if="plots[selectedPlot] && plots[selectedPlot].subplots[selectedSubplot]">
<img id="output-subplot"
x-show="plots[selectedPlot].subplots[selectedSubplot].content"
:src="`data:image/png;base64,${plots[selectedPlot].subplots[selectedSubplot].content}`"
class="max-h-full">
</template>
</div>
<div class="border-t border-gray-300 p-4">
<div class="text-lg font-medium mb-4 flex gap-4 items-center">
Expand All @@ -236,25 +259,25 @@
</div>
<div id="cmd-string"
class="font-mono whitespace-nowrap bg-gray-100 text-sm overflow-x-scroll p-2 flex gap-4 text-gray-900 underline-offset-4">
<span>trace_explorer compare</span>
<template x-for="(source, index) in selectedSources">
<span>
<span x-show="index === 0">--superset</span>
<span x-show="index > 0">--subset</span>
<span class="text-gray-600 underline decoration-dashed"
x-text="source"></span>
</span>
</template>
<span x-text="`--threshold ${threshold}`"
class="text-gray-600 underline decoration-dashed"></span>
<span>trace_explorer compare</span>
<template x-for="(source, index) in selectedSources">
<span>
<span x-show="index === 0">--superset</span>
<span x-show="index > 0">--subset</span>
<span class="text-gray-600 underline decoration-dashed"
x-text="source"></span>
</span>
<span x-text="`--tsne_perplexity ${perplexity}`"
class="text-gray-600 underline decoration-dashed"></span>
<span x-text="`--tsne_n_iter ${iterations}`"
class="text-gray-600 underline decoration-dashed"></span>
<template x-for="column in excludedColumns">
<span x-text="`--exclude ${column}`"></span>
</template>
</template>
<span x-text="`--threshold ${threshold}`"
class="text-gray-600 underline decoration-dashed"></span>
</span>
<span x-text="`--tsne_perplexity ${perplexity}`"
class="text-gray-600 underline decoration-dashed"></span>
<span x-text="`--tsne_n_iter ${iterations}`"
class="text-gray-600 underline decoration-dashed"></span>
<template x-for="column in excludedColumns">
<span x-text="`--exclude ${column}`"></span>
</template>
</div>
</div>
</div>
Expand Down
33 changes: 15 additions & 18 deletions trace_explorer/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,30 +167,27 @@ def compare_with_params():
# generate compare files
datasets = [pd.read_parquet(s) for s in sources]
overview_path = os.path.join(tempdir, 'overview.png')
cluster_path = os.path.join(tempdir, 'cluster_%d.png')
n_clusters = compare.by_limiting_columns(
cluster_path = os.path.join(tempdir, 'cluster_%s.png')
n_clusters, cluster_path_set = compare.by_limiting_columns(
datasets, excluded_columns, overview_path,
iterations, perplexity, sources, threshold,
cluster_path=cluster_path, separate_overview=True)
cluster_path=cluster_path, separate_overview=True,
cluster_subplots=False, cluster_figsize=(10, 10))

# read pngs into base64
overview_data = None
cluster_overview_data = None
cluster_data = []

response = {}
with open(overview_path, 'rb') as f:
overview_data = f.read()
with open(cluster_path % -1, 'rb') as f:
cluster_overview_data = f.read()
response['overview'] = b64encode(f.read()).decode('utf-8')
with open(cluster_path % 'all', 'rb') as f:
response['clusters_all'] = b64encode(f.read()).decode('utf-8')
response['clusters'] = []
for i in range(n_clusters):
with open(cluster_path % i, 'rb') as f:
cluster_data.append(f.read())

return {
'overview': b64encode(overview_data).decode('utf-8'),
'clusters_overview': b64encode(cluster_overview_data).decode('utf-8'),
'clusters': [b64encode(c).decode('utf-8') for c in cluster_data],
}
cluster_data = {}
for (plot_type, path) in cluster_path_set[i].items():
with open(path, 'rb') as f:
cluster_data[plot_type] = b64encode(f.read()).decode('utf-8')
response['clusters'].append(cluster_data)
return response


@app.route('/visualize', methods=['POST'])
Expand Down

0 comments on commit 7399694

Please sign in to comment.