Showing 3 changed files with 235 additions and 57 deletions.
@@ -0,0 +1,153 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussian_renderer import GaussianModel
from random import randint
from utils.loss_utils import l1_loss, ssim
from utils.image_utils import psnr

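# training_report is essentially the reporting hook from the base 3DGS training script: it writes the
# per-iteration losses and timing to TensorBoard and, at the requested test iterations, renders the held-out
# test cameras plus a small sample of training cameras to report L1 and PSNR.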
def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
                    renderArgs):
    if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('iter_time', elapsed, iteration)

    # Report test and samples of training set
    if iteration in testing_iterations:
        torch.cuda.empty_cache()
        validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
                              {'name': 'train',
                               'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
                                           range(5, 30, 5)]})
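        # The 'train' config samples a fixed handful of training cameras (indices 5, 10, 15, 20, 25,
        # wrapped modulo the number of cameras) so the same views are compared across evaluations.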

        for config in validation_configs:
            if config['cameras'] and len(config['cameras']) > 0:
                images = torch.tensor([], device="cuda")
                gts = torch.tensor([], device="cuda")
                for idx, viewpoint in enumerate(config['cameras']):
                    image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
                    gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
                    images = torch.cat((images, image.unsqueeze(0)), dim=0)
                    gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
                    if tb_writer and (idx < 5):
                        tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
                                             image[None], global_step=iteration)
                        if iteration == testing_iterations[0]:
                            tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
                                                 gt_image[None], global_step=iteration)

                l1_test = l1_loss(images, gts)
                psnr_test = psnr(images, gts).mean()
                print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
                if tb_writer:
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)

        if tb_writer:
            tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
            tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
        torch.cuda.empty_cache()

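# fine_tune_sets loads a scene previously trained up to `iteration` (-1 picks the latest saved checkpoint),
# then continues optimizing the Gaussians for opt.iterations further steps over randomly drawn training
# cameras, periodically evaluating at testing_iterations and saving at saving_iterations.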
def fine_tune_sets(dataset: ModelParams, opt: OptimizationParams, pipe: PipelineParams, iteration: int,
                   testing_iterations: int, saving_iterations: int):
    gaussians = GaussianModel(dataset.sh_degree)

    scene = Scene(dataset, gaussians, load_iteration=iteration)
    gaussians.training_setup(opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)
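    # CUDA events measure per-iteration GPU time; the elapsed milliseconds are what training_report logs as 'iter_time'.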

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(opt.iterations), desc="Fine Tune progress")

    loaded_iter = scene.loaded_iter + 1
    final_iter = opt.iterations + loaded_iter
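    # Iteration numbering continues from the loaded checkpoint, so evaluation, checkpoint names and the
    # --test/--save iteration lists below all refer to absolute iteration counts.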
    for iteration in range(loaded_iter, final_iter):
        iter_start.record()

        # Pick a random Camera
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()

        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
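        # Cameras are drawn without replacement; the stack is refilled only after every training view has been used once.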

        # Render
        render_pkg = render(viewpoint_cam, gaussians, pipe, background)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \
            render_pkg["visibility_filter"], render_pkg["radii"]

        # Loss
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = 1.0 - ssim(image, gt_image)
        loss.backward()
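        # Note: only the D-SSIM term (1 - SSIM) drives this fine-tuning loss; Ll1 is computed for reporting only.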

        iter_end.record()

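        # The progress bar shows an exponential moving average of the loss (0.4 * current + 0.6 * previous)
        # and advances every 10 iterations.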
        with torch.no_grad():
            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == final_iter - 1:  # last value produced by range(loaded_iter, final_iter)
                progress_bar.close()

            # Log and save
            training_report(None, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
                            testing_iterations, scene, render, (pipe, background))

            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # Optimizer step
            if iteration < final_iter:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)
                gaussians.update_learning_rate(iteration)

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Fine-tuning script parameters")
    model = ModelParams(parser, sentinel=True)
    op = OptimizationParams(parser)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[35_000, 40_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[35_000, 40_000])
parser.add_argument("--quiet", action="store_true") | ||
args = get_combined_args(parser) | ||
print("Rendering " + args.model_path) | ||
|
||
# Initialize system state (RNG) | ||
safe_state(args.quiet) | ||

    fine_tune_sets(model.extract(args), op.extract(args), pipeline.extract(args), args.iteration, args.test_iterations,
                   args.save_iterations)
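
# Example invocation (the script name below is a placeholder; use the filename this commit adds):
#   python fine_tune.py -m <path/to/trained/model> --iteration 30000
# get_combined_args also merges the cfg_args saved in the model directory, so most model and pipeline flags
# usually do not need to be repeated on the command line.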