Bias (#19)
Co-authored-by: Yang Wang <wyatuestc@gmail.com>
wejoncy and YangWang92 authored Sep 25, 2024
1 parent 5b52fb5 commit fd08d0f
Showing 6 changed files with 42 additions and 26 deletions.
4 changes: 2 additions & 2 deletions csrc/common.h
@@ -5,8 +5,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#define gpuErrchk(ret) \
{ gpuAssert((ret), __FILE__, __LINE__); }
#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);

inline void gpuAssert(cudaError_t code, const char* file, int line) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
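Editor's note on the macro change above: the old two-line definition wrapped the call in braces, while the new one expands to a bare gpuAssert(...) statement that carries its own trailing semicolon, so gpuErrchk(x); now produces an extra empty statement. Neither form composes safely with an un-braced if/else. The snippet below is a minimal sketch, not part of this commit, and gpuErrchkSafe is a hypothetical name; it shows the conventional do { ... } while (0) idiom that avoids both issues.

#include <cstdio>
#include <cuda_runtime.h>

inline void gpuAssert(cudaError_t code, const char* file, int line) {
  if (code != cudaSuccess) {
    fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
  }
}

// Hypothetical alternative, shown for comparison only: expands to exactly one
// statement, so it composes safely with un-braced if/else chains.
#define gpuErrchkSafe(ret)                \
  do {                                    \
    gpuAssert((ret), __FILE__, __LINE__); \
  } while (0)

void example(bool sync) {
  if (sync)
    gpuErrchkSafe(cudaDeviceSynchronize());
  else
    gpuErrchkSafe(cudaGetLastError());
}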
42 changes: 23 additions & 19 deletions csrc/dequant_impl_packed.cu
@@ -20,14 +20,12 @@ struct C10ToNvType<c10::Half> {
};

template <typename scalar_t, int IDXBITS, int ResidualBits, int GROUPSIZE, int OL_GroupSize, int Do_Reduce>
__global__ void WqA16WithOutliers_PackIndice(scalar_t* out, const scalar_t* input_data, const int32_t* q_indice,
const uint16_t* q_indice_outliers, const scalar_t* __restrict__ centroids,
const scalar_t* __restrict__ residual_centroids,
const scalar_t* outliers_centroids, const uint16_t* invert_perm,
const scalar_t* weight_scale, const scalar_t* weight_bias,
int out_features, int in_features, int outliers_infeatures,
const int index_stride_0, const int index_stride_1,
const int centroids_stride_0, const int group_nums) {
__global__ void WqA16WithOutliers_PackIndice(
scalar_t* out, const scalar_t* input_data, const int32_t* q_indice, const uint16_t* q_indice_outliers,
const scalar_t* __restrict__ centroids, const scalar_t* __restrict__ residual_centroids,
const scalar_t* outliers_centroids, const uint16_t* invert_perm, const scalar_t* weight_scale,
const scalar_t* weight_bias, const scalar_t* bias, int out_features, int in_features, int outliers_infeatures,
const int index_stride_0, const int index_stride_1, const int centroids_stride_0, const int group_nums) {
int bidx = blockIdx.x; // out_features//base_groupsize
int bidy = blockIdx.y; // batch
int bidz = blockIdx.z; // segment in_features
@@ -169,9 +167,12 @@ __global__ void WqA16WithOutliers_PackIndice(scalar_t* out, const scalar_t* inpu

if constexpr (Do_Reduce > 0) {
out += (in_y * GROUPSIZE) * gridDim.z + bidz;
bias += (bias == nullptr ? 0 : (in_y * GROUPSIZE) + bidz);
} else {
out += in_y * GROUPSIZE;
bias += (bias == nullptr ? 0 : GROUPSIZE);
}

__syncthreads();
if (landid < cuda::kBlockSize / 32) {
#pragma unroll
@@ -180,9 +181,10 @@ __global__ void WqA16WithOutliers_PackIndice(scalar_t* out, const scalar_t* inpu
reduce_out = cuda::warpReduceSum<cuda::kBlockSize / 32>(reduce_out);
if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
if constexpr (Do_Reduce) {
out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out);
out[(wid)*gridDim.z] =
cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0));
} else {
out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out);
out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bias != 0) ? bias[wid] : scalar_t(0));
}
}
}
@@ -439,7 +441,8 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
const c10::optional<torch::Tensor>& residual_centroids, //[num_c, c_size, vec_len]
const c10::optional<torch::Tensor>& outliers_indices, //[num_cen, c_size, ol_in_f]
const c10::optional<torch::Tensor>& outliers_centroids, //[num_c, c_size, out_vec_len]
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias) {
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias) {
const int base_groupsize = centroids.size(-1);
int index_bits = log2(centroids.size(1));
int res_index_bits = residual_centroids.has_value() ? log2(residual_centroids.value().size(1)) : 0;
@@ -464,14 +467,15 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
const uint16_t* outliers_indices_ptr =
(const uint16_t*)(outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr);
const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
<<<blocks, threads, shared_memory_size, stream>>>( \
out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
outliers_indices_ptr, centroids.data_ptr<scalar_t>(), \
residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr, \
outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), out_features, in_features, \
const c10::Half* bias_ptr = bias.has_value() ? (bias.value().data_ptr<c10::Half>()) : nullptr;
#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
<<<blocks, threads, shared_memory_size, stream>>>( \
out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
outliers_indices_ptr, centroids.data_ptr<scalar_t>(), \
residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr, \
outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), bias_ptr, out_features, in_features, \
outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), centroids.stride(0), q_indice.size(0));
#define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
if (input.dtype() == at::ScalarType::Half) { \
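Editor's note on the kernel changes above: WqA16WithOutliers_PackIndice gains a const scalar_t* bias argument; the pointer is offset per output group, and the bias is added to the converted result, guarded by bidz == 0 in the Do_Reduce path so that only one of the segments split along gridDim.z contributes it. The sketch below is a simplified, hypothetical distillation of just that bias logic (names, layout, and the fp32 accumulation are the editor's, not the repo's kernel).

#include <cuda_fp16.h>

template <bool DoReduce>
__global__ void add_bias_sketch(__half* out, const float* partial, const __half* bias,
                                int out_features, int num_z_segments) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= out_features) return;

  if constexpr (DoReduce) {
    // Split-K style layout: each output column keeps num_z_segments partial sums.
    for (int z = 0; z < num_z_segments; ++z) {
      float v = partial[col * num_z_segments + z];
      // Add the bias exactly once (segment 0), mirroring the (bidz == 0) guard above.
      if (z == 0 && bias != nullptr) v += __half2float(bias[col]);
      out[col * num_z_segments + z] = __float2half(v);
    }
  } else {
    float v = partial[col];
    if (bias != nullptr) v += __half2float(bias[col]);
    out[col] = __float2half(v);
  }
}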
8 changes: 5 additions & 3 deletions csrc/ops.cc
@@ -30,7 +30,8 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
const c10::optional<torch::Tensor>& residual_centroids, //[num_c, c_size, vec_len]
const c10::optional<torch::Tensor>& outliers_indices, //[num_cen, c_size, ol_in_f]
const c10::optional<torch::Tensor>& outliers_centroids, //[num_c, c_size, out_vec_len]
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias);
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias);

torch::Tensor dequant(const torch::Tensor& q_indice, const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
@@ -85,7 +86,8 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input, const torch::Tensor& q_indic
const c10::optional<torch::Tensor>& q_indice_outliers,
const c10::optional<torch::Tensor>& outliers_centroids,
const c10::optional<torch::Tensor>& invperm, const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias, int groupsize, int in_features, int out_features) {
const torch::Tensor& weight_bias, const c10::optional<torch::Tensor>& bias, int groupsize,
int in_features, int out_features) {
CHECK_INPUT(q_indice);
CHECK_INPUT(input);
if (q_indice_residual.has_value()) {
@@ -113,7 +115,7 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input, const torch::Tensor& q_indic

output = lauch_gemv_outliers_cuda_packkernel(out_features, input, q_indice, centroids, q_indice_residual,
residual_centroids, q_indice_outliers, outliers_centroids, invperm,
weight_scale, weight_bias);
weight_scale, weight_bias, bias);

gpuErrchk(cudaPeekAtLastError());

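Editor's note on the host-side plumbing above: the bias travels as a c10::optional<torch::Tensor> through wqA16Gemm into lauch_gemv_outliers_cuda_packkernel, where it is lowered to a raw pointer (or nullptr) right before the launch, as bias_ptr does earlier in this diff. The snippet below is a hedged libtorch sketch of that pattern with made-up names (my_wq_gemm is not a symbol in this repo); it only illustrates how a caller would thread an optional bias through such an API.

#include <torch/torch.h>

torch::Tensor my_wq_gemm(const torch::Tensor& input,
                         const c10::optional<torch::Tensor>& bias,
                         int64_t out_features) {
  // Optional tensor -> raw pointer, the same shape of logic as bias_ptr above.
  const at::Half* bias_ptr =
      bias.has_value() ? bias.value().data_ptr<at::Half>() : nullptr;
  (void)bias_ptr;  // a real implementation would hand this to the kernel launch
  return torch::empty({input.size(0), out_features}, input.options());
}

void example() {
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto x = torch::randn({1, 4096}, opts);
  c10::optional<torch::Tensor> bias = torch::zeros({11008}, opts);  // or c10::nullopt
  auto y = my_wq_gemm(x, bias, 11008);
}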
8 changes: 8 additions & 0 deletions csrc/utils.cuh
@@ -28,6 +28,14 @@ T __device__ __forceinline__ ConvertFromFloat(float v) {
return __float2half(v);
}

template <typename T>
T __device__ __forceinline__ ConvertToFloat(float v) {
if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(v);
}
return __half2float(v);
}

template <unsigned int WarpSize>
__device__ __forceinline__ float warpReduceSum(float sum) {
if constexpr (WarpSize >= 32) sum += __shfl_down_sync(0xffffffff, sum, 16); // 0-16, 1-17, 2-18, etc.
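Editor's note on utils.cuh: warpReduceSum (partially shown above) folds a per-lane float across a warp with __shfl_down_sync. The program below is a self-contained sketch of that standard shuffle-reduction pattern, not the repo's helper; it sums 32 ones in a single warp, so the expected output is 32.

#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ float warp_sum(float v) {
  // Each step folds lanes that are `offset` apart; after 5 steps lane 0 holds the total.
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(0xffffffff, v, offset);
  }
  return v;
}

__global__ void sum_one_warp(const float* in, float* out) {
  float total = warp_sum(in[threadIdx.x]);
  if (threadIdx.x == 0) *out = total;  // only lane 0 has the complete sum
}

int main() {
  float h_in[32], h_out = 0.0f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected warp sum: 32
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  sum_one_warp<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %.1f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}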
5 changes: 3 additions & 2 deletions format.sh
@@ -1,4 +1,5 @@
yapf --recursive . --style='{based_on_style: google, column_limit: 120, indent_width: 4}' -i

isort .
find csrc/ \( -name '*.h' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print \
| xargs clang-format -i

find csrc/ \( -name '*.h' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print | xargs clang-format -i
1 change: 1 addition & 0 deletions vptq/layers/vqlinear.py
@@ -417,6 +417,7 @@ def fast_gemv(self, x):
self.perm,
self.weight_scale,
self.weight_bias,
self.bias,
self.vector_len,
self.in_features,
self.out_features,
