Add Reading of Depth Image
ingra14m committed Aug 2, 2023
1 parent 179d768 commit 9da0c16
Showing 4 changed files with 23 additions and 11 deletions.
10 changes: 5 additions & 5 deletions scene/dataset_readers.py
@@ -18,6 +18,7 @@
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np
import json
+import imageio
from pathlib import Path
from plyfile import PlyData, PlyElement
from utils.sh_utils import SH2RGB
@@ -190,8 +191,7 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
        for idx, frame in enumerate(frames):
            cam_name = os.path.join(path, frame["file_path"] + extension)

-            if is_test:
-                depth_name = os.path.join(path, frame["file_path"] + "_depth_0001" + extension)
+            depth_name = os.path.join(path, frame["file_path"] + "_depth0000" + '.exr')

            matrix = np.linalg.inv(np.array(frame["transform_matrix"]))
            R = -np.transpose(matrix[:3,:3])
@@ -201,7 +201,7 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
            image_path = os.path.join(path, cam_name)
            image_name = Path(cam_name).stem
            image = Image.open(image_path)
-            depth = Image.open(depth_name).convert('RGBA') if is_test else None
+            depth = imageio.imread(depth_name)

            im_data = np.array(image.convert("RGBA"))

@@ -222,9 +222,9 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=

def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
    print("Reading Training Transforms")
-    train_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
+    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
    print("Reading Test Transforms")
-    test_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
+    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)

    if not eval:
        train_cam_infos.extend(test_cam_infos)
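Note: imageio can only read the Blender-exported _depth0000.exr files if an EXR-capable backend is available, which this commit does not set up by itself. A minimal sketch, assuming the FreeImage plugin and a made-up scene path (neither is part of the commit):

import imageio

# One-time download of the FreeImage backend, which gives imageio EXR support.
imageio.plugins.freeimage.download()

# Path follows the naming convention introduced above; the scene is hypothetical.
depth = imageio.imread("data/lego/train/r_0_depth0000.exr")
print(depth.shape, depth.dtype)  # typically float32, H x W or H x W x C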
4 changes: 2 additions & 2 deletions train.py
@@ -71,7 +71,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
            viewpoint_stack = scene.getTrainCameras().copy()

        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
-        gt_depth = viewpoint_cam.depth.unsqueeze(0)
+        gt_depth = viewpoint_cam.depth
        # Render
        render_pkg = render(viewpoint_cam, gaussians, pipe, background)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
@@ -82,7 +82,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        depth_loss = l1_loss(depth, gt_depth) * 0.1
-        loss = loss + depth_loss
+        # loss = loss + depth_loss
        loss.backward()

        iter_end.record()
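For reference, a self-contained sketch of the weighted depth term this hunk computes (and currently leaves disabled). Tensor names and shapes are illustrative, not taken from the repo:

import torch

def depth_l1(pred_depth, gt_depth, weight=0.1):
    # Plain L1 between rendered and ground-truth depth, scaled by the
    # same 0.1 factor used in the diff above.
    return weight * torch.abs(pred_depth - gt_depth).mean()

pred = torch.rand(1, 800, 800)   # hypothetical rendered depth
gt = torch.rand(1, 800, 800)     # hypothetical ground-truth depth
print(depth_l1(pred, gt))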
11 changes: 7 additions & 4 deletions utils/camera_utils.py
@@ -11,7 +11,7 @@

from scene.cameras import Camera
import numpy as np
-from utils.general_utils import PILtoTorch
+from utils.general_utils import PILtoTorch, ArrayToTorch
from utils.graphics_utils import fov2focal

WARNED = False
@@ -39,13 +39,16 @@ def loadCam(args, id, cam_info, resolution_scale):
        resolution = (int(orig_w / scale), int(orig_h / scale))

    resized_image_rgb = PILtoTorch(cam_info.image, resolution)
-    resized_depth_rgb = PILtoTorch(cam_info.depth, resolution) if cam_info.depth is not None else None
+    if cam_info.depth is not None:
+        resized_depth_rgb = ArrayToTorch(cam_info.depth, resolution)
+    else:
+        resized_depth_rgb = None

    gt_image = resized_image_rgb[:3, ...]
    if resized_depth_rgb is not None:
-        depth_mask = resized_depth_rgb[3, ...] > 0
+        depth_mask = resized_depth_rgb[0, ...] > 60000
        gt_depth = resized_depth_rgb[0, ...]
-        gt_depth[depth_mask] = 2. + 6. * (1 - gt_depth[depth_mask])
+        gt_depth[depth_mask] = 0
    else:
        gt_depth = None

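The new mask flags background pixels: in Blender-exported EXR depth, pixels that hit nothing carry a very large sentinel value, so anything above 60000 is treated as empty and zeroed. A minimal sketch of the same logic with a made-up 2x2 depth map (the 65504.0 sentinel is an assumption about the export, not something stated in the commit):

import numpy as np
import torch

depth = np.array([[2.5, 65504.0],
                  [3.1, 65504.0]], dtype=np.float32)

gt_depth = torch.from_numpy(depth)
depth_mask = gt_depth > 60000   # background pixels
gt_depth[depth_mask] = 0        # zeroed out, as in loadCam above
print(gt_depth)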
9 changes: 9 additions & 0 deletions utils/general_utils.py
@@ -26,6 +26,15 @@ def PILtoTorch(pil_image, resolution):
    else:
        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)

+def ArrayToTorch(array, resolution):
+    # resized_image = np.resize(array, resolution)
+    resized_image_torch = torch.from_numpy(array)
+
+    if len(resized_image_torch.shape) == 3:
+        return resized_image_torch.permute(2, 0, 1)
+    else:
+        return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1)
+
def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
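Note that ArrayToTorch, as added here, only converts the NumPy array to a channel-first tensor; the resize to the requested resolution is commented out, so the resolution argument is currently unused. A quick illustration, assuming the function is imported from utils.general_utils and using a made-up depth map:

import numpy as np
from utils.general_utils import ArrayToTorch

depth = np.zeros((800, 800), dtype=np.float32)   # H x W depth map
t = ArrayToTorch(depth, (400, 400))              # resolution is ignored as written
print(t.shape)                                   # torch.Size([1, 800, 800])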
