memory and bf16 (#23)
- shrink memory
- support bf16
wejoncy authored Sep 26, 2024
1 parent 41db6e0 commit 03b4187
Showing 5 changed files with 100 additions and 63 deletions.
csrc/dequant_impl_packed.cu: 72 additions, 50 deletions
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include <cuda_bf16.h>
 #include <cmath>
 #include <math_constants.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -35,15 +34,15 @@ __global__ void WqA16WithOutliers_PackIndice(
     tidx += bidz * cuda::kBlockSize * Do_Reduce;
   }
   int in_y = bidx;
-  extern __shared__ scalar_t shared_memory[];  // 3xin_features, dynamic
-  scalar_t* shared_input = shared_memory;      // in_features, dynamic
+  __shared__ scalar_t shared_memory[1];    // 3xin_features, dynamic
+  scalar_t* shared_input = shared_memory;  // in_features, dynamic
   // scalar_t* shared_w_scales = shared_memory+in_features;// in_features, dynamic
   scalar_t* shared_w_bias = shared_memory + in_features;  // in_features, dynamic
   __shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / 32 + 1];
-  scalar_t tmp_output[GROUPSIZE] = {0};
+  scalar_t tmp_output[GROUPSIZE];
 #pragma unroll
   for (int i = 0; i < GROUPSIZE; i++) {
-    tmp_output[i] = scalar_t(0);
+    tmp_output[i] = scalar_t(0.0f);
   }
   input_data = input_data + in_features * bidy;
   out = out + out_features * bidy * gridDim.z;
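A note on the shared-memory change in this hunk (an inference; the commit message only says "shrink memory" and "support bf16"): a templated `extern __shared__ scalar_t shared_memory[];` binds every instantiation of the kernel to one symbol, so instantiating it for both __half and __nv_bfloat16 in the same translation unit redeclares that symbol with conflicting element types. The conventional workaround, sketched below under that assumption with hypothetical names, declares the dynamic buffer once with an untyped element and casts per instantiation:

#include <cuda_fp16.h>

// Sketch only: one untyped dynamic shared-memory symbol, cast to the
// instantiation's scalar type.
template <typename scalar_t>
__global__ void smem_demo(const scalar_t* in, scalar_t* out, int n) {
  extern __shared__ unsigned char smem_raw[];  // single untyped symbol
  scalar_t* shared_input = reinterpret_cast<scalar_t*>(smem_raw);
  if (threadIdx.x < n) shared_input[threadIdx.x] = in[threadIdx.x];
  __syncthreads();
  if (threadIdx.x < n) out[threadIdx.x] = shared_input[threadIdx.x];
}

// Launched with the byte count as the third launch parameter, e.g.:
//   smem_demo<__half><<<1, 128, n * sizeof(__half)>>>(in, out, n);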
@@ -154,11 +153,7 @@ __global__ void WqA16WithOutliers_PackIndice(
 #pragma unroll
   for (int gi = 0; gi < GROUPSIZE; gi++) {
     float reduce_out = 0.f;
-    if constexpr (!std::is_same_v<scalar_t, c10::BFloat16>) {
-      reduce_out = __half2float(tmp_output[gi]);
-    } else {
-      reduce_out = __bfloat162float(tmp_output[gi]);
-    }
+    reduce_out = cuda::ConvertToFloat(tmp_output[gi]);
     reduce_out = cuda::warpReduceSum<32>(reduce_out);
     if (landid == 0) {
       shared_output[gi][warpid] = reduce_out;
@@ -181,10 +176,11 @@ __global__ void WqA16WithOutliers_PackIndice(
     reduce_out = cuda::warpReduceSum<cuda::kBlockSize / 32>(reduce_out);
     if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
       if constexpr (Do_Reduce) {
-        out[(wid)*gridDim.z] =
-            cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0));
+        out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) +
+                               ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0.0f));
       } else {
-        out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bias != 0) ? bias[wid] : scalar_t(0));
+        out[wid] =
+            cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) + ((bias != 0) ? bias[wid] : scalar_t(0.0f));
       }
     }
   }
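cuda::warpReduceSum itself is not part of this diff; for reference, the standard shuffle-based reduction it presumably implements looks like this (a sketch assuming the conventional pattern):

template <unsigned int WarpSize>
__device__ __forceinline__ float warpReduceSum(float v) {
#pragma unroll
  for (unsigned int offset = WarpSize / 2; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(0xffffffff, v, offset);  // fold lane i+offset into lane i
  }
  return v;  // lane 0 ends up holding the warp-wide sum
}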
@@ -204,6 +200,7 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
   int tid = (bid * cuda::kBlockSize + threadIdx.x);
   int in_x = tid % in_features;
   int in_y = tid / in_features;
+  using VecType = typename cuda::TypeVec2<scalar_t>::type;
 
   uint16_t mapped_index_x = invert_perm ? invert_perm[in_x] : in_x;
   const scalar_t scale = weight_scale[in_x];
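cuda::TypeVec2 maps a scalar type to its packed two-lane vector type, which is what lets the dequant kernel stay dtype-generic below. Only its __nv_bfloat16 specialization is visible in the utils.cuh hunk further down, so the following reconstruction is an assumption:

#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T>
struct TypeVec2;

template <>
struct TypeVec2<__half> {
  using type = __half2;  // two fp16 lanes in one 32-bit register
};

template <>
struct TypeVec2<__nv_bfloat16> {
  using type = __nv_bfloat162;  // two bf16 lanes in one 32-bit register
};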
@@ -247,25 +244,25 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
       cuda::iterator_packed_tensor<IDXBITS + ResidualBits>((const uint32_t*)q_indice, mappped_inx_in_a_codebook);
 
   const uint16_t base_ind = merged_ind & ((1 << IDXBITS) - 1);
-  __half2 base[GROUPSIZE / 2];
+  VecType base[GROUPSIZE / 2];
   const scalar_t* centroids_start = centroids + base_ind * GROUPSIZE;
   cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(base), (const uint32_t*)(centroids_start));
 
   if constexpr (ResidualBits > 0) {
-    __half2 residual[GROUPSIZE / 2];
+    VecType residual[GROUPSIZE / 2];
     merged_ind >>= IDXBITS;
     const uint16_t res_ind = merged_ind & ((1 << ResidualBits) - 1);
     const scalar_t* residual_centroids_start = residual_centroids + res_ind * GROUPSIZE;
     cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(residual), (const uint32_t*)(residual_centroids_start));
 #pragma unroll
     for (int i = 0; i < GROUPSIZE / 2; i++) {
-      base[i] = __hadd2(*(((__half2*)base) + i), *(((__half2*)residual) + i));
+      base[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
     }
   }
 
-  __half2 hres[GROUPSIZE / 2];
-  __half2 scale2 = __half2(scale, scale);
-  __half2 bias2 = __half2(bias, bias);
+  VecType hres[GROUPSIZE / 2];
+  VecType scale2 = VecType(scale, scale);
+  VecType bias2 = VecType(bias, bias);
 #pragma unroll
   for (int i = 0; i < GROUPSIZE / 2; i++) {
     hres[i] = __hfma2(base[i], scale2, bias2);
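Replacing hard-coded __half2 with VecType works because CUDA overloads the paired-math intrinsics (__hadd2, __hfma2, ...) for both __half2 and __nv_bfloat162; the bf16 overloads need sm_80 or newer, which matches the setup.py change below. A minimal check that the same body compiles for both (illustrative only):

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// One templated body, two instantiations: element-wise a*b + c on packed pairs.
template <typename VecT>
__global__ void fma2_demo(VecT* out, const VecT* a, const VecT* b, const VecT* c) {
  out[blockIdx.x] = __hfma2(a[blockIdx.x], b[blockIdx.x], c[blockIdx.x]);
}

template __global__ void fma2_demo<__half2>(__half2*, const __half2*, const __half2*, const __half2*);
template __global__ void fma2_demo<__nv_bfloat162>(__nv_bfloat162*, const __nv_bfloat162*, const __nv_bfloat162*, const __nv_bfloat162*);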
@@ -317,46 +314,61 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
   }
   int outliers_indices_size_n1 = outliers_indices.has_value() ? outliers_indices.value().size(-1) : 0;
   int outliers_centroids_size_n1 = outliers_centroids.has_value() ? outliers_centroids.value().size(-1) : 1;
-  using scalar_t = at::Half;
 
   const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
   const int16_t* outliers_indices_ptr =
       outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr;
-  const scalar_t* residual_centroids_ptr =
-      residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr;
-  const scalar_t* outliers_centroids_ptr =
-      outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr;
   auto stream = at::cuda::getCurrentCUDAStream().stream();
-#define callDequantWithOutliers(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
-  DequantizeWithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, OUT_OUF_INF> \
-      <<<blocks, threads, 0, stream>>>(output.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
-                                       outliers_indices_ptr, centroids.data_ptr<scalar_t>(), residual_centroids_ptr, \
-                                       outliers_centroids_ptr, perm_ptr, weight_scale.data_ptr<scalar_t>(), \
-                                       weight_bias.data_ptr<scalar_t>(), out_size[0], out_size[1], \
-                                       outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), \
-                                       q_indice.stride(1), centroids.stride(0), q_indice.size(0));
+#define callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+  { \
+    using nv_type = typename C10ToNvType<scalar_t>::type; \
+    DequantizeWithOutliers_PackIndice<nv_type, IDXBITS, ResidualBits, BASEGROUP, OUT_OUF_INF> \
+        <<<blocks, threads, 0, stream>>>( \
+            reinterpret_cast<nv_type*>(output.data_ptr<scalar_t>()), q_indice.data_ptr<int32_t>(), \
+            outliers_indices_ptr, reinterpret_cast<const nv_type*>(centroids.data_ptr<scalar_t>()), \
+            residual_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(residual_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            outliers_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(outliers_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            perm_ptr, reinterpret_cast<const nv_type*>(weight_scale.data_ptr<scalar_t>()), \
+            reinterpret_cast<const nv_type*>(weight_bias.data_ptr<scalar_t>()), out_size[0], out_size[1], \
+            outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), q_indice.stride(1), \
+            centroids.stride(0), q_indice.size(0)); \
+  }
+
+#define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+  if (centroids.dtype() == at::ScalarType::Half) { \
+    using scalar_t = c10::Half; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  } else { \
+    using scalar_t = c10::BFloat16; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  }
+
 #define callDequantWithOutliers_bits(BASEGROUP, OUT_OUF_INF, ResidualBits) \
   switch (index_bits) { \
     case 16: \
-      callDequantWithOutliers(16, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(16, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     case 15: \
-      callDequantWithOutliers(15, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(15, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     case 14: \
-      callDequantWithOutliers(14, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(14, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     case 13: \
-      callDequantWithOutliers(13, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(13, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     case 12: \
-      callDequantWithOutliers(12, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(12, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     case 8: \
-      callDequantWithOutliers(8, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(8, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     case 4: \
-      callDequantWithOutliers(4, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(4, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     default: \
       TORCH_CHECK(false, "unspportetd index_bits:" + std::to_string(index_bits)); \
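Both launch macros route through a C10ToNvType trait that maps ATen's 16-bit float wrappers to the CUDA device types. Its definition is not in the rendered files (it presumably lives in the changeset's remaining header), so the following is an assumed reconstruction; the reinterpret_casts above are sound because each pair is layout-compatible 16-bit storage:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>

template <typename T>
struct C10ToNvType;

template <>
struct C10ToNvType<c10::Half> {
  using type = __half;
};

template <>
struct C10ToNvType<c10::BFloat16> {
  using type = __nv_bfloat16;
};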
@@ -469,22 +481,32 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
   const uint16_t* outliers_indices_ptr =
       (const uint16_t*)(outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr);
   const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
-  const c10::Half* bias_ptr = bias.has_value() ? (bias.value().data_ptr<c10::Half>()) : nullptr;
-#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
-  WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
-      <<<blocks, threads, shared_memory_size, stream>>>( \
-          out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
-          outliers_indices_ptr, centroids.data_ptr<scalar_t>(), \
-          residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr, \
-          outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
-          weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), bias_ptr, out_features, in_features, \
-          outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), centroids.stride(0), q_indice.size(0));
+#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
+  { \
+    using nv_type = typename C10ToNvType<scalar_t>::type; \
+    WqA16WithOutliers_PackIndice<nv_type, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
+        <<<blocks, threads, shared_memory_size, stream>>>( \
+            reinterpret_cast<nv_type*>(out_buf.data_ptr<scalar_t>()), \
+            reinterpret_cast<const nv_type*>(input.data_ptr<scalar_t>()), q_indice.data_ptr<int32_t>(), \
+            outliers_indices_ptr, reinterpret_cast<const nv_type*>(centroids.data_ptr<scalar_t>()), \
+            residual_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(residual_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            outliers_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(outliers_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            perm_ptr, reinterpret_cast<const nv_type*>(weight_scale.data_ptr<scalar_t>()), \
+            reinterpret_cast<const nv_type*>(weight_bias.data_ptr<scalar_t>()), \
+            bias.has_value() ? reinterpret_cast<const nv_type*>(bias.value().data_ptr<scalar_t>()) : nullptr, \
+            out_features, in_features, outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), \
+            centroids.stride(0), q_indice.size(0)); \
+  }
 #define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
   if (input.dtype() == at::ScalarType::Half) { \
     using scalar_t = c10::Half; \
     CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
   } else { \
-    using scalar_t = c10::Half; \
+    using scalar_t = c10::BFloat16; \
     CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
   }
 #define CallWqA16kernel_bits(out_buf, BASEGROUP, Do_Reduce, ResidualBits) \
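The single-line fix in CallWqA16kernel_dtype is the heart of the bf16 support here: the else branch previously re-selected c10::Half, so bf16 inputs were silently launched through the fp16 kernel. A macro-free sketch of the intended dispatch (launch_kernel is a hypothetical stand-in for the macro chain):

template <typename scalar_t>
void launch_kernel(const at::Tensor& input);  // hypothetical stand-in

void dispatch_wqa16(const at::Tensor& input) {
  if (input.dtype() == at::ScalarType::Half) {
    launch_kernel<c10::Half>(input);      // device side runs on __half
  } else if (input.dtype() == at::ScalarType::BFloat16) {
    launch_kernel<c10::BFloat16>(input);  // device side runs on __nv_bfloat16
  } else {
    TORCH_CHECK(false, "only fp16 and bf16 inputs are supported");
  }
}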
csrc/utils.cuh: 12 additions, 8 deletions
@@ -21,19 +21,23 @@ struct TypeVec2<__nv_bfloat16> {
 };
 
 template <typename T>
-T __device__ __forceinline__ ConvertFromFloat(float v) {
-  if constexpr (std::is_same_v<T, __nv_bfloat16>) {
-    return __float2bfloat16(v);
+T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
+  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
+    return vv = __float2bfloat16(v);
+  } else {
+    static_assert(std::is_same<T, __half>::value);
+    return vv = __float2half(v);
   }
-  return __float2half(v);
 }
 
 template <typename T>
-T __device__ __forceinline__ ConvertToFloat(float v) {
-  if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+float __device__ __forceinline__ ConvertToFloat(T v) {
+  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
     return __bfloat162float(v);
+  } else {
+    static_assert(std::is_same<T, __half>::value);
+    return __half2float(v);
   }
-  return __half2float(v);
 }
 
 template <unsigned int WarpSize>
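Two things changed in these helpers. ConvertToFloat's old signature (taking a float and returning T) was inverted; it now takes the scalar and returns float, matching its use in the reduction above. Both helpers also gained a static_assert so an unsupported scalar_t fails at compile time instead of silently taking the fp16 path. The extra T vv parameter of ConvertFromFloat lets T be pinned by an argument; call sites pass a throwaway zero. A round-trip sketch (assuming the helpers sit in the cuda namespace, as the kernel call sites suggest):

#include <cuda_bf16.h>

__global__ void roundtrip_demo(__nv_bfloat16* out, const __nv_bfloat16* in) {
  float f = cuda::ConvertToFloat(in[0]);                    // T deduced as __nv_bfloat16
  out[0] = cuda::ConvertFromFloat(f, __nv_bfloat16(0.0f));  // dummy arg pins T; its value is unused
}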
@@ -122,7 +126,7 @@ __device__ __forceinline__ uint32_t iterator_packed_tensor(const uint32_t* ptr,
   int second = end_bits / 32;
   start_bits = start_bits % 32;
   end_bits = end_bits % 32;
-  uint32_t v = (ptr[first] >> (start_bits)) & ((1 << WBITS) - 1);
+  uint32_t v = (ptr[first] >> (start_bits)) & (uint32_t(1 << WBITS) - 1);
   if (first == second || end_bits == 0) {
     return v;
   } else {
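The changed line builds the index mask in unsigned arithmetic; the truncated else branch below it handles a field that straddles ptr[first] and ptr[first + 1]. A host-side sketch of that straddling read, reconstructed from the visible half (an assumption; valid for 0 < start_bits and wbits < 32):

#include <stdint.h>

uint32_t extract_straddling(const uint32_t* ptr, int first, int start_bits, int wbits) {
  uint32_t lo = ptr[first] >> start_bits;             // tail bits of the first word
  uint32_t hi = ptr[first + 1] << (32 - start_bits);  // head bits of the second word
  return (lo | hi) & ((uint32_t(1) << wbits) - 1u);   // keep exactly wbits bits
}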
setup.py: 1 addition, 1 deletion
@@ -20,7 +20,7 @@ def get_version():
 
 
 def build_cuda_extensions():
-    compute_capabilities = [70, 75, 80, 86, 90]
+    compute_capabilities = [80, 86, 90]
     arch_flags = []
     for cap in compute_capabilities:
         arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
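Dropping compute capabilities 70 and 75 from the build list lines up with the bf16 work: the __nv_bfloat162 math intrinsics used in the kernels require sm_80 (Ampere) or newer. An explicit guard like the following (illustrative, not part of this commit) would surface the constraint at compile time:

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#error "bf16 kernels in this extension require compute capability 8.0 or newer"
#endif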
vptq/app_utils.py: 2 additions, 1 deletion
@@ -122,12 +122,13 @@ def main():
     args = get_valid_args(parser)
     print(args)
 
+    #hf_args = {"dtype": torch.bfloat16}
     hf_args = {}
     token = os.getenv("HF_TOKEN", None)
     if token is not None:
         hf_args["token"] = token
 
-    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto", **hf_args).half()
+    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto", **hf_args)
     tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer or args.model, **hf_args)
 
     chat_loop(model, tokenizer, args)
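With the trailing .half() gone, the model keeps the dtype stored in the checkpoint, so bf16 weights are no longer force-cast to fp16 at load time; the commented-out hf_args line hints at how a dtype could instead be forced explicitly.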
(One further changed file, 13 additions and 3 deletions, did not render on this page.)