Skip to content

Commit

Permalink
Add depth visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
ingra14m committed Jul 13, 2023
1 parent 9f45a4f commit 13c6602
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
5 changes: 3 additions & 2 deletions gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
colors_precomp = override_color

# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, radii = rasterizer(
rendered_image, radii, depth = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
Expand All @@ -96,4 +96,5 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
return {"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii}
"radii": radii,
"depth": depth}
9 changes: 8 additions & 1 deletion render.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth")

makedirs(render_path, exist_ok=True)
makedirs(gts_path, exist_ok=True)
makedirs(depth_path, exist_ok=True)

for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
rendering = render(view, gaussians, pipeline, background)["render"]
results = render(view, gaussians, pipeline, background)
rendering = results["render"]
depth = results["depth"]
depth = depth / (depth.max() + 1e-5)

gt = view.original_image[0:3, :, :]
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png"))

def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
with torch.no_grad():
Expand Down

0 comments on commit 13c6602

Please sign in to comment.