diff --git a/.gitignore b/.gitignore index f8360ca3..745df54f 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ nosetests.xml coverage.xml *.cover .hypothesis/ +test-reports/ # Translations *.mo @@ -102,4 +103,7 @@ ENV/ .vscode -.DS_Store \ No newline at end of file +.DS_Store + +test-reports/ + diff --git a/asap/dataimport/make_montage_scapes_stack.py b/asap/dataimport/make_montage_scapes_stack.py index f859bb79..53e96b2d 100644 --- a/asap/dataimport/make_montage_scapes_stack.py +++ b/asap/dataimport/make_montage_scapes_stack.py @@ -237,7 +237,7 @@ def run(self): # import tilespecs to render self.render.run(renderapi.client.import_jsonfiles_parallel, self.output_stack, - jsonfiles) + jsonfiles, close_stack=False) if self.close_stack: # set stack state to complete diff --git a/asap/em_montage_qc/detect_montage_defects.py b/asap/em_montage_qc/detect_montage_defects.py index 0bab75e6..82a1dc98 100644 --- a/asap/em_montage_qc/detect_montage_defects.py +++ b/asap/em_montage_qc/detect_montage_defects.py @@ -1,13 +1,15 @@ from functools import partial import time +import igraph import networkx as nx import numpy as np import renderapi +import renderapi.utils import requests -from rtree import index as rindex -from six import viewkeys from scipy.spatial import cKDTree +import shapely +import shapely.strtree from asap.residuals import compute_residuals as cr from asap.em_montage_qc.schemas import ( @@ -17,10 +19,10 @@ from asap.em_montage_qc.plots import plot_section_maps from asap.em_montage_qc.distorted_montages import ( - do_get_z_scales_nopm, - get_z_scales_nopm, + get_scales_from_tilespecs, + get_rts_fallthrough, get_scale_statistics_mad - ) +) example = { "render": { @@ -41,157 +43,277 @@ } -def detect_seams( - render, stack, match_collection, match_owner, z, - residual_threshold=8, distance=60, min_cluster_size=15, tspecs=None): - # seams will always be computed for montages using montage point matches - # but the input stack can be either montage, rough, or fine - # Compute residuals and other stats for this z - stats, allmatches = cr.compute_residuals_within_group( - render, stack, match_owner, match_collection, z, tilespecs=tspecs) +# TODO methods for tile borders/boundary points should go in render-python +def determine_numX_numY_rectangle(width, height, meshcellsize=32): + numX = max([2, np.around(width / meshcellsize)]) + numY = max([2, np.around(height / meshcellsize)]) + return int(numX), int(numY) + + +def determine_numX_numY_triangle(width, height, meshcellsize=32): + # TODO do we always want width to define the geometry? + numX = max([2, np.around(width / meshcellsize)]) + numY = max([ + 2, + np.around( + height / + ( + 2 * np.sqrt(3. / 4. * (width / (numX - 1)) ** 2) + ) + 1) + ]) + return int(numX), int(numY) + + +def generate_border_mesh_pts( + width, height, meshcellsize=64, mesh_type="square", **kwargs): + numfunc = { + "square": determine_numX_numY_rectangle, + "triangle": determine_numX_numY_triangle + }[mesh_type] + numX, numY = numfunc(width, height, meshcellsize) + + xs = np.linspace(0, width - 1, numX).reshape(-1, 1) + ys = np.linspace(0, height - 1, numY).reshape(-1, 1) + + perim = np.vstack([ + np.hstack([xs, np.zeros(xs.shape)]), + np.hstack([np.ones(ys.shape) * float(height - 1), ys])[1:-1], + np.hstack([xs, np.ones(xs.shape) * float(width - 1)])[::-1], + np.hstack([np.zeros(ys.shape), ys])[1:-1][::-1], + + ]) + return perim + + +def polygon_from_ts(ts, ref_tforms, **kwargs): + tsarr = generate_border_mesh_pts(ts.width, ts.height, **kwargs) + return shapely.geometry.Polygon( + renderapi.transform.estimate_dstpts( + ts.tforms, src=tsarr, reference_tforms=ref_tforms)) + + +def polygons_from_rts(rts, **kwargs): + return [ + polygon_from_ts(ts, rts.transforms, **kwargs) + for ts in rts.tilespecs + ] + + +def strtree_query_geometries(tree, q): + res = tree.query(q) + return tree.geometries[res] + + +def pair_clusters_networkx(pairs, min_cluster_size=25): + G = nx.Graph() + G.add_edges_from(pairs) + + # get the connected subraphs from G + Gc = nx.connected_components(G) + # get the list of nodes in each component + fnodes = sorted((list(n) for n in Gc if len(n) > min_cluster_size), + key=len, reverse=True) + return fnodes + + +def pair_clusters_igraph(pairs, min_cluster_size=25): + G_ig = igraph.Graph(edges=pairs, directed=False) + + # get the connected subraphs from G + cc_ig = G_ig.connected_components(mode='strong') + # filter nodes list with min_cluster_size + fnodes = sorted((c for c in cc_ig if len(c) > min_cluster_size), + key=len, reverse=True) + return fnodes + + +def detect_seams(tilespecs, matches, residual_threshold=10, + distance=80, min_cluster_size=25, cluster_method="igraph"): + stats, allmatches = cr.compute_residuals(tilespecs, matches) # get mean positions of the point matches as numpy array pt_match_positions = np.concatenate( - list(stats['pt_match_positions'].values()), - 0) + list(stats['pt_match_positions'].values()), 0 + ) # get the tile residuals - tile_residuals = np.concatenate(list(stats['tile_residuals'].values())) + tile_residuals = np.concatenate( + list(stats['tile_residuals'].values()) + ) # threshold the points based on residuals new_pts = pt_match_positions[ - np.where(tile_residuals >= residual_threshold), :][0] - - if len(new_pts) > 0: - # construct a KD Tree using these points - tree = cKDTree(new_pts) - # construct a networkx graph - G = nx.Graph() - # find the pairs of points within a distance to each other - pairs = tree.query_pairs(r=distance) - G.add_edges_from(pairs) - # get the connected subraphs from G - Gc = nx.connected_components(G) - # get the list of nodes in each component - nodes = sorted(Gc, key=len, reverse=True) - # filter nodes list with min_cluster_size - fnodes = [list(nn) for nn in nodes if len(nn) > min_cluster_size] - # get pts list for each filtered node list - points_list = [new_pts[mm, :] for mm in fnodes] - centroids = [[np.sum(pt[:, 0])/len(pt), np.sum(pt[:, 1])/len(pt)] - for pt in points_list] - else: - centroids = [] + np.where(tile_residuals >= residual_threshold), : + ][0] + + # construct a KD Tree using these points + tree = cKDTree(new_pts) + # construct a networkx graph + + # find the pairs of points within a distance to each other + pairs = tree.query_pairs(r=distance) + + fnodes = { + "igraph": pair_clusters_igraph, + "networkx": pair_clusters_networkx + }[cluster_method](pairs, min_cluster_size=min_cluster_size) + + # get pts list for each filtered node list + points_list = [new_pts[mm, :] for mm in fnodes] + centroids = np.array([(np.sum(pt, axis=0) / len(pt)).tolist() + for pt in points_list]) return centroids, allmatches, stats -def detect_disconnected_tiles(render, prestitched_stack, poststitched_stack, - z, pre_tilespecs=None, post_tilespecs=None): +def detect_seams_from_collections( + render, stack, match_collection, match_owner, z, + residual_threshold=8, distance=60, min_cluster_size=15, tspecs=None): + session = requests.session() + + groupId = render.run( + renderapi.stack.get_sectionId_for_z, stack, z, session=session) + allmatches = render.run( + renderapi.pointmatch.get_matches_within_group, + match_collection, + groupId, + owner=match_owner, + session=session) + if tspecs is None: + tspecs = render.run( + renderapi.tilespec.get_tile_specs_from_z, + stack, z, session=session) + session.close() + + return detect_seams( + tspecs, allmatches, + residual_threshold=residual_threshold, + distance=distance, + min_cluster_size=min_cluster_size) + + +def detect_disconnected_tiles(pre_tilespecs, post_tilespecs): + pre_tileIds = {ts.tileId for ts in pre_tilespecs} + post_tileIds = {ts.tileId for ts in post_tilespecs} + missing_tileIds = list(pre_tileIds - post_tileIds) + return missing_tileIds + + +def detect_disconnected_tiles_from_collections( + render, prestitched_stack, poststitched_stack, + z, pre_tilespecs=None, post_tilespecs=None): session = requests.session() # get the tilespecs for both prestitched_stack and poststitched_stack if pre_tilespecs is None: pre_tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - prestitched_stack, - z, - session=session) + renderapi.tilespec.get_tile_specs_from_z, + prestitched_stack, + z, + session=session) if post_tilespecs is None: post_tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - poststitched_stack, - z, - session=session) - # pre tile_ids - pre_tileIds = [] - pre_tileIds = [ts.tileId for ts in pre_tilespecs] - # post tile_ids - post_tileIds = [] - post_tileIds = [ts.tileId for ts in post_tilespecs] - missing_tileIds = list(set(pre_tileIds) - set(post_tileIds)) + renderapi.tilespec.get_tile_specs_from_z, + poststitched_stack, + z, + session=session) session.close() - return missing_tileIds + return detect_disconnected_tiles(pre_tilespecs, post_tilespecs) + + +def detect_stitching_gaps(pre_rts, post_rts, polygon_kwargs={}, use_bbox=False): + tId_to_pre_polys = { + ts.tileId: ( + polygon_from_ts( + ts, pre_rts.transforms, **polygon_kwargs) + if not use_bbox else shapely.geometry.box(*ts.bbox) + ) + for ts in pre_rts.tilespecs + } + + tId_to_post_polys = { + ts.tileId: ( + polygon_from_ts( + ts, post_rts.transforms, **polygon_kwargs) + if not use_bbox else shapely.geometry.box(*ts.bbox) + ) + for ts in post_rts.tilespecs + } + + poly_id_to_tId = { + id(p): tId for tId, p in + (i for l in ( + tId_to_pre_polys.items(), tId_to_post_polys.items()) + for i in l) + } + + pre_polys = [*tId_to_pre_polys.values()] + post_polys = [*tId_to_post_polys.values()] + + pre_tree = shapely.strtree.STRtree(pre_polys) + post_tree = shapely.strtree.STRtree(post_polys) + + pre_graph = nx.Graph({ + poly_id_to_tId[id(p)]: [ + poly_id_to_tId[id(r)] + for r in strtree_query_geometries(pre_tree, p) + ] + for p in pre_polys}) + post_graph = nx.Graph({ + poly_id_to_tId[id(p)]: [ + poly_id_to_tId[id(r)] + for r in strtree_query_geometries(post_tree, p) + ] + for p in post_polys}) + + diff_g = nx.Graph(pre_graph.edges - post_graph.edges) + gap_tiles = [n for n in diff_g.nodes() if diff_g.degree(n) > 0] + return gap_tiles -def detect_stitching_gaps(render, prestitched_stack, poststitched_stack, - z, pre_tilespecs=None, tilespecs=None): +def detect_stitching_gaps_legacy(render, prestitched_stack, poststitched_stack, + z, pre_tilespecs=None, tilespecs=None): session = requests.session() - # setup an rtree to find overlapping tiles - pre_ridx = rindex.Index() - # setup a graph to store overlapping tiles - G1 = nx.Graph() - # get the tilespecs for both prestitched_stack and poststitched_stack + if pre_tilespecs is None: pre_tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - prestitched_stack, - z, - session=session) + renderapi.tilespec.get_tile_specs_from_z, + prestitched_stack, + z, + session=session) if tilespecs is None: tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - poststitched_stack, - z, - session=session) - # insert the prestitched_tilespecs into rtree - # with their bounding boxes to find overlaps - for i, ts in enumerate(pre_tilespecs): - pre_ridx.insert(i, ts.bbox) - - pre_tileIds = {} - for i, ts in enumerate(pre_tilespecs): - pre_tileIds[ts.tileId] = i - nodes = list(pre_ridx.intersection(ts.bbox)) - nodes.remove(i) - [G1.add_edge(i, node) for node in nodes] - # G1 contains the prestitched_stack tile)s and the degree - # of each node representing the number of tiles that overlap. - # This overlap count has to match in the poststitched_stack - G2 = nx.Graph() - post_ridx = rindex.Index() - tileId_to_ts = {ts.tileId: ts for ts in tilespecs} - shared_tileIds = viewkeys(tileId_to_ts) & viewkeys(pre_tileIds) - [post_ridx.insert(pre_tileIds[tId], tileId_to_ts[tId].bbox) - for tId in shared_tileIds] - for ts in tilespecs: - try: - i = pre_tileIds[ts.tileId] - except KeyError: - continue - nodes = list(post_ridx.intersection(ts.bbox)) - nodes.remove(i) - [G2.add_edge(i, node) for node in nodes] - # Now G1 and G2 have the same index for the same tileId - # comparing the degree of each node pre and post - # stitching should reveal stitching gaps - gap_tiles = [] - for n in G2.nodes(): - if G1.degree(n) > G2.degree(n): - tileId = list(pre_tileIds.keys())[ - list(pre_tileIds.values()).index(n)] - gap_tiles.append(tileId) + renderapi.tilespec.get_tile_specs_from_z, + poststitched_stack, + z, + session=session) + + gap_tiles = detect_stitching_gaps( + renderapi.resolvedtiles.ResolvedTiles(tilespecs=pre_tilespecs), + renderapi.resolvedtiles.ResolvedTiles(tilespecs=tilespecs), + use_bbox=True) session.close() return gap_tiles -def detect_distortion(render, poststitched_stack, zvalue, threshold_cutoff=[0.005, 0.005], pool_size=20): - #z_to_scales = {zvalue: do_get_z_scales_nopm(zvalue, [poststitched_stack], render)} - z_to_scales = {} - # check if any scale is None - #zs = [z for z, scales in z_to_scales.items() if scales is None] - #for z in zs: - # z_to_scales[z] = get_z_scales_nopm(z, [poststitched_stack], render) +def detect_distortion_tilespecs(tilespecs, zvalue, threshold_cutoff=[0.005, 0.005]): + scales = get_scales_from_tilespecs(tilespecs) + mad_stats = get_scale_statistics_mad(scales) + badzs_cutoff = ( + [zvalue] if ( + mad_stats[0] > threshold_cutoff[0] or + mad_stats[1] > threshold_cutoff[1] + ) + else []) + return badzs_cutoff - try: - z_to_scales[zvalue] = get_z_scales_nopm(zvalue, [poststitched_stack], render) - except Exception: - z_to_scales[zvalue] = None - # get the mad statistics - z_to_scalestats = {z: get_scale_statistics_mad(scales) for z, scales in z_to_scales.items() if scales is not None} +def detect_distortion( + render, poststitched_stack, zvalue, + threshold_cutoff=[0.005, 0.005], pool_size=20, tilespecs=None): + if tilespecs is None: + rts = get_rts_fallthrough([poststitched_stack], zvalue, render=render) + tilespecs = rts.tilespecs - # find zs that fall outside cutoff - badzs_cutoff = [z for z, s in z_to_scalestats.items() if s[0] > threshold_cutoff[0] or s[1] > threshold_cutoff[1]] - return badzs_cutoff + return detect_distortion_tilespecs(tilespecs, zvalue, threshold_cutoff) def get_pre_post_tspecs(render, prestitched_stack, poststitched_stack, z): @@ -216,18 +338,19 @@ def run_analysis( min_cluster_size, threshold_cutoff, z): pre_tspecs, post_tspecs = get_pre_post_tspecs( render, prestitched_stack, poststitched_stack, z) - disconnected_tiles = detect_disconnected_tiles( + disconnected_tiles = detect_disconnected_tiles_from_collections( render, prestitched_stack, poststitched_stack, z, pre_tspecs, post_tspecs) - gap_tiles = detect_stitching_gaps( + gap_tiles = detect_stitching_gaps_legacy( render, prestitched_stack, poststitched_stack, z, pre_tspecs, post_tspecs) - seam_centroids, matches, stats = detect_seams( - render, poststitched_stack, match_collection, match_collection_owner, + seam_centroids, matches, stats = detect_seams_from_collections( + render, poststitched_stack, match_collection, match_collection_owner, z, residual_threshold=residual_threshold, distance=neighbor_distance, min_cluster_size=min_cluster_size, tspecs=post_tspecs) distorted_zs = detect_distortion( - render, poststitched_stack, z, threshold_cutoff=threshold_cutoff) + render, poststitched_stack, z, threshold_cutoff=threshold_cutoff, + tilespecs=post_tspecs) return (disconnected_tiles, gap_tiles, seam_centroids, distorted_zs, post_tspecs, matches, stats) @@ -346,8 +469,8 @@ def run(self): 'gap_sections': gaps, 'seam_sections': seams, 'distorted_sections': distorted_zs, - 'seam_centroids': np.array(centroids, dtype=object)}) - print(self.output) + 'seam_centroids': np.array(centroids, dtype=object)}, + cls=renderapi.utils.RenderEncoder) # delete the stacks that were cloned if status1 == 'LOADING': self.render.run(renderapi.stack.delete_stack, new_prestitched) diff --git a/asap/em_montage_qc/distorted_montages.py b/asap/em_montage_qc/distorted_montages.py index 6346f325..a422f142 100644 --- a/asap/em_montage_qc/distorted_montages.py +++ b/asap/em_montage_qc/distorted_montages.py @@ -90,6 +90,11 @@ def groupId_from_tilespec(ts): def sections_from_resolvedtiles(rts): return list({groupId_from_tilespec(ts) for ts in rts.tilespecs}) + +def get_scales_from_tilespecs(tilespecs): + return np.array([ts.tforms[-1].scale for ts in tilespecs]) + + def get_z_scales_nopm(z, input_stacks, render): rts = get_rts_fallthrough(input_stacks, z, render=render) scales = np.array([ts.tforms[-1].scale for ts in rts.tilespecs]) diff --git a/asap/em_montage_qc/plots.py b/asap/em_montage_qc/plots.py index e80dfc2c..3e5792b3 100644 --- a/asap/em_montage_qc/plots.py +++ b/asap/em_montage_qc/plots.py @@ -7,13 +7,13 @@ import renderapi from bokeh.palettes import Plasma256, Viridis256 -from bokeh.plotting import figure, output_file, save +from bokeh.plotting import figure, save from bokeh.layouts import row -from bokeh.models.widgets import Tabs, Panel from bokeh.models import (HoverTool, ColumnDataSource, CustomJS, CategoricalColorMapper, LinearColorMapper, - TapTool, OpenURL, Div, ColorBar) + TapTool, OpenURL, Div, ColorBar, + Tabs, TabPanel) from asap.residuals import compute_residuals as cr @@ -33,13 +33,29 @@ xrange = range -def point_match_plot(tilespecsA, matches, tilespecsB=None): +def bbox_tup_to_xs_ys(bbox_tup): + min_x, min_y, max_x, max_y = bbox_tup + return ( + [min_x, max_x, max_x, min_x, min_x], + [min_y, min_y, max_y, max_y, min_y] + ) + + +def tilespecs_to_xs_ys(tilespecs): + return zip(*(bbox_tup_to_xs_ys(ts.bbox) for ts in tilespecs)) + + +def point_match_plot(tilespecsA, matches, tilespecsB=None, match_max=500): if tilespecsB is None: tilespecsB = tilespecsA if len(matches) > 0: - x1, y1, id1 = cr.get_tile_centers(tilespecsA) - x2, y2, id2 = cr.get_tile_centers(tilespecsB) + tId_to_ctr_a = { + idx: (x, y) + for x, y, idx in zip(*cr.get_tile_centers(tilespecsA))} + tId_to_ctr_b = { + idx: (x, y) + for x, y, idx in zip(*cr.get_tile_centers(tilespecsB))} xs = [] ys = [] @@ -54,20 +70,26 @@ def point_match_plot(tilespecsA, matches, tilespecsB=None): xs.append([0.1, 0]) ys.append([0.1, 0.1]) - if (set(x1) != set(x2)): - clist.append(500) - else: - clist.append(200) - - for k in np.arange(len(matches)): - t1 = np.argwhere(id1 == matches[k]['qId']).flatten() - t2 = np.argwhere(id2 == matches[k]['pId']).flatten() - if (t1.size != 0) & (t2.size != 0): - t1 = t1[0] - t2 = t2[0] - xs.append([x1[t1], x2[t2]]) - ys.append([y1[t1], y2[t2]]) - clist.append(len(matches[k]['matches']['q'][0])) + # if (set(tId_to_ctr_a.keys()) != set(tId_to_ctr_b.keys())): + # clist.append(500) + # else: + # clist.append(200) + clist.append(match_max) + + for m in matches: + match_qId = m['qId'] + match_pId = m['pId'] + num_pts = len(m['matches']['q'][0]) + + try: + a_ctr = tId_to_ctr_a[match_qId] + b_ctr = tId_to_ctr_b[match_pId] + except KeyError: + continue + + xs.append([a_ctr[0], b_ctr[0]]) + ys.append([a_ctr[1], b_ctr[1]]) + clist.append(num_pts) mapper = LinearColorMapper( palette=Plasma256, low=min(clist), high=max(clist)) @@ -77,9 +99,24 @@ def point_match_plot(tilespecsA, matches, tilespecsB=None): TOOLS = "pan,box_zoom,reset,hover,save" + w = np.ptp(xs) + h = np.ptp(ys) + base_dim = 1000 + if w > h: + h = int(np.round(base_dim * (h / w))) + w = base_dim + else: + h = base_dim + w = int(np.round(base_dim * w / h)) + plot = figure( - plot_width=800, plot_height=700, - background_fill_color='gray', tools=TOOLS) + # plot_width=800, plot_height=700, + width=w, height=h, + background_fill_color='gray', tools=TOOLS, + # sizing_mode="stretch_both", + match_aspect=True + ) + plot.tools[1].match_aspect = True plot.multi_line( xs="xs", ys="ys", source=source, color={'field': 'colors', 'transform': mapper}, line_width=2) @@ -88,16 +125,28 @@ def point_match_plot(tilespecsA, matches, tilespecsB=None): plot.ygrid.visible = False else: - plot = figure(plot_width=800, plot_height=700, - background_fill_color='gray') + plot = figure( + # width=800, height=700, + background_fill_color='gray') return plot -def plot_residual(xs, ys, residual): - p = figure(width=1000, height=1000) +def plot_residual(xs, ys, residual, residual_max=None): + w = np.ptp(xs) + h = np.ptp(ys) + base_dim = 1000 + if w > h: + h = int(np.round(base_dim * (h / w))) + w = base_dim + else: + h = base_dim + w = int(np.round(base_dim * w / h)) + p = figure(width=w, height=h) + if residual_max is None: + residual_max = max(residual) color_mapper = LinearColorMapper( - palette=Viridis256, low=min(residual), high=max(residual)) + palette=Viridis256, low=min(residual), high=residual_max) source = ColumnDataSource(data=dict(x=xs, y=ys, residual=residual)) @@ -107,7 +156,7 @@ def plot_residual(xs, ys, residual): fill_alpha=1.0, line_color="black", line_width=0.05) color_bar = ColorBar(color_mapper=color_mapper, label_standoff=12, - border_line_color=None, location=(0,0)) + border_line_color=None, location=(0, 0)) p.add_layout(color_bar, 'right') @@ -117,68 +166,37 @@ def plot_residual(xs, ys, residual): return p -def plot_defects(render, stack, out_html_dir, args): - tspecs = args[0] - matches = args[1] - dis_tiles = args[2] - gap_tiles = args[3] - seam_centroids = np.array(args[4]) - stats = args[5] - z = args[6] - - # Tile residual mean +def plot_residual_tilespecs( + tspecs, tileId_to_point_residuals, default_residual=50, residual_max=None): tile_residual_mean = cr.compute_mean_tile_residuals( - stats['tile_residuals']) - - tile_positions = [] - tile_ids = [] - residual = [] - for ts in tspecs: - tile_ids.append(ts.tileId) - pts = [] - pts.append([ts.minX, ts.minY]) - pts.append([ts.maxX, ts.minY]) - pts.append([ts.maxX, ts.maxY]) - pts.append([ts.minX, ts.maxY]) - pts.append([ts.minX, ts.minY]) - tile_positions.append(pts) - - try: - residual.append(tile_residual_mean[ts.tileId]) - except KeyError: - residual.append(50) # a high value for residual for that tile + tileId_to_point_residuals) + xs, ys = tilespecs_to_xs_ys(tspecs) + residual = [ + tile_residual_mean.get(ts.tileId, default_residual) + for ts in tspecs + ] + + return plot_residual(xs, ys, residual, residual_max=residual_max) - out_html = os.path.join( - out_html_dir, - "%s_%d_%s.html" % ( - stack, - z, - datetime.datetime.now().strftime('%Y%m%d%H%S%M%f'))) - - output_file(out_html) - xs = [] - ys = [] - alphas = [] - for tp in tile_positions: - sp = np.array(tp) - x = list(sp[:, 0]) - y = list(sp[:, 1]) - xs.append(x) - ys.append(y) - alphas.append(0.5) + +def montage_defect_plot( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=None): + xs, ys = tilespecs_to_xs_ys(tspecs) + alphas = [0.5] * len(xs) fill_color = [] label = [] - for t in tile_ids: - if t in gap_tiles: - label.append("Gap tiles") - fill_color.append("red") - elif t in dis_tiles: - label.append("Disconnected tiles") - fill_color.append("yellow") - else: - label.append("Stitched tiles") - fill_color.append("blue") + gap_tiles = set(gap_tiles) + disconnected_tiles = set(disconnected_tiles) + + tile_ids = [ts.tileId for ts in tspecs] + + label, fill_color = zip(*( + (("Gap tiles", "red") if tileId in gap_tiles + else ("Disconnected tiles", "yellow") if tileId in disconnected_tiles + else ("Stitched tiles", "blue")) + for tileId in tile_ids)) color_mapper = CategoricalColorMapper( factors=['Gap tiles', 'Disconnected tiles', 'Stitched tiles'], @@ -194,13 +212,28 @@ def plot_defects(render, stack, out_html_dir, args): TOOLS = "pan,box_zoom,reset,hover,tap,save" - p = figure(title=str(z), width=1000, height=1000, - tools=TOOLS, match_aspect=True) + w = np.ptp(xs) + h = np.ptp(ys) + base_dim = 1000 + if w > h: + h = int(np.round(base_dim * (h / w))) + w = base_dim + else: + h = base_dim + w = int(np.round(base_dim * w / h)) + + p = figure(title=str(z), + # width=1000, height=1000, + width=w, height=h, + tools=TOOLS, + match_aspect=True) + p.tools[1].match_aspect = True pp = p.patches( 'x', 'y', source=source, alpha='alpha', line_width=2, - color={'field': 'labels', 'transform': color_mapper}, legend='labels') - cp = p.circle('x', 'y', source=seam_source, legend='lbl', size=11) + color={'field': 'labels', 'transform': color_mapper}, + legend_group='labels') + cp = p.scatter('x', 'y', source=seam_source, legend_group='lbl', size=11) jscode = """ var inds = cb_obj.selected['1d'].indices; @@ -211,38 +244,78 @@ def plot_defects(render, stack, out_html_dir, args): if ( lines.length > 35 ) { lines.shift(); } div.text = lines.join("\\n"); """ - div = Div(width=1000) + div = Div(width=w) layout = row(p, div) - urls = "%s:%d/render-ws/v1/owner/%s/project/%s/stack/%s/tile/@names/withNeighbors/jpeg-image?scale=0.1" % (render.DEFAULT_HOST, render.DEFAULT_PORT, render.DEFAULT_OWNER, render.DEFAULT_PROJECT, stack) - - taptool = p.select(type=TapTool) - taptool.renderers = [pp] - taptool.callback = OpenURL(url=urls) + if tile_url_format: + taptool = p.select(type=TapTool) + taptool.renderers = [pp] + taptool.callback = OpenURL(url=tile_url_format) hover = p.select(dict(type=HoverTool)) hover.renderers = [pp] hover.point_policy = "follow_mouse" hover.tooltips = [("tileId", "@names"), ("x", "$x{int}"), ("y", "$y{int}")] - source.callback = CustomJS(args=dict(div=div), code=jscode % ('names')) + source.js_event_callbacks['identify'] = [ + CustomJS(args=dict(div=div), code=jscode % ('names')) + ] + return layout + + +def create_montage_qc_plots( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=None, + match_max=500, residual_max=None): + # montage qc + layout = montage_defect_plot( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=tile_url_format) # add point match plot in another tab - plot = point_match_plot(tspecs, matches) + plot = point_match_plot(tspecs, matches, match_max=match_max) # montage statistics plots in other tabs - - stat_layout = plot_residual(xs, ys, residual) + stat_layout = plot_residual_tilespecs( + tspecs, stats["tile_residuals"], + residual_max=residual_max + ) tabs = [] - tabs.append(Panel(child=layout, title="Defects")) - tabs.append(Panel(child=plot, title="Point match plot")) - tabs.append(Panel(child=stat_layout, title="Mean tile residual")) + tabs.append(TabPanel(child=layout, title="Defects")) + tabs.append(TabPanel(child=plot, title="Point match plot")) + tabs.append(TabPanel(child=stat_layout, title="Mean tile residual")) plot_tabs = Tabs(tabs=tabs) - save(plot_tabs) + return plot_tabs + +def write_montage_qc_plots(out_fn, plt): + return save(plt, out_fn) + + +def run_montage_qc_plots_legacy(render, stack, out_html_dir, args): + tspecs = args[0] + matches = args[1] + disconnected_tiles = args[2] + gap_tiles = args[3] + seam_centroids = np.array(args[4]) + stats = args[5] + z = args[6] + + out_html = os.path.join( + out_html_dir, + "%s_%d_%s.html" % ( + stack, + z, + datetime.datetime.now().strftime('%Y%m%d%H%S%M%f'))) + tile_url_format = f"{render.DEFAULT_HOST}:{render.DEFAULT_PORT}/render-ws/v1/owner/{render.DEFAULT_OWNER}/project/{render.DEFAULT_PROJECT}/stack/{stack}/tile/@names/withNeighbors/jpeg-image?scale=0.1" + + qc_plot = create_montage_qc_plots( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=tile_url_format) + write_montage_qc_plots(out_html, qc_plot) return out_html @@ -253,7 +326,9 @@ def plot_section_maps( if out_html_dir is None: out_html_dir = tempfile.mkdtemp() - mypartial = partial(plot_defects, render, stack, out_html_dir) + mypartial = partial( + run_montage_qc_plots_legacy, render, stack, out_html_dir + ) args = zip(post_tspecs, matches, disconnected_tiles, gap_tiles, seam_centroids, stats, zvalues) diff --git a/asap/materialize/validate_materialized_tilesource.py b/asap/materialize/validate_materialized_tilesource.py index 3d5f77da..9ccae04f 100644 --- a/asap/materialize/validate_materialized_tilesource.py +++ b/asap/materialize/validate_materialized_tilesource.py @@ -7,9 +7,9 @@ from multiprocessing.pool import ThreadPool import argschema -import imageio - -from asap.materialize.schemas import ( +import imageio + +from asap.materialize.schemas import ( ValidateMaterializationParameters, ValidateMaterializationOutput) try: @@ -51,6 +51,8 @@ def try_load_file(fn, allow_ENOENT=True, allow_EACCES=False): return True elif not allow_EACCES and (e.errno == errno.EACCES): raise + else: + return else: return return True diff --git a/asap/mesh_lens_correction/run_mesh_lens_correction.py b/asap/mesh_lens_correction/run_mesh_lens_correction.py new file mode 100755 index 00000000..7a15e81e --- /dev/null +++ b/asap/mesh_lens_correction/run_mesh_lens_correction.py @@ -0,0 +1,281 @@ +""" +run mesh lens correction from an acquisition directory/metadata file. +This avoids calls to java clients or other modules and does not require reading or writing from a render collection + +Does not provide the same mask support as do_mesh_lens_correction. +""" + + +import concurrent.futures +import pathlib +import json + +import renderapi +import uri_handler.uri_functions +import shapely +import numpy +import argschema + +import em_stitch.utils.generate_EM_tilespecs_from_metafile +import asap.em_montage_qc.detect_montage_defects +import em_stitch.montage.meta_to_collection +import asap.em_montage_qc.plots +import em_stitch.lens_correction.mesh_and_solve_transform +import asap.pointmatch.generate_point_matches_opencv + + +# FIXME this should be in em-stitch +class GenerateEMTilespecsModule_URIPrefix( + em_stitch.utils.generate_EM_tilespecs_from_metafile.GenerateEMTileSpecsModule): + @staticmethod + def ts_from_imgdata_tileId(imgdata, img_prefix, x, y, tileId, + minint=0, maxint=255, maskUrl=None, + width=3840, height=3840, z=None, sectionId=None, + scopeId=None, cameraId=None, pixelsize=None): + raw_tforms = [renderapi.transform.AffineModel(B0=x, B1=y)] + imageUrl = uri_handler.uri_functions.uri_join( + img_prefix, imgdata["img_path"]) + + if maskUrl is not None: + maskUrl = pathlib.Path(maskUrl).as_uri() + + ip = renderapi.image_pyramid.ImagePyramid() + ip[0] = renderapi.image_pyramid.MipMap(imageUrl=imageUrl, + maskUrl=maskUrl) + return renderapi.tilespec.TileSpec( + tileId=tileId, z=z, + width=width, height=height, + minint=minint, maxint=maxint, + tforms=raw_tforms, + imagePyramid=ip, + sectionId=sectionId, scopeId=scopeId, cameraId=cameraId, + imageCol=imgdata['img_meta']['raster_pos'][0], + imageRow=imgdata['img_meta']['raster_pos'][1], + stageX=imgdata['img_meta']['stage_pos'][0], + stageY=imgdata['img_meta']['stage_pos'][1], + rotation=imgdata['img_meta']['angle'], pixelsize=pixelsize) + + +def resolvedtiles_from_temca_md( + md, image_prefix, z, + sectionId=None, + minimum_intensity=0, maximum_intensity=255, + initial_transform=None): + tspecs = GenerateEMTilespecsModule_URIPrefix.ts_from_metadata( + md, image_prefix, z, sectionId, minimum_intensity, maximum_intensity) + tformlist = [] + if initial_transform is not None: + initial_transform_ref = renderapi.transform.ReferenceTransform( + refId=initial_transform.transformId) + for ts in tspecs: + ts.tforms.insert(0, initial_transform_ref) + tformlist = [initial_transform] + return renderapi.resolvedtiles.ResolvedTiles( + tilespecs=tspecs, + transformList=tformlist + ) + + +def bbox_from_ts(ts, ref_tforms, **kwargs): + tsarr = asap.em_montage_qc.detect_montage_defects.generate_border_mesh_pts( + ts.width, ts.height, **kwargs) + pts = renderapi.transform.estimate_dstpts( + ts.tforms, src=tsarr, + reference_tforms=ref_tforms) + minX, minY = numpy.min(pts, axis=0) + maxX, maxY = numpy.max(pts, axis=0) + return minX, minY, maxX, maxY + + +def apply_resolvedtiles_bboxes(rts, **kwargs): + for ts in rts.tilespecs: + minX, minY, maxX, maxY = bbox_from_ts(ts, rts.transforms, **kwargs) + ts.minX = minX + ts.minY = minY + ts.maxX = maxX + ts.maxY = maxY + + +# TODO: match tilepair options and make part of tilepair generation +def pair_tiles_rts(rts, query_fraction=0.1): + boxes = [shapely.box(*ts.bbox) for ts in rts.tilespecs] + tree = shapely.strtree.STRtree(boxes) + + qresults = {} + + for qidx, qbox in enumerate(boxes): + qbox = boxes[qidx] + qfraction = query_fraction + qdiag = numpy.sqrt(qbox.length ** 2 - 8 * qbox.area) / 2 + qminX, qminY, qmaxX, qmaxY = qbox.bounds + + qresult_idxs = tree.query_nearest( + qbox, max_distance=qdiag * qfraction, exclusive=True) + for ridx in qresult_idxs: + if frozenset((qidx, ridx)) not in qresults: + rbox = boxes[ridx] + is_not_corner = ((qminX < rbox.centroid.x < qmaxX) & + (qminY < rbox.centroid.y < qmaxY)) + if is_not_corner: + rminX, rminY, rmaxX, rmaxY = rbox.bounds + + # copied from java + deltaX = qminX - rminX + deltaY = qminY - rminY + if abs(deltaX) > abs(deltaY): + orientation = ("RIGHT" if deltaX > 0 else "LEFT") + else: + orientation = ("BOTTOM" if deltaY > 0 else "TOP") + + opposite_orientation = { + "LEFT": "RIGHT", + "TOP": "BOTTOM", + "RIGHT": "LEFT", + "BOTTOM": "TOP" + }[orientation] + + p_ts = rts.tilespecs[qidx] + q_ts = rts.tilespecs[ridx] + qresults[frozenset((qidx, ridx))] = { + "p": { + "groupId": p_ts.layout.sectionId, + "id": p_ts.tileId, + "relativePosition": orientation + }, + "q": { + "groupId": q_ts.layout.sectionId, + "id": q_ts.tileId, + "relativePosition": opposite_orientation + } + } + + neighborpairs_rp = [*qresults.values()] + return neighborpairs_rp + + +# TODO: allow different inputs, or substitute with an equivalent helper function from em-stitch +def solve_lc(rts, matches, transformId=None): + solve_matches = renderapi.pointmatch.copy_matches_explicit(matches) + # define input parameters for solving + nvertex = 1000 + + # regularization parameters for components + regularization_dict = { + "translation_factor": 0.001, + "default_lambda": 1.0, + "lens_lambda": 1.0 + } + + # thresholds defining an acceptable solution. + # solves exceeding these will raise an exception in em_stitch.lens_correction.mesh_and_solve_transform._solve_resolvedtiles + good_solve_dict = { + "error_mean": 0.8, + "error_std": 3.0, + "scale_dev": 0.1 + } + + solve_resolvedtiles_args = ( + rts, solve_matches, + nvertex, regularization_dict["default_lambda"], + regularization_dict["translation_factor"], + regularization_dict["lens_lambda"], + good_solve_dict + ) + solved_rts, lc_tform, jresult = em_stitch.lens_correction.mesh_and_solve_transform._solve_resolvedtiles(*solve_resolvedtiles_args) + if transformId: + lc_tform.transformId = transformId + return lc_tform + + +def match_tiles_rts(rts, tpairs, concurrency=10): + sectionId_tId_to_ts = {(ts.layout.sectionId, ts.tileId): ts for ts in rts.tilespecs} + matches_rp = [] + + with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as e: + futs = [ + e.submit( + asap.pointmatch.generate_point_matches_opencv.process_matches, + tpair["p"]["id"], tpair["p"]["groupId"], + ( + sectionId_tId_to_ts[ + (tpair["p"]["groupId"], tpair["p"]["id"]) + ].ip[0].imageUrl, + None + ), + tpair["q"]["id"], tpair["q"]["groupId"], + ( + sectionId_tId_to_ts[ + (tpair["q"]["groupId"], tpair["q"]["id"]) + ].ip[0].imageUrl, + None + ), + downsample_scale=0.3, + ndiv=8, + matchMax=1000, + sift_kwargs={ + "nfeatures": 20000, + "nOctaveLayers": 3, + "sigma": 1.5, + }, + # SIFT_nfeature=20000, + # SIFT_noctave=3, + # SIFT_sigma=1.5, + RANSAC_outlier=5.0, + FLANN_ncheck=50, + FLANN_ntree=5, + ratio_of_dist=0.7, + CLAHE_grid=None, + CLAHE_clip=None) + for tpair in tpairs + ] + for fut in concurrent.futures.as_completed(futs): + match_d, num_matches, num_features_p, num_features_q = fut.result() + matches_rp.append(match_d) + return matches_rp + + +class CalculateLensCorrectionParams(argschema.ArgSchema): + metafile_uri = argschema.fields.Str(required=True) + image_prefix = argschema.fields.Str(required=True) + transformId = argschema.fields.Str(required=True) + concurrency = argschema.fields.Int(required=False, default=10) + + +class CalculateLensCorrectionOutputSchema(argschema.schemas.DefaultSchema): + lc_transform = argschema.fields.Dict(required=True) + + +class CalculateLensCorrectionModule(argschema.ArgSchemaParser): + default_schema = CalculateLensCorrectionParams + default_output_schema = CalculateLensCorrectionOutputSchema + + @staticmethod + def compute_lc_from_metadata_uri( + md_uri, image_prefix, sectionId=None, + transformId=None, match_concurrency=10): + md = json.loads(uri_handler.uri_functions.uri_readbytes(md_uri)) + rts = resolvedtiles_from_temca_md( + md, image_prefix, 0, sectionId=sectionId) + apply_resolvedtiles_bboxes(rts) + tpairs = pair_tiles_rts(rts) + + matches = match_tiles_rts(rts, tpairs) + lc_tform = solve_lc(rts, matches, transformId=transformId) + return lc_tform + + def run(self): + lc_tform = self.compute_lc_from_metadata_uri( + self.args["metafile_uri"], + self.args["image_prefix"], + sectionId=self.args["transformId"], + transformId=self.args["transformId"], + match_concurrency=self.args["concurrency"] + ) + self.output({ + "lc_transform": json.loads(renderapi.utils.renderdumps(lc_tform)) + }) + + +if __name__ == "__main__": + mod = CalculateLensCorrectionModule() + mod.run() diff --git a/asap/module/schemas/renderclient_schemas.py b/asap/module/schemas/renderclient_schemas.py index 4cca7e28..7dba3189 100644 --- a/asap/module/schemas/renderclient_schemas.py +++ b/asap/module/schemas/renderclient_schemas.py @@ -18,7 +18,7 @@ class MaterializedBoxParameters(argschema.schemas.DefaultSchema): "height of flat rectangular tiles to generate")) maxLevel = Int(required=False, default=0, description=( "maximum mipMapLevel to generate.")) - fmt = Str(required=False, validator=validate.OneOf(['PNG', 'TIF', 'JPG']), + fmt = Str(required=False, validate=validate.OneOf(['PNG', 'TIF', 'JPG']), description=("image format to generate mipmaps -- " "PNG if not specified")) maxOverviewWidthAndHeight = Int(required=False, description=( @@ -148,7 +148,7 @@ class MatchDerivationParameters(argschema.schemas.DefaultSchema): matchRod = Float(required=False, description=( "Ratio of first to second nearest neighbors used as a cutoff in " "matching features. 0.92 if excluded or None")) - matchModelType = Str(required=False, validator=validate.OneOf([ + matchModelType = Str(required=False, validate=validate.OneOf([ "AFFINE", "RIGID", "SIMILARITY", "TRANSLATION"]), description=( "Model to match for RANSAC filtering. 'AFFINE' if excluded or None")) matchIterations = Int(required=False, description=( @@ -170,7 +170,7 @@ class MatchDerivationParameters(argschema.schemas.DefaultSchema): "3.0 if excluded or None")) matchFilter = Str( required=False, - validator=validate.OneOf( + validate=validate.OneOf( ['SINGLE_SET', 'CONSENSUS_SETS', 'AGGREGATED_CONSENSUS_SETS']), description=( "whether to match one set of matches, or multiple " diff --git a/asap/pointmatch/generate_point_matches_opencv.py b/asap/pointmatch/generate_point_matches_opencv.py index c1821696..d8b5ec9a 100644 --- a/asap/pointmatch/generate_point_matches_opencv.py +++ b/asap/pointmatch/generate_point_matches_opencv.py @@ -10,6 +10,8 @@ import pathlib2 as pathlib import renderapi +import imageio + from asap.pointmatch.schemas import ( PointMatchOpenCVParameters, PointMatchClientOutputSchema) @@ -46,7 +48,7 @@ def ransac_chunk(fargs): [k1xy, k2xy, des1, des2, k1ind, args] = fargs - FLANN_INDEX_KDTREE = 0 + FLANN_INDEX_KDTREE = 1 index_params = dict( algorithm=FLANN_INDEX_KDTREE, trees=args['FLANN_ntree']) @@ -90,11 +92,27 @@ def ransac_chunk(fargs): return k1, k2 +# FIXME work w/ tile min/max in layout +def to_8bpp(im, min_val=None, max_val=None): + if im.dtype == np.uint16: + if max_val is not None or min_val is not None: + min_val = min_val or 0 + max_val = max_val or 65535 + scale_factor = 255 / (max_val - min_val) + im = ((np.clip(im, min_val, max_val) - min_val) * scale_factor) + return (im).astype(np.uint8) + return im + + def read_downsample_equalize_mask_uri( - impath, scale, CLAHE_grid=None, CLAHE_clip=None): - im = cv2.imdecode( - np.fromstring(uri_utils.uri_readbytes(impath[0]), np.uint8), - 0) + impath, scale, CLAHE_grid=None, CLAHE_clip=None, min_val=None, max_val=None): + # im = cv2.imdecode( + # np.fromstring(uri_utils.uri_readbytes(impath[0]), np.uint8), + # 0) + im = imageio.v3.imread(uri_utils.uri_readbytes(impath[0])) + # FIXME this should be read from tilespec + max_val = max_val or im.max() + im = to_8bpp(im, min_val, max_val) im = cv2.resize(im, (0, 0), fx=scale, @@ -110,9 +128,10 @@ def read_downsample_equalize_mask_uri( im = cv2.equalizeHist(im) if impath[1] is not None: - mask = cv2.imdecode( - np.fromstring(uri_utils.uri_readbytes(impath[1]), np.uint8), - 0) + # mask = cv2.imdecode( + # np.fromstring(uri_utils.uri_readbytes(impath[1]), np.uint8), + # 0) + mask = imageio.v3.imread(uri_utils.uri_readbytes(impath[1])) mask = cv2.resize(mask, (0, 0), fx=scale, @@ -121,6 +140,7 @@ def read_downsample_equalize_mask_uri( im = cv2.bitwise_and(im, im, mask=mask) return im + # return to_8bpp(im, min_val, max_val) def read_downsample_equalize_mask( @@ -129,24 +149,101 @@ def read_downsample_equalize_mask( return read_downsample_equalize_mask_uri(uri_impath, *args, **kwargs) -def find_matches(fargs): - [impaths, ids, gids, args] = fargs +FLANN_INDEX_KDTREE = 1 + + +# TODO take this from existing ransac_chunk +def match_and_ransac( + loc_p, des_p, loc_q, des_q, + FLANN_ntree=5, ratio_of_dist=0.7, + FLANN_ncheck=50, RANSAC_outlier=5.0, + min_match_count=10, FLANN_index=FLANN_INDEX_KDTREE, **kwargs): + + index_params = dict( + algorithm=FLANN_INDEX_KDTREE, + trees=FLANN_ntree) + search_params = dict(checks=FLANN_ncheck) + flann = cv2.FlannBasedMatcher(index_params, search_params) + + matches = flann.knnMatch(des_p, des_q, k=2) + + # store all the good matches as per Lowe's ratio test. + good = [] + k1 = [] + k2 = [] + for m, n in matches: + if m.distance < ratio_of_dist * n.distance: + good.append(m) + if len(good) > min_match_count: + src_pts = np.float32( + [loc_p[m.queryIdx] for m in good]).reshape(-1, 1, 2) + dst_pts = np.float32( + [loc_q[m.trainIdx] for m in good]).reshape(-1, 1, 2) + M, mask = cv2.findHomography( + src_pts, + dst_pts, + cv2.RANSAC, + RANSAC_outlier) + matchesMask = mask.ravel().tolist() + + good = np.array(good)[np.array(matchesMask).astype('bool')] + imgIdx = np.array([g.imgIdx for g in good]) + tIdx = np.array([g.trainIdx for g in good]) + qIdx = np.array([g.queryIdx for g in good]) + for i in range(len(tIdx)): + if imgIdx[i] == 1: + k1.append(loc_p[tIdx[i]]) + k2.append(loc_q[qIdx[i]]) + else: + k1.append(loc_p[qIdx[i]]) + k2.append(loc_q[tIdx[i]]) + + return k1, k2 + + +# TODO change this to pq terminology +def chunk_match_keypoints( + loc1, des1, loc2, des2, ndiv=1, full_shape=None, + ransac_kwargs=None, **kwargs): + ransac_kwargs = ransac_kwargs or {} + if full_shape is None: + full_shape = np.ptp(np.concatenate([loc1, loc2]), axis=0) + + nr, nc = full_shape + + chunk_results = [] + + # FIXME better way than doing min and max of arrays + for i in range(ndiv): + r = np.arange(nr * i / ndiv, nr * (i + 1) / ndiv) + for j in range(ndiv): + c = np.arange(nc * j / ndiv, nc * (j + 1) / ndiv) + k1ind = np.argwhere( + (loc1[:, 0] >= r.min()) & + (loc1[:, 0] <= r.max()) & + (loc1[:, 1] >= c.min()) & + (loc1[:, 1] <= c.max())).flatten() + + chunk_results.append(match_and_ransac( + loc1[k1ind, ...], des1[k1ind, ...], + loc2, des2, + **{**ransac_kwargs, **kwargs})) + + p_results, q_results = zip(*chunk_results) + p_results = np.concatenate([i for i in p_results if len(i)]) + q_results = np.concatenate([i for i in q_results if len(i)]) + return p_results, q_results - pim = read_downsample_equalize_mask_uri( - impaths[0], - args['downsample_scale'], - CLAHE_grid=args['CLAHE_grid'], - CLAHE_clip=args['CLAHE_clip']) - qim = read_downsample_equalize_mask_uri( - impaths[1], - args['downsample_scale'], - CLAHE_grid=args['CLAHE_grid'], - CLAHE_clip=args['CLAHE_clip']) - sift = cv2.xfeatures2d.SIFT_create( - nfeatures=args['SIFT_nfeature'], - nOctaveLayers=args['SIFT_noctave'], - sigma=args['SIFT_sigma']) +def sift_match_images( + pim, qim, sift_kwargs=None, + ransac_kwargs=None, match_kwargs=None, + return_num_features=False, + **kwargs): + sift_kwargs = sift_kwargs or {} + match_kwargs = match_kwargs or {} + + sift = cv2.SIFT_create(**sift_kwargs) # find the keypoints and descriptors kp1, des1 = sift.detectAndCompute(pim, None) @@ -155,47 +252,107 @@ def find_matches(fargs): k1xy = np.array([np.array(k.pt) for k in kp1]) k2xy = np.array([np.array(k.pt) for k in kp2]) - nr, nc = pim.shape - k1 = [] - k2 = [] - ransac_args = [] - results = [] - ndiv = args['ndiv'] - for i in range(ndiv): - r = np.arange(nr*i/ndiv, nr*(i+1)/ndiv) - for j in range(ndiv): - c = np.arange(nc*j/ndiv, nc*(j+1)/ndiv) - k1ind = np.argwhere( - (k1xy[:, 0] >= r.min()) & - (k1xy[:, 0] <= r.max()) & - (k1xy[:, 1] >= c.min()) & - (k1xy[:, 1] <= c.max())).flatten() - ransac_args.append([k1xy, k2xy, des1, des2, k1ind, args]) - results.append(ransac_chunk(ransac_args[-1])) - - for result in results: - k1 += result[0] - k2 += result[1] - - if len(k1) >= 1: - k1 = np.array(k1) / args['downsample_scale'] - k2 = np.array(k2) / args['downsample_scale'] - - if k1.shape[0] > args['matchMax']: - a = np.arange(k1.shape[0]) - np.random.shuffle(a) - k1 = k1[a[0: args['matchMax']], :] - k2 = k2[a[0: args['matchMax']], :] - - render = renderapi.connect(**args['render']) - pm_dict = make_pm(ids, gids, k1, k2) - - renderapi.pointmatch.import_matches( - args['match_collection'], - [pm_dict], - render=render) - - return [impaths, len(kp1), len(kp2), len(k1), len(k2)] + k1, k2 = chunk_match_keypoints( + k1xy, des1, k2xy, des2, + full_shape=pim.shape, + ransac_kwargs=ransac_kwargs, + **{**match_kwargs, **kwargs} + ) + if return_num_features: + return (k1, k2), (len(kp1), len(kp2)) + return k1, k2 + + +def locs_to_dict( + pGroupId, pId, loc_p, + qGroupId, qId, loc_q, + scale_factor=1.0, match_max=1000): + if loc_p.shape[0] < 0: + return + loc_p *= scale_factor + loc_q *= scale_factor + + if loc_p.shape[0] > match_max: + ind = np.arange(loc_p.shape[0]) + np.random.shuffle(ind) + ind = ind[0:match_max] + loc_p = loc_p[ind, ...] + loc_q = loc_q[ind, ...] + + return make_pm( + (pId, qId), + (pGroupId, qGroupId), + loc_p, loc_q) + + +def process_matches( + pId, pGroupId, p_image_uri, + qId, qGroupId, q_image_uri, + downsample_scale=1.0, + CLAHE_grid=None, CLAHE_clip=None, + matchMax=1000, + sift_kwargs=None, + **kwargs): + + pim = read_downsample_equalize_mask_uri( + p_image_uri, + downsample_scale, + CLAHE_grid=CLAHE_grid, + CLAHE_clip=CLAHE_clip) + qim = read_downsample_equalize_mask_uri( + q_image_uri, + downsample_scale, + CLAHE_grid=CLAHE_grid, + CLAHE_clip=CLAHE_clip) + + (loc_p, loc_q), (num_features_p, num_features_q) = sift_match_images( + pim, qim, sift_kwargs=sift_kwargs, + return_num_features=True, + **kwargs) + + pm_dict = locs_to_dict( + pGroupId, pId, loc_p, + qGroupId, qId, loc_q, + scale_factor=(1. / downsample_scale), + match_max=matchMax) + + return pm_dict, len(loc_p), num_features_p, num_features_q + + +def find_matches(fargs): + [impaths, ids, gids, args] = fargs + + pm_dict, num_matches, num_features_p, num_features_q = process_matches( + ids[0], gids[0], impaths[0], + ids[1], gids[1], impaths[1], + downsample_scale=args["downsample_scale"], + CLAHE_grid=args["CLAHE_grid"], + CLAHE_clip=args["CLAHE_clip"], + sift_kwargs={ + "nfeatures": args["SIFT_nfeature"], + "nOctaveLayers": args['SIFT_noctave'], + "sigma": args['SIFT_sigma'] + }, + match_kwargs={ + "ndiv": args["ndiv"], + "FLANN_ntree": args["FLANN_ntree"], + "ratio_of_dist": args["ratio_of_dist"], + "FLANN_ncheck": args["FLANN_ncheck"] + }, + ransac_kwargs={ + "RANSAC_outlier": args["RANSAC_outlier"] + }, + matchMax=args["matchMax"] + ) + + render = renderapi.connect(**args['render']) + + renderapi.pointmatch.import_matches( + args['match_collection'], + [pm_dict], + render=render) + + return [impaths, num_features_p, num_features_q, num_matches, num_matches] def make_pm(ids, gids, k1, k2): diff --git a/asap/residuals/compute_residuals.py b/asap/residuals/compute_residuals.py index 5201eb33..f5ce9d1d 100644 --- a/asap/residuals/compute_residuals.py +++ b/asap/residuals/compute_residuals.py @@ -3,48 +3,33 @@ import renderapi -def compute_residuals_within_group(render, stack, matchCollectionOwner, - matchCollection, z, min_points=1, - tilespecs=None): - session = requests.session() +def compute_residuals(tilespecs, matches, min_points=1, extra_statistics=None): + """from compute_residuals_in_group for in-memory""" + extra_statistics = extra_statistics or {} - # get the sectionID which is the group ID in point match collection - groupId = render.run( - renderapi.stack.get_sectionId_for_z, stack, z, session=session) + tId_to_tforms = {ts.tileId: ts.tforms for ts in tilespecs} - # get matches within the group for this section - allmatches = render.run( - renderapi.pointmatch.get_matches_within_group, - matchCollection, - groupId, - owner=matchCollectionOwner, - session=session) - - # get the tilespecs to extract the transformations - if tilespecs is None: - tilespecs = render.run(renderapi.tilespec.get_tile_specs_from_z, - stack, z, session=session) - tforms = {ts.tileId: ts.tforms for ts in tilespecs} - - tile_residuals = {key: np.empty((0, 1)) for key in tforms.keys()} - tile_rmse = {key: np.empty((0, 1)) for key in tforms.keys()} - pt_match_positions = {key: np.empty((0, 2)) for key in tforms.keys()} + tile_residuals = {key: np.empty((0, 1)) for key in tId_to_tforms.keys()} + tile_rmse = {key: np.empty((0, 1)) for key in tId_to_tforms.keys()} + pt_match_positions = {key: np.empty((0, 2)) for key in tId_to_tforms.keys()} statistics = {} - for i, match in enumerate(allmatches): + + for match in matches: pts_p = np.array(match['matches']['p']) pts_q = np.array(match['matches']['q']) - + if pts_p.shape[1] < min_points: continue + try: - t_p = tforms[match['pId']][-1].tform(pts_p.T) - t_q = tforms[match['qId']][-1].tform(pts_q.T) + t_p = tId_to_tforms[match['pId']][-1].tform(pts_p.T) + t_q = tId_to_tforms[match['qId']][-1].tform(pts_q.T) except KeyError: continue positions = (t_p + t_q) / 2. - + res = np.linalg.norm(t_p - t_q, axis=1) rmse = np.true_divide(res, res.shape[0]) @@ -55,38 +40,66 @@ def compute_residuals_within_group(render, stack, matchCollectionOwner, pt_match_positions[match['pId']] = np.append( pt_match_positions[match['pId']], positions, axis=0) - # remove empty entries from these dicts empty_keys = [k for k in tile_residuals if tile_residuals[k].size == 0] + for k in empty_keys: tile_residuals.pop(k) tile_rmse.pop(k) pt_match_positions.pop(k) statistics['tile_rmse'] = tile_rmse - statistics['z'] = z statistics['tile_residuals'] = tile_residuals statistics['pt_match_positions'] = pt_match_positions - session.close() + statistics = {**extra_statistics, **statistics} + + return statistics, matches - return statistics, allmatches +def compute_residuals_within_group(render, stack, matchCollectionOwner, + matchCollection, z, min_points=1, + tilespecs=None): + session = requests.session() -def compute_mean_tile_residuals(residuals): - tile_mean = {} + # get the sectionID which is the group ID in point match collection + groupId = render.run( + renderapi.stack.get_sectionId_for_z, stack, z, session=session) + + # get matches within the group for this section + allmatches = render.run( + renderapi.pointmatch.get_matches_within_group, + matchCollection, + groupId, + owner=matchCollectionOwner, + session=session) - # loop over each tile and compute the mean residual for each tile - # iteritems is specific to py2.7 - maxes = [np.nanmean(v) for v in residuals.values() if len(v) > 0] - maximum = np.max(maxes) + # get the tilespecs to extract the transformations + if tilespecs is None: + tilespecs = render.run(renderapi.tilespec.get_tile_specs_from_z, + stack, z, session=session) - for key in residuals: - if len(residuals[key]) == 0: - tile_mean[key] = maximum - else: - tile_mean[key] = np.nanmean(residuals[key]) + statistics, allmatches = compute_residuals( + tilespecs, allmatches, min_points=min_points, + extra_statistics={"z": z}) - return tile_mean + session.close() + + return statistics, allmatches + + +def compute_mean_tile_residuals(residuals): + tile_mean = { + tileId: (np.nanmean(tile_residuals) if tile_residuals.size else np.nan) + for tileId, tile_residuals in residuals.items() + } + tile_residual_max = np.nanmax(tile_mean.values()) + + return { + tileId: ( + tile_residual if not np.isnan(tile_residual) + else tile_residual_max) + for tileId, tile_residual in tile_mean.items() + } def get_tile_centers(tilespecs): diff --git a/asap/rough_align/fit_multiple_solves.py b/asap/rough_align/fit_multiple_solves.py new file mode 100644 index 00000000..63045793 --- /dev/null +++ b/asap/rough_align/fit_multiple_solves.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python + +'''============================= import block ==============================''' +import concurrent.futures +import copy +import numpy as np +import logging +import renderapi +import bigfeta.bigfeta +import bigfeta.utils +import bigfeta.solve +import collections +import argschema +from bigfeta.bigfeta import create_CSR_A_fromobjects as CSRA +from asap.module.render_module import RenderModuleException +from asap.rough_align.schemas import (FitMultipleSolvesSchema, + FitMultipleSolvesOutputSchema) +# correct schema to use + +if __name__ == "__main__" and __package__ is None: + __package__ = "asap.rough_align.fit_multiple_solves" + +# in the following example, just the keywords are important. +# everything else is most likely customizable unless stated otherwise. + +example_render = { + "owner": "TEM", # use your custom owner name + "project": "MN12_L2_1A", # use your own project name + "port": 8888, # use your own port number + "host": "http://em-131db2.corp.alleninstitute.org", + # use your own render server + "client_scripts": ("/allen/aibs/pipeline/image_processing/" + "volume_assembly/render-jars/production/scripts" + ), + # use your own client-scripts + "memGB": "2G" # customizable, and optional. + # Can be changed/added later in the code if necessary. +} +example = { + "input_stack": dict(example_render, **{ + "name": "MN12_L2_1A_montscape_reord", # use your own name + "collection_type": "stack", + "db_interface": "render" + }), + # remapped, downsampled montage scapes to rough align + "pointmatch_collection": dict(example_render, **{ + "name": "MN12_L2_1A_rough_matches", # use your own name + "collection_type": "pointmatch", + "db_interface": "render" + }), # pointmatched collection run on the input stack + "rigid_output_stack": dict(example_render, **{ + "name": "MN12_L2_1A_rigid_rot_test", # use your own name + "collection_type": "stack", + "db_interface": "render" + }), # output of rigid rotation + "translation_output_stack": None, # if using this, follow similar schema. + # It is the output of rigid translation + "affine_output_stack": dict(example_render, **{ + "name": "MN12_L2_1A_affine_test", # use your own name + "collection_type": "stack", + "db_interface": "render" + }), # output of affine transform + "thin_plate_output_stack": dict(example_render, **{ + "name": "Mn12_L2_1A_tps_test", # use your own name + "collection_type": "stack", + "db_interface": "render" + }), # output of thin plate spline transform + "minZ": 0, # first remapped Z value as an integer + "maxZ": 801, # last remapped Z value as an integer + "pool_size": 20, # customizable. 20 is probably overkill. + # Will be used in any inherent concurrency in the module. + "close_stack": False, # Must remain False to keep the stack open- + # -so new ROI can be added. + "output_json": "/path/to/output/json/files" # required for proper- + # -functioning of the module. +} + +logger = logging.getLogger() + + +class ApplyMultipleSolvesException(RenderModuleException): + + """Raise exception by using try except blocks....""" + + +class TileException(RenderModuleException): + """Raise this when unmatched tile ids exist or + if multiple sections per z value exist""" + + +def create_resolved_tiles(render, pm_render, input_stack: str, + pm_collection: str, solve_range: tuple, + pool_size: int + ): + # input_stack and pm_collection are names of the + # (a) montage-scaped & Z-mapped stack, and + # (b) pointmatch collection from render, respectively; + # solve_range is a tuple of the minimum and maximum of the new_z values. + # pool_size is the concurrency from module input + stack_zs = [z for z in renderapi.stack.get_z_values_for_stack(input_stack, + render=render) if solve_range[0] <= z <= solve_range[-1]] + num_threads = pool_size # customizable + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as e: + rtfut_to_z = { + e.submit(renderapi.resolvedtiles.get_resolved_tiles_from_z, + input_stack, z, render=render): z for z in stack_zs + } + z_to_rts = { + rtfut_to_z[fut]: + fut.result() for fut in concurrent.futures.as_completed(rtfut_to_z) + } + + matchgroups = { + ts.layout.sectionId for ts in (i for l in ( + rts.tilespecs for rts in z_to_rts.values()) + for i in l) + } + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as e: + futs = [e.submit(renderapi.pointmatch.get_matches_with_group, + pm_collection, g, render=pm_render) for g in matchgroups] + allmatches = [i for l in ( + fut.result() for fut in concurrent.futures.as_completed(futs) + ) for i in l] + + match_pIds, match_qIds = map(set, + zip(*[(m["pId"], m["qId"]) for m in allmatches + ])) + match_tileIds = match_pIds | match_qIds + all_tileIds = {ts.tileId for ts in (i for l in ( + rts.tilespecs for rts in z_to_rts.values() + ) for i in l)} + + tIds_nomatch = all_tileIds - match_tileIds # any tiles without matches + + return {'resolved_tiles': z_to_rts, 'matched_tiles': match_tileIds, + 'All_matches': allmatches, 'All_tiles': all_tileIds, + 'Unmatched_tiles': tIds_nomatch + } + + +def nomatchz_counter(combined_rts, tIds_nomatch): + nomatchz_count = collections.Counter( + [combined_rts[tId].z for tId in tIds_nomatch] + ) # should be nothing in it if everything works well + if len(nomatchz_count) > 0: + word = sorted({tId: combined_rts[tId].z for tId in tIds_nomatch}) + raise TileException( + f'Unmatched TileIds Exist: {list(map(int, word))}') + return None + + +def multiz_counter(combined_rts): + multiz_count = { + k: v for k, v in collections.Counter( + [ts.z for ts in combined_rts.tilespecs] + ).items() if v > 1 + } + # should be nothing in it if everything works well + if len(multiz_count) > 0: + keys = list(multiz_count.keys()) + raise TileException(f'multiple sections in {keys}') + return None + + +def get_ts_timestamp(ts): + imgurl = ts.ip[0].imageUrl + return imgurl.split("_")[-4] + + +def create_input_defaults(render, pm_render, input_stack: str, + pm_collection: str, solve_range: tuple, + pool_size: int, multiz=False): + + combine_resolvedtiles = renderapi.resolvedtiles.combine_resolvedtiles + + res_tiles = create_resolved_tiles(render, pm_render, + input_stack=input_stack, + pm_collection=pm_collection, + solve_range=solve_range, + pool_size=pool_size) + + combined_rts = combine_resolvedtiles(res_tiles['resolved_tiles'].values()) + tId_to_tilespecs = {ts.tileId: ts for ts in combined_rts.tilespecs} + swap_matchpair = renderapi.pointmatch.swap_matchpair + new_matches = [] + for m in res_tiles['All_matches']: + try: + pz, qz = tId_to_tilespecs[m["pId"]].z, tId_to_tilespecs[m["qId"]].z + except KeyError: + continue + if pz < qz: + new_matches.append(m) + else: + new_matches.append(swap_matchpair(m)) + rmt = res_tiles['matched_tiles'] + combined_rts.tilespecs = [ + tId_to_tilespecs[tId] for tId in tId_to_tilespecs.keys() & rmt + ] + nomatchz_counter(combined_rts, res_tiles['Unmatched_tiles']) + if multiz: + multiz_counter(combined_rts) + return {'combined_tilespecs': combined_rts, 'new_matches': new_matches} + + +def rigid_solve(combined_rts, new_matches): + rigid_transform_name = "RotationModel" + rigid_order = 2 + rigid_fullsize = False + rigid_transform_apply = [] + rigid_regularization_dict = { + "translation_factor": 1e-10, + "default_lambda": 1e-12, + "freeze_first_tile": False, + "thinplate_factor": 5e7 + } + rigid_matrix_assembly_dict = { + "choose_random": True, + "depth": [1, 2, 3, 4, 5, 6], + "npts_max": 500, + "inverse_dz": True, + "npts_min": 5, + "cross_pt_weight": 0.5, + "montage_pt_weight": 1.0 + } + rig_matches = renderapi.pointmatch.copy_matches_explicit(new_matches) + rigid_create_CSR_A_input = (combined_rts, rig_matches, + rigid_transform_name, + rigid_transform_apply, + rigid_regularization_dict, + rigid_matrix_assembly_dict, rigid_order, + rigid_fullsize) + return rigid_create_CSR_A_input + + +def translation_solve(rigid_result_rts, new_matches): + translation_transform_name = "TranslationModel" + translation_order = 2 + translation_fullsize = False + translation_transform_apply = [] + translation_regularization_dict = { + "translation_factor": 1e0, + "default_lambda": 1e-5, + "freeze_first_tile": False, + "thinplate_factor": 5e7 + } + translation_matrix_assembly_dict = { + "choose_random": True, + "depth": [1, 2, 3, 4], + "npts_max": 500, + "inverse_dz": True, + "npts_min": 5, + "cross_pt_weight": 0.5, + "montage_pt_weight": 1.0 + } + trans_matches = renderapi.pointmatch.copy_matches_explicit(new_matches) + translation_create_CSR_A_input = ( + rigid_result_rts, trans_matches, translation_transform_name, + translation_transform_apply, translation_regularization_dict, + translation_matrix_assembly_dict, + translation_order, translation_fullsize) + + return translation_create_CSR_A_input + + +def affine_solve(rigid_result_rts, new_matches): + aff_transform_name = "AffineModel" + aff_order = 2 + aff_fullsize = False + aff_transform_apply = [] + aff_regularization_dict = { + "translation_factor": 1e-10, + "default_lambda": 1e7, + "freeze_first_tile": False, + "thinplate_factor": 5e7 + } + aff_matrix_assembly_dict = { + "choose_random": True, + "depth": [1, 2, 3, 4, 5, 6], + "npts_max": 500, + "inverse_dz": True, + "npts_min": 5, + "cross_pt_weight": 0.5, + "montage_pt_weight": 1.0 + } + aff_matches = renderapi.pointmatch.copy_matches_explicit(new_matches) + aff_create_CSR_A_input = ( + rigid_result_rts, aff_matches, aff_transform_name, aff_transform_apply, + aff_regularization_dict, aff_matrix_assembly_dict, aff_order, + aff_fullsize) + return aff_create_CSR_A_input + + +def baseline_vertices(xmin, xmax, ymin, ymax, nx, ny): + xt, yt = np.meshgrid( + np.linspace(xmin, xmax, nx), + np.linspace(ymin, ymax, ny) + ) + return np.vstack((xt.flatten(), yt.flatten())).transpose() + + +def tps_from_vertices(vertices): + tf = renderapi.transform.ThinPlateSplineTransform() + tf.ndims = 2 + tf.nLm = vertices.shape[0] + tf.aMtx = np.array([[0.0, 0.0], [0.0, 0.0]]) + tf.bVec = np.array([0.0, 0.0]) + tf.srcPts = vertices.transpose() + tf.dMtxDat = np.zeros_like(tf.srcPts) + return tf + + +def append_tps_tform(tspec, npts=5, ext=0.05): + bb = tspec.bbox_transformed() + xmin, ymin = bb.min(axis=0) + xmax, ymax = bb.max(axis=0) + w, h = np.ptp(bb, axis=0) + vert = baseline_vertices( + xmin - ext * w, + xmax + ext * w, + ymin - ext * h, + ymax + ext * h, + npts, npts + ) + tform = tps_from_vertices(vert) + tspec.tforms.append(tform) + + +def thin_plate_spline_solve(aff_result_rts, new_matches): + + tpsadded_aff_result_rts = copy.deepcopy(aff_result_rts) + for tspec in tpsadded_aff_result_rts.tilespecs: + append_tps_tform(tspec, npts=5) + + tps_transform_name = "ThinPlateSplineTransform" + tps_order = 2 + tps_fullsize = False + tps_transform_apply = [0] + tps_regularization_dict = { + "translation_factor": 1e-10, + "default_lambda": 1e7, + "freeze_first_tile": False, + "thinplate_factor": 5e7 + } + tps_matrix_assembly_dict = { + "choose_random": True, + "depth": 3, + "npts_max": 500, + "inverse_dz": True, + "npts_min": 5, + "cross_pt_weight": 0.5, + "montage_pt_weight": 1.0 + } + + tps_matches = renderapi.pointmatch.copy_matches_explicit(new_matches) + tps_create_CRS_A_input = ( + tpsadded_aff_result_rts, tps_matches, tps_transform_name, + tps_transform_apply, tps_regularization_dict, tps_matrix_assembly_dict, + tps_order, tps_fullsize) + return tps_create_CRS_A_input + + +class FitMultipleSolves(argschema.ArgSchemaParser): + default_schema = FitMultipleSolvesSchema + default_output_schema = FitMultipleSolvesOutputSchema + + def apply_transform(self, new_matches, rts, app): + app_solver = {'rigid': rigid_solve, + 'translate': translation_solve, + 'affine': affine_solve, + 'tps': thin_plate_spline_solve} + app_input = app_solver[app](rts, new_matches) + app_fr, app_draft_resolvedtiles = CSRA(*app_input, + return_draft_resolvedtiles=True) + app_result_rts = copy.deepcopy(app_draft_resolvedtiles) + app_sol = bigfeta.solve.solve(app_fr["A"], app_fr["weights"], + app_fr["reg"], app_fr["x"], app_fr["rhs"] + ) + bigfeta.utils.update_tilespecs(app_result_rts, app_sol["x"]) + return app_result_rts + + @staticmethod + def save_transform(render, rts, outstack): + renderapi.stack.create_stack(outstack, render=render) + renderapi.resolvedtiles.put_tilespecs(outstack, rts, render=render) + + def run(self): + r_in = renderapi.connect(**self.args['input_stack']) + name_in = self.args['input_stack']['name'][0] + + r_pm = renderapi.connect(**self.args['pointmatch_collection']) + name_pm = self.args['pointmatch_collection']['name'][0] + + r_rot = renderapi.connect(**self.args['rigid_output_stack']) + name_rot = self.args["rigid_output_stack"]["name"][0] + + r_aff = renderapi.connect(**self.args['affine_output_stack']) + name_aff = self.args["affine_output_stack"]["name"][0] + + r_tps = renderapi.connect(**self.args['thin_plate_output_stack']) + name_tps = self.args["thin_plate_output_stack"]["name"][0] + + if self.args['translation_output_stack'] is not None: + r_trans = renderapi.connect(**self.args['translation_output_stack']) + name_trans = self.args["translation_output_stack"]["name"][0] + do_translate = True + else: + do_translate = False + + allZ = [int(z) for z in renderapi.stack.get_z_values_for_stack( + name_in, render=r_in)] + + minZ = (allZ[0] if self.args["minZ"] is None else max(self.args["minZ"], + allZ[0])) + + maxZ = (allZ[-1] if self.args["maxZ"] is None else min(self.args["maxZ"], + allZ[-1])) + + sol_range = (minZ, maxZ) + + in_default = create_input_defaults( + render=r_in, pm_render=r_pm, + input_stack=name_in, + pm_collection=name_pm, + solve_range=sol_range, pool_size=self.args['pool_size'], + multiz=False + ) + + '''---------------------Rigid solve-------------------------''' + rigid_result_rts = self.apply_transform( + new_matches=in_default['new_matches'], + rts=in_default['combined_tilespecs'], app='rigid') + self.save_transform(r_rot, rigid_result_rts, name_rot) + if do_translate: + '''---------------Translation solve -------''' + + trans_result_rts = self.apply_transform( + new_matches=in_default['new_matches'], + rts=rigid_result_rts, app='translate') + + self.save_transform(r_trans, trans_result_rts, + name_trans) + + '''---------------Affine solve ------------''' + + aff_result_rts = self.apply_transform( + new_matches=in_default['new_matches'], + rts=trans_result_rts, app='affine') + else: + '''---------------Affine solve ------------''' + + aff_result_rts = self.apply_transform( + new_matches=in_default['new_matches'], + rts=rigid_result_rts, app='affine') + self.save_transform(r_aff, aff_result_rts, name_aff) + + '''---------------Thin_plate_spline solve--------------------''' + tps_result_rts = self.apply_transform( + new_matches=in_default['new_matches'], + rts=aff_result_rts, app='tps') + self.save_transform(r_tps, tps_result_rts, name_tps) + + allZ_out = [ + int(z) for z in renderapi.stack.get_z_values_for_stack( + name_rot, render=r_rot) + ] + + out_dict = { + 'zs': allZ_out, + 'rigid_output_stack': self.args['rigid_output_stack'], + 'affine_output_stack': self.args['affine_output_stack'], + 'thin_plate_output_stack': self.args['thin_plate_output_stack'] + } + + if do_translate: + out_dict['translation_output_stack'] = self.args[ + "translation_output_stack"] + + self.output(out_dict) + + +if __name__ == "__main__": + mod = FitMultipleSolves() + # Use input_data = example as an argument here for test runs. + mod.run() diff --git a/asap/rough_align/schemas.py b/asap/rough_align/schemas.py index 63d18f4b..4d4dd988 100644 --- a/asap/rough_align/schemas.py +++ b/asap/rough_align/schemas.py @@ -1,15 +1,16 @@ import argschema from argschema import InputDir import marshmallow as mm +from bigfeta.schemas import input_stack, output_stack, pointmatch from marshmallow import post_load, ValidationError from argschema.fields import ( - Bool, Float, Int, - Str, InputFile, List, Dict) -from argschema.schemas import DefaultSchema + Bool, Float, Int, NumpyArray, + Str, InputFile, List, Dict, Nested) +from argschema.schemas import DefaultSchema, ArgSchema from asap.module.schemas import ( - RenderParameters, - StackTransitionParameters) + RenderParameters, + StackTransitionParameters) class MakeAnchorStackSchema(StackTransitionParameters): @@ -29,11 +30,11 @@ class MakeAnchorStackSchema(StackTransitionParameters): "AffineModel transform jsons" "will override xml input.")) zValues = List( - Int, - required=False, - missing=[1000], - default=[1000], - description="not used in this module, keeps parents happy") + Int, + required=False, + missing=[1000], + default=[1000], + description="not used in this module, keeps parents happy") class PairwiseRigidSchema(StackTransitionParameters): @@ -298,9 +299,9 @@ class PointMatchCollectionParameters(DefaultSchema): class ApplyRoughAlignmentOutputParameters(DefaultSchema): zs = argschema.fields.NumpyArray( - description="list of z values that were applied to") + description="list of z values that were applied to") output_stack = argschema.fields.Str( - description="stack where applied transforms were set") + description="stack where applied transforms were set") class DownsampleMaskHandlerSchema(RenderParameters): @@ -341,3 +342,74 @@ class DownsampleMaskHandlerSchema(RenderParameters): required=False, default=['png', 'tif'], description="what kind of mask files to recognize") + + +class FitMultipleSolvesSchema(ArgSchema): + input_stack = Nested( + input_stack, + required=True, + description='downsampled sections for rough alignment') + pointmatch_collection = Nested( + pointmatch, + required=True, + description='pointmatch collection parameters') + rigid_output_stack = Nested( + output_stack, + required=True, + description='output stack name of rigid rotation transformed montages') + translation_output_stack = Nested( + output_stack, + allow_none=True, + required=False, + default=None, + missing=None, + description='output stack name of rigid translated montages') + affine_output_stack = Nested( + output_stack, + required=True, + description='output stack name of affine transformed montages') + thin_plate_output_stack = Nested( + output_stack, + required=True, + description=('output stack name of' + 'thin plate spline transformed montages') + ) + minZ = Int( + required=True, + description='first remapped Z value') + maxZ = Int( + required=True, + description='last remapped Z value') + pool_size = Int( + required=False, + default=10, + missing=10, + description='pool size for concurrency') + close_stack = Bool( + required=False, + default=True, + missing=True, + description='if True, then updates stack status to COMPLETE') + + +class FitMultipleSolvesOutputSchema(ArgSchema): + zs = List( + Int, + required=True, + description="list of z values that were applied to") + rigid_output_stack = Nested( + output_stack, + required=True, + description="stack where rigid transforms were set") + translation_output_stack = Nested( + output_stack, + required=False, + description="stack where rigid translation transforms were set") + affine_output_stack = Nested( + output_stack, + required=True, + description="stack where rigid transforms were set") + thin_plate_output_stack = Nested( + output_stack, + required=True, + description="stack where rigid transforms were set") diff --git a/entrypoint.sh b/entrypoint.sh index dc165354..47ff0f4e 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,3 +1,3 @@ #!/bin/bash -source activate render-modules -exec "$@" \ No newline at end of file +source activate asap +exec "$@" diff --git a/integration_tests/test_em_montage_qc.py b/integration_tests/test_em_montage_qc.py index 369156e9..3faca2a5 100644 --- a/integration_tests/test_em_montage_qc.py +++ b/integration_tests/test_em_montage_qc.py @@ -10,11 +10,16 @@ render_params, montage_qc_project) -from asap.em_montage_qc.detect_montage_defects import ( - DetectMontageDefectsModule) -from asap.module.render_module import RenderModuleException -from asap.em_montage_qc import detect_montage_defects -from asap.em_montage_qc.distorted_montages import DetectDistortedMontagesModule +IMPORTS_ERRORED = False +try: + from asap.em_montage_qc.detect_montage_defects import ( + DetectMontageDefectsModule) + from asap.module.render_module import RenderModuleException + from asap.em_montage_qc import detect_montage_defects + from asap.em_montage_qc.distorted_montages import DetectDistortedMontagesModule +except: + IMPORTS_ERRORED = True + raise @pytest.fixture(scope='module') @@ -148,9 +153,10 @@ def test_detect_montage_defects(render, # read the output json with open(ex['output_json'], 'r') as f: data = json.load(f) - f.close() assert(len(data['output_html']) > 0) + assert all([os.path.isfile(html_file) for html_file in data['output_html']]) + assert(len(data['seam_sections']) > 0) assert(len(data['hole_sections']) == 1) @@ -158,6 +164,7 @@ def test_detect_montage_defects(render, assert(len(data['distorted_sections']) == 0) assert(len(data['qc_passed_sections']) == 0) + for s in data['seam_centroids']: assert(len(s) > 0) diff --git a/integration_tests/test_rough_align.py b/integration_tests/test_rough_align.py index 08807d39..0a910374 100644 --- a/integration_tests/test_rough_align.py +++ b/integration_tests/test_rough_align.py @@ -32,6 +32,7 @@ from asap.solver.solve import Solve_stack from asap.rough_align.apply_rough_alignment_to_montages import ( ApplyRoughAlignmentTransform) +from asap.rough_align.fit_multiple_solves import FitMultipleSolves import shutil import numpy as np @@ -1238,3 +1239,67 @@ def test_multiple_transform_apply_rough( assert np.linalg.norm( np.array(a_lpt['local'][:2]) - np.array(r_lpt['local'][:2])) < 1 + + +@pytest.mark.parametrize("do_translate", [True, False]) +def test_fit_multiple_solves( + render, montage_scape_stack, + rough_point_match_collection, + tmpdir_factory, do_translate): + + output_json = tmpdir_factory.mktemp('output').join( + f'fit_multiple_t{int(do_translate)}_output.json') + + output_stack_base = f'{montage_scape_stack}_fit_multiple_t{int(do_translate)}_DS_Rough' + rotation_stack_name = '{}_Rotation'.format(output_stack_base) + translation_stack_name = '{}_Translation'.format(output_stack_base) + affine_stack_name = '{}_Affine'.format(output_stack_base) + tps_stack_name = '{}_TPS'.format(output_stack_base) + + input_dict = { + "input_stack": dict(render.make_kwargs(), **{ + "name": montage_scape_stack, + "collection_type": "stack" + }), + "pointmatch_collection": dict(render.make_kwargs(), **{ + "name": rough_point_match_collection, + "collection_type": "pointmatch" + }), + "rigid_output_stack": dict(render.make_kwargs(), **{ + "name": rotation_stack_name, + "collection_type": "stack", + }), + "translation_output_stack": ( + None if not do_translate else dict(render.make_kwargs(), **{ + "name": translation_stack_name, + "collection_type": "stack" + })), + "affine_output_stack": dict(render.make_kwargs(), **{ + "name": affine_stack_name, + "collection_type": "stack", + }), + "thin_plate_output_stack": dict(render.make_kwargs(), **{ + "name": tps_stack_name, + "collection_type": "stack", + }), + "minZ": 1020, + "maxZ": 1022, + "pool_size": pool_size, + "close_stack": False, + "output_json": str(output_json), + } + + mod = FitMultipleSolves(input_data=input_dict, args=[]) + mod.run() + + with open(output_json, 'r') as f: + output_data = json.load(f) + + # test output stacks for expected zs + stackresults = [v for k, v in output_data.items() if k.endswith('_stack')] + + expected_zvalues = [1020, 1021, 1022] + for stackresult in stackresults: + zvalues = renderapi.connect(**stackresult).run( + renderapi.stack.get_z_values_for_stack, stackresult['name'][0]) + assert(set(zvalues) == set(expected_zvalues)) diff --git a/integration_tests/test_rough_align_qc.py b/integration_tests/test_rough_align_qc.py index db8c903e..7b48998a 100644 --- a/integration_tests/test_rough_align_qc.py +++ b/integration_tests/test_rough_align_qc.py @@ -9,7 +9,11 @@ render_params ) -from asap.em_montage_qc.rough_align_qc import RoughAlignmentQC +IMPORT_ERROR = False +try: + from asap.em_montage_qc.rough_align_qc import RoughAlignmentQC +except: + IMPORT_ERROR = True @pytest.fixture(scope='module') diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..e4c5eedb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,109 @@ +[project] +name = "asap" +requires-python = ">3.8,<3.12" +dynamic = ["version"] +dependencies = [ + "render-python>=2.3.1", + "marshmallow", + "argschema", + "numpy", + "pillow", + "tifffile", + "pathlib2", + "scipy", + "rtree", + "networkx", + "bokeh", + "bigfeta", + "opencv-contrib-python-headless", + "em_stitch", + "matplotlib", + # "jinja2<3.1", + "six", + "scikit-image", + "shapely", + "triangle", + "uri_handler", + "imageio", + "seaborn", + "mpld3", + "descartes", + "lxml", + "pymongo==3.11.1", + "igraph" +] + +[project.optional-dependencies] +test = [ + "coverage>=4.1", + "mock>=2.0.0", + "pep8>=1.7.0", + "pytest>=3.0.5", + "pytest-cov>=2.2.1", + "pytest-pep8>=1.0.6", + "pytest-xdist>=1.14", + "flake8>=3.0.4", + "pylint>=1.5.4", + "codecov", + # "jinja2", +] + +[build-system] +requires = ["setuptools>=64", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +py-modules = [ + "asap" +] + +[tool.setuptools_scm] +version_file = "asap/_version.py" + +# some downgrades from pixi's expectations +[tool.pixi.system-requirements] +linux = "5.0.0" +libc = "2.27" + +[tool.pixi.project] +channels = ["conda-forge"] +platforms = ["linux-64"] + +[tool.pixi.pypi-dependencies] +asap = { path = ".", editable = true } + +# conda-enabled features +[tool.pixi.feature.pixi-base.dependencies] +numpy = "*" +pandas = "*" +scipy = "*" +imageio = "*" +matplotlib = "*" +petsc4py = "*" +python-igraph = "*" + +# version-specific python features +[tool.pixi.feature.py310.dependencies] +python = "3.10.*" +[tool.pixi.feature.py311.dependencies] +python = "3.11.*" + +[tool.pixi.feature.jupyterlab.dependencies] +jupyterlab = "*" + +[tool.pixi.environments] +py310 = ["py310"] +py310-test = ["py310", "test"] +py311 = ["py311"] +py311-test = ["py311", "test"] +py310-conda = ["py310", "pixi-base"] +py310-conda-test = ["py310", "pixi-base", "test"] +py311-conda = ["py311", "pixi-base"] +py311-conda-test = ["py311", "pixi-base", "test"] +py311-jupyter = ["py311", "pixi-base", "jupyterlab"] + +[tool.coverage.run] +omit = ["integration_tests/*", "tests/*"] + +[tool.pixi.feature.test.tasks] +test = "pytest --cov --cov-report=xml --junitxml=test-reports/test.xml" diff --git a/requirements.txt b/requirements.txt index 26ec0b91..d87c80cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ render-python>=2.3.1 -marshmallow -argschema +marshmallow<3.0 +argschema<2.0 numpy pillow tifffile @@ -8,7 +8,7 @@ pathlib2 scipy rtree networkx -bokeh<=1.4.0 +bokeh bigfeta opencv-contrib-python em_stitch @@ -24,3 +24,4 @@ seaborn mpld3 descartes lxml +igraph diff --git a/scripts/conda_env_setup.py b/scripts/conda_env_setup.py index 631a2b7e..8b8d9408 100644 --- a/scripts/conda_env_setup.py +++ b/scripts/conda_env_setup.py @@ -1,3 +1,3 @@ -conda create --name render-modules --prefix $IMAGE_PROCESSING_DEPLOY_PATH python=2.7 -source activate render-modules +conda create --name asap --prefix $IMAGE_PROCESSING_DEPLOY_PATH python=3.11 +source activate asap pip install -r ../requirements.txt diff --git a/tests/test_materialization.py b/tests/test_materialization.py index 0058de1d..6af56d6c 100644 --- a/tests/test_materialization.py +++ b/tests/test_materialization.py @@ -95,11 +95,11 @@ def test_validate_materialization( # truncate file causing truncated file ValueError on read truncbytes = os.path.getsize(truncfn) truncatefile(truncfn, truncbytes//2) - with pytest.raises(ValueError): + with pytest.raises((ValueError, IOError)): _ = imageio.imread(truncfn) # truncate file so that it is unreadable truncatefile(badfn, 0) - with pytest.raises(ValueError): + with pytest.raises((ValueError,)): _ = imageio.imread(badfn) # TODO there are some png cases which lead to SyntaxErrors? # run validation