Skip to content

Commit

Permalink
add fine tune stage
Browse files Browse the repository at this point in the history
  • Loading branch information
ingra14m committed Aug 2, 2023
1 parent 3ce0b41 commit d44473b
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 57 deletions.
3 changes: 2 additions & 1 deletion arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, parser):

class OptimizationParams(ParamGroup):
def __init__(self, parser):
self.iterations = 30_000
self.iterations = 10_000
self.position_lr_init = 0.00016
self.position_lr_final = 0.0000016
self.position_lr_delay_mult = 0.01
Expand All @@ -85,6 +85,7 @@ def __init__(self, parser):
self.densify_from_iter = 500
self.densify_until_iter = 15_000
self.densify_grad_threshold = 0.0002
self.fine_tune = False
super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser):
Expand Down
153 changes: 153 additions & 0 deletions fine_tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussian_renderer import GaussianModel
from random import randint
from utils.loss_utils import l1_loss, ssim
from utils.image_utils import psnr


def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
renderArgs):
if tb_writer:
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
tb_writer.add_scalar('iter_time', elapsed, iteration)

# Report test and samples of training set
if iteration in testing_iterations:
torch.cuda.empty_cache()
validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
{'name': 'train',
'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
range(5, 30, 5)]})

for config in validation_configs:
if config['cameras'] and len(config['cameras']) > 0:
images = torch.tensor([], device="cuda")
gts = torch.tensor([], device="cuda")
for idx, viewpoint in enumerate(config['cameras']):
image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
images = torch.cat((images, image.unsqueeze(0)), dim=0)
gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
if tb_writer and (idx < 5):
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
image[None], global_step=iteration)
if iteration == testing_iterations[0]:
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
gt_image[None], global_step=iteration)

l1_test = l1_loss(images, gts)
psnr_test = psnr(images, gts).mean()
print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
if tb_writer:
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)

if tb_writer:
tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
torch.cuda.empty_cache()


def fine_tune_sets(dataset: ModelParams, opt: OptimizationParams, pipe: PipelineParams, iteration: int,
testing_iterations: int, saving_iterations: int):
gaussians = GaussianModel(dataset.sh_degree)

scene = Scene(dataset, gaussians, load_iteration=iteration)
gaussians.training_setup(opt)

bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

iter_start = torch.cuda.Event(enable_timing=True)
iter_end = torch.cuda.Event(enable_timing=True)

viewpoint_stack = None
ema_loss_for_log = 0.0
progress_bar = tqdm(range(opt.iterations), desc="Fine Tune progress")

loaded_iter = scene.loaded_iter + 1
final_iter = opt.iterations + loaded_iter
for iteration in range(loaded_iter, final_iter):
iter_start.record()

# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()

viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
# Render
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \
render_pkg["visibility_filter"], render_pkg["radii"]

# Loss
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = 1.0 - ssim(image, gt_image)
loss.backward()

iter_end.record()

with torch.no_grad():
# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
if iteration % 10 == 0:
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
progress_bar.update(10)
if iteration == final_iter:
progress_bar.close()

# Log and save
training_report(None, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
testing_iterations, scene, render, (pipe, background))

if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)

# Optimizer step
if iteration < final_iter:
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none=True)
gaussians.update_learning_rate(iteration)


if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Testing script parameters") # add argument into parser
model = ModelParams(parser, sentinel=True)
op = OptimizationParams(parser)
pipeline = PipelineParams(parser)
parser.add_argument("--iteration", default=-1, type=int)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[35_000, 40_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[35_000, 40_000])
parser.add_argument("--quiet", action="store_true")
args = get_combined_args(parser)
print("Rendering " + args.model_path)

# Initialize system state (RNG)
safe_state(args.quiet)

fine_tune_sets(model.extract(args), op.extract(args), pipeline.extract(args), args.iteration, args.test_iterations,
args.save_iterations)
Loading

0 comments on commit d44473b

Please sign in to comment.