From 397d5e851ed23c7a00fe8f24f190ef8f8957f707 Mon Sep 17 00:00:00 2001 From: Pavle Glusac Date: Mon, 3 Mar 2025 11:25:35 +0000 Subject: [PATCH 1/4] Add NeRF test --- forge/test/mlir/nerf/spherical_harmonics.py | 79 ++++++++++++ forge/test/mlir/nerf/test_training.py | 135 ++++++++++++++++++++ forge/test/mlir/nerf/utils.py | 70 ++++++++++ 3 files changed, 284 insertions(+) create mode 100644 forge/test/mlir/nerf/spherical_harmonics.py create mode 100644 forge/test/mlir/nerf/test_training.py create mode 100644 forge/test/mlir/nerf/utils.py diff --git a/forge/test/mlir/nerf/spherical_harmonics.py b/forge/test/mlir/nerf/spherical_harmonics.py new file mode 100644 index 000000000..ac1026b0c --- /dev/null +++ b/forge/test/mlir/nerf/spherical_harmonics.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + """ + assert deg <= 4 and deg >= 0 + assert (deg + 1) ** 2 == sh.shape[-1] + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result \ No newline at end of file diff --git a/forge/test/mlir/nerf/test_training.py b/forge/test/mlir/nerf/test_training.py new file mode 100644 index 000000000..7ad739af0 --- /dev/null +++ b/forge/test/mlir/nerf/test_training.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import time + +from test.mlir.nerf.utils import NeRF +import torch +from loguru import logger + +import forge +from forge.tensor import to_forge_tensors +from forge.verify.compare import compare_with_golden +from test.mlir.nerf.spherical_harmonics import eval_sh +from test.mlir.utils import * + + +@pytest.mark.push +def test_nerf_training(): + dtype = torch.float32 + + # Set training hyperparameters + num_epochs = 3 + num_batches = 10 + batch_size = 4096 + learning_rate = 1e-3 + deg = 2 + + # Define models + nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg) + nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg) + + golden_nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg) + golden_nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg) + + copy_params(nerf_coarse, golden_nerf_coarse) + copy_params(nerf_fine, golden_nerf_fine) + + loss_fn = torch.nn.MSELoss() + + # Define optimizer + framework_optimizer_coarse = torch.optim.SGD(nerf_coarse.parameters(), lr=learning_rate) + framework_optimizer_fine = torch.optim.SGD(nerf_fine.parameters(), lr=learning_rate) + golden_optimizer_coarse = torch.optim.SGD(golden_nerf_coarse.parameters(), lr=learning_rate) + golden_optimizer_fine = torch.optim.SGD(golden_nerf_fine.parameters(), lr=learning_rate) + + tt_nerf_coarse = forge.compile( + nerf_coarse, + sample_inputs=[torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)], + optimizer=framework_optimizer_coarse, + training=True, + ) + + tt_nerf_fine = forge.compile( + nerf_fine, + sample_inputs=[torch.rand(batch_size, 63, dtype=dtype, requires_grad=True)], + optimizer=framework_optimizer_fine, + training=True, + ) + + logger.info("Starting NeRF training loop... (logger will be disabled)") + logger.disable("") + for epoch_idx in range(num_epochs): + for batch_idx in range(num_batches): + # zero the parameter gradients + framework_optimizer_coarse.zero_grad() + framework_optimizer_fine.zero_grad() + golden_optimizer_coarse.zero_grad() + golden_optimizer_fine.zero_grad() + + # Generate random input data + input_xyz = torch.rand(batch_size, 63, dtype=dtype, requires_grad=True) + input_dirs = torch.rand(batch_size, 3, dtype=dtype, requires_grad=True) + target_data_sigma = torch.rand(batch_size, 1, dtype=dtype, requires_grad=True) + target_data_sh = torch.rand(batch_size, 3, dtype=dtype, requires_grad=True) + + # Forward pass on TT + output_coarse_sigma, output_coarse_sh = tt_nerf_coarse(input_xyz) + output_coarse_sh = output_coarse_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2) + output_coarse_rgb = eval_sh(deg, output_coarse_sh, input_dirs) + + output_fine_sigma, output_fine_sh = tt_nerf_fine(input_xyz) + output_fine_sh = output_fine_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2) + output_fine_rgb = eval_sh(deg, output_fine_sh, input_dirs) + + # Forward pass on PyTorch + golden_output_coarse_sigma, output_coarse_sh_pt = golden_nerf_coarse(input_xyz) + golden_output_coarse_sh_pt = output_coarse_sh_pt[:, :27].reshape(-1, 3, (deg + 1) ** 2) + golden_output_coarse_rgb_pt = eval_sh(deg, golden_output_coarse_sh_pt, input_dirs) + + golden_output_fine_sigma, golden_output_fine_sh = golden_nerf_fine(input_xyz) + golden_output_fine_sh = golden_output_fine_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2) + golden_output_fine_rgb = eval_sh(deg, golden_output_fine_sh, input_dirs) + + # Compute loss for TT + loss_coarse_sigma = loss_fn(output_coarse_sigma, target_data_sigma) + loss_coarse_sh = loss_fn(output_coarse_rgb, target_data_sh) + loss_coarse = loss_coarse_sigma + loss_coarse_sh + + loss_fine_sigma = loss_fn(output_fine_sigma, target_data_sigma) + loss_fine_sh = loss_fn(output_fine_rgb, target_data_sh) + loss_fine = loss_fine_sigma + loss_fine_sh + + # Compute loss for PyTorch + golden_loss_coarse_sigma = loss_fn(golden_output_coarse_sigma, target_data_sigma) + golden_loss_coarse_sh = loss_fn(golden_output_coarse_rgb_pt, target_data_sh) + golden_loss_coarse = golden_loss_coarse_sigma + golden_loss_coarse_sh + + golden_loss_fine_sigma = loss_fn(golden_output_fine_sigma, target_data_sigma) + golden_loss_fine_sh = loss_fn(golden_output_fine_rgb, target_data_sh) + golden_loss_fine = golden_loss_fine_sigma + golden_loss_fine_sh + + # Compare TT and PyTorch losses + assert compare_with_golden(loss_coarse, golden_loss_coarse, rtol=0.05, atol=0.05), f"Loss coarse mismatch at epoch {epoch_idx}, batch {batch_idx}" + assert compare_with_golden(loss_fine, golden_loss_fine, rtol=0.05, atol=0.05), f"Loss fine mismatch at epoch {epoch_idx}, batch {batch_idx}" + + # Backward pass + loss_coarse.backward() + loss_fine.backward() + + golden_loss_coarse.backward() + golden_loss_fine.backward() + + # Update weights + framework_optimizer_coarse.step() + framework_optimizer_fine.step() + + golden_optimizer_coarse.step() + golden_optimizer_fine.step() + + logger.enable("") + logger.info("NeRF training loop completed.") + + diff --git a/forge/test/mlir/nerf/utils.py b/forge/test/mlir/nerf/utils.py new file mode 100644 index 000000000..a06e5b73a --- /dev/null +++ b/forge/test/mlir/nerf/utils.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn +from test.mlir.nerf.spherical_harmonics import eval_sh + +class NeRFHead(nn.Module): + def __init__(self, W, out_dim, *args, **kwargs): + super().__init__(*args, **kwargs) + self.layer1 = nn.Linear(W, W) + self.relu1 = nn.ReLU(False) + self.layer2 = nn.Linear(W, out_dim) + + def forward(self, x): + x = self.layer1(x) + x = self.relu1(x) + x = self.layer2(x) + return x + + +class NeRFEncoding(nn.Module): + def __init__(self, in_dim, W, out_dim, *args, **kwargs): + super().__init__(*args, **kwargs) + self.layer1 = nn.Linear(in_dim, W) + self.relu1 = nn.ReLU(False) + self.layer2 = nn.Linear(W, out_dim) + + def forward(self, x): + x = self.layer1(x) + x = self.relu1(x) + x = self.layer2(x) + return x + + +class NeRF(nn.Module): + def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, deg=2): + super(NeRF, self).__init__() + self.D = D + self.W = W + self.in_channels_xyz = in_channels_xyz + self.in_channels_dir = in_channels_dir + self.deg = deg + + for i in range(D): + if i == 0: + layer = NeRFEncoding(in_channels_xyz, W, W) + else: + layer = NeRFEncoding(W, W, W) + setattr(self, f"xyz_encoding_{i+1}", layer) + self.sigma = NeRFHead(W, 1) + self.sh = NeRFHead(W, 32) + + def forward(self, xyz): + input_xyz = xyz + + xyz_ = input_xyz + for i in range(self.D): + xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) + + sigma = self.sigma(xyz_) + sh = self.sh(xyz_) + return sigma, sh + + def postprocess(self, sigma, sh, dirs=None): + sh = sh[:, :27] + rgb = eval_sh(deg=self.deg, sh=sh.reshape(-1, 3, (self.deg + 1) ** 2), dirs=dirs) + rgb = torch.sigmoid(rgb) + return sigma, rgb, sh From 98abdd5578ac44b70c060d36d73c34b6b191b530 Mon Sep 17 00:00:00 2001 From: Pavle Glusac Date: Mon, 3 Mar 2025 13:45:52 +0000 Subject: [PATCH 2/4] Reformat --- forge/test/mlir/nerf/spherical_harmonics.py | 71 ++++++++++----------- forge/test/mlir/nerf/test_training.py | 18 +++--- forge/test/mlir/nerf/utils.py | 7 +- 3 files changed, 49 insertions(+), 47 deletions(-) diff --git a/forge/test/mlir/nerf/spherical_harmonics.py b/forge/test/mlir/nerf/spherical_harmonics.py index ac1026b0c..7f1f65ba0 100644 --- a/forge/test/mlir/nerf/spherical_harmonics.py +++ b/forge/test/mlir/nerf/spherical_harmonics.py @@ -4,13 +4,7 @@ C0 = 0.28209479177387814 C1 = 0.4886025119029199 -C2 = [ - 1.0925484305920792, - -1.0925484305920792, - 0.31539156525252005, - -1.0925484305920792, - 0.5462742152960396 -] +C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396] C3 = [ -0.5900435899266435, 2.890611442640554, @@ -18,7 +12,7 @@ 0.3731763325901154, -0.4570457994644658, 1.445305721320277, - -0.5900435899266435 + -0.5900435899266435, ] C4 = [ 2.5033429417967046, @@ -32,6 +26,7 @@ 0.6258357354491761, ] + def eval_sh(deg, sh, dirs): """ Evaluate spherical harmonics at unit directions @@ -43,37 +38,41 @@ def eval_sh(deg, sh, dirs): result = C0 * sh[..., 0] if deg > 0: x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - result = (result - - C1 * y * sh[..., 1] + - C1 * z * sh[..., 2] - - C1 * x * sh[..., 3]) + result = result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] if deg > 1: xx, yy, zz = x * x, y * y, z * z xy, yz, xz = x * y, y * z, x * z - result = (result + - C2[0] * xy * sh[..., 4] + - C2[1] * yz * sh[..., 5] + - C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + - C2[3] * xz * sh[..., 7] + - C2[4] * (xx - yy) * sh[..., 8]) + result = ( + result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8] + ) if deg > 2: - result = (result + - C3[0] * y * (3 * xx - yy) * sh[..., 9] + - C3[1] * xy * z * sh[..., 10] + - C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + - C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + - C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + - C3[5] * z * (xx - yy) * sh[..., 14] + - C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + result = ( + result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15] + ) if deg > 3: - result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + - C4[1] * yz * (3 * xx - yy) * sh[..., 17] + - C4[2] * xy * (7 * zz - 1) * sh[..., 18] + - C4[3] * yz * (7 * zz - 3) * sh[..., 19] + - C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + - C4[5] * xz * (7 * zz - 3) * sh[..., 21] + - C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + - C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + - C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) - return result \ No newline at end of file + result = ( + result + + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24] + ) + return result diff --git a/forge/test/mlir/nerf/test_training.py b/forge/test/mlir/nerf/test_training.py index 7ad739af0..b04853a11 100644 --- a/forge/test/mlir/nerf/test_training.py +++ b/forge/test/mlir/nerf/test_training.py @@ -28,10 +28,10 @@ def test_nerf_training(): deg = 2 # Define models - nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg) + nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg) nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg) - - golden_nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg) + + golden_nerf_coarse = NeRF(D=4, W=128, in_channels_xyz=63, in_channels_dir=32, deg=deg) golden_nerf_fine = NeRF(D=4, W=192, in_channels_xyz=63, in_channels_dir=32, deg=deg) copy_params(nerf_coarse, golden_nerf_coarse) @@ -79,7 +79,7 @@ def test_nerf_training(): output_coarse_sigma, output_coarse_sh = tt_nerf_coarse(input_xyz) output_coarse_sh = output_coarse_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2) output_coarse_rgb = eval_sh(deg, output_coarse_sh, input_dirs) - + output_fine_sigma, output_fine_sh = tt_nerf_fine(input_xyz) output_fine_sh = output_fine_sh[:, :27].reshape(-1, 3, (deg + 1) ** 2) output_fine_rgb = eval_sh(deg, output_fine_sh, input_dirs) @@ -112,8 +112,12 @@ def test_nerf_training(): golden_loss_fine = golden_loss_fine_sigma + golden_loss_fine_sh # Compare TT and PyTorch losses - assert compare_with_golden(loss_coarse, golden_loss_coarse, rtol=0.05, atol=0.05), f"Loss coarse mismatch at epoch {epoch_idx}, batch {batch_idx}" - assert compare_with_golden(loss_fine, golden_loss_fine, rtol=0.05, atol=0.05), f"Loss fine mismatch at epoch {epoch_idx}, batch {batch_idx}" + assert compare_with_golden( + loss_coarse, golden_loss_coarse, rtol=0.05, atol=0.05 + ), f"Loss coarse mismatch at epoch {epoch_idx}, batch {batch_idx}" + assert compare_with_golden( + loss_fine, golden_loss_fine, rtol=0.05, atol=0.05 + ), f"Loss fine mismatch at epoch {epoch_idx}, batch {batch_idx}" # Backward pass loss_coarse.backward() @@ -131,5 +135,3 @@ def test_nerf_training(): logger.enable("") logger.info("NeRF training loop completed.") - - diff --git a/forge/test/mlir/nerf/utils.py b/forge/test/mlir/nerf/utils.py index a06e5b73a..7aacf2475 100644 --- a/forge/test/mlir/nerf/utils.py +++ b/forge/test/mlir/nerf/utils.py @@ -6,6 +6,7 @@ import torch.nn as nn from test.mlir.nerf.spherical_harmonics import eval_sh + class NeRFHead(nn.Module): def __init__(self, W, out_dim, *args, **kwargs): super().__init__(*args, **kwargs) @@ -18,7 +19,7 @@ def forward(self, x): x = self.relu1(x) x = self.layer2(x) return x - + class NeRFEncoding(nn.Module): def __init__(self, in_dim, W, out_dim, *args, **kwargs): @@ -57,12 +58,12 @@ def forward(self, xyz): xyz_ = input_xyz for i in range(self.D): - xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) + xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) sigma = self.sigma(xyz_) sh = self.sh(xyz_) return sigma, sh - + def postprocess(self, sigma, sh, dirs=None): sh = sh[:, :27] rgb = eval_sh(deg=self.deg, sh=sh.reshape(-1, 3, (self.deg + 1) ** 2), dirs=dirs) From 20734cf8759ddcde64b8088ca51ca49aae5e615f Mon Sep 17 00:00:00 2001 From: Pavle Glusac Date: Mon, 3 Mar 2025 16:34:21 +0000 Subject: [PATCH 3/4] Refactor spherical harmonics --- forge/test/mlir/nerf/spherical_harmonics.py | 119 ++++++++++++-------- 1 file changed, 73 insertions(+), 46 deletions(-) diff --git a/forge/test/mlir/nerf/spherical_harmonics.py b/forge/test/mlir/nerf/spherical_harmonics.py index 7f1f65ba0..9c2c06c46 100644 --- a/forge/test/mlir/nerf/spherical_harmonics.py +++ b/forge/test/mlir/nerf/spherical_harmonics.py @@ -4,7 +4,13 @@ C0 = 0.28209479177387814 C1 = 0.4886025119029199 -C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396] +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] C3 = [ -0.5900435899266435, 2.890611442640554, @@ -12,7 +18,7 @@ 0.3731763325901154, -0.4570457994644658, 1.445305721320277, - -0.5900435899266435, + -0.5900435899266435 ] C4 = [ 2.5033429417967046, @@ -26,53 +32,74 @@ 0.6258357354491761, ] - def eval_sh(deg, sh, dirs): """ - Evaluate spherical harmonics at unit directions - using hardcoded SH polynomials. + Evaluate spherical harmonics at unit directions using hardcoded SH polynomials. + + Args: + deg: Degree of spherical harmonics (0-4) + sh: Spherical harmonics coefficients + dirs: Unit direction vectors [..., 3] + + Returns: + Evaluated spherical harmonics """ - assert deg <= 4 and deg >= 0 - assert (deg + 1) ** 2 == sh.shape[-1] + assert 0 <= deg <= 4, "Degree must be between 0 and 4" + assert (deg + 1) ** 2 == sh.shape[-1], "Invalid SH coefficients shape" result = C0 * sh[..., 0] - if deg > 0: - x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - result = result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] - if deg > 1: - xx, yy, zz = x * x, y * y, z * z - xy, yz, xz = x * y, y * z, x * z - result = ( - result - + C2[0] * xy * sh[..., 4] - + C2[1] * yz * sh[..., 5] - + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] - + C2[3] * xz * sh[..., 7] - + C2[4] * (xx - yy) * sh[..., 8] - ) + + if deg == 0: + return result - if deg > 2: - result = ( - result - + C3[0] * y * (3 * xx - yy) * sh[..., 9] - + C3[1] * xy * z * sh[..., 10] - + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] - + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] - + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] - + C3[5] * z * (xx - yy) * sh[..., 14] - + C3[6] * x * (xx - 3 * yy) * sh[..., 15] - ) - if deg > 3: - result = ( - result - + C4[0] * xy * (xx - yy) * sh[..., 16] - + C4[1] * yz * (3 * xx - yy) * sh[..., 17] - + C4[2] * xy * (7 * zz - 1) * sh[..., 18] - + C4[3] * yz * (7 * zz - 3) * sh[..., 19] - + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] - + C4[5] * xz * (7 * zz - 3) * sh[..., 21] - + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] - + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] - + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24] - ) - return result + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + + result += ( + -C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] + - C1 * x * sh[..., 3] + ) + + if deg == 1: + return result + + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + + result += ( + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8] + ) + + if deg == 2: + return result + + result += ( + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15] + ) + + if deg == 3: + return result + + result += ( + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24] + ) + + return result \ No newline at end of file From b067296e146f3173cbcd5499b1f3ff1ffbb92021 Mon Sep 17 00:00:00 2001 From: Pavle Glusac Date: Tue, 4 Mar 2025 08:53:17 +0000 Subject: [PATCH 4/4] Reformat --- forge/test/mlir/nerf/spherical_harmonics.py | 43 ++++++++------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/forge/test/mlir/nerf/spherical_harmonics.py b/forge/test/mlir/nerf/spherical_harmonics.py index 9c2c06c46..7b7244ab0 100644 --- a/forge/test/mlir/nerf/spherical_harmonics.py +++ b/forge/test/mlir/nerf/spherical_harmonics.py @@ -4,13 +4,7 @@ C0 = 0.28209479177387814 C1 = 0.4886025119029199 -C2 = [ - 1.0925484305920792, - -1.0925484305920792, - 0.31539156525252005, - -1.0925484305920792, - 0.5462742152960396 -] +C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396] C3 = [ -0.5900435899266435, 2.890611442640554, @@ -18,7 +12,7 @@ 0.3731763325901154, -0.4570457994644658, 1.445305721320277, - -0.5900435899266435 + -0.5900435899266435, ] C4 = [ 2.5033429417967046, @@ -32,15 +26,16 @@ 0.6258357354491761, ] + def eval_sh(deg, sh, dirs): """ Evaluate spherical harmonics at unit directions using hardcoded SH polynomials. - + Args: deg: Degree of spherical harmonics (0-4) sh: Spherical harmonics coefficients dirs: Unit direction vectors [..., 3] - + Returns: Evaluated spherical harmonics """ @@ -48,24 +43,20 @@ def eval_sh(deg, sh, dirs): assert (deg + 1) ** 2 == sh.shape[-1], "Invalid SH coefficients shape" result = C0 * sh[..., 0] - + if deg == 0: return result x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - - result += ( - -C1 * y * sh[..., 1] - + C1 * z * sh[..., 2] - - C1 * x * sh[..., 3] - ) - + + result += -C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] + if deg == 1: return result - + xx, yy, zz = x * x, y * y, z * z xy, yz, xz = x * y, y * z, x * z - + result += ( C2[0] * xy * sh[..., 4] + C2[1] * yz * sh[..., 5] @@ -73,10 +64,10 @@ def eval_sh(deg, sh, dirs): + C2[3] * xz * sh[..., 7] + C2[4] * (xx - yy) * sh[..., 8] ) - + if deg == 2: return result - + result += ( C3[0] * y * (3 * xx - yy) * sh[..., 9] + C3[1] * xy * z * sh[..., 10] @@ -86,10 +77,10 @@ def eval_sh(deg, sh, dirs): + C3[5] * z * (xx - yy) * sh[..., 14] + C3[6] * x * (xx - 3 * yy) * sh[..., 15] ) - + if deg == 3: return result - + result += ( C4[0] * xy * (xx - yy) * sh[..., 16] + C4[1] * yz * (3 * xx - yy) * sh[..., 17] @@ -101,5 +92,5 @@ def eval_sh(deg, sh, dirs): + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24] ) - - return result \ No newline at end of file + + return result