Skip to content

Commit

Permalink
Add depth backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ingra14m committed Jul 26, 2023
1 parent 6787328 commit 4c6d250
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 1 deletion.
23 changes: 23 additions & 0 deletions cuda_rasterizer/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,11 @@ renderCUDA(
const float2* __restrict__ points_xy_image,
const float4* __restrict__ conic_opacity,
const float* __restrict__ colors,
const float* __restrict__ depths,
const float* __restrict__ final_Ts,
const uint32_t* __restrict__ n_contrib,
const float* __restrict__ dL_dpixels,
const float* __restrict__ dL_depths,
float3* __restrict__ dL_dmean2D,
float4* __restrict__ dL_dconic2D,
float* __restrict__ dL_dopacity,
Expand All @@ -435,6 +437,7 @@ renderCUDA(
__shared__ float2 collected_xy[BLOCK_SIZE];
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
__shared__ float collected_colors[C * BLOCK_SIZE];
__shared__ float collected_depths[BLOCK_SIZE];

// In the forward, we stored the final value for T, the
// product of all (1 - alpha) factors.
Expand All @@ -448,12 +451,16 @@ renderCUDA(

float accum_rec[C] = { 0 };
float dL_dpixel[C];
float dL_depth;
float accum_depth_rec = 0;
if (inside)
for (int i = 0; i < C; i++)
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
dL_depth = dL_depths[pix_id];

float last_alpha = 0;
float last_color[C] = { 0 };
float last_depth = 0;

// Gradient of pixel coordinate w.r.t. normalized
// screen-space viewport corrdinates (-1 to 1)
Expand All @@ -475,6 +482,7 @@ renderCUDA(
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
for (int i = 0; i < C; i++)
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
collected_depths[block.thread_rank()] = depths[coll_id];
}
block.sync();

Expand Down Expand Up @@ -522,6 +530,17 @@ renderCUDA(
// many that were affected by this Gaussian.
atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
}

// Propagate gradients to per-Gaussian depths
const float c_d = collected_depths[j];
accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec;
last_depth = c_d;
dL_dalpha += (c_d - accum_depth_rec) * dL_depth;
// for (int ch = 0; ch < C; ch++)
// {
// atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_depth);
// }

dL_dalpha *= T;
// Update last alpha (to be used in the next iteration)
last_alpha = alpha;
Expand Down Expand Up @@ -630,9 +649,11 @@ void BACKWARD::render(
const float2* means2D,
const float4* conic_opacity,
const float* colors,
const float* depths,
const float* final_Ts,
const uint32_t* n_contrib,
const float* dL_dpixels,
const float* dL_depths,
float3* dL_dmean2D,
float4* dL_dconic2D,
float* dL_dopacity,
Expand All @@ -646,9 +667,11 @@ void BACKWARD::render(
means2D,
conic_opacity,
colors,
depths,
final_Ts,
n_contrib,
dL_dpixels,
dL_depths,
dL_dmean2D,
dL_dconic2D,
dL_dopacity,
Expand Down
2 changes: 2 additions & 0 deletions cuda_rasterizer/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ namespace BACKWARD
const float2* means2D,
const float4* conic_opacity,
const float* colors,
const float* depths,
const float* final_Ts,
const uint32_t* n_contrib,
const float* dL_dpixels,
const float* dL_depths,
float3* dL_dmean2D,
float4* dL_dconic2D,
float* dL_dopacity,
Expand Down
1 change: 1 addition & 0 deletions cuda_rasterizer/rasterizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ namespace CudaRasterizer
char* binning_buffer,
char* image_buffer,
const float* dL_dpix,
const float* dL_depths,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dopacity,
Expand Down
4 changes: 4 additions & 0 deletions cuda_rasterizer/rasterizer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ void CudaRasterizer::Rasterizer::backward(
char* binning_buffer,
char* img_buffer,
const float* dL_dpix,
const float* dL_depths,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dopacity,
Expand Down Expand Up @@ -389,6 +390,7 @@ void CudaRasterizer::Rasterizer::backward(
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
const float* depth_ptr = geomState.depths;
BACKWARD::render(
tile_grid,
block,
Expand All @@ -399,9 +401,11 @@ void CudaRasterizer::Rasterizer::backward(
geomState.means2D,
geomState.conic_opacity,
color_ptr,
depth_ptr,
imgState.accum_alpha,
imgState.n_contrib,
dL_dpix,
dL_depths,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
dL_dopacity,
Expand Down
3 changes: 2 additions & 1 deletion diff_gaussian_rasterization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def backward(ctx, grad_out_color, grad_radii, grad_depth):
raster_settings.projmatrix,
raster_settings.tanfovx,
raster_settings.tanfovy,
grad_out_color,
grad_out_color,
grad_depth,
sh,
raster_settings.sh_degree,
raster_settings.campos,
Expand Down
2 changes: 2 additions & 0 deletions rasterize_points.cu
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const float tan_fovx,
const float tan_fovy,
const torch::Tensor& dL_dout_color,
const torch::Tensor& dL_dout_depth,
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
Expand Down Expand Up @@ -179,6 +180,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
dL_dout_color.contiguous().data<float>(),
dL_dout_depth.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
dL_dopacity.contiguous().data<float>(),
Expand Down
1 change: 1 addition & 0 deletions rasterize_points.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const float tan_fovx,
const float tan_fovy,
const torch::Tensor& dL_dout_color,
const torch::Tensor& dL_dout_depth,
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
Expand Down

0 comments on commit 4c6d250

Please sign in to comment.