Skip to content

Commit

Permalink
downsample option in make_video, add scripts to generate flowrate plo…
Browse files Browse the repository at this point in the history
…ts using orb and xcorr
  • Loading branch information
i-jey committed Mar 17, 2024
1 parent c7379a7 commit 8db75ac
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 6 deletions.
2 changes: 1 addition & 1 deletion lfm_data_utilities/image_processing/make_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
print("Generating videos...")
with mp.Pool() as pool:
pool.starmap(
utils.make_video, [(x, path_to_save) for x in tqdm(valid_datasets)]
utils.make_video, [(x, path_to_save, 1) for x in tqdm(valid_datasets)]
)
115 changes: 115 additions & 0 deletions lfm_data_utilities/image_processing/orb_flowrate_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import argparse
from pathlib import Path
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import zarr


def get_ewma(data, alpha=0.1):
prev = data[0]
ewma_vals = [prev]
for v in data[1:]:
new_val = prev * (1 - alpha) + v * alpha
ewma_vals.append(new_val)
prev = new_val
return ewma_vals


def get_diffs_from_matches(matches, kp, t_kp):
x_diffs = []
y_diffs = []
for match in matches:
p1 = kp[match.queryIdx].pt
p2 = t_kp[match.trainIdx].pt
x_diffs.append(p2[0] - p1[0])
y_diffs.append(p2[1] - p1[1])
return np.asarray(x_diffs), np.asarray(y_diffs)


def get_orb_xy_diffs(zf_path: Path, scale_factor: int = 1):
num_features = 500
zf = zarr.open(zf_path, "r")
orb = cv2.ORB_create(num_features)

h, w = zf[:, :, 0].shape

matcher = cv2.BFMatcher()

x_diff_pointer = 0
y_diff_pointer = 0
all_x_diffs = np.zeros(num_features * zf.initialized)
all_y_diffs = np.zeros(num_features * zf.initialized)

pos_y_diff_means = np.zeros(zf.initialized)

for i in tqdm(range(1, zf.initialized)):
i1 = cv2.resize(zf[:, :, i - 1], (w // scale_factor, h // scale_factor))
i2 = cv2.resize(zf[:, :, i], (w // scale_factor, h // scale_factor))

kp, des = orb.detectAndCompute(i1, None)
t_kp, t_des = orb.detectAndCompute(i2, None)

matches = matcher.match(des, t_des)

x_diffs, y_diffs = np.asarray(get_diffs_from_matches(matches, kp, t_kp))

if len(y_diffs[y_diffs > 0]) > 0:
pos_mean = np.mean(y_diffs[y_diffs > 0])
else:
pos_mean = 0

all_x_diffs[x_diff_pointer : x_diff_pointer + len(x_diffs)] = x_diffs
all_y_diffs[y_diff_pointer : y_diff_pointer + len(y_diffs)] = y_diffs
x_diff_pointer += len(x_diffs)
y_diff_pointer += len(y_diffs)

pos_y_diff_means[i] = pos_mean

return all_x_diffs, all_y_diffs, pos_y_diff_means


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="Simple ORB plot",
description="Given a zarr file and a downsampling factor, save a plot of the ORB y diffs",
)

parser.add_argument("zarr_path", type=Path, help="Path to the zarr file")
parser.add_argument(
"downsample_factor", type=int, help="Downsampling factor", default=1
)
parser.add_argument("save_loc", type=Path, help="Path to save the plot")

args = parser.parse_args()
ds_factor = args.downsample_factor

# Get orb diffs
all_x_diffs, all_y_diffs, pos_y_diff_means = get_orb_xy_diffs(
args.zarr_path, scale_factor=ds_factor
)
m, sd = np.mean(pos_y_diff_means) * ds_factor, np.std(pos_y_diff_means) * ds_factor

ewma_alpha = 0.05

# Plot
fig = plt.figure(figsize=(12, 8))
fig.suptitle(f"{Path(args.zarr_path).stem}")
plt.plot(pos_y_diff_means * ds_factor, "o", markersize=0.5, alpha=0.5, label="Raw")
plt.plot(
get_ewma(pos_y_diff_means * ds_factor, ewma_alpha),
alpha=0.75,
label=f"EWMA, alpha={ewma_alpha}",
)
plt.title(
f"Downsampled {ds_factor}x ORB positive y feature diffs vs. frame\nm={m:.2f}, sd={sd:.2f}"
)
plt.xlabel("Frame idx")
plt.ylabel("Displacement (pixels)")
plt.ylim(0, 772)
plt.legend()

plt.savefig(f"{args.save_loc}/{Path(args.zarr_path).stem}_orb_ds{ds_factor}.png")
115 changes: 115 additions & 0 deletions lfm_data_utilities/image_processing/orb_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import argparse
from pathlib import Path
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import zarr


def get_ewma(data, alpha=0.1):
prev = data[0]
ewma_vals = [prev]
for v in data[1:]:
new_val = prev * (1 - alpha) + v * alpha
ewma_vals.append(new_val)
prev = new_val
return ewma_vals


def get_diffs_from_matches(matches, kp, t_kp):
x_diffs = []
y_diffs = []
for match in matches:
p1 = kp[match.queryIdx].pt
p2 = t_kp[match.trainIdx].pt
x_diffs.append(p2[0] - p1[0])
y_diffs.append(p2[1] - p1[1])
return np.asarray(x_diffs), np.asarray(y_diffs)


def get_orb_xy_diffs(zf_path: Path, scale_factor: int = 1):
num_features = 500
zf = zarr.open(zf_path, "r")
orb = cv2.ORB_create(num_features)

h, w = zf[:, :, 0].shape

matcher = cv2.BFMatcher()

x_diff_pointer = 0
y_diff_pointer = 0
all_x_diffs = np.zeros(num_features * zf.initialized)
all_y_diffs = np.zeros(num_features * zf.initialized)

pos_y_diff_means = np.zeros(zf.initialized)

for i in tqdm(range(1, zf.initialized)):
i1 = cv2.resize(zf[:, :, i - 1], (w // scale_factor, h // scale_factor))
i2 = cv2.resize(zf[:, :, i], (w // scale_factor, h // scale_factor))

kp, des = orb.detectAndCompute(i1, None)
t_kp, t_des = orb.detectAndCompute(i2, None)

matches = matcher.match(des, t_des)

x_diffs, y_diffs = np.asarray(get_diffs_from_matches(matches, kp, t_kp))

if len(y_diffs[y_diffs > 0]) > 0:
pos_mean = np.mean(y_diffs[y_diffs > 0])
else:
pos_mean = 0

all_x_diffs[x_diff_pointer : x_diff_pointer + len(x_diffs)] = x_diffs
all_y_diffs[y_diff_pointer : y_diff_pointer + len(y_diffs)] = y_diffs
x_diff_pointer += len(x_diffs)
y_diff_pointer += len(y_diffs)

pos_y_diff_means[i] = pos_mean

return all_x_diffs, all_y_diffs, pos_y_diff_means


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="Simple ORB plot",
description="Given a zarr file and a downsampling factor, save a plot of the ORB y diffs",
)

parser.add_argument("zarr_path", type=Path, help="Path to the zarr file")
parser.add_argument(
"downsample_factor", type=int, help="Downsampling factor", default=1
)
parser.add_argument("save_loc", type=Path, help="Path to save the plot")

args = parser.parse_args()
ds_factor = args.downsample_factor

# Get orb diffs
all_x_diffs, all_y_diffs, pos_y_diff_means = get_orb_xy_diffs(
args.zarr_path, scale_factor=ds_factor
)
m, sd = np.mean(pos_y_diff_means) * ds_factor, np.std(pos_y_diff_means) * ds_factor

ewma_alpha = 0.05

# Plot
fig = plt.figure(figsize=(12, 8))
fig.suptitle(f"{Path(args.zarr_path).stem}")
plt.plot(pos_y_diff_means * ds_factor, "o", markersize=0.5, alpha=0.5, label="Raw")
plt.plot(
get_ewma(pos_y_diff_means * ds_factor, ewma_alpha),
alpha=0.75,
label=f"EWMA, alpha={ewma_alpha}",
)
plt.title(
f"Downsampled {ds_factor}x ORB positive y feature diffs vs. frame\nm={m:.2f}, sd={sd:.2f}"
)
plt.xlabel("Frame idx")
plt.ylabel("Displacement (pixels)")
plt.ylim(0, 772)
plt.legend()

plt.savefig(f"{args.save_loc}/{Path(args.zarr_path).stem}_orb_ds{ds_factor}.png")
18 changes: 13 additions & 5 deletions lfm_data_utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def get_rms(data: List[float]):
return np.sqrt(ms / N)


def make_video(dataset: Dataset, save_dir: PathLike):
def make_video(dataset: Dataset, save_dir: PathLike, ds_factor: int = 1):
zf = dataset.zarr_file
per_img_csv = dataset.per_img_metadata

Expand All @@ -164,20 +164,28 @@ def make_video(dataset: Dataset, save_dir: PathLike):
height, width = zf[:, :, 0].shape

save_dir.mkdir(exist_ok=True)
output_path = Path(save_dir) / Path(dataset.dp.zarr_path.stem + ".mp4")
output_path = Path(save_dir) / Path(
dataset.dp.zarr_path.stem + f"ds_{ds_factor}.mp4"
)

writer = cv2.VideoWriter(
f"{output_path}",
fourcc=cv2.VideoWriter_fourcc(*"mp4v"),
fps=framerate,
frameSize=(width, height),
frameSize=(width // ds_factor, height // ds_factor),
isColor=False,
)

for i, _ in enumerate(tqdm(range(num_frames))):
img = zf[:, :, i]
img = cv2.resize(zf[:, :, i], (width // ds_factor, height // ds_factor))
img = cv2.putText(
img, f"img {i}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1
img,
f"img {i}",
(50 // ds_factor, 50 // ds_factor),
cv2.FONT_HERSHEY_SIMPLEX,
1 // ds_factor,
(0, 0, 0),
1,
)
writer.write(img)
writer.release()
Expand Down

0 comments on commit 8db75ac

Please sign in to comment.