diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
index 7b4653af3b..01209356a9 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
@@ -32,12 +32,6 @@ public:
         state(state) {
 
     mask_mod_b = (1ll << base_log) - 1ll;
-    int tid = threadIdx.x;
-    for (int i = 0; i < num_poly * params::opt; i++) {
-      state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
-      tid += params::degree / params::opt;
-    }
-    synchronize_threads_in_block();
   }
 
   // Decomposes all polynomials at once
@@ -51,21 +45,25 @@ public:
   // Decomposes a single polynomial
   __device__ void decompose_and_compress_next_polynomial(double2 *result,
                                                          int j) {
-
-    int tid = threadIdx.x;
-    auto state_slice = state + j * params::degree;
+    uint32_t tid = threadIdx.x;
+    auto state_slice = &state[j * params::degree];
     for (int i = 0; i < params::opt / 2; i++) {
-      T res_re = state_slice[tid] & mask_mod_b;
-      T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
-      state_slice[tid] >>= base_log;
-      state_slice[tid + params::degree / 2] >>= base_log;
-      T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
-      T carry_im =
-          ((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
+      auto input1 = &state_slice[tid];
+      auto input2 = &state_slice[tid + params::degree / 2];
+      T res_re = *input1 & mask_mod_b;
+      T res_im = *input2 & mask_mod_b;
+
+      *input1 >>= base_log; // Update state
+      *input2 >>= base_log; // Update state
+
+      T carry_re = ((res_re - 1ll) | *input1) & res_re;
+      T carry_im = ((res_im - 1ll) | *input2) & res_im;
       carry_re >>= (base_log - 1);
       carry_im >>= (base_log - 1);
-      state_slice[tid] += carry_re;
-      state_slice[tid + params::degree / 2] += carry_im;
+
+      *input1 += carry_re; // Update state
+      *input2 += carry_im; // Update state
+
       res_re -= carry_re << base_log;
       res_im -= carry_im << base_log;
 
diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
index f84fee7030..d4c77fb69c 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -71,9 +71,8 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
 
   // This loop distribution seems to benefit the global mem reads
   for (int i = start_i; i < end_i; i++) {
-    Torus a_i = round_to_closest_multiple(block_lwe_array_in[i], base_log,
-                                          level_count);
-    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
+    Torus state =
+        init_decomposer_state(block_lwe_array_in[i], base_log, level_count);
 
     for (int j = 0; j < level_count; j++) {
       auto ksk_block =
@@ -201,9 +200,8 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
   // Iterate through all lwe elements
   for (int i = 0; i < lwe_dimension_in; i++) {
     // Round and prepare decomposition
-    Torus a_i = round_to_closest_multiple(lwe_in[i], base_log, level_count);
+    Torus state = init_decomposer_state(lwe_in[i], base_log, level_count);
 
-    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mod_b_mask = (1ll << base_log) - 1ll;
 
     // block of key for current lwe coefficient (cur_input_lwe[i])
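
Both keyswitch hunks fold the previous two steps (round_to_closest_multiple followed by a shift) into the single init_decomposer_state call defined in torus.cuh below. What follows is a host-side sketch, not part of the patch: it mirrors init_decomposer_state together with one scalar pass of the carry-based extraction from gadget.cuh, and checks that the signed digits recombine to the input rounded to the closest multiple of 2^(64 - base_log * level_count). The base_log/level_count values and test inputs are illustrative, not taken from the library.

    // Host-side sketch (not part of the patch); illustrative parameters.
    #include <cassert>
    #include <cstdint>

    using Torus = uint64_t;

    // Mirrors the device function added in torus.cuh below.
    static Torus init_decomposer_state(Torus input, uint32_t base_log,
                                       uint32_t level_count) {
      const uint32_t rep = level_count * base_log;
      const uint32_t non_rep = 64 - rep;
      Torus res = input >> (non_rep - 1); // keep one extra (rounding) bit
      Torus rounding_bit = res & 1;
      res++;
      res >>= 1;                   // round to nearest
      res &= ~Torus(0) >> non_rep; // reduce mod 2^rep
      Torus shifted_random = rounding_bit << (rep - 1);
      Torus need_balance = (((res - 1) | shifted_random) & res) >> (rep - 1);
      return res - (need_balance << rep); // centered representative
    }

    int main() {
      const uint32_t base_log = 4, level_count = 3; // illustrative only
      const uint32_t shift = 64 - base_log * level_count;
      const Torus mask = (Torus(1) << base_log) - 1;
      for (Torus x : {0x123456789abcdef0ull, 0xf000000000000000ull,
                      0x8000000000000000ull, 0xffffffffffffffffull}) {
        Torus state = init_decomposer_state(x, base_log, level_count);
        Torus recomposed = 0;
        // Same mask/shift/carry steps that
        // decompose_and_compress_next_polynomial applies per coefficient.
        for (uint32_t j = 0; j < level_count; j++) {
          Torus digit = state & mask;
          state >>= base_log;
          Torus carry = ((digit - 1) | state) & digit;
          carry >>= base_log - 1;
          state += carry;
          digit -= carry << base_log;                    // digit in [-B/2, B/2]
          recomposed += digit << (shift + base_log * j); // wraps mod 2^64
        }
        // Digits must recombine to x rounded to a multiple of 2^shift.
        Torus rounded = (((x >> (shift - 1)) + 1) >> 1) << shift;
        assert(recomposed == rounded);
      }
      return 0;
    }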
diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
index f9875b107a..10c59bf31e 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
@@ -12,6 +12,11 @@
 __host__ __device__ __forceinline__ constexpr double get_two_pow_torus_bits() {
   return (sizeof(T) == 4) ? 4294967296.0 : 18446744073709551616.0;
 }
 
+template <typename T>
+__host__ __device__ __forceinline__ constexpr T scalar_max() {
+  return std::numeric_limits<T>::max();
+}
+
 template <typename T>
 __device__ inline void typecast_double_to_torus(double x, T &r) {
   r = T(x);
 }
@@ -60,14 +65,21 @@ __device__ inline void typecast_torus_to_double(uint64_t x,
 }
 
 template <typename T>
-__device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
-                                              uint32_t level_count) {
-  const T non_rep_bit_count = sizeof(T) * 8 - level_count * base_log;
-  const T shift = non_rep_bit_count - 1;
-  T res = x >> shift;
-  res += 1;
-  res &= (T)(-2);
-  return res << shift;
+__device__ inline T init_decomposer_state(T input, uint32_t base_log,
+                                          uint32_t level_count) {
+  const T rep_bit_count = level_count * base_log;
+  const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
+  T res = input >> (non_rep_bit_count - 1);
+  T rounding_bit = res & (T)(1);
+  res++;
+  res >>= 1;
+  T torus_max = scalar_max<T>();
+  T mod_mask = torus_max >> non_rep_bit_count;
+  res &= mod_mask;
+  T shifted_random = rounding_bit << (rep_bit_count - 1);
+  T need_balance =
+      (((res - (T)(1)) | shifted_random) & res) >> (rep_bit_count - 1);
+  return res - (need_balance << rep_bit_count);
 }
 
 template <typename T>
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_amortized.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_amortized.cuh
index 10cf9c27b0..374bbce727 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_amortized.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_amortized.cuh
@@ -117,8 +117,8 @@ __global__ void device_programmable_bootstrap_amortized(
 
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator_rotated, base_log, level_count, glwe_dimension + 1);
 
     // Initialize the polynomial multiplication via FFT arrays
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh
index 0cfd95efb7..2e5f83d45b 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh
@@ -117,8 +117,8 @@ __global__ void device_programmable_bootstrap_cg(
 
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator_rotated, base_log, level_count);
 
     synchronize_threads_in_block();
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh
index 1e89164235..d736534e48 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh
@@ -96,8 +96,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
   for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size;
        i++) {
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator_rotated, base_log, level_count);
 
     // Decompose the accumulator_rotated. Each block gets one level of the
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh
index 831d3a6478..31f1e9487f 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh
@@ -106,8 +106,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
 
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator, base_log, level_count);
 
     synchronize_threads_in_block();
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh
index b57c61cf90..a58647185b 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh
@@ -224,8 +224,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
 
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator, base_log, level_count);
 
     // Decompose the accumulator. Each block gets one level of the
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh
index bbdf1ab43e..b7dc557e3a 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh
@@ -121,8 +121,8 @@ __global__ void device_programmable_bootstrap_tbc(
 
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator_rotated, base_log, level_count);
 
     synchronize_threads_in_block();
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh
index ade8d1f423..22b6f4e196 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh
@@ -104,8 +104,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
   for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size;
        i++) {
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-                                      params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+                                  params::degree / params::opt>(
         accumulator_rotated, base_log, level_count);
 
     // Decompose the accumulator. Each block gets one level of the
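
All seven PBS variants above change only the name of the call (template arguments and operands are untouched), so the behavioral change is the one made in torus.cuh: the kernels now keep the centered, balanced representative of the rounded coefficient instead of the rounded value shifted back into place. The following host-side sketch, not part of the patch, checks that invariant; it assumes two's-complement reinterpretation of the 64-bit word, and the base_log/level_count values are illustrative.

    // Host-side sketch (not part of the patch); illustrative parameters.
    #include <cassert>
    #include <cstdint>
    #include <random>

    using Torus = uint64_t;

    // Same arithmetic as the device function in torus.cuh, compiled for host.
    static Torus init_decomposer_state(Torus input, uint32_t base_log,
                                       uint32_t level_count) {
      const uint32_t rep = level_count * base_log;
      const uint32_t non_rep = 64 - rep;
      Torus res = input >> (non_rep - 1);
      Torus rounding_bit = res & 1;
      res++;
      res >>= 1;
      res &= ~Torus(0) >> non_rep;
      Torus shifted_random = rounding_bit << (rep - 1);
      Torus need_balance = (((res - 1) | shifted_random) & res) >> (rep - 1);
      return res - (need_balance << rep);
    }

    int main() {
      const uint32_t base_log = 8, level_count = 2; // illustrative only
      const uint32_t rep = base_log * level_count;
      const int64_t half = int64_t(1) << (rep - 1);
      std::mt19937_64 rng(42);
      for (int k = 0; k < 100000; k++) {
        Torus x = rng();
        Torus state = init_decomposer_state(x, base_log, level_count);
        int64_t s = int64_t(state); // two's-complement reinterpretation
        // Centered: the state lies in [-2^(rep-1), 2^(rep-1)].
        assert(-half <= s && s <= half);
        // Its low rep bits still agree with plain round-to-nearest of x.
        Torus mod = Torus(1) << rep;
        Torus plain = ((x >> (64 - rep - 1)) + 1) >> 1;
        assert(state % mod == plain % mod);
      }
      return 0;
    }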
diff --git a/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh b/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
index 6b19f08d76..5bf3af711a 100644
--- a/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
@@ -141,22 +141,17 @@ __device__ void multiply_by_monomial_negacyclic_and_sub_polynomial(
  * By default, it works on a single polynomial.
  */
 template <typename T, int elems_per_thread, int block_size>
-__device__ void round_to_closest_multiple_inplace(T *rotated_acc, int base_log,
-                                                  int level_count,
-                                                  uint32_t num_poly = 1) {
+__device__ void init_decomposer_state_inplace(T *rotated_acc, int base_log,
+                                              int level_count,
+                                              uint32_t num_poly = 1) {
   constexpr int degree = block_size * elems_per_thread;
   for (int z = 0; z < num_poly; z++) {
-    T *rotated_acc_slice = (T *)rotated_acc + (ptrdiff_t)(z * degree);
-    int tid = threadIdx.x;
+    T *rotated_acc_slice = &rotated_acc[z * degree];
+    uint32_t tid = threadIdx.x;
     for (int i = 0; i < elems_per_thread; i++) {
       T x_acc = rotated_acc_slice[tid];
-      T shift = sizeof(T) * 8 - level_count * base_log;
-      T mask = 1ll << (shift - 1);
-      T b_acc = (x_acc & mask) >> (shift - 1);
-      T res_acc = x_acc >> shift;
-      res_acc += b_acc;
-      res_acc <<= shift;
-      rotated_acc_slice[tid] = res_acc;
+      rotated_acc_slice[tid] =
+          init_decomposer_state(x_acc, base_log, level_count);
       tid = tid + block_size;
     }
   }
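
init_decomposer_state_inplace keeps the former thread layout: each of block_size threads rewrites elems_per_thread coefficients at stride block_size, so one block covers degree = block_size * elems_per_thread coefficients per polynomial, and the num_poly slices are processed back to back. Below is a host-side emulation of just that indexing, not part of the patch; the constants are illustrative.

    // Host-side sketch (not part of the patch); illustrative constants.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      constexpr int block_size = 256;     // threads per block (illustrative)
      constexpr int elems_per_thread = 4; // params::opt in the PBS kernels
      constexpr int degree = block_size * elems_per_thread;
      constexpr int num_poly = 2;         // e.g. glwe_dimension + 1

      std::vector<int> hits(num_poly * degree, 0);
      for (int z = 0; z < num_poly; z++) {
        int *slice = &hits[z * degree]; // rotated_acc_slice in the kernel
        for (uint32_t t = 0; t < block_size; t++) { // emulate threadIdx.x
          uint32_t tid = t;
          for (int i = 0; i < elems_per_thread; i++) {
            // Here the kernel does:
            //   rotated_acc_slice[tid] =
            //       init_decomposer_state(x_acc, base_log, level_count);
            slice[tid]++;
            tid += block_size; // same stride as the device loop
          }
        }
      }
      for (int hit : hits)
        assert(hit == 1); // full, non-overlapping coverage
      return 0;
    }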