Skip to content


support bf16 (#51)
Browse files Browse the repository at this point in the history
Gracefully support bf16 on all NVIDIA GPUs.
For those older GPUs like V100/T4, we just simulate bf16 and the
computation actually happens in float32.
As a result, bf16 is slower than half on those GPUs.
  • Loading branch information
wejoncy authored Oct 8, 2024
1 parent 9d81547 commit 2d6d252
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 80 deletions.
124 changes: 51 additions & 73 deletions csrc/
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ struct C10ToNvType<c10::Half> {
typedef __half type;

template <>
struct C10ToNvType<float> {
typedef float type;

template <typename scalar_t, int IDXBITS, int ResidualBits, int GROUPSIZE, int OL_GroupSize, int Do_Reduce>
__global__ void WqA16WithOutliers_PackIndice(
scalar_t* out, const scalar_t* input_data, const int32_t* q_indice, const uint16_t* q_indice_outliers,
const scalar_t* __restrict__ centroids, const scalar_t* __restrict__ residual_centroids,
const scalar_t* outliers_centroids, const uint16_t* invert_perm, const scalar_t* weight_scale,
const scalar_t* weight_bias, const scalar_t* bias, int out_features, int in_features, int outliers_infeatures,
const int index_stride_0, const int index_stride_1, const int centroids_stride_0, const int group_nums) {
static_assert((GROUPSIZE & 1) == 0, "GROUPSIZE must be even ");
int bidx = blockIdx.x; // out_features//base_groupsize
int bidy = blockIdx.y; // batch
int bidz = blockIdx.z; // segment in_features
Expand All @@ -34,10 +40,6 @@ __global__ void WqA16WithOutliers_PackIndice(
tidx += bidz * cuda::kBlockSize * Do_Reduce;
int in_y = bidx;
__shared__ scalar_t shared_memory[1]; // 3xin_features, dynamic
scalar_t* shared_input = shared_memory; // in_features, dynamic
// scalar_t* shared_w_scales = shared_memory+in_features;// in_features, dynamic
scalar_t* shared_w_bias = shared_memory + in_features; // in_features, dynamic
__shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / 32 + 1];
scalar_t tmp_output[GROUPSIZE];
#pragma unroll
Expand All @@ -46,14 +48,6 @@ __global__ void WqA16WithOutliers_PackIndice(
input_data = input_data + in_features * bidy;
out = out + out_features * bidy * gridDim.z;
if constexpr (Do_Reduce == 0) {
for (int i = tidx; i < in_features; i += cuda::kBlockSize) {
int w_col = invert_perm ? invert_perm[i] : i;
shared_input[i] = input_data[w_col] * weight_scale[w_col];
shared_w_bias[i] = input_data[w_col] * weight_bias[w_col];
if (tidx >= in_features) {
Expand All @@ -64,10 +58,10 @@ __global__ void WqA16WithOutliers_PackIndice(
// const scalar_t scale = shared_w_scales[col];
const int w_col = Do_Reduce ? (invert_perm ? invert_perm[col] : col) : 0;
const scalar_t input_col_v = input_data[w_col];
const scalar_t bias = Do_Reduce ? input_col_v * weight_bias[w_col] : shared_w_bias[col];
scalar_t input_v = Do_Reduce ? input_col_v * weight_scale[w_col] : shared_input[col];
VecType input_v2 = VecType(input_v, input_v);
VecType bias2 = VecType(bias, bias);
const scalar_t bias = input_col_v * weight_bias[w_col];
scalar_t input_v = input_col_v * weight_scale[w_col];
VecType input_v2 = VecType{input_v, input_v};
VecType bias2 = VecType{bias, bias};

int32_t mapped_index_x = col;
if (mapped_index_x < outliers_infeatures) {
Expand All @@ -84,14 +78,13 @@ __global__ void WqA16WithOutliers_PackIndice(
scalar_t* tmp_output_off_p = tmp_output + gi;
scalar_t scalar_weight[OL_GroupSize];
if (out_y < out_features) {
(const uint32_t*)outliers_centroids_start);
cuda::ldg_vec_x<OL_GroupSize>((scalar_weight), (const uint32_t*)outliers_centroids_start);
VecType* weight_h2 = (VecType*)scalar_weight;
VecType* tmp_output_off_h2 = (VecType*)tmp_output_off_p;
tmp_output_off_h2[0] = __hfma2(weight_h2[0], input_v2, tmp_output_off_h2[0]);
tmp_output_off_h2[1] = __hfma2(weight_h2[1], input_v2, tmp_output_off_h2[1]);
tmp_output_off_h2[0] = __hadd2(tmp_output_off_h2[0], bias2);
tmp_output_off_h2[1] = __hadd2(tmp_output_off_h2[1], bias2);
tmp_output_off_h2[0] = FMA2(weight_h2[0], input_v2, tmp_output_off_h2[0]);
tmp_output_off_h2[1] = FMA2(weight_h2[1], input_v2, tmp_output_off_h2[1]);
tmp_output_off_h2[0] = ADD2(tmp_output_off_h2[0], bias2);
tmp_output_off_h2[1] = ADD2(tmp_output_off_h2[1], bias2);
} else {
Expand All @@ -113,21 +106,21 @@ __global__ void WqA16WithOutliers_PackIndice(
const uint32_t base_ind = merged_ind & ((1 << IDXBITS) - 1);

const scalar_t* centroids_start = (centroids_cb) + base_ind * GROUPSIZE;
cuda::ldg_vec_x<GROUPSIZE>(reinterpret_cast<uint32_t*>(base), (const uint32_t*)(centroids_start));
cuda::ldg_vec_x<GROUPSIZE>((base), (const uint32_t*)(centroids_start));

VecType* hres_ptr = nullptr;
if constexpr (ResidualBits > 0) {
scalar_t residual[GROUPSIZE];
const uint32_t res_ind = (merged_ind >> IDXBITS) & ((1 << ResidualBits) - 1);
const scalar_t* residual_centroids_start = (residual_centroids_cb) + res_ind * GROUPSIZE;
cuda::ldg_vec_x<GROUPSIZE>(reinterpret_cast<uint32_t*>(residual), (const uint32_t*)(residual_centroids_start));
cuda::ldg_vec_x<GROUPSIZE>((residual), (const uint32_t*)(residual_centroids_start));

VecType hres[GROUPSIZE / 2];
hres_ptr = hres;
#pragma unroll
for (int i = 0; i < GROUPSIZE / 2; i++) {
hres[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
// hres[i] = __hfma2(hres[i], scale2, bias2);
hres[i] = ADD2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
// hres[i] = FMA2(hres[i], scale2, bias2);
} else {
hres_ptr = (VecType*)base;
Expand All @@ -141,8 +134,8 @@ __global__ void WqA16WithOutliers_PackIndice(
VecType* h2_tmp_output = (VecType*)tmp_output;
#pragma unroll
for (int gi = 0; gi < GROUPSIZE / 2; gi++) {
h2_tmp_output[gi] = __hfma2(hres_ptr[gi], input_v2, h2_tmp_output[gi]);
h2_tmp_output[gi] = __hadd2(h2_tmp_output[gi], bias2);
h2_tmp_output[gi] = FMA2(hres_ptr[gi], input_v2, h2_tmp_output[gi]);
h2_tmp_output[gi] = ADD2(h2_tmp_output[gi], bias2);
Expand Down Expand Up @@ -246,26 +239,26 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
const uint16_t base_ind = merged_ind & ((1 << IDXBITS) - 1);
VecType base[GROUPSIZE / 2];
const scalar_t* centroids_start = centroids + base_ind * GROUPSIZE;
cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(base), (const uint32_t*)(centroids_start));
cuda::ldg_vec_x<GROUPSIZE>((base), (const uint32_t*)(centroids_start));

if constexpr (ResidualBits > 0) {
VecType residual[GROUPSIZE / 2];
merged_ind >>= IDXBITS;
const uint16_t res_ind = merged_ind & ((1 << ResidualBits) - 1);
const scalar_t* residual_centroids_start = residual_centroids + res_ind * GROUPSIZE;
cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(residual), (const uint32_t*)(residual_centroids_start));
cuda::ldg_vec_x<GROUPSIZE>((residual), (const uint32_t*)(residual_centroids_start));
#pragma unroll
for (int i = 0; i < GROUPSIZE / 2; i++) {
base[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
base[i] = ADD2(*(((VecType*)base) + i), *(((VecType*)residual) + i));

VecType hres[GROUPSIZE / 2];
VecType scale2 = VecType(scale, scale);
VecType bias2 = VecType(bias, bias);
VecType scale2 = VecType{scale, scale};
VecType bias2 = VecType{bias, bias};
#pragma unroll
for (int i = 0; i < GROUPSIZE / 2; i++) {
hres[i] = __hfma2(base[i], scale2, bias2);
hres[i] = FMA2(base[i], scale2, bias2);
scalar_t* res = (scalar_t*)hres;
const int group_step = in_y * GROUPSIZE;
Expand Down Expand Up @@ -298,8 +291,8 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
OptionalCUDAGuard cudaguard(q_indice.device().index());
int base_groupsize = centroids.size(-1); // how many elements in a vector
int res_groupsize = residual_centroids.has_value() ? residual_centroids.value().size(-1) : 0;
// TORCH_CHECK((res_groupsize===base_groupsize||res_groupsize==0), "res_groupsize===base_groupsize is false, must be
// true");
TORCH_CHECK(((res_groupsize == base_groupsize) || (res_groupsize == 0)),
"res_groupsize==base_groupsize is false, must be true");
int index_bits = log2(centroids.size(1)); // how many bits to index quantization vector
int res_index_bits = residual_centroids.has_value() ? log2(residual_centroids.value().size(1)) : 0;
auto out_size = outf_x_inf;
Expand Down Expand Up @@ -337,26 +330,18 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), q_indice.stride(1), \
centroids.stride(0), q_indice.size(0)); \
#if __CUDA_ARCH__ < 800
#define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
if (centroids.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
} else { \
TORCH_CHECK(false, "un-supported dtype: bfloat16"); \

#define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
if (centroids.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
} else { \
using scalar_t = c10::BFloat16; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \

#define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
if (centroids.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
} else if (centroids.dtype() == at::ScalarType::Float) { \
using scalar_t = float; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
} else { \
using scalar_t = c10::BFloat16; \
callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \

#define callDequantWithOutliers_bits(BASEGROUP, OUT_OUF_INF, ResidualBits) \
switch (index_bits) { \
Expand Down Expand Up @@ -516,24 +501,17 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
centroids.stride(0), q_indice.size(0)); \

#if __CUDA_ARCH__ < 800
#define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
if (input.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
} else { \
TORCH_CHECK(false, "un-supported dtype: bfloat16"); \
#define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
if (input.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
} else { \
using scalar_t = c10::BFloat16; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
#define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
if (input.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
} else if (input.dtype() == at::ScalarType::Float) { \
using scalar_t = float; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
} else { \
using scalar_t = c10::BFloat16; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \

#define CallWqA16kernel_bits(out_buf, BASEGROUP, Do_Reduce, ResidualBits) \
switch (index_bits) { \
Expand Down
79 changes: 74 additions & 5 deletions csrc/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@ struct TypeVec2<__nv_bfloat16> {
typedef __nv_bfloat162 type;

template <>
struct TypeVec2<float> {
typedef float2 type;

template <typename T>
T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return vv = __float2bfloat16(v);
} else if constexpr (std::is_same<T, float>::value) {
return vv = v;
} else {
static_assert(std::is_same<T, __half>::value);
return vv = __float2half(v);
Expand All @@ -34,6 +42,8 @@ template <typename T>
float __device__ __forceinline__ ConvertToFloat(T v) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return __bfloat162float(v);
} else if constexpr (std::is_same<T, float>::value) {
return v;
} else {
static_assert(std::is_same<T, __half>::value);
return __half2float(v);
Expand All @@ -50,8 +60,12 @@ __device__ __forceinline__ float warpReduceSum(float sum) {
return sum;

template <int GROUPSIZE>
__device__ __forceinline__ void ldg_vec_x(uint32_t* __restrict__ dst_u32, const uint32_t* __restrict__ src_u32) {
template <int GROUPSIZE, typename T>
__device__ __forceinline__ void ldg_vec_x(T* __restrict__ dst_t32, const uint32_t* __restrict__ src_u32) {
uint32_t* dst_u32 = (uint32_t*)dst_t32;
if constexpr (std::is_same<T, float>::value || std::is_same<T, float2>::value) {
return ldg_vec_x<GROUPSIZE * 2>(dst_u32, src_u32);
int2* dst = (int2*)dst_u32;
const int2* src = (const int2*)src_u32;
if constexpr (GROUPSIZE == 2) {
Expand Down Expand Up @@ -88,9 +102,9 @@ __device__ __forceinline__ void ldg_vec_x(uint32_t* __restrict__ dst_u32, const
} else if constexpr (GROUPSIZE == 12) {
if (uint64_t(src) % 16) {
dst[0] = __ldg(src);
int4 b = __ldg((int4*)(src + 1));
dst[1] = *((int2*)&b);
dst[2] = *((int2*)&b + 1);
int4 b = __ldg((const int4*)(src + 1));
dst[1] = *((const int2*)&b);
dst[2] = *((const int2*)&b + 1);
} else {
*(int4*)dst = __ldg((int4*)(src));
dst[2] = __ldg((src + 2));
Expand All @@ -110,6 +124,25 @@ __device__ __forceinline__ void ldg_vec_x(uint32_t* __restrict__ dst_u32, const
// : "=r"(dec[4]), "=r"(dec[5])
// : "l"((const void*)src)
// );
} else if constexpr (GROUPSIZE == 24) {
*((int4*)(dst)) = __ldg((const int4*)(src));
*(((int4*)(dst)) + 1) = __ldg(((const int4*)(src)) + 1);
*(((int4*)(dst)) + 2) = __ldg(((const int4*)(src)) + 2);
} else if constexpr (GROUPSIZE == 32) {
asm volatile(" {%0, %1, %2, %3}, [%4];"
: "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]), "=r"(dst_u32[3])
: "l"((const void*)src_u32));
asm volatile(" {%0, %1, %2, %3}, [%4];"
: "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]), "=r"(dst_u32[7])
: "l"((const void*)(src_u32 + 4)));
asm volatile(" {%0, %1, %2, %3}, [%4];"
: "=r"(dst_u32[8]), "=r"(dst_u32[9]), "=r"(dst_u32[10]), "=r"(dst_u32[11])
: "l"((const void*)(src_u32 + 8)));
asm volatile(" {%0, %1, %2, %3}, [%4];"
: "=r"(dst_u32[12]), "=r"(dst_u32[13]), "=r"(dst_u32[14]), "=r"(dst_u32[15])
: "l"((const void*)(src_u32 + 12)));
} else {

Expand Down Expand Up @@ -144,3 +177,39 @@ __forceinline__ T ceil_div(T a, T b) {

} // namespace cuda

template <typename T>
T __device__ __forceinline__ FMA2(T a, T b, T c) {
if constexpr (std::is_same<T, __nv_bfloat162>::value) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float x = __bfloat162float(a.x) * __bfloat162float(b.x) + __bfloat162float(c.x);
float y = __bfloat162float(a.y) * __bfloat162float(b.y) + __bfloat162float(c.y);
return __nv_bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
return __hfma2(a, b, c);
} else if constexpr (std::is_same<T, float2>::value) {
return float2{a.x * b.x + c.x, a.y * b.y + c.y};
} else {
return __hfma2(a, b, c);
__builtin_unreachable(); // Suppress missing return statement warning

template <typename T>
T __device__ __forceinline__ ADD2(T a, T b) {
if constexpr (std::is_same<T, __nv_bfloat162>::value) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float x = __bfloat162float(a.x) + __bfloat162float(b.x);
float y = __bfloat162float(a.y) + __bfloat162float(b.y);
return __nv_bfloat162{__float2bfloat16(x), __float2bfloat16(y)};
return __hadd2(a, b);
} else if constexpr (std::is_same<T, float2>::value) {
return float2{a.x + b.x, a.y + b.y};
} else {
return __hadd2(a, b);
__builtin_unreachable(); // Suppress missing return statement warning
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
2 changes: 2 additions & 0 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def build_cuda_extensions():
delimiter = ' ' if ';' not in TORCH_CUDA_ARCH_LIST else ' '
compute_capabilities = [int(10 * float(arch)) for arch in TORCH_CUDA_ARCH_LIST if '+' not in arch]

print(" build for compute capabilities: ==============", compute_capabilities)
for cap in compute_capabilities:
arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
extra_compile_args = {
Expand Down
2 changes: 1 addition & 1 deletion vptq/
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

__version__ = "0.0.2"
__version__ = "0.0.2.post1"
from .layers import AutoModelForCausalLM as AutoModelForCausalLM

0 comments on commit 2d6d252

Please sign in to comment.