# ============================================================================
# File: forge/test/mlir/nerf/spherical_harmonics.py
# (new file in the patch: mode 100644, index 000000000..7b7244ab0)
# ============================================================================
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

# Coefficients of the hardcoded SH polynomials evaluated by eval_sh, grouped by
# degree: C0 scales the degree-0 term, C1 the degree-1 terms, and C2/C3/C4 hold
# one coefficient per degree-2/3/4 term, in the order eval_sh consumes them.
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435,
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions using hardcoded SH polynomials.

    The sum is built up degree by degree and returned as soon as the requested
    degree has been added, so higher-degree terms are never computed.

    Args:
        deg: Degree of spherical harmonics (0-4)
        sh: Spherical harmonics coefficients. The last dimension must hold
            exactly (deg + 1) ** 2 coefficients; the callers in this package
            pass a tensor reshaped to (-1, 3, (deg + 1) ** 2).
        dirs: Unit direction vectors [..., 3]. Not read when deg == 0.

    Returns:
        Evaluated spherical harmonics, with the last (coefficient) dimension
        of ``sh`` reduced away.

    Raises:
        AssertionError: If deg is outside 0-4, or the last dimension of sh is
            not (deg + 1) ** 2.
    """
    assert 0 <= deg <= 4, "Degree must be between 0 and 4"
    assert (deg + 1) ** 2 == sh.shape[-1], "Invalid SH coefficients shape"

    # Degree 0: constant term, independent of direction.
    result = C0 * sh[..., 0]

    if deg == 0:
        return result

    # Slices (0:1 rather than 0) keep a trailing dimension of size 1 so each
    # direction component broadcasts against sh[..., k].
    x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]

    # Degree 1: terms linear in the direction components.
    result += -C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]

    if deg == 1:
        return result

    # Second-order products are shared by the degree 2, 3 and 4 terms.
    xx, yy, zz = x * x, y * y, z * z
    xy, yz, xz = x * y, y * z, x * z

    # Degree 2.
    result += (
        C2[0] * xy * sh[..., 4]
        + C2[1] * yz * sh[..., 5]
        + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
        + C2[3] * xz * sh[..., 7]
        + C2[4] * (xx - yy) * sh[..., 8]
    )

    if deg == 2:
        return result

    # Degree 3.
    result += (
        C3[0] * y * (3 * xx - yy) * sh[..., 9]
        + C3[1] * xy * z * sh[..., 10]
        + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
        + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
        + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
        + C3[5] * z * (xx - yy) * sh[..., 14]
        + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
    )

    if deg == 3:
        return result

    # Degree 4.
    result += (
        C4[0] * xy * (xx - yy) * sh[..., 16]
        + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
        + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
        + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
        + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
        + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
        + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
        + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
        + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]
    )

    return result


# ============================================================================
# File: forge/test/mlir/nerf/test_training.py
# (new file in the patch: mode 100644, index 000000000..b04853a11)
# ============================================================================
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import pytest
import time

from test.mlir.nerf.utils import NeRF
import torch
from loguru import logger

import forge
from forge.tensor import to_forge_tensors
from forge.verify.compare import compare_with_golden
from test.mlir.nerf.spherical_harmonics import eval_sh
from test.mlir.utils import *


def _sh_head_to_rgb(sh_out, dirs, deg):
    """Keep the first 3 * (deg + 1) ** 2 SH-head channels and evaluate them with eval_sh."""
    num_coeffs = (deg + 1) ** 2
    # The head emits more channels than are used; only 3 groups of
    # (deg + 1) ** 2 coefficients are consumed.
    sh = sh_out[:, : 3 * num_coeffs].reshape(-1, 3, num_coeffs)
    return eval_sh(deg, sh, dirs)


@pytest.mark.push
def test_nerf_training():
    """
    Train a coarse and a fine NeRF through forge.compile and check their losses
    against PyTorch golden copies.

    Each compiled model gets a golden twin initialised with the same parameters
    (via copy_params). On every batch of random data, all four models run a
    forward pass, the sigma and RGB MSE losses are summed per model, and the
    compiled-model loss must match the golden loss within rtol=0.05 / atol=0.05
    before the backward pass and SGD step are applied.
    """
    dtype = torch.float32

    # Set training hyperparameters
    num_epochs = 3
    num_batches = 10
    batch_size = 4096
    learning_rate = 1e-3
    deg = 2

    # Define models
    nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg)
    nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg)

    golden_nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg)
    golden_nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg)

    # Start the golden models from the same parameters as the compiled ones.
    copy_params(nerf_coarse, golden_nerf_coarse)
    copy_params(nerf_fine, golden_nerf_fine)

    loss_fn = torch.nn.MSELoss()

    # Define optimizer
    framework_optimizer_coarse = torch.optim.SGD(nerf_coarse.parameters(), lr=learning_rate)
    framework_optimizer_fine = torch.optim.SGD(nerf_fine.parameters(), lr=learning_rate)
    golden_optimizer_coarse = torch.optim.SGD(golden_nerf_coarse.parameters(), lr=learning_rate)
    golden_optimizer_fine = torch.optim.SGD(golden_nerf_fine.parameters(), lr=learning_rate)

    tt_nerf_coarse = forge.compile(
        nerf_coarse,
        sample_inputs=[torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)],
        optimizer=framework_optimizer_coarse,
        training=True,
    )

    tt_nerf_fine = forge.compile(
        nerf_fine,
        sample_inputs=[torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)],
        optimizer=framework_optimizer_fine,
        training=True,
    )

    logger.info("Starting NeRF training loop... (logger will be disabled)")
    logger.disable("")
    try:
        for epoch_idx in range(num_epochs):
            for batch_idx in range(num_batches):
                # zero the parameter gradients
                framework_optimizer_coarse.zero_grad()
                framework_optimizer_fine.zero_grad()
                golden_optimizer_coarse.zero_grad()
                golden_optimizer_fine.zero_grad()

                # Generate random input data
                input_xyz = torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)
                input_dirs = torch.rand(batch_size, 3, dtype=dtype, requires_grad=True)
                target_data_sigma = torch.rand(batch_size, 1, dtype=dtype, requires_grad=True)
                target_data_sh = torch.rand(batch_size, 3, dtype=dtype, requires_grad=True)

                # Forward pass on TT
                output_coarse_sigma, output_coarse_sh = tt_nerf_coarse(input_xyz)
                output_coarse_rgb = _sh_head_to_rgb(output_coarse_sh, input_dirs, deg)

                output_fine_sigma, output_fine_sh = tt_nerf_fine(input_xyz)
                output_fine_rgb = _sh_head_to_rgb(output_fine_sh, input_dirs, deg)

                # Forward pass on PyTorch
                golden_output_coarse_sigma, golden_output_coarse_sh = golden_nerf_coarse(input_xyz)
                golden_output_coarse_rgb = _sh_head_to_rgb(golden_output_coarse_sh, input_dirs, deg)

                golden_output_fine_sigma, golden_output_fine_sh = golden_nerf_fine(input_xyz)
                golden_output_fine_rgb = _sh_head_to_rgb(golden_output_fine_sh, input_dirs, deg)

                # Compute loss for TT
                loss_coarse_sigma = loss_fn(output_coarse_sigma, target_data_sigma)
                loss_coarse_sh = loss_fn(output_coarse_rgb, target_data_sh)
                loss_coarse = loss_coarse_sigma + loss_coarse_sh

                loss_fine_sigma = loss_fn(output_fine_sigma, target_data_sigma)
                loss_fine_sh = loss_fn(output_fine_rgb, target_data_sh)
                loss_fine = loss_fine_sigma + loss_fine_sh

                # Compute loss for PyTorch
                golden_loss_coarse_sigma = loss_fn(golden_output_coarse_sigma, target_data_sigma)
                golden_loss_coarse_sh = loss_fn(golden_output_coarse_rgb, target_data_sh)
                golden_loss_coarse = golden_loss_coarse_sigma + golden_loss_coarse_sh

                golden_loss_fine_sigma = loss_fn(golden_output_fine_sigma, target_data_sigma)
                golden_loss_fine_sh = loss_fn(golden_output_fine_rgb, target_data_sh)
                golden_loss_fine = golden_loss_fine_sigma + golden_loss_fine_sh

                # Compare TT and PyTorch losses
                assert compare_with_golden(
                    loss_coarse, golden_loss_coarse, rtol=0.05, atol=0.05
                ), f"Loss coarse mismatch at epoch {epoch_idx}, batch {batch_idx}"
                assert compare_with_golden(
                    loss_fine, golden_loss_fine, rtol=0.05, atol=0.05
                ), f"Loss fine mismatch at epoch {epoch_idx}, batch {batch_idx}"

                # Backward pass
                loss_coarse.backward()
                loss_fine.backward()

                golden_loss_coarse.backward()
                golden_loss_fine.backward()

                # Update weights
                framework_optimizer_coarse.step()
                framework_optimizer_fine.step()

                golden_optimizer_coarse.step()
                golden_optimizer_fine.step()
    finally:
        # Re-enable logging even when an assertion fails, so later tests in the
        # same session are not left with the logger disabled.
        logger.enable("")

    logger.info("NeRF training loop completed.")


# ============================================================================
# File: forge/test/mlir/nerf/utils.py
# (new file in the patch: mode 100644, index 000000000..7aacf2475)
# ============================================================================
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
from test.mlir.nerf.spherical_harmonics import eval_sh


class NeRFHead(nn.Module):
    """Two-layer MLP head: Linear(W, W) -> ReLU -> Linear(W, out_dim)."""

    def __init__(self, W, out_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layer1 = nn.Linear(W, W)
        self.relu1 = nn.ReLU(False)  # positional argument is `inplace`
        self.layer2 = nn.Linear(W, out_dim)

    def forward(self, x):
        """Apply layer1, ReLU and layer2 in sequence."""
        x = self.layer1(x)
        x = self.relu1(x)
        x = self.layer2(x)
        return x


class NeRFEncoding(nn.Module):
    """Two-layer MLP block: Linear(in_dim, W) -> ReLU -> Linear(W, out_dim)."""

    def __init__(self, in_dim, W, out_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layer1 = nn.Linear(in_dim, W)
        self.relu1 = nn.ReLU(False)  # positional argument is `inplace`
        self.layer2 = nn.Linear(W, out_dim)

    def forward(self, x):
        """Apply layer1, ReLU and layer2 in sequence."""
        x = self.layer1(x)
        x = self.relu1(x)
        x = self.layer2(x)
        return x


class NeRF(nn.Module):
    """
    Stack of D NeRFEncoding blocks followed by two heads: sigma and SH.

    Args:
        D: Number of encoding blocks.
        W: Hidden width of every block and head.
        in_channels_xyz: Input width of the first encoding block.
        in_channels_dir: Stored on the instance; not used to build any layer here.
        deg: Spherical-harmonics degree used by postprocess.
    """

    def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, deg=2):
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.in_channels_xyz = in_channels_xyz
        self.in_channels_dir = in_channels_dir
        self.deg = deg

        # Blocks are registered as attributes xyz_encoding_1 .. xyz_encoding_D;
        # only the first one takes the raw input width.
        for i in range(D):
            if i == 0:
                layer = NeRFEncoding(in_channels_xyz, W, W)
            else:
                layer = NeRFEncoding(W, W, W)
            setattr(self, f"xyz_encoding_{i+1}", layer)
        self.sigma = NeRFHead(W, 1)
        # The SH head width is fixed at 32; postprocess uses only the first
        # 3 * (deg + 1) ** 2 of those channels.
        self.sh = NeRFHead(W, 32)

    def forward(self, xyz):
        """
        Run the encoding stack and both heads.

        Args:
            xyz: Input whose last dimension is in_channels_xyz.

        Returns:
            Tuple (sigma, sh): outputs of the sigma head (width 1) and of the
            SH head (width 32).
        """
        input_xyz = xyz

        xyz_ = input_xyz
        for i in range(self.D):
            xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)

        sigma = self.sigma(xyz_)
        sh = self.sh(xyz_)
        return sigma, sh

    def postprocess(self, sigma, sh, dirs=None):
        """
        Convert the raw SH head output into RGB in (0, 1) via eval_sh and sigmoid.

        Args:
            sigma: Sigma head output; returned unchanged.
            sh: SH head output, indexed as [batch, channels].
            dirs: Direction vectors passed to eval_sh. eval_sh does not read
                them when deg == 0.

        Returns:
            Tuple (sigma, rgb, sh) where sh is the sliced tensor holding only
            the 3 * (deg + 1) ** 2 channels that were used.

        Raises:
            ValueError: If sh has fewer than 3 * (deg + 1) ** 2 channels.
        """
        # Derive the channel count from deg so the slice and the reshape below
        # always agree (27 channels for deg == 2).
        num_coeffs = (self.deg + 1) ** 2
        num_channels = 3 * num_coeffs
        if sh.shape[-1] < num_channels:
            raise ValueError(
                f"SH head output has {sh.shape[-1]} channels, but deg={self.deg} needs {num_channels}"
            )
        sh = sh[:, :num_channels]
        rgb = eval_sh(deg=self.deg, sh=sh.reshape(-1, 3, num_coeffs), dirs=dirs)
        rgb = torch.sigmoid(rgb)
        return sigma, rgb, sh