diff --git a/convert.py b/convert.py
index fd220eb3..e07bf16d 100755
--- a/convert.py
+++ b/convert.py
@@ -3,6 +3,7 @@
 
 import argparse
 import concurrent.futures
+import dataclasses
 import enum
 import faulthandler
 import functools
@@ -138,6 +139,28 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
 # hparams loading
 #
 
+@dataclass
+class PredictorParams:
+    sparse_threshold: float | None = None
+
+    @staticmethod
+    def loadPredictorJson(model: LazyModel, config_path: Path) -> PredictorParams:
+        config = json.load(open(config_path))
+        return PredictorParams(
+            sparse_threshold = config.get("sparse_threshold"),
+        )
+
+    @staticmethod
+    def load(model_plus: ModelPlus) -> PredictorParams:
+        config_path = model_plus.paths[0].parent / "config.json"
+
+        if config_path.exists():
+            params = PredictorParams.loadPredictorJson(model_plus.model, config_path)
+        else:
+            params = PredictorParams()
+
+        return params
+
 @dataclass
 class Params:
     n_vocab: int
@@ -160,6 +183,9 @@ class Params:
     # path to the directory containing the model files
     path_model: Path | None = None
 
+    # MLP predictor parameters
+    predictor_params: PredictorParams = dataclasses.field(default_factory=PredictorParams)
+
     @staticmethod
     def guessed(model: LazyModel) -> Params:
         # try transformer naming first
@@ -843,6 +869,9 @@ def add_meta_arch(self, params: Params) -> None:
         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
 
+        if params.predictor_params.sparse_threshold is not None:
+            self.gguf.add_sparse_threshold(params.predictor_params.sparse_threshold)
+
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
         scores = []
@@ -1181,10 +1210,13 @@ def main(args_in: list[str] | None = None) -> None:
 
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
+        params = Params.load(model_plus)
         mlp_predictor_plus = load_mlp_model(args.mlp_model)
+        params.predictor_params = PredictorParams.load(mlp_predictor_plus)
         model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
     else:
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+        params = Params.load(model_plus)
 
     if args.dump:
         do_dump_model(model_plus)
@@ -1193,7 +1225,6 @@ def main(args_in: list[str] | None = None) -> None:
     if args.bigendian:
         endianess = gguf.GGUFEndian.BIG
 
-    params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
             raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 8ef1d013..4ce2b912 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -108,6 +108,8 @@
 // max batch size to use MMQ kernels when tensor cores are available
 #define MMQ_MAX_BATCH_SIZE 32
 
+__constant__ float dev_sparse_threshold;
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -4483,7 +4485,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__
     //     printf("row in gpu %d cols %d, value %d %d %d\n", id, ncols, *d, *(d+1), *(d+4095));
     // }
     // int id = row;
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
 
@@ -4552,12 +4554,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
         return;
     }
     int id = lst[row];
-    // int id = row;
-    // if (idx[id] < 0.0f) {
-    //     return;
-    // }
     const int bid = blockIdx.y;
-    // if (bid == 0) global_lock = 0;
 
     extern __shared__ float shared_dst[]; // TODO:dynamic
 
@@ -4578,7 +4575,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
     // __syncthreads();
     for (int col_id = 0; col_id < src1_ncols; col_id++) {
         __syncthreads();
-        if (loop_idx[id] < 0.0f) {
+        if (loop_idx[id] < dev_sparse_threshold) {
             loop_dst += ncols;
             loop_idx += src1_ne0;
             loop_y += src1_ne0;
@@ -4640,7 +4637,7 @@ static __global__ void dequantize_axpy_sparse(const void * __restrict__ vx, cons
         return;
     }
     int id = lst[row];
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
 
@@ -4689,8 +4686,7 @@ static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ v
         return;
     }
     int id = lst[row];
-    // int id = row;
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
 
@@ -4782,7 +4778,7 @@ static __global__ void dequantize_mul_mat_batch_sparse(const void * __restrict__
     {
         __syncthreads();
         tmp = 0.0f;
-        if (loop_idx[id] < 0.0f)
+        if (loop_idx[id] < dev_sparse_threshold)
         {
             loop_dst += dst_ne0;
             loop_idx += dst_ne0;
@@ -9618,3 +9614,6 @@ ggml_backend_t ggml_backend_cuda_init() {
 
     return cuda_backend;
 }
+void ggml_cuda_set_device_constants(float sparse_pred_threshold) {
+    CUDA_CHECK(cudaMemcpyToSymbol(dev_sparse_threshold, &sparse_pred_threshold, sizeof(float)));
+}
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 6ea73cd0..5d3a35a9 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -53,6 +53,8 @@ GGML_API int ggml_cuda_get_device_count(void);
 GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 GGML_API size_t ggml_cuda_get_free_memory(int device);
 
+GGML_API void ggml_cuda_set_device_constants(float sparse_pred_threshold);
+
 // backend API
 GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
 
diff --git a/ggml.c b/ggml.c
index 262d86e7..207bc585 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14059,6 +14059,8 @@ static void ggml_compute_forward_mul_mat_sparse(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
     GGML_ASSERT(ne2 == ne12);
@@ -14262,7 +14264,7 @@ static void ggml_compute_forward_mul_mat_sparse(
 
             float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
             // if (ffdata[ir0] <= 0.0f) {
-            if (gid[ir0] == 1 || ffdata[ir0] < -0.0f) {
+            if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
                 dst_col[ir0] = 0;
                 continue;
             }
@@ -14413,11 +14415,6 @@ static void ggml_compute_forward_mul_mat_axpy_dense(
         const int ir0 = atomic_fetch_add(params->aic, dr);
         for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
             if (ir1 >= nr) break;
-            // if (gid[ir1] == 1)
-            //     continue;
-            // if (idx[ir1] < 0.0f)
-            //     continue;
-            // ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
             ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
         }
         if (ir0 + dr >= nr)
@@ -14482,6 +14479,8 @@ static void ggml_compute_forward_mul_mat_axpy(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     // GGML_ASSERT(ne0 == ne01);
     // GGML_ASSERT(ne1 == ne11);
     // GGML_ASSERT(ne2 == ne12);
@@ -14569,7 +14568,7 @@ static void ggml_compute_forward_mul_mat_axpy(
             if (gid[ir1] == 1) {
                 continue;
             }
-            if (idx[ir1] < -0.0f)
+            if (idx[ir1] < threshold)
                 continue;
             // ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
             ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, src1_ptr[ir1]);
@@ -14632,6 +14631,8 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     // GGML_ASSERT(ne0 == ne01);
     // GGML_ASSERT(ne1 == ne11);
     // GGML_ASSERT(ne2 == ne12);
@@ -14713,7 +14714,7 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
                 break;
             if (gid[ir1] == 1)
                 continue;
-            if (idx[ir1] < 0.0f)
+            if (idx[ir1] < threshold)
                 continue;
             int bid = ir1 / QK8_0;
             int qsid = ir1 % QK8_0;
diff --git a/ggml.h b/ggml.h
index b430c05f..94aa5c9d 100644
--- a/ggml.h
+++ b/ggml.h
@@ -2196,6 +2196,12 @@ extern "C" {
     GGML_API int ggml_cpu_has_ssse3     (void);
     GGML_API int ggml_cpu_has_vsx       (void);
 
+    //
+    // global variables
+    //
+    // TODO: these should be moved to the context
+    extern float sparse_pred_threshold;
+
     //
     // Internal types and functions exposed for tests and benchmarks
     //
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index cb31e527..6e90d34f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -70,6 +70,9 @@ class Tokenizer:
         ADD_EOS          = "tokenizer.ggml.add_eos_token"
         HF_JSON          = "tokenizer.huggingface.json"
         RWKV             = "tokenizer.rwkv.world"
+
+    class PowerInfer:
+        SPARSE_THRESHOLD = "powerinfer.sparse_threshold"
 
 
 #
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index c3b8c588..0483d7ba 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -399,6 +399,9 @@ def add_add_bos_token(self, value: bool) -> None:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_sparse_threshold(self, value: float) -> None:
+        self.add_float32(Keys.PowerInfer.SPARSE_THRESHOLD, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
diff --git a/llama.cpp b/llama.cpp
index 5814a1ea..b4d63bae 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -93,6 +93,13 @@
 
 #define LLAMA_MAX_NODES 4096
 
+//
+// global variables
+//
+
+// sparsity threshold for sparse matrix multiplication prediction
+float sparse_pred_threshold = 0.;
+
 //
 // logging
 //
@@ -257,6 +264,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PAD_ID,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
+
+    LLM_KV_SPARSE_THRESHOLD,
 };
 
 static std::map<llm_kv, std::string> LLM_KV_NAMES = {
@@ -305,6 +314,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_PAD_ID,      "tokenizer.ggml.padding_token_id" },
     { LLM_KV_TOKENIZER_HF_JSON,     "tokenizer.huggingface.json"      },
     { LLM_KV_TOKENIZER_RWKV,        "tokenizer.rwkv.world"            },
+
+    { LLM_KV_SPARSE_THRESHOLD,      "powerinfer.sparse_threshold"     },
 };
 
 struct LLM_KV {
@@ -1150,6 +1161,9 @@ struct llama_hparams {
 
     float f_clamp_kqv;
     float f_max_alibi_bias;
+
+    // sparse predictor threshold if sparse inference is enabled
+    float sparse_pred_threshold = atof(getenv("LLAMA_SPARSE_PRED_THRESHOLD") ?: "0.0");
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -2220,6 +2234,11 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }
 
+    if (gguf_get_sparse_deriv(ctx)) {
+        // read sparse threshold override if sparse deriv is enabled
+        GGUF_GET_KEY(ctx, hparams.sparse_pred_threshold, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_SPARSE_THRESHOLD));
+    }
+
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -2607,6 +2626,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
     if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
     if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+
+    // sparse inference
+    LLAMA_LOG_INFO("%s: sparse_pred_threshold = %.2f\n", __func__, hparams.sparse_pred_threshold);
 }
 
@@ -2808,7 +2830,7 @@ struct llama_augmentation_model_loader {
             return NULL;
         }
         // allocate and copy selected weights to gpu
-        #ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUBLAS
         int64_t row_len = src->ne[0];
         int64_t gpu_rows = gpu_bucket->ne[0];
         if (gpu_rows == 0)
@@ -2841,10 +2863,9 @@ struct llama_augmentation_model_loader {
 
         ggml_set_no_alloc(aux_ctx, false);
         return gpu_dst;
-        #else
-        printf("As you do not support CUDA. Split to GPU is not allowed.\n");
+#else
         return NULL;
-        #endif
+#endif
     }
 
     void slice_ffn_mat_to_gpu(llama_layer & layer) {
@@ -2882,22 +2903,11 @@ struct llama_augmentation_model_loader {
         const int64_t t_start_aug_us = ggml_time_us();
         std::vector<uint8_t> work_buffer;
 
-        // transpose ffn_down to use axpy
-        // ggml_cgraph * tmp_transpose_gf = ggml_new_graph(aux_ctx);
-        // for (llama_layer &model_layer : model -> layers) {
-        //     // gpu_w2 transpose load
-        //     ggml_tensor * ffn_down_t = ggml_cont(aux_ctx, ggml_transpose(aux_ctx, model_layer.ffn_down));
-        //     ggml_build_forward_expand(tmp_transpose_gf, ffn_down_t);
-        //     model_layer.ffn_down_t = ffn_down_t;
-        //     LLAMA_LOG_INFO(".");
-        // }
-        // ggml_graph_compute_helper(work_buffer, tmp_transpose_gf, 2);
-        // for (llama_layer &model_layer : model -> layers) {
-        //     model_layer.ffn_down_t->op = GGML_OP_NONE;
-        //     model_layer.ffn_down_t->src[0] = NULL;
-        //     model_layer.ffn_down_t->src[1] = NULL;
-        //     model_layer.ffn_down_t->src[2] = NULL;
-        // }
+        // Set sparsity threshold via global variables
+        sparse_pred_threshold = model->hparams.sparse_pred_threshold;
+#if defined (GGML_USE_CUBLAS)
+        ggml_cuda_set_device_constants(model->hparams.sparse_pred_threshold);
+#endif
 
         // load gpu_idx and slice mat to gpu
         for (llama_layer &model_layer : model -> layers) {
diff --git a/requirements.txt b/requirements.txt
index 81c909d0..2b737d98 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 numpy==1.24.4
 sentencepiece==0.1.98
-gguf>=0.1.0
+-e ./gguf-py