Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NeRF test #1355

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions forge/test/mlir/nerf/spherical_harmonics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396]
C3 = [
-0.5900435899266435,
2.890611442640554,
-0.4570457994644658,
0.3731763325901154,
-0.4570457994644658,
1.445305721320277,
-0.5900435899266435,
]
C4 = [
2.5033429417967046,
-1.7701307697799304,
0.9461746957575601,
-0.6690465435572892,
0.10578554691520431,
-0.6690465435572892,
0.47308734787878004,
-1.7701307697799304,
0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
"""
Evaluate spherical harmonics at unit directions using hardcoded SH polynomials.

Args:
deg: Degree of spherical harmonics (0-4)
sh: Spherical harmonics coefficients
dirs: Unit direction vectors [..., 3]

Returns:
Evaluated spherical harmonics
"""
assert 0 <= deg <= 4, "Degree must be between 0 and 4"
assert (deg + 1) ** 2 == sh.shape[-1], "Invalid SH coefficients shape"

result = C0 * sh[..., 0]

if deg == 0:
return result

x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]

result += -C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]

if deg == 1:
return result

xx, yy, zz = x * x, y * y, z * z
xy, yz, xz = x * y, y * z, x * z

result += (
C2[0] * xy * sh[..., 4]
+ C2[1] * yz * sh[..., 5]
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
+ C2[3] * xz * sh[..., 7]
+ C2[4] * (xx - yy) * sh[..., 8]
)

if deg == 2:
return result

result += (
C3[0] * y * (3 * xx - yy) * sh[..., 9]
+ C3[1] * xy * z * sh[..., 10]
+ C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
+ C3[5] * z * (xx - yy) * sh[..., 14]
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15]
)

if deg == 3:
return result

result += (
C4[0] * xy * (xx - yy) * sh[..., 16]
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17]
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18]
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19]
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21]
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]
)

return result
137 changes: 137 additions & 0 deletions forge/test/mlir/nerf/test_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import pytest
import time

from test.mlir.nerf.utils import NeRF
import torch
from loguru import logger

import forge
from forge.tensor import to_forge_tensors
from forge.verify.compare import compare_with_golden
from test.mlir.nerf.spherical_harmonics import eval_sh
from test.mlir.utils import *


@pytest.mark.push
def test_nerf_training():
dtype = torch.float32

# Set training hyperparameters
num_epochs = 3
num_batches = 10
batch_size = 4096
learning_rate = 1e-3
deg = 2

# Define models
nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg)
nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg)

golden_nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg)
golden_nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg)

copy_params(nerf_coarse, golden_nerf_coarse)
copy_params(nerf_fine, golden_nerf_fine)

loss_fn = torch.nn.MSELoss()

# Define optimizer
framework_optimizer_coarse = torch.optim.SGD(nerf_coarse.parameters(), lr=learning_rate)
framework_optimizer_fine = torch.optim.SGD(nerf_fine.parameters(), lr=learning_rate)
golden_optimizer_coarse = torch.optim.SGD(golden_nerf_coarse.parameters(), lr=learning_rate)
golden_optimizer_fine = torch.optim.SGD(golden_nerf_fine.parameters(), lr=learning_rate)

tt_nerf_coarse = forge.compile(
nerf_coarse,
sample_inputs=[torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)],
optimizer=framework_optimizer_coarse,
training=True,
)

tt_nerf_fine = forge.compile(
nerf_fine,
sample_inputs=[torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)],
optimizer=framework_optimizer_fine,
training=True,
)

logger.info("Starting NeRF training loop... (logger will be disabled)")
logger.disable("")
for epoch_idx in range(num_epochs):
for batch_idx in range(num_batches):
# zero the parameter gradients
framework_optimizer_coarse.zero_grad()
framework_optimizer_fine.zero_grad()
golden_optimizer_coarse.zero_grad()
golden_optimizer_fine.zero_grad()

# Generate random input data
input_xyz = torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)
input_dirs = torch.rand(batch_size, 3, dtype=dtype, requires_grad=True)
target_data_sigma = torch.rand(batch_size, 1, dtype=dtype, requires_grad=True)
target_data_sh = torch.rand(batch_size, 3, dtype=dtype, requires_grad=True)

# Forward pass on TT
output_coarse_sigma, output_coarse_sh = tt_nerf_coarse(input_xyz)
output_coarse_sh = output_coarse_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2)
output_coarse_rgb = eval_sh(deg, output_coarse_sh, input_dirs)

output_fine_sigma, output_fine_sh = tt_nerf_fine(input_xyz)
output_fine_sh = output_fine_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2)
output_fine_rgb = eval_sh(deg, output_fine_sh, input_dirs)

# Forward pass on PyTorch
golden_output_coarse_sigma, output_coarse_sh_pt = golden_nerf_coarse(input_xyz)
golden_output_coarse_sh_pt = output_coarse_sh_pt[:, :27].reshape(-1, 3, (deg + 1) ** 2)
golden_output_coarse_rgb_pt = eval_sh(deg, golden_output_coarse_sh_pt, input_dirs)

golden_output_fine_sigma, golden_output_fine_sh = golden_nerf_fine(input_xyz)
golden_output_fine_sh = golden_output_fine_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2)
golden_output_fine_rgb = eval_sh(deg, golden_output_fine_sh, input_dirs)

# Compute loss for TT
loss_coarse_sigma = loss_fn(output_coarse_sigma, target_data_sigma)
loss_coarse_sh = loss_fn(output_coarse_rgb, target_data_sh)
loss_coarse = loss_coarse_sigma + loss_coarse_sh

loss_fine_sigma = loss_fn(output_fine_sigma, target_data_sigma)
loss_fine_sh = loss_fn(output_fine_rgb, target_data_sh)
loss_fine = loss_fine_sigma + loss_fine_sh

# Compute loss for PyTorch
golden_loss_coarse_sigma = loss_fn(golden_output_coarse_sigma, target_data_sigma)
golden_loss_coarse_sh = loss_fn(golden_output_coarse_rgb_pt, target_data_sh)
golden_loss_coarse = golden_loss_coarse_sigma + golden_loss_coarse_sh

golden_loss_fine_sigma = loss_fn(golden_output_fine_sigma, target_data_sigma)
golden_loss_fine_sh = loss_fn(golden_output_fine_rgb, target_data_sh)
golden_loss_fine = golden_loss_fine_sigma + golden_loss_fine_sh

# Compare TT and PyTorch losses
assert compare_with_golden(
loss_coarse, golden_loss_coarse, rtol=0.05, atol=0.05
), f"Loss coarse mismatch at epoch {epoch_idx}, batch {batch_idx}"
assert compare_with_golden(
loss_fine, golden_loss_fine, rtol=0.05, atol=0.05
), f"Loss fine mismatch at epoch {epoch_idx}, batch {batch_idx}"

# Backward pass
loss_coarse.backward()
loss_fine.backward()

golden_loss_coarse.backward()
golden_loss_fine.backward()

# Update weights
framework_optimizer_coarse.step()
framework_optimizer_fine.step()

golden_optimizer_coarse.step()
golden_optimizer_fine.step()

logger.enable("")
logger.info("NeRF training loop completed.")
71 changes: 71 additions & 0 deletions forge/test/mlir/nerf/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
from test.mlir.nerf.spherical_harmonics import eval_sh


class NeRFHead(nn.Module):
def __init__(self, W, out_dim, *args, **kwargs):
super().__init__(*args, **kwargs)
self.layer1 = nn.Linear(W, W)
self.relu1 = nn.ReLU(False)
self.layer2 = nn.Linear(W, out_dim)

def forward(self, x):
x = self.layer1(x)
x = self.relu1(x)
x = self.layer2(x)
return x


class NeRFEncoding(nn.Module):
def __init__(self, in_dim, W, out_dim, *args, **kwargs):
super().__init__(*args, **kwargs)
self.layer1 = nn.Linear(in_dim, W)
self.relu1 = nn.ReLU(False)
self.layer2 = nn.Linear(W, out_dim)

def forward(self, x):
x = self.layer1(x)
x = self.relu1(x)
x = self.layer2(x)
return x


class NeRF(nn.Module):
def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, deg=2):
super(NeRF, self).__init__()
self.D = D
self.W = W
self.in_channels_xyz = in_channels_xyz
self.in_channels_dir = in_channels_dir
self.deg = deg

for i in range(D):
if i == 0:
layer = NeRFEncoding(in_channels_xyz, W, W)
else:
layer = NeRFEncoding(W, W, W)
setattr(self, f"xyz_encoding_{i+1}", layer)
self.sigma = NeRFHead(W, 1)
self.sh = NeRFHead(W, 32)

def forward(self, xyz):
input_xyz = xyz

xyz_ = input_xyz
for i in range(self.D):
xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)

sigma = self.sigma(xyz_)
sh = self.sh(xyz_)
return sigma, sh

def postprocess(self, sigma, sh, dirs=None):
sh = sh[:, :27]
rgb = eval_sh(deg=self.deg, sh=sh.reshape(-1, 3, (self.deg + 1) ** 2), dirs=dirs)
rgb = torch.sigmoid(rgb)
return sigma, rgb, sh
Loading