Commit e054326 (1 parent: e1ad3b0)
Showing 18 changed files with 2,352 additions and 1 deletion.
```diff
@@ -127,3 +127,6 @@ dmypy.json
 # Pyre type checker
 .pyre/
+.idea/*
+*events.out.tfevents.*
+/outputs
```
@@ -1,2 +1,46 @@
# GlueStick
-Joint Deep Matcher for Points and Lines 🖼️💥🖼️
+Joint deep matcher for points and lines 🖼️💥🖼️

![Visualization of point and line matches](resources/demo_seq1.gif)

This repository contains the official implementation of
[GlueStick: Robust Image Matching by Sticking Points and Lines Together](#).

## Install 🛠️

To install the software on Ubuntu 22.04, follow these instructions:
```bash
sudo apt-get install build-essential cmake libopencv-dev libopencv-contrib-dev
git clone --recursive https://github.com/cvg/GlueStick.git
cd GlueStick
# Create and activate a virtual environment
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
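As a quick sanity check of the environment, you can try importing the third-party packages used by the modules added in this commit; this is only a sketch, and it assumes `requirements.txt` (not shown here) pulls in PyTorch and Matplotlib:

```python
# Hypothetical smoke test: these are the external imports used by the new modules below.
import torch
import matplotlib

print(torch.__version__, matplotlib.__version__)
```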

## Running GlueStick 🏃
Download the weights of the model:
```
pip install gdown
gdown -O resources/weights/checkpoint_GlueStick_MD.tar "https://drive.google.com/uc?id=1Gw26jVaU9fwOemQ3jBVINdJFdMXpV8Qv&export=download"
```

You can then run inference with:
```
python -m gluestick.run -img1 resources/img1.jpg -img2 resources/img2.jpg
```
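If you prefer to sanity-check the download from Python first, a minimal sketch is shown below; it assumes the `.tar` file is a regular `torch.save` checkpoint, which this commit does not state:

```python
import torch

# Hypothetical check: load the downloaded checkpoint on CPU and inspect its top-level structure.
ckpt = torch.load('resources/weights/checkpoint_GlueStick_MD.tar', map_location='cpu')
print(type(ckpt), list(ckpt.keys()) if isinstance(ckpt, dict) else '')
```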

## Training 🏋️
We want to provide you with high-quality and flexible code for training. Stay tuned, we will release it soon!

## Citation 📝
If you use this code in your project, please consider citing the following paper:
```bibtex
@article{pautrat_suarez_2023_gluestick,
  title={{GlueStick}: Robust Image Matching by Sticking Points and Lines Together},
  author={Pautrat, R{\'e}mi* and Su{\'a}rez, Iago* and Yu, Yifan and Pollefeys, Marc and Larsson, Viktor},
  journal={ArXiv},
  year={2023}
}
```
@@ -0,0 +1,53 @@
```python
import collections.abc as collections
from pathlib import Path

import torch

GLUESTICK_ROOT = Path(__file__).parent.parent


def get_class(mod_name, base_path, BaseClass):
    """Get the class object which inherits from BaseClass and is defined in
    the module named mod_name, child of base_path.
    """
    import inspect
    mod_path = '{}.{}'.format(base_path, mod_name)
    mod = __import__(mod_path, fromlist=[''])
    classes = inspect.getmembers(mod, inspect.isclass)
    # Filter classes defined in the module
    classes = [c for c in classes if c[1].__module__ == mod_path]
    # Filter classes inherited from BaseModel
    classes = [c for c in classes if issubclass(c[1], BaseClass)]
    assert len(classes) == 1, classes
    return classes[0][1]


def get_model(name):
    from .models.base_model import BaseModel
    return get_class('models.' + name, __name__, BaseModel)


def numpy_image_to_torch(image):
    """Normalize the image tensor and reorder the dimensions."""
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f'Not an image: {image.shape}')
    return torch.from_numpy(image / 255.).float()


def map_tensor(input_, func):
    # Recursively apply func to every tensor in a (possibly nested) structure,
    # leaving strings and bytes untouched.
    if isinstance(input_, (str, bytes)):
        return input_
    elif isinstance(input_, collections.Mapping):
        return {k: map_tensor(sample, func) for k, sample in input_.items()}
    elif isinstance(input_, collections.Sequence):
        return [map_tensor(sample, func) for sample in input_]
    else:
        return func(input_)


def batch_to_np(batch):
    # Detach every tensor in the batch and keep the first sample as a NumPy array.
    return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0])
```
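As a rough illustration of how an image pair could be prepared with these helpers and converted back for visualization, here is a minimal sketch; the import path and the `image0`/`image1` batch keys are assumptions made for the example, not an interface documented in this commit:

```python
import cv2  # assumes an OpenCV Python binding is installed in the environment

# Hypothetical import path: the helpers above read like package-level utilities.
from gluestick import numpy_image_to_torch, batch_to_np

# Read two grayscale images and turn them into normalized CxHxW float tensors.
img0 = cv2.imread('resources/img1.jpg', cv2.IMREAD_GRAYSCALE)
img1 = cv2.imread('resources/img2.jpg', cv2.IMREAD_GRAYSCALE)
t0 = numpy_image_to_torch(img0)[None]  # add a batch dimension -> 1x1xHxW
t1 = numpy_image_to_torch(img1)[None]

# After a forward pass on such a batch, batch_to_np detaches every tensor,
# moves it to the CPU, and keeps only the first sample as a NumPy array.
batch = {'image0': t0, 'image1': t1}
np_batch = batch_to_np(batch)
print({k: v.shape for k, v in np_batch.items()})
```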
@@ -0,0 +1,166 @@
```python
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
                adaptive=True):
    """Plot a set of images horizontally.
    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images.
        adaptive: whether the figure size should fit the image aspect ratios.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
    else:
        ratios = [4 / 3] * n
    figsize = [sum(ratios) * 4.5, 4.5]
    fig, ax = plt.subplots(
        1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])
    fig.tight_layout(pad=pad)
    return ax


def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
    """Plot keypoints for existing images.
    Args:
        kpts: list of ndarrays of size (N, 2).
        colors: string, or list of lists of tuples (one color per keypoint).
        ps: size of the keypoints as float.
    """
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)
    axes = plt.gcf().axes
    for a, k, c in zip(axes, kpts, colors):
        a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0)


def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
    """Plot matches for a pair of existing images.
    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2).
        color: color of each match, string or RGB tuple. Random if not given.
        lw: width of the lines.
        ps: size of the end points (no endpoint if ps=0).
        indices: indices of the images to draw the matches on.
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    ax0, ax1 = ax[indices[0]], ax[indices[1]]
    fig.canvas.draw()

    assert len(kpts0) == len(kpts1)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
        color = [color] * len(kpts0)

    if lw > 0:
        # transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
        fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
        fig.lines += [matplotlib.lines.Line2D(
            (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
            zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
            alpha=a)
            for i in range(len(kpts0))]

    # freeze the axes to prevent the transform from changing
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def plot_lines(lines, line_colors='orange', point_colors='cyan',
               ps=4, lw=2, alpha=1., indices=(0, 1)):
    """Plot lines and endpoints for existing images.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        line_colors: string, or list of line colors (one per image).
        point_colors: string, or list of endpoint colors (one per image).
        ps: size of the keypoints as float pixels.
        lw: line width as float pixels.
        alpha: transparency of the points and lines.
        indices: indices of the images to draw the matches on.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(lines)
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * len(lines)

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines and junctions
    for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
        for i in range(len(l)):
            line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
                                           (l[i, 0, 1], l[i, 1, 1]),
                                           zorder=1, c=lc, linewidth=lw,
                                           alpha=alpha)
            a.add_line(line)
        pts = l.reshape(-1, 2)
        a.scatter(pts[:, 0], pts[:, 1],
                  c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)


def plot_color_line_matches(lines, correct_matches=None,
                            lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        correct_matches: bool array of size (N,) indicating correct matches.
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    n_lines = len(lines[0])
    colors = sns.color_palette('husl', n_colors=n_lines)
    np.random.shuffle(colors)
    alphas = np.ones(n_lines)
    # If correct_matches is not None, display wrong matches with a low alpha
    if correct_matches is not None:
        alphas[~np.array(correct_matches)] = 0.2

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines
    for a, l in zip(axes, lines):
        # Transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()
        endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
        endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
        fig.lines += [matplotlib.lines.Line2D(
            (endpoint0[i, 0], endpoint1[i, 0]),
            (endpoint0[i, 1], endpoint1[i, 1]),
            zorder=1, transform=fig.transFigure, c=colors[i],
            alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
```
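Taken together, the plotting helpers above are enough to preview matches. Below is a minimal sketch on placeholder data; the images, keypoints, and line segments are random stand-ins rather than GlueStick outputs, and the import path is a guess based on this commit, not a documented API:

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical import path for the helpers defined above.
from gluestick.drawing import plot_images, plot_matches, plot_color_line_matches

# Placeholder data: two blank 480x640 grayscale images, 50 fake point matches,
# and 10 fake line matches given as (N, 2, 2) endpoint arrays.
img0, img1 = np.ones((480, 640)), np.ones((480, 640))
kpts0 = np.random.rand(50, 2) * [640, 480]
kpts1 = np.random.rand(50, 2) * [640, 480]
lines0 = np.random.rand(10, 2, 2) * [640, 480]
lines1 = np.random.rand(10, 2, 2) * [640, 480]

# The helpers draw on the current figure: create the image axes first,
# then overlay point matches and colored line matches on those axes.
plot_images([img0, img1], titles=['image 0', 'image 1'])
plot_matches(kpts0, kpts1, ps=4, lw=0.5)
plot_color_line_matches([lines0, lines1], lw=2)
plt.savefig('matches_preview.png')
```

`plot_keypoints` and `plot_lines` follow the same pattern when you only want to show per-image detections instead of matches.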