diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh index 87660a5977..682f073ca7 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh @@ -67,12 +67,12 @@ __global__ void device_programmable_bootstrap_cg( // We always compute the pointer with most restrictive alignment to avoid // alignment issues - double2 *accumulator_fft = (double2 *)selected_memory; - Torus *accumulator = - (Torus *)accumulator_fft + - (ptrdiff_t)(sizeof(double2) * polynomial_size / 2 / sizeof(Torus)); + Torus *accumulator = (Torus *)selected_memory; Torus *accumulator_rotated = - (Torus *)accumulator + (ptrdiff_t)polynomial_size; + (Torus *)accumulator + (ptrdiff_t)(polynomial_size); + double2 *accumulator_fft = + (double2 *)(accumulator_rotated) + + (ptrdiff_t)(polynomial_size * sizeof(Torus) / sizeof(double2)); if constexpr (SMD == PARTIALSM) accumulator_fft = (double2 *)sharedmem; diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh index 7a2267b813..8a534b2e83 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh @@ -162,10 +162,10 @@ __global__ void __launch_bounds__(params::degree / params::opt) // We always compute the pointer with most restrictive alignment to avoid // alignment issues - double2 *accumulator_fft = (double2 *)selected_memory; - Torus *accumulator = - (Torus *)accumulator_fft + - (ptrdiff_t)(sizeof(double2) * params::degree / 2 / sizeof(Torus)); + Torus *accumulator = (Torus *)selected_memory; + double2 *accumulator_fft = + (double2 *)accumulator + + (ptrdiff_t)(sizeof(Torus) * params::degree / sizeof(double2)); if constexpr (SMD == PARTIALSM) accumulator_fft = (double2 *)sharedmem;