diff --git a/arguments/__init__.py b/arguments/__init__.py index eba1dbae9..c91fa9c86 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -69,7 +69,7 @@ def __init__(self, parser): class OptimizationParams(ParamGroup): def __init__(self, parser): - self.iterations = 30_000 + self.iterations = 10_000 self.position_lr_init = 0.00016 self.position_lr_final = 0.0000016 self.position_lr_delay_mult = 0.01 @@ -85,6 +85,7 @@ def __init__(self, parser): self.densify_from_iter = 500 self.densify_until_iter = 15_000 self.densify_grad_threshold = 0.0002 + self.fine_tune = False super().__init__(parser, "Optimization Parameters") def get_combined_args(parser : ArgumentParser): diff --git a/fine_tune.py b/fine_tune.py new file mode 100644 index 000000000..8e770d08d --- /dev/null +++ b/fine_tune.py @@ -0,0 +1,153 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussian_renderer import GaussianModel
from random import randint
from utils.loss_utils import l1_loss, ssim
from utils.image_utils import psnr


def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
                    renderArgs):
    """Log per-iteration losses and, at selected iterations, evaluate test/train views.

    Args:
        tb_writer: tensorboard SummaryWriter, or None to skip all tensorboard logging.
        iteration: current absolute training iteration.
        Ll1: current L1 loss tensor (logged only).
        loss: current total loss tensor (logged only).
        l1_loss: callable computing L1 between two image batches (shadows the
            module-level import on purpose; kept for interface compatibility).
        elapsed: iteration wall time in milliseconds.
        testing_iterations: iterations at which full evaluation is run.
        scene: Scene providing train/test cameras and the gaussians.
        renderFunc: renderer called as renderFunc(viewpoint, gaussians, *renderArgs).
        renderArgs: extra positional arguments for renderFunc.
    """
    if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('iter_time', elapsed, iteration)

    # Report test and samples of training set
    if iteration in testing_iterations:
        torch.cuda.empty_cache()
        validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
                              {'name': 'train',
                               'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
                                           range(5, 30, 5)]})

        for config in validation_configs:
            if config['cameras'] and len(config['cameras']) > 0:
                # Accumulate renders in Python lists and stack once at the end:
                # the original repeatedly torch.cat'ed onto a growing tensor,
                # which copies everything on every append (accidental O(n^2)).
                images = []
                gts = []
                for idx, viewpoint in enumerate(config['cameras']):
                    image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
                    gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
                    images.append(image)
                    gts.append(gt_image)
                    if tb_writer and (idx < 5):
                        tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
                                             image[None], global_step=iteration)
                        # Ground truth never changes, so write it only once
                        # (at the first evaluation iteration).
                        if iteration == testing_iterations[0]:
                            tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
                                                 gt_image[None], global_step=iteration)

                images = torch.stack(images, dim=0)
                gts = torch.stack(gts, dim=0)
                l1_test = l1_loss(images, gts)
                psnr_test = psnr(images, gts).mean()
                print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
                if tb_writer:
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)

        if tb_writer:
            tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
            tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
        torch.cuda.empty_cache()


def fine_tune_sets(dataset: ModelParams, opt: OptimizationParams, pipe: PipelineParams, iteration: int,
                   testing_iterations: list, saving_iterations: list):
    """Fine-tune a pre-trained Gaussian model loaded from a given checkpoint.

    Runs `opt.iterations` additional optimization steps starting from the
    checkpoint's iteration count, so testing/saving iteration numbers are
    absolute (e.g. 35_000 after a 30_000-iteration base training).

    Args:
        dataset: extracted ModelParams (scene paths, sh_degree, background).
        opt: extracted OptimizationParams (learning rates, iteration budget).
        pipe: extracted PipelineParams forwarded to the renderer.
        iteration: checkpoint iteration to load (-1 = latest).
        testing_iterations: absolute iterations at which to evaluate.
        saving_iterations: absolute iterations at which to save gaussians.
    """
    gaussians = GaussianModel(dataset.sh_degree)

    scene = Scene(dataset, gaussians, load_iteration=iteration)
    gaussians.training_setup(opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(opt.iterations), desc="Fine Tune progress")

    # Continue counting from the loaded checkpoint so that logging/saving use
    # absolute iteration numbers; range() runs exactly opt.iterations steps.
    loaded_iter = scene.loaded_iter + 1
    final_iter = opt.iterations + loaded_iter
    for iteration in range(loaded_iter, final_iter):
        iter_start.record()

        # Pick a random camera; refill the stack once every view has been used.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        # Render (densification tensors from render_pkg are unused here:
        # fine-tuning does not densify, so only the image is needed).
        render_pkg = render(viewpoint_cam, gaussians, pipe, background)
        image = render_pkg["render"]

        # Loss. NOTE: the objective is SSIM-only by design; L1 is computed
        # purely for reporting via training_report.
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = 1.0 - ssim(image, gt_image)
        loss.backward()

        iter_end.record()

        with torch.no_grad():
            # Exponential moving average for the progress-bar readout.
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.7f}"})
                progress_bar.update(10)

            # Log and save
            training_report(None, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
                            testing_iterations, scene, render, (pipe, background))

            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # Optimizer step on every iteration. The original guarded this with
            # `if iteration < final_iter`, which is always true inside this
            # range() loop, so the guard was dead code.
            gaussians.optimizer.step()
            gaussians.optimizer.zero_grad(set_to_none=True)
            gaussians.update_learning_rate(iteration)

    # BUGFIX: the original closed the bar only when `iteration == final_iter`,
    # a value range(loaded_iter, final_iter) never yields, so the bar was
    # never closed. Close it unconditionally after the loop.
    progress_bar.close()


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Fine-tuning script parameters")
    model = ModelParams(parser, sentinel=True)
    op = OptimizationParams(parser)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[35_000, 40_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[35_000, 40_000])
    parser.add_argument("--quiet", action="store_true")
    args = get_combined_args(parser)
    # BUGFIX: message said "Rendering", but this script fine-tunes.
    print("Fine-tuning " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    fine_tune_sets(model.extract(args), op.extract(args), pipeline.extract(args), args.iteration,
                   args.test_iterations, args.save_iterations)
strip_symmetric, build_scaling_rotation + class GaussianModel: - def __init__(self, sh_degree : int): + def __init__(self, sh_degree: int): def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): L = build_scaling_rotation(scaling_modifier * scaling, rotation) @@ -31,7 +32,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): return symm self.active_sh_degree = 0 - self.max_sh_degree = sh_degree + self.max_sh_degree = sh_degree self._xyz = torch.empty(0) self._features_dc = torch.empty(0) @@ -57,52 +58,52 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): @property def get_scaling(self): return self.scaling_activation(self._scaling) - + @property def get_rotation(self): return self.rotation_activation(self._rotation) - + @property def get_xyz(self): return self._xyz - + @property def get_features(self): features_dc = self._features_dc features_rest = self._features_rest return torch.cat((features_dc, features_rest), dim=1) - + @property def get_opacity(self): return self.opacity_activation(self._opacity) - - def get_covariance(self, scaling_modifier = 1): + + def get_covariance(self, scaling_modifier=1): return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) def oneupSHdegree(self): if self.active_sh_degree < self.max_sh_degree: self.active_sh_degree += 1 - def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): + def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float): self.spatial_lr_scale = 5 fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() - features[:, :3, 0 ] = fused_color + features[:, :3, 0] = fused_color features[:, 3:, 1:] = 0.0 print("Number of points at initialisation : ", fused_point_cloud.shape[0]) dist2 = 
torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) - scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3) rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") rots[:, 0] = 1 opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) - self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) - self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) self._scaling = nn.Parameter(scales.requires_grad_(True)) self._rotation = nn.Parameter(rots.requires_grad_(True)) self._opacity = nn.Parameter(opacities.requires_grad_(True)) @@ -112,19 +113,28 @@ def training_setup(self, training_args): self.percent_dense = training_args.percent_dense self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") - - l = [ - {'params': [self._xyz], 'lr': training_args.position_lr_init*self.spatial_lr_scale, "name": "xyz"}, - {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, - {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, - {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, - {'params': [self._scaling], 'lr': training_args.scaling_lr*self.spatial_lr_scale, "name": "scaling"}, - {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} - ] + is_fine_tune = training_args.fine_tune + + if is_fine_tune: + self.spatial_lr_scale = 5 + l = 
[ + {'params': [self._features_dc], 'lr': training_args.feature_lr / 5, "name": "f_dc"}, + # {'params': [self._features_rest], 'lr': training_args.feature_lr / 5, "name": "f_rest"}, + # {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + ] + else: + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr * self.spatial_lr_scale, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} + ] self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) - self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, - lr_final=training_args.position_lr_final*self.spatial_lr_scale, + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, + lr_final=training_args.position_lr_final * self.spatial_lr_scale, lr_delay_mult=training_args.position_lr_delay_mult, max_steps=training_args.position_lr_max_steps) @@ -139,9 +149,9 @@ def update_learning_rate(self, iteration): def construct_list_of_attributes(self): l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] # All channels except the 3 DC - for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): l.append('f_dc_{}'.format(i)) - for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): l.append('f_rest_{}'.format(i)) l.append('opacity') for i in range(self._scaling.shape[1]): @@ -170,7 +180,7 @@ def save_ply(self, path): 
PlyData([el]).write(path) def reset_opacity(self): - opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)) optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") self._opacity = optimizable_tensors["opacity"] @@ -180,7 +190,7 @@ def load_ply(self, path, og_number_points=-1): xyz = np.stack((np.asarray(plydata.elements[0]["x"]), np.asarray(plydata.elements[0]["y"]), - np.asarray(plydata.elements[0]["z"])), axis=1) + np.asarray(plydata.elements[0]["z"])), axis=1) opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] features_dc = np.zeros((xyz.shape[0], 3, 1)) @@ -189,7 +199,7 @@ def load_ply(self, path, og_number_points=-1): features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] - assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) for idx, attr_name in enumerate(extra_f_names): features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) @@ -207,8 +217,12 @@ def load_ply(self, path, og_number_points=-1): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) - self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_dc = nn.Parameter( + torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( + True)) + self._features_rest = 
nn.Parameter( + torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( + True)) self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) @@ -272,27 +286,32 @@ def cat_tensors_to_optimizer(self, tensors_dict): stored_state = self.optimizer.state.get(group['params'][0], None) if stored_state is not None: - stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) - stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), + dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), + dim=0) del self.optimizer.state[group['params'][0]] - group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + group["params"][0] = nn.Parameter( + torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) self.optimizer.state[group['params'][0]] = stored_state optimizable_tensors[group["name"]] = group["params"][0] else: - group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + group["params"][0] = nn.Parameter( + torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) optimizable_tensors[group["name"]] = group["params"][0] return optimizable_tensors - def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, 
new_scaling, + new_rotation): d = {"xyz": new_xyz, - "f_dc": new_features_dc, - "f_rest": new_features_rest, - "opacity": new_opacities, - "scaling" : new_scaling, - "rotation" : new_rotation} + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling": new_scaling, + "rotation": new_rotation} optimizable_tensors = self.cat_tensors_to_optimizer(d) self._xyz = optimizable_tensors["xyz"] @@ -313,30 +332,33 @@ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): padded_grad[:grads.shape[0]] = grads.squeeze() selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) selected_pts_mask = torch.logical_and(selected_pts_mask, - torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) + torch.max(self.get_scaling, + dim=1).values > self.percent_dense * scene_extent) - stds = self.get_scaling[selected_pts_mask].repeat(N,1) - means =torch.zeros((stds.size(0), 3),device="cuda") + stds = self.get_scaling[selected_pts_mask].repeat(N, 1) + means = torch.zeros((stds.size(0), 3), device="cuda") samples = torch.normal(mean=means, std=stds) - rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1) new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) - new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) - new_rotation = self._rotation[selected_pts_mask].repeat(N,1) - new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) - new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) - new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N, 1) + new_features_dc = 
self._features_dc[selected_pts_mask].repeat(N, 1, 1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1) + new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) - prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + prune_filter = torch.cat( + (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) self.prune_points(prune_filter) def densify_and_clone(self, grads, grad_threshold, scene_extent): # Extract points that satisfy the gradient condition selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) selected_pts_mask = torch.logical_and(selected_pts_mask, - torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) - + torch.max(self.get_scaling, + dim=1).values <= self.percent_dense * scene_extent) + new_xyz = self._xyz[selected_pts_mask] new_features_dc = self._features_dc[selected_pts_mask] new_features_rest = self._features_rest[selected_pts_mask] @@ -344,7 +366,8 @@ def densify_and_clone(self, grads, grad_threshold, scene_extent): new_scaling = self._scaling[selected_pts_mask] new_rotation = self._rotation[selected_pts_mask] - self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, + new_rotation) def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): grads = self.xyz_gradient_accum / self.denom @@ -363,5 +386,6 @@ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): torch.cuda.empty_cache() def add_densification_stats(self, viewspace_point_tensor, update_filter): - self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], 
dim=-1, keepdim=True) - self.denom[update_filter] += 1 \ No newline at end of file + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, + keepdim=True) + self.denom[update_filter] += 1