Skip to content

Commit

Permalink
chore(gpu): add back async entry points
Browse files Browse the repository at this point in the history
We need to keep async entry points so that asynchronous execution between
the CPU & GPU can be exposed at the HL API level later. We can't remove them.
  • Loading branch information
agnesLeroy committed Jan 28, 2025
1 parent db4592d commit f47fe52
Show file tree
Hide file tree
Showing 24 changed files with 252 additions and 218 deletions.
9 changes: 3 additions & 6 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ pub unsafe fn decompress_integer_radix_async<T: UnsignedInteger, B: Numeric>(
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_add_integer_radix_assign(
pub unsafe fn unchecked_add_integer_radix_assign_async(
streams: &CudaStreams,
radix_lwe_left: &mut CudaRadixCiphertext,
radix_lwe_right: &CudaRadixCiphertext,
Expand Down Expand Up @@ -589,7 +589,6 @@ pub unsafe fn unchecked_add_integer_radix_assign(
&radix_lwe_right_data,
);
update_noise_degree(radix_lwe_left, &radix_lwe_left_data);
streams.synchronize();
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -2163,7 +2162,7 @@ pub unsafe fn unchecked_rotate_left_integer_radix_kb_assign_async<
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_cmux_integer_radix_kb<T: UnsignedInteger, B: Numeric>(
pub unsafe fn unchecked_cmux_integer_radix_kb_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
radix_lwe_out: &mut CudaRadixCiphertext,
radix_lwe_condition: &CudaBooleanBlock,
Expand Down Expand Up @@ -2347,7 +2346,6 @@ pub unsafe fn unchecked_cmux_integer_radix_kb<T: UnsignedInteger, B: Numeric>(
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
streams.synchronize()
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -3302,7 +3300,7 @@ pub(crate) unsafe fn unchecked_unsigned_overflowing_sub_integer_radix_kb_assign_
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_signed_abs_radix_kb_assign<T: UnsignedInteger, B: Numeric>(
pub unsafe fn unchecked_signed_abs_radix_kb_assign_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
ct: &mut CudaRadixCiphertext,
bootstrapping_key: &CudaVec<B>,
Expand Down Expand Up @@ -3382,7 +3380,6 @@ pub unsafe fn unchecked_signed_abs_radix_kb_assign<T: UnsignedInteger, B: Numeri
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
streams.synchronize()
}

#[allow(clippy::too_many_arguments)]
Expand Down
20 changes: 11 additions & 9 deletions tfhe/src/integer/gpu/server_key/radix/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{unchecked_signed_abs_radix_kb_assign, PBSType};
use crate::integer::gpu::{unchecked_signed_abs_radix_kb_assign_async, PBSType};

impl CudaServerKey {
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
pub fn unchecked_abs_assign<T>(&self, ct: &mut T, streams: &CudaStreams)
pub unsafe fn unchecked_abs_assign_async<T>(&self, ct: &mut T, streams: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
{
Expand All @@ -18,7 +18,7 @@ impl CudaServerKey {
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
unchecked_signed_abs_radix_kb_assign(
unchecked_signed_abs_radix_kb_assign_async(
streams,
ct.as_mut(),
&d_bsk.d_vec,
Expand All @@ -43,7 +43,7 @@ impl CudaServerKey {
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
unchecked_signed_abs_radix_kb_assign(
unchecked_signed_abs_radix_kb_assign_async(
streams,
ct.as_mut(),
&d_multibit_bsk.d_vec,
Expand Down Expand Up @@ -74,10 +74,11 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut res = ct.duplicate(streams);
let mut res = unsafe { ct.duplicate_async(streams) };
if T::IS_SIGNED {
self.unchecked_abs_assign(&mut res, streams);
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
}
streams.synchronize();
res
}

Expand Down Expand Up @@ -131,13 +132,14 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut res = ct.duplicate(streams);
let mut res = unsafe { ct.duplicate_async(streams) };
if !ct.block_carries_are_empty() {
self.full_propagate_assign(&mut res, streams);
unsafe { self.full_propagate_assign_async(&mut res, streams) };
};
if T::IS_SIGNED {
self.unchecked_abs_assign(&mut res, streams);
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
}
streams.synchronize();
res
}
}
45 changes: 28 additions & 17 deletions tfhe/src/integer/gpu/server_key/radix/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::integer::gpu::ciphertext::{
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{
unchecked_add_integer_radix_assign,
unchecked_add_integer_radix_assign_async,
unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async, PBSType,
};
use crate::integer::server_key::radix_parallel::OutputFlag;
Expand Down Expand Up @@ -70,7 +70,7 @@ impl CudaServerKey {
ct_right: &T,
streams: &CudaStreams,
) -> T {
let mut result = ct_left.duplicate(streams);
let mut result = unsafe { ct_left.duplicate_async(streams) };
self.add_assign(&mut result, ct_right, streams);
result
}
Expand All @@ -94,18 +94,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
Expand Down Expand Up @@ -179,7 +179,7 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
pub unsafe fn unchecked_add_assign_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
Expand All @@ -204,10 +204,21 @@ impl CudaServerKey {
);

unsafe {
unchecked_add_integer_radix_assign(streams, ciphertext_left, ciphertext_right);
unchecked_add_integer_radix_assign_async(streams, ciphertext_left, ciphertext_right);
}
}

/// Blocking counterpart of [`Self::unchecked_add_assign_async`].
///
/// Forwards `ct_left`, `ct_right` and `streams` to the async entry point, then calls
/// [`CudaStreams::synchronize`] before returning, so the caller does not have to
/// synchronize the streams itself.
pub fn unchecked_add_assign<T>(&self, ct_left: &mut T, ct_right: &T, streams: &CudaStreams)
where
    T: CudaIntegerRadixCiphertext,
{
    // SAFETY: the async variant requires the streams to be synchronized before the
    // result is relied upon and the inputs to outlive that synchronization; both
    // borrows are still held when `streams.synchronize()` runs right below.
    unsafe { self.unchecked_add_assign_async(ct_left, ct_right, streams) };
    streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
Expand Down Expand Up @@ -396,7 +407,7 @@ impl CudaServerKey {
.iter_mut()
.filter(|ct| !ct.block_carries_are_empty())
.for_each(|ct| {
self.full_propagate_assign(&mut *ct, streams);
self.full_propagate_assign_async(&mut *ct, streams);
});

Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
Expand Down Expand Up @@ -456,14 +467,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
Expand All @@ -472,8 +483,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
Expand Down Expand Up @@ -643,14 +654,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
Expand All @@ -659,8 +670,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
Expand Down
26 changes: 13 additions & 13 deletions tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,18 +465,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
}
Expand Down Expand Up @@ -570,18 +570,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
Expand Down Expand Up @@ -675,18 +675,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
Expand Down Expand Up @@ -764,7 +764,7 @@ impl CudaServerKey {
streams: &CudaStreams,
) {
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}

self.unchecked_bitnot_assign_async(ct, streams);
Expand Down
Loading

0 comments on commit f47fe52

Please sign in to comment.