From fd08d0f2d4ecf1bfa6ce262d0f414f376cb59f73 Mon Sep 17 00:00:00 2001
From: wejoncy <247153481@qq.com>
Date: Wed, 25 Sep 2024 14:37:15 +0800
Subject: [PATCH] Bias (#19)

Co-authored-by: Yang Wang
---
 csrc/common.h               |  4 ++--
 csrc/dequant_impl_packed.cu | 42 ++++++++++++++++++++-----------------
 csrc/ops.cc                 |  8 ++++---
 csrc/utils.cuh              |  8 ++++++++
 format.sh                   |  5 +++--
 vptq/layers/vqlinear.py     |  1 +
 6 files changed, 42 insertions(+), 26 deletions(-)

diff --git a/csrc/common.h b/csrc/common.h
index 9725b023..b56fc14a 100644
--- a/csrc/common.h
+++ b/csrc/common.h
@@ -5,8 +5,8 @@
 #include <cuda_runtime.h>
 #include <cstdio>
 
-#define gpuErrchk(ret) \
-  { gpuAssert((ret), __FILE__, __LINE__); }
+#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);
+
 inline void gpuAssert(cudaError_t code, const char* file, int line) {
   if (code != cudaSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
diff --git a/csrc/dequant_impl_packed.cu b/csrc/dequant_impl_packed.cu
index dea1e825..99dc573d 100644
--- a/csrc/dequant_impl_packed.cu
+++ b/csrc/dequant_impl_packed.cu
@@ -20,14 +20,12 @@ struct C10ToNvType {
 };
 
 template <typename scalar_t, int IDXBITS, int ResidualBits, int GROUPSIZE, int OL_GroupSize, int Do_Reduce>
-__global__ void WqA16WithOutliers_PackIndice(scalar_t* out, const scalar_t* input_data, const int32_t* q_indice,
-                                             const uint16_t* q_indice_outliers, const scalar_t* __restrict__ centroids,
-                                             const scalar_t* __restrict__ residual_centroids,
-                                             const scalar_t* outliers_centroids, const uint16_t* invert_perm,
-                                             const scalar_t* weight_scale, const scalar_t* weight_bias,
-                                             int out_features, int in_features, int outliers_infeatures,
-                                             const int index_stride_0, const int index_stride_1,
-                                             const int centroids_stride_0, const int group_nums) {
+__global__ void WqA16WithOutliers_PackIndice(
+    scalar_t* out, const scalar_t* input_data, const int32_t* q_indice, const uint16_t* q_indice_outliers,
+    const scalar_t* __restrict__ centroids, const scalar_t* __restrict__ residual_centroids,
+    const scalar_t* outliers_centroids, const uint16_t* invert_perm, const scalar_t* weight_scale,
+    const scalar_t* weight_bias, const scalar_t* bias, int out_features, int in_features, int outliers_infeatures,
+    const int index_stride_0, const int index_stride_1, const int centroids_stride_0, const int group_nums) {
   int bidx = blockIdx.x;  // out_features//base_groupsize
   int bidy = blockIdx.y;  // batch
   int bidz = blockIdx.z;  // segment in_features
@@ -169,9 +167,12 @@ __global__ void WqA16WithOutliers_PackIndice(scalar_t* out, const scalar_t* inpu
 
   if constexpr (Do_Reduce > 0) {
     out += (in_y * GROUPSIZE) * gridDim.z + bidz;
+    bias += (bias == nullptr ? 0 : (in_y * GROUPSIZE) + bidz);
   } else {
     out += in_y * GROUPSIZE;
+    bias += (bias == nullptr ? 0 : in_y * GROUPSIZE);
   }
+
   __syncthreads();
   if (landid < cuda::kBlockSize / 32) {
 #pragma unroll
@@ -180,9 +181,10 @@ __global__ void WqA16WithOutliers_PackIndice(scalar_t* out, const scalar_t* inpu
       reduce_out = cuda::warpReduceSum<32>(reduce_out);
       if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
         if constexpr (Do_Reduce) {
-          out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out);
+          out[(wid)*gridDim.z] =
+              cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0));
         } else {
-          out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out);
+          out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bias != 0) ? bias[wid] : scalar_t(0));
         }
       }
     }
@@ -439,7 +441,8 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
     const c10::optional<torch::Tensor>& residual_centroids,  //[num_c, c_size, vec_len]
     const c10::optional<torch::Tensor>& outliers_indices,    //[num_cen, c_size, ol_in_f]
     const c10::optional<torch::Tensor>& outliers_centroids,  //[num_c, c_size, out_vec_len]
-    const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias) {
+    const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias,
+    const c10::optional<torch::Tensor>& bias) {
   const int base_groupsize = centroids.size(-1);
   int index_bits = log2(centroids.size(1));
   int res_index_bits = residual_centroids.has_value() ? log2(residual_centroids.value().size(1)) : 0;
@@ -464,14 +467,15 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
   const uint16_t* outliers_indices_ptr =
       (const uint16_t*)(outliers_indices.has_value() ? outliers_indices.value().data_ptr() : nullptr);
   const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr()) : nullptr;
-#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits)                         \
-  WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce>                        \
-      <<<blocks, threads, shared_memory_size, stream>>>(                                                        \
-          out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(),               \
-          outliers_indices_ptr, centroids.data_ptr<scalar_t>(),                                                 \
-          residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr,           \
-          outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
-          weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), out_features, in_features,       \
+  const c10::Half* bias_ptr = bias.has_value() ? (bias.value().data_ptr<c10::Half>()) : nullptr;
+#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits)                         \
+  WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce>                        \
+      <<<blocks, threads, shared_memory_size, stream>>>(                                                        \
+          out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(),               \
+          outliers_indices_ptr, centroids.data_ptr<scalar_t>(),                                                 \
+          residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr,           \
+          outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
+          weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), bias_ptr, out_features, in_features, \
           outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), centroids.stride(0), q_indice.size(0));
 #define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
   if (input.dtype() == at::ScalarType::Half) {                                      \
diff --git a/csrc/ops.cc b/csrc/ops.cc
index b6a5e3dc..94da426b 100644
--- a/csrc/ops.cc
+++ b/csrc/ops.cc
@@ -30,7 +30,8 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
     const c10::optional<torch::Tensor>& residual_centroids,  //[num_c, c_size, vec_len]
     const c10::optional<torch::Tensor>& outliers_indices,    //[num_cen, c_size, ol_in_f]
     const c10::optional<torch::Tensor>& outliers_centroids,  //[num_c, c_size, out_vec_len]
-    const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias);
+    const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias,
+    const c10::optional<torch::Tensor>& bias);
 
 torch::Tensor dequant(const torch::Tensor& q_indice, const torch::Tensor& centroids,
                       const c10::optional<torch::Tensor>& q_indice_residual,
@@ -85,7 +86,8 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input, const torch::Tensor& q_indic
                         const c10::optional<torch::Tensor>& q_indice_outliers,
                         const c10::optional<torch::Tensor>& outliers_centroids,
                         const c10::optional<torch::Tensor>& invperm, const torch::Tensor& weight_scale,
-                        const torch::Tensor& weight_bias, int groupsize, int in_features, int out_features) {
+                        const torch::Tensor& weight_bias, const c10::optional<torch::Tensor>& bias, int groupsize,
+                        int in_features, int out_features) {
   CHECK_INPUT(q_indice);
   CHECK_INPUT(input);
   if (q_indice_residual.has_value()) {
@@ -113,7 +115,7 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input, const torch::Tensor& q_indic
 
   output = lauch_gemv_outliers_cuda_packkernel(out_features, input, q_indice, centroids, q_indice_residual,
                                                residual_centroids, q_indice_outliers, outliers_centroids, invperm,
-                                               weight_scale, weight_bias);
+                                               weight_scale, weight_bias, bias);
 
   gpuErrchk(cudaPeekAtLastError());
diff --git a/csrc/utils.cuh b/csrc/utils.cuh
index e7b05f98..f66154be 100644
--- a/csrc/utils.cuh
+++ b/csrc/utils.cuh
@@ -28,6 +28,14 @@ T __device__ __forceinline__ ConvertFromFloat(float v) {
   return __float2half(v);
 }
 
+template <typename T>
+float __device__ __forceinline__ ConvertToFloat(T v) {
+  if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+    return __bfloat162float(v);
+  }
+  return __half2float(v);
+}
+
 template <int WarpSize>
 __device__ __forceinline__ float warpReduceSum(float sum) {
   if constexpr (WarpSize >= 32) sum += __shfl_down_sync(0xffffffff, sum, 16);  // 0-16, 1-17, 2-18, etc.
diff --git a/format.sh b/format.sh
index 3ce1ecc5..9822b677 100644
--- a/format.sh
+++ b/format.sh
@@ -1,4 +1,5 @@
 yapf --recursive . --style='{based_on_style: google, column_limit: 120, indent_width: 4}' -i
+
 isort .
-find csrc/ \( -name '*.h' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print \
-| xargs clang-format -i
\ No newline at end of file
+
+find csrc/ \( -name '*.h' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print | xargs clang-format -i
diff --git a/vptq/layers/vqlinear.py b/vptq/layers/vqlinear.py
index 572b3b6f..59f42ae3 100644
--- a/vptq/layers/vqlinear.py
+++ b/vptq/layers/vqlinear.py
@@ -417,6 +417,7 @@ def fast_gemv(self, x):
             self.perm,
             self.weight_scale,
             self.weight_bias,
+            self.bias,
             self.vector_len,
             self.in_features,
             self.out_features,
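
Note on the reduce path: when Do_Reduce > 0 the kernel writes one partial sum per
blockIdx.z segment to out[wid * gridDim.z + bidz], and the partials are reduced
afterwards outside this kernel, so exactly one segment (bidz == 0) may fold in the
bias; otherwise it would be counted gridDim.z times. The non-reduce path has a
single segment and adds the bias unconditionally whenever the pointer is non-null.
A standalone toy kernel illustrating the same guard (illustrative names only, not
part of this patch):

    #include <cuda_fp16.h>

    // Each z-block writes its own partial dot product for output column `col`;
    // only segment 0 adds the per-output bias, mirroring the patch's
    // `(bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0)` term.
    __global__ void partial_out_with_bias(half* out, const float* partials, const half* bias, int out_features) {
      int col = blockIdx.x * blockDim.x + threadIdx.x;  // output feature index
      int seg = blockIdx.z;                             // reduction segment index
      if (col >= out_features) return;
      float acc = partials[seg * out_features + col];
      half b = (seg == 0 && bias != nullptr) ? bias[col] : __float2half(0.f);
      out[col * gridDim.z + seg] = __float2half(acc) + b;  // segments summed later
    }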
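
Note on naming: weight_bias is VPTQ's per-output-channel offset applied when
de-normalizing the dequantized weights, while the new bias argument is the ordinary
linear-layer bias, now fused into the GEMV epilogue. On the host side the optional
tensor is lowered to a raw pointer before the launch; a minimal sketch of that
pattern (hypothetical helper name, and it assumes a contiguous fp16 bias, which the
patch hard-codes via c10::Half even though the kernel is templated on scalar_t):

    #include <torch/extension.h>

    // `maybe_bias` is an illustrative name. c10::optional follows std::optional
    // semantics, and data_ptr<c10::Half>() checks that the tensor's dtype is kHalf.
    static const c10::Half* bias_raw_ptr(const c10::optional<torch::Tensor>& maybe_bias) {
      return maybe_bias.has_value() ? maybe_bias.value().data_ptr<c10::Half>() : nullptr;
    }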