Skip to content

Commit

Permalink
chore(gpu): pass over all cuda bind
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Sep 6, 2024
1 parent 019548d commit 1d549df
Show file tree
Hide file tree
Showing 4 changed files with 504 additions and 662 deletions.
10 changes: 5 additions & 5 deletions backends/tfhe-cuda-backend/cuda/include/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ void *cuda_malloc_async(uint64_t size, cudaStream_t stream, uint32_t gpu_index);

void cuda_check_valid_malloc(uint64_t size, uint32_t gpu_index);

bool cuda_check_support_cooperative_groups();

bool cuda_check_support_thread_block_clusters();

void cuda_memcpy_async_to_gpu(void *dest, void *src, uint64_t size,
cudaStream_t stream, uint32_t gpu_index);

Expand All @@ -62,9 +58,13 @@ void cuda_synchronize_device(uint32_t gpu_index);
void cuda_drop(void *ptr, uint32_t gpu_index);

void cuda_drop_async(void *ptr, cudaStream_t stream, uint32_t gpu_index);
}

int cuda_get_max_shared_memory(uint32_t gpu_index);
}

bool cuda_check_support_cooperative_groups();

bool cuda_check_support_thread_block_clusters();

template <typename Torus>
void cuda_set_value_async(cudaStream_t stream, uint32_t gpu_index,
Expand Down
2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ extern std::mutex m;
extern bool p2p_enabled;

extern "C" {
int cuda_setup_multi_gpu();
int32_t cuda_setup_multi_gpu();
}

// Define a variant type that can be either a vector or a single pointer
Expand Down
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
std::mutex m;
bool p2p_enabled = false;

int cuda_setup_multi_gpu() {
int32_t cuda_setup_multi_gpu() {
int num_gpus = cuda_get_number_of_gpus();
if (num_gpus == 0)
PANIC("GPU error: the number of GPUs should be > 0.")
Expand All @@ -32,7 +32,7 @@ int cuda_setup_multi_gpu() {
}
m.unlock();
}
return num_used_gpus;
return (int32_t)(num_used_gpus);
}

int get_active_gpu_count(int num_inputs, int gpu_count) {
Expand Down
Loading

0 comments on commit 1d549df

Please sign in to comment.