speed optimization
andreaconti committed Dec 29, 2022
1 parent f165013 · commit fa6f54a
Showing 3 changed files with 47 additions and 47 deletions.
1 change: 1 addition & 0 deletions evaluate.py
@@ -28,6 +28,7 @@ def main():
         else str(Path(__file__).parent),
         args.dataset.replace("-", "_") + "_precomputed",
         args.density,
+        in_memory=True,
         source="github" if not _DEBUG_LOCAL else "local",
     )
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
66 changes: 29 additions & 37 deletions hubconf.py
@@ -1,44 +1,39 @@
 from pathlib import Path as _Path
 import torch as _torch
 from torch.utils.data import Dataset as _Dataset
-import tarfile as _tarfile
 from typing import Literal, Dict
 import h5py
 import numpy as np
 
 
-def _download(name: str, ext: str = "") -> _Path:
+def _download(name: str) -> _Path:
     """
-    Downloads from github the required resource, and if it is a tar dir it extract it in a
-    folder with the same name without extension
+    downloads a file from github and puts it under downloads
     """
-    if not (out_path := (_Path(__file__).parent / f"downloads/{name}{ext}")).exists():
+    to = _Path(__file__).parent / "downloads" / name
+    if not to.exists():
         _torch.hub.download_url_to_file(
-            f"https://github.com/andreaconti/sparsity-agnostic-depth-completion/releases/download/v0.1.0/{name}{ext}",
-            str(_Path(__file__).parent / f"downloads/{name}{ext}"),
+            f"https://github.com/andreaconti/sparsity-agnostic-depth-completion/releases/download/v0.1.0/{name}",
+            to
         )
-    if ext == ".tar":
-        if not (out_dir := out_path.parent / name).exists():
-            with _tarfile.open(out_path) as tar:
-                out_dir = out_path.parent / name
-                out_dir.mkdir(exist_ok=True, parents=True)
-                tar.extractall(out_dir)
-        return out_dir
-    else:
-        return out_path
-
+    return to
 
 # precomputed results
 
 
 class _PrecomputedDataset(_Dataset):
-    def __init__(self, img_gt_root: _Path, pred_hints_root: _Path):
+    def __init__(self, img_gt_root: _Path, pred_hints_root: _Path, in_memory: bool = False):
         img_gt = h5py.File(img_gt_root)
-        self._img = np.array(img_gt["img"])
-        self._gt = np.array(img_gt["gt"])
+        self._img = img_gt["img"]
+        self._gt = img_gt["gt"]
         pred_hints = h5py.File(pred_hints_root)
-        self._preds = np.array(pred_hints["preds"])
-        self._hints = np.array(pred_hints["hints"])
+        self._preds = pred_hints["preds"]
+        self._hints = pred_hints["hints"]
+        if in_memory:
+            self._img = np.array(self._img)
+            self._gt = np.array(self._gt)
+            self._preds = np.array(self._preds)
+            self._hints = np.array(self._hints)
 
     def __len__(self):
         return self._img.shape[0]
@@ -53,23 +48,20 @@ def __getitem__(self, index) -> Dict[str, np.ndarray]:
 
 
 def kitti_official_precomputed(
-    hints_density: Literal["lines4", "lines8", "lines16", "lines32", "lines64"]
+    hints_density: Literal["lines4", "lines8", "lines16", "lines32", "lines64"],
+    in_memory: bool = False,
 ) -> _Dataset:
-    root = _download("kitti-official", ".tar")
-    return _PrecomputedDataset(
-        root / "img_gt.h5", root / f"pred_with_{hints_density}.h5"
-    )
+    assert hints_density in ["lines4", "lines8", "lines16", "lines32", "lines64"], f"{hints_density} not available"
+    img_gt = _download("kitti_img_gt.h5")
+    preds = _download(f"kitti_pred_with_{hints_density}.h5")
+    return _PrecomputedDataset(img_gt, preds, in_memory)
 
 
 def nyu_depth_v2_ma_downsampled_precomputed(
-    hints_density: Literal[5, 50, 100, 200, 500, "livox", "grid-shift"]
+    hints_density: Literal[5, 50, 100, 200, 500, "livox", "grid-shift"],
+    in_memory: bool = False
 ) -> _Dataset:
-    root = _download("nyu-depth-v2-ma-downsampled", ".tar")
-    return _PrecomputedDataset(
-        root / "img_gt.h5", root / f"pred_with_{hints_density}.h5"
-    )
-
-
-# models
-
-# TODO: wip
+    assert hints_density in [5, 50, 100, 200, 500, "livox", "grid-shift"], f"{hints_density} not available"
+    img_gt = _download("nyu_img_gt.h5")
+    preds = _download(f"nyu_pred_with_{hints_density}.h5")
+    return _PrecomputedDataset(img_gt, preds, in_memory)
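
Since the diff above changes how the precomputed datasets are exposed, here is a minimal usage sketch (not part of the commit) of the new in_memory flag through torch.hub.load. The repository string, entrypoint name, and dictionary keys come from the diff and from visualize.py below; the "lines64" density and the timing loop are illustrative assumptions. With in_memory=False the dataset keeps lazy h5py handles and reads from disk on every access, while in_memory=True materializes the arrays once with np.array, trading RAM for faster repeated indexing.

import time
import torch

# Load the KITTI precomputed results twice: lazily (h5py handles) and fully in memory.
ds_lazy = torch.hub.load(
    "andreaconti/sparsity_agnostic_depth_completion",
    "kitti_official_precomputed",
    "lines64",
    in_memory=False,
    source="github",
)
ds_fast = torch.hub.load(
    "andreaconti/sparsity_agnostic_depth_completion",
    "kitti_official_precomputed",
    "lines64",
    in_memory=True,
    source="github",
)

for label, ds in [("lazy h5py", ds_lazy), ("in_memory", ds_fast)]:
    start = time.perf_counter()
    for i in range(len(ds)):
        _ = ds[i]  # dict of np.ndarray with keys "img", "gt", "pred", "hints"
    print(f"{label}: {time.perf_counter() - start:.2f}s for {len(ds)} examples")
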
27 changes: 17 additions & 10 deletions visualize.py
@@ -11,10 +11,6 @@
 
 ## Utilities
 
-# simple introduction and
-# a button for the dataset, a button for the number of hints
-# and finally a button for the example to visualize
-
 _DEBUG_LOCAL = False
 
 
@@ -54,12 +50,11 @@ def main():
         options = [5, 50, 100, 200, 500, "grid-shift", "livox"]
         density = st.selectbox("Hints Density", options=options)
 
-    data = load_data(ds_name, density)
 
     with right:
-        idx = st.number_input("Example Index", min_value=0, max_value=len(data))
+        idx = st.number_input("Example Index", min_value=0, max_value=len_data(ds_name, density) - 1)
 
-    ex = data[idx]
+    ex = load_data(ds_name, density, idx)
     img, hints, pred, gt = ex["img"], ex["hints"], ex["pred"], ex["gt"]
     if (d := cfg["dilate"]) > 0:
         hints = morphology.dilation(hints[..., 0], np.ones([d, d]))[..., None]
@@ -102,8 +97,20 @@ def show_img(label: str, dmap: np.ndarray, cmap=None):
     st.pyplot(fig)
 
 
-@st.cache
-def load_data(ds_name, density):
+@st.cache(show_spinner=False)
+def load_data(ds_name, density, idx):
+    ds = torch.hub.load(
+        "andreaconti/sparsity_agnostic_depth_completion"
+        if not _DEBUG_LOCAL
+        else str(Path(__file__).parent),
+        ds_name.replace("-", "_") + "_precomputed",
+        density,
+        source="github" if not _DEBUG_LOCAL else "local",
+    )
+    return ds[idx]
+
+@st.cache(show_spinner=False)
+def len_data(ds_name, density):
     ds = torch.hub.load(
         "andreaconti/sparsity_agnostic_depth_completion"
         if not _DEBUG_LOCAL
@@ -112,7 +119,7 @@ def load_data(ds_name, density):
         density,
         source="github" if not _DEBUG_LOCAL else "local",
     )
-    return ds
+    return len(ds)
 
 
 if __name__ == "__main__":
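
One more note on the visualize.py change above: caching one example per (dataset, density, index) key and caching the dataset length separately means that when Streamlit reruns the script (for example after the index widget changes), an already-visited index is served from the cache instead of being read from the HDF5 file again. Below is a minimal standalone sketch of that pattern; the function and argument names are illustrative, not from the repository.

import time

import streamlit as st

@st.cache(show_spinner=False)
def load_item(source: str, index: int) -> str:
    # Stands in for load_data(ds_name, density, idx): each (source, index) pair
    # is memoized on its own, so revisiting an index is instant.
    time.sleep(1)  # simulate an expensive per-example read
    return f"{source}[{index}]"

index = st.number_input("Example Index", min_value=0, max_value=9)
st.write(load_item("demo", int(index)))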
