Skip to content

Commit

Permalink
chore(gpu): use same balanced decomposition code as in the CPU code
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker authored and agnesLeroy committed Nov 13, 2024
1 parent b041608 commit d280403
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 57 deletions.
34 changes: 16 additions & 18 deletions backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ public:
state(state) {

mask_mod_b = (1ll << base_log) - 1ll;
int tid = threadIdx.x;
for (int i = 0; i < num_poly * params::opt; i++) {
state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
}

// Decomposes all polynomials at once
Expand All @@ -51,21 +45,25 @@ public:
// Decomposes a single polynomial
__device__ void decompose_and_compress_next_polynomial(double2 *result,
int j) {

int tid = threadIdx.x;
auto state_slice = state + j * params::degree;
uint32_t tid = threadIdx.x;
auto state_slice = &state[j * params::degree];
for (int i = 0; i < params::opt / 2; i++) {
T res_re = state_slice[tid] & mask_mod_b;
T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
state_slice[tid] >>= base_log;
state_slice[tid + params::degree / 2] >>= base_log;
T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
T carry_im =
((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
auto input1 = &state_slice[tid];
auto input2 = &state_slice[tid + params::degree / 2];
T res_re = *input1 & mask_mod_b;
T res_im = *input2 & mask_mod_b;

*input1 >>= base_log; // Update state
*input2 >>= base_log; // Update state

T carry_re = ((res_re - 1ll) | *input1) & res_re;
T carry_im = ((res_im - 1ll) | *input2) & res_im;
carry_re >>= (base_log - 1);
carry_im >>= (base_log - 1);
state_slice[tid] += carry_re;
state_slice[tid + params::degree / 2] += carry_im;

*input1 += carry_re; // Update state
*input2 += carry_im; // Update state

res_re -= carry_re << base_log;
res_im -= carry_im << base_log;

Expand Down
8 changes: 3 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,

// This loop distribution seems to benefit the global mem reads
for (int i = start_i; i < end_i; i++) {
Torus a_i = round_to_closest_multiple(block_lwe_array_in[i], base_log,
level_count);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
Torus state =
init_decomposer_state(block_lwe_array_in[i], base_log, level_count);

for (int j = 0; j < level_count; j++) {
auto ksk_block =
Expand Down Expand Up @@ -201,9 +200,8 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
// Iterate through all lwe elements
for (int i = 0; i < lwe_dimension_in; i++) {
// Round and prepare decomposition
Torus a_i = round_to_closest_multiple(lwe_in[i], base_log, level_count);
Torus state = init_decomposer_state(lwe_in[i], base_log, level_count);

Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;

// block of key for current lwe coefficient (cur_input_lwe[i])
Expand Down
28 changes: 20 additions & 8 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ __host__ __device__ __forceinline__ constexpr double get_two_pow_torus_bits() {
return (sizeof(T) == 4) ? 4294967296.0 : 18446744073709551616.0;
}

template <typename T>
__host__ __device__ __forceinline__ constexpr T scalar_max() {
return std::numeric_limits<T>::max();
}

template <typename T>
__device__ inline void typecast_double_to_torus(double x, T &r) {
r = T(x);
Expand Down Expand Up @@ -60,14 +65,21 @@ __device__ inline void typecast_torus_to_double<uint64_t>(uint64_t x,
}

template <typename T>
__device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
uint32_t level_count) {
const T non_rep_bit_count = sizeof(T) * 8 - level_count * base_log;
const T shift = non_rep_bit_count - 1;
T res = x >> shift;
res += 1;
res &= (T)(-2);
return res << shift;
__device__ inline T init_decomposer_state(T input, uint32_t base_log,
uint32_t level_count) {
const T rep_bit_count = level_count * base_log;
const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
T res = input >> (non_rep_bit_count - 1);
T rounding_bit = res & (T)(1);
res++;
res >>= 1;
T torus_max = scalar_max<T>();
T mod_mask = torus_max >> non_rep_bit_count;
res &= mod_mask;
T shifted_random = rounding_bit << (rep_bit_count - 1);
T need_balance =
(((res - (T)(1)) | shifted_random) & res) >> (rep_bit_count - 1);
return res - (need_balance << rep_bit_count);
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ __global__ void device_programmable_bootstrap_amortized(

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_rotated, base_log, level_count, glwe_dimension + 1);

// Initialize the polynomial multiplication via FFT arrays
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ __global__ void device_programmable_bootstrap_cg(

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_rotated, base_log, level_count);

synchronize_threads_in_block();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_rotated, base_log, level_count);

// Decompose the accumulator_rotated. Each block gets one level of the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator, base_log, level_count);

synchronize_threads_in_block();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator, base_log, level_count);

// Decompose the accumulator. Each block gets one level of the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ __global__ void device_programmable_bootstrap_tbc(

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_rotated, base_log, level_count);

synchronize_threads_in_block();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
init_decomposer_state_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_rotated, base_log, level_count);

// Decompose the accumulator. Each block gets one level of the
Expand Down
19 changes: 7 additions & 12 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,17 @@ __device__ void multiply_by_monomial_negacyclic_and_sub_polynomial(
* By default, it works on a single polynomial.
*/
template <typename T, int elems_per_thread, int block_size>
__device__ void round_to_closest_multiple_inplace(T *rotated_acc, int base_log,
int level_count,
uint32_t num_poly = 1) {
__device__ void init_decomposer_state_inplace(T *rotated_acc, int base_log,
int level_count,
uint32_t num_poly = 1) {
constexpr int degree = block_size * elems_per_thread;
for (int z = 0; z < num_poly; z++) {
T *rotated_acc_slice = (T *)rotated_acc + (ptrdiff_t)(z * degree);
int tid = threadIdx.x;
T *rotated_acc_slice = &rotated_acc[z * degree];
uint32_t tid = threadIdx.x;
for (int i = 0; i < elems_per_thread; i++) {
T x_acc = rotated_acc_slice[tid];
T shift = sizeof(T) * 8 - level_count * base_log;
T mask = 1ll << (shift - 1);
T b_acc = (x_acc & mask) >> (shift - 1);
T res_acc = x_acc >> shift;
res_acc += b_acc;
res_acc <<= shift;
rotated_acc_slice[tid] = res_acc;
rotated_acc_slice[tid] =
init_decomposer_state(x_acc, base_log, level_count);
tid = tid + block_size;
}
}
Expand Down

0 comments on commit d280403

Please sign in to comment.