Skip to content

Commit

Permalink
Adding GlueStick inference code
Browse files Browse the repository at this point in the history
  • Loading branch information
iago-suarez committed Apr 3, 2023
1 parent e1ad3b0 commit e054326
Show file tree
Hide file tree
Showing 18 changed files with 2,352 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/
.idea/*
*events.out.tfevents.*
/outputs
46 changes: 45 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,46 @@
# GlueStick
Joint Deep Matcher for Points and Lines 🖼️💥🖼️
Joint deep matcher for points and lines 🖼️💥🖼️

![Visualization of point and line matches](resources/demo_seq1.gif)

This repository contains the official implementation of
[GlueStick: Robust Image Matching by Sticking Points and Lines Together](#).

## Install 🛠️

To install the software in Ubuntu 22.04 follow these instructions:
```bash
sudo apt-get install build-essential cmake libopencv-dev libopencv-contrib-dev
git clone --recursive https://github.com/cvg/GlueStick.git
cd GlueStick
# Create and activate a virtual environment
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

## Running GlueStick 🏃
Download the weights of the model:
```
pip install gdown
gdown -O resources/weights/checkpoint_GlueStick_MD.tar https://drive.google.com/uc?id=1Gw26jVaU9fwOemQ3jBVINdJFdMXpV8Qv&export=download
```

You can execute the inference with it with:
```
python -m gluestick.run -img1 resources/img1.jpg -img2 resources/img2.jpg
```

## Training 🏋️
We want to provide you with high-quality and flexible code for training. Stay tuned, we will release it soon!

## Citation 📝
If you use this code in your project, please consider citing the following paper:
```bibtex
@article{pautrat_suarez_2023_gluestick,
title={{GlueStick}: Robust Image Matching by Sticking Points and Lines Together},
author={Pautrat, R{\'e}mi* and Su{\'a}rez, Iago* and Yu, Yifan and Pollefeys, Marc and Larsson, Viktor},
journal={ArXiv},
year={2023}
}
```
53 changes: 53 additions & 0 deletions gluestick/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import collections.abc as collections
from pathlib import Path

import torch

GLUESTICK_ROOT = Path(__file__).parent.parent


def get_class(mod_name, base_path, BaseClass):
"""Get the class object which inherits from BaseClass and is defined in
the module named mod_name, child of base_path.
"""
import inspect
mod_path = '{}.{}'.format(base_path, mod_name)
mod = __import__(mod_path, fromlist=[''])
classes = inspect.getmembers(mod, inspect.isclass)
# Filter classes defined in the module
classes = [c for c in classes if c[1].__module__ == mod_path]
# Filter classes inherited from BaseModel
classes = [c for c in classes if issubclass(c[1], BaseClass)]
assert len(classes) == 1, classes
return classes[0][1]


def get_model(name):
from .models.base_model import BaseModel
return get_class('models.' + name, __name__, BaseModel)


def numpy_image_to_torch(image):
"""Normalize the image tensor and reorder the dimensions."""
if image.ndim == 3:
image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
elif image.ndim == 2:
image = image[None] # add channel axis
else:
raise ValueError(f'Not an image: {image.shape}')
return torch.from_numpy(image / 255.).float()


def map_tensor(input_, func):
if isinstance(input_, (str, bytes)):
return input_
elif isinstance(input_, collections.Mapping):
return {k: map_tensor(sample, func) for k, sample in input_.items()}
elif isinstance(input_, collections.Sequence):
return [map_tensor(sample, func) for sample in input_]
else:
return func(input_)


def batch_to_np(batch):
return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0])
166 changes: 166 additions & 0 deletions gluestick/drawing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
adaptive=True):
"""Plot a set of images horizontally.
Args:
imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
titles: a list of strings, as titles for each image.
cmaps: colormaps for monochrome images.
adaptive: whether the figure size should fit the image aspect ratios.
"""
n = len(imgs)
if not isinstance(cmaps, (list, tuple)):
cmaps = [cmaps] * n

if adaptive:
ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
else:
ratios = [4 / 3] * n
figsize = [sum(ratios) * 4.5, 4.5]
fig, ax = plt.subplots(
1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
if n == 1:
ax = [ax]
for i in range(n):
ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
ax[i].get_yaxis().set_ticks([])
ax[i].get_xaxis().set_ticks([])
ax[i].set_axis_off()
for spine in ax[i].spines.values(): # remove frame
spine.set_visible(False)
if titles:
ax[i].set_title(titles[i])
fig.tight_layout(pad=pad)
return ax


def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
"""Plot keypoints for existing images.
Args:
kpts: list of ndarrays of size (N, 2).
colors: string, or list of list of tuples (one for each keypoints).
ps: size of the keypoints as float.
"""
if not isinstance(colors, list):
colors = [colors] * len(kpts)
axes = plt.gcf().axes
for a, k, c in zip(axes, kpts, colors):
a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0)


def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
"""Plot matches for a pair of existing images.
Args:
kpts0, kpts1: corresponding keypoints of size (N, 2).
color: color of each match, string or RGB tuple. Random if not given.
lw: width of the lines.
ps: size of the end points (no endpoint if ps=0)
indices: indices of the images to draw the matches on.
a: alpha opacity of the match lines.
"""
fig = plt.gcf()
ax = fig.axes
assert len(ax) > max(indices)
ax0, ax1 = ax[indices[0]], ax[indices[1]]
fig.canvas.draw()

assert len(kpts0) == len(kpts1)
if color is None:
color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
color = [color] * len(kpts0)

if lw > 0:
# transform the points into the figure coordinate system
transFigure = fig.transFigure.inverted()
fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
fig.lines += [matplotlib.lines.Line2D(
(fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
alpha=a)
for i in range(len(kpts0))]

# freeze the axes to prevent the transform to change
ax0.autoscale(enable=False)
ax1.autoscale(enable=False)

if ps > 0:
ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def plot_lines(lines, line_colors='orange', point_colors='cyan',
ps=4, lw=2, alpha=1., indices=(0, 1)):
""" Plot lines and endpoints for existing images.
Args:
lines: list of ndarrays of size (N, 2, 2).
colors: string, or list of list of tuples (one for each keypoints).
ps: size of the keypoints as float pixels.
lw: line width as float pixels.
alpha: transparency of the points and lines.
indices: indices of the images to draw the matches on.
"""
if not isinstance(line_colors, list):
line_colors = [line_colors] * len(lines)
if not isinstance(point_colors, list):
point_colors = [point_colors] * len(lines)

fig = plt.gcf()
ax = fig.axes
assert len(ax) > max(indices)
axes = [ax[i] for i in indices]
fig.canvas.draw()

# Plot the lines and junctions
for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
for i in range(len(l)):
line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
(l[i, 0, 1], l[i, 1, 1]),
zorder=1, c=lc, linewidth=lw,
alpha=alpha)
a.add_line(line)
pts = l.reshape(-1, 2)
a.scatter(pts[:, 0], pts[:, 1],
c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)


def plot_color_line_matches(lines, correct_matches=None,
lw=2, indices=(0, 1)):
"""Plot line matches for existing images with multiple colors.
Args:
lines: list of ndarrays of size (N, 2, 2).
correct_matches: bool array of size (N,) indicating correct matches.
lw: line width as float pixels.
indices: indices of the images to draw the matches on.
"""
n_lines = len(lines[0])
colors = sns.color_palette('husl', n_colors=n_lines)
np.random.shuffle(colors)
alphas = np.ones(n_lines)
# If correct_matches is not None, display wrong matches with a low alpha
if correct_matches is not None:
alphas[~np.array(correct_matches)] = 0.2

fig = plt.gcf()
ax = fig.axes
assert len(ax) > max(indices)
axes = [ax[i] for i in indices]
fig.canvas.draw()

# Plot the lines
for a, l in zip(axes, lines):
# Transform the points into the figure coordinate system
transFigure = fig.transFigure.inverted()
endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
fig.lines += [matplotlib.lines.Line2D(
(endpoint0[i, 0], endpoint1[i, 0]),
(endpoint0[i, 1], endpoint1[i, 1]),
zorder=1, transform=fig.transFigure, c=colors[i],
alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
Loading

0 comments on commit e054326

Please sign in to comment.