diff --git a/infer.py b/infer.py index a98b3a2..2df553c 100644 --- a/infer.py +++ b/infer.py @@ -41,6 +41,8 @@ def setup_parser(): parser.add_argument("--gen-length", type=int, default=32, help="Batch Size") parser.add_argument("--seed", type=int, default=1234, help="Seed") parser.add_argument("--use-optimized-code", action='store_true', default=False) + parser.add_argument("--warmup-iters", type=int, default=5) + parser.add_argument("--total-iters", type=int, default=10) return parser @@ -48,8 +50,12 @@ def load_prompts(tokenizer, batch_size, prompt_length): dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt") total_tokens = encodings.input_ids.shape[1] - start_index = min(random.randint(0, total_tokens), total_tokens - batch_size * prompt_length) - input_ids = encodings.input_ids[:, start_index : start_index + batch_size * prompt_length].reshape(batch_size, prompt_length) + input_ids = [] + for _ in range(batch_size): + start_index = min(random.randint(0, total_tokens), total_tokens - prompt_length) + tokens = encodings.input_ids[:, start_index : start_index + prompt_length].reshape(1, prompt_length) + input_ids.append(tokens) + input_ids = torch.cat(input_ids, dim=0) return input_ids if __name__ == "__main__": @@ -62,8 +68,8 @@ def load_prompts(tokenizer, batch_size, prompt_length): set_seed(args.seed) if args.method == "pca-topk": - args.top_k = args.prompt_length - args.top_r = 128 + args.top_k = int(0.25 * args.prompt_length) + args.top_r = 32 args.rotary_type = "postrotary" if args.use_optimized_code: @@ -84,20 +90,26 @@ def load_prompts(tokenizer, batch_size, prompt_length): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - generations = [] - input_ids = tokenized_prompts.cuda() - with torch.autocast(device_type='cuda', dtype=dtype): - outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length, num_beams=4) + + # warmup iters + for _ in range(args.warmup_iters): + with torch.autocast(device_type='cuda', dtype=dtype): + outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length) + # timed iters + start_event.record() + for _ in range(args.total_iters - args.warmup_iters): + with torch.autocast(device_type='cuda', dtype=dtype): + outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length) end_event.record() + + generated_tokens = outputs.numel() - input_ids.numel() total_generated_tokens += generated_tokens torch.cuda.synchronize() - total_time = start_event.elapsed_time(end_event) + total_time = start_event.elapsed_time(end_event) / (args.total_iters - args.warmup_iters) tput = total_generated_tokens * 1000 / total_time output_ids = outputs[:, args.prompt_length:] diff --git a/methods/pca_topk/cache_utils.py b/methods/pca_topk/cache_utils.py index 12d20bb..f6a9bad 100644 --- a/methods/pca_topk/cache_utils.py +++ b/methods/pca_topk/cache_utils.py @@ -5,8 +5,8 @@ import time import torch -import external.gather_matmul as G - +#import external.gather_matmul as G +import kernel.pca_topk as G topk_time = 0 @@ -16,6 +16,8 @@ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
torch.backends.cudnn.allow_tf32 = False +from timers import Timers +import json # Work In Progress class PcaTopKCache(Cache): # Not used anymore @@ -120,6 +122,10 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - return max_length - new_seq_length return previous_seq_length + def reset(self): + self.key_cache: List[torch.Tensor] = [] # Stores the reduced keys for each layer + self.value_cache: List[torch.Tensor] = [] + def test_pcatopk_cache(): cache = PcaTopKCache(2, 4) @@ -139,57 +145,91 @@ def test_pcatopk_cache(): -def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=2000, use_optimised_gather=False): +def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_layers, timers, + num_gen_steps=2000, use_optimised_gather=False): import time torch.set_float32_matmul_precision("highest") - head_dim = prompt_keys.shape[-1] - bs = prompt_keys.shape[0] - num_heads = prompt_keys.shape[1] - - generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda") - generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda") + head_dim = prompt_keys[0].shape[-1] + bs = prompt_keys[0].shape[0] + num_heads = prompt_keys[0].shape[1] + dtype = prompt_keys[0].dtype + prompt_seq_length = prompt_keys[0].shape[2] - print ("Starting microbenchmark") matmul_time = 0 top_keys = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda") top_vals = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda") + pca_projection_mat = torch.randn(num_heads, head_dim, head_dim, dtype=dtype, device='cuda') + + assert use_optimised_gather if use_optimised_gather: + timers.start('total') for i in range(num_gen_steps): - keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False) - torch.cuda.synchronize() - - start = time.time() - attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim) - - # Get top-k keys and top-k values based on the attention scores - key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda") - key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1) - key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1]) - - keys = keys.reshape(-1, keys.shape[-2] , keys.shape[-1]) - vals = vals.reshape(-1, vals.shape[-2] , vals.shape[-1]) - - attn_weights = G.gather_outer_bmv( - generative_query.reshape(-1, 1, head_dim), - keys.transpose(-1, -2), - key_states_topk_indices, - #.squeeze(0).squeeze(-1), - chunk=256 - #chunk=min(k2, 65536 // Q.shape[-1]), - ) / math.sqrt(head_dim) - attn_weights = torch.softmax(attn_weights, dim=-1) - - attn_output = G.gather_inner_matrix_only_bmv( - attn_weights, vals, key_states_topk_indices, chunk=64 - ) - - torch.cuda.synchronize() - end = time.time() - - if i > 5: - matmul_time += end - start + for layer in range(num_layers): + timers.start('qk-gen') + generative_query = torch.rand(bs, num_heads, 1, head_dim, device='cuda', dtype=dtype) + generative_key = torch.rand(bs, num_heads, 1, head_dim, device='cuda', dtype=dtype) + timers.stop('qk-gen') + + timers.start('project') + generative_key = generative_key.squeeze().transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2) + generative_query = generative_query.squeeze().transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2) + timers.stop('project') + + timers.start('cache-update') + keys, vals = cache.update(generative_key, generative_key, generative_query, layer, False) + timers.stop('cache-update') + + 
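+                # Approximate attention scores over only the leading top_r PCA
+                # dimensions of the (already projected) query and keys, using the
+                # fused Triton kernel in place of a sliced torch.matmul.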
timers.start('qk-matmul-1') + #attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim) + nh, bs, s, r = keys.shape + attn_weights = G.topr_bmv_optimized(A=generative_query.view(nh*bs, 1, r), B=keys.view(nh*bs, s, r).transpose(-1,-2), + r=top_r) + attn_weights = attn_weights.view(nh, bs, 1, s) + timers.stop('qk-matmul-1') + + # Get top-k keys and top-k values based on the attention scores + timers.start('top-k') + #key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1, sorted=False).indices.to("cuda") + #key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1) + key_states_topk_indices = torch.argsort(attn_weights, dim=-1, descending=True)[:,:,:,:top_k] + timers.stop('top-k') + + + timers.start('reshape-0') + key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1]) + timers.stop('reshape-0') + + timers.start('reshape-1') + keys = keys.view(-1, keys.shape[-2] , keys.shape[-1]) + vals = vals.view(-1, vals.shape[-2] , vals.shape[-1]) + timers.stop('reshape-1') + + timers.start('qk-matmul-2') + attn_weights = G.gather_outer_bmv_optimized( + generative_query.reshape(-1, 1, head_dim), + keys.transpose(-1, -2), + key_states_topk_indices, + #.squeeze(0).squeeze(-1), + #chunk=256 + #chunk=min(k2, 65536 // Q.shape[-1]), + ) / math.sqrt(head_dim) + timers.stop('qk-matmul-2') + + timers.start('softmax') + attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype) + timers.stop('softmax') + + timers.start('sv-matmul') + attn_output = G.gather_inner_matrix_only_bmv_optimized( + attn_weights, vals, key_states_topk_indices) + timers.stop('sv-matmul') + + timers.start('reshape-output') + attn_output = attn_output.view(num_heads, bs, 1, head_dim).transpose(0,1).transpose(1,2).contiguous() + timers.stop('reshape-output') + timers.stop('total') else: for i in range(num_gen_steps): keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False) @@ -197,7 +237,6 @@ def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=200 start = time.time() attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim) - # Get top-k keys and top-k values based on the attention scores key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda") key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1) @@ -212,70 +251,119 @@ def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=200 torch.cuda.synchronize() end = time.time() - if i > 5: - matmul_time += end - start - print (f"Matmul Time: {matmul_time}") -def micro_bench_actual_attention(cache, prompt_keys, num_gen_steps=2000): +def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen_steps=2000): import time torch.set_float32_matmul_precision("highest") - head_dim = prompt_keys.shape[-1] - bs = prompt_keys.shape[0] - num_heads = prompt_keys.shape[1] - - generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda") - generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda") + head_dim = prompt_keys[0].shape[-1] + bs = prompt_keys[0].shape[0] + num_heads = prompt_keys[0].shape[1] + dtype = prompt_keys[0].dtype - print ("Starting microbenchmark") matmul_time = 0 - for i in range(num_gen_steps): - keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False) - torch.cuda.synchronize() - start = time.time() - attn_weights = 
torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, vals) - torch.cuda.synchronize() - end = time.time() + timers.start('total') + for i in range(num_gen_steps): + for layer in range(num_layers): + timers.start('qk-gen') + generative_query = torch.rand(bs, num_heads, 1, head_dim, dtype=dtype, device='cuda') + generative_key = torch.rand(bs, num_heads, 1, head_dim, dtype=dtype, device='cuda') + timers.stop('qk-gen') + + timers.start('cache-update') + keys, vals = cache.update(generative_key, generative_key, generative_query, layer, False) + timers.stop('cache-update') + + timers.start('qk-matmul-1') + attn_weights = torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim) + timers.stop('qk-matmul-1') + + timers.start('softmax') + attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype) + timers.stop('softmax') + + timers.start('sv-matmul') + attn_output = torch.matmul(attn_weights, vals) + timers.stop('sv-matmul') + + timers.start('reshape-output') + attn_output = attn_output.transpose(1, 2).contiguous() + timers.stop('reshape-output') + - if i > 5: - matmul_time += end - start - print (f"Matmul Time: {matmul_time}") + timers.stop('total') +@torch.no_grad() def benchmark_attention(batch_size=1, num_heads=32, num_gen_steps=128, prompt_length=3072, - topk=256): + topk=256, + num_layers=32, + dtype=torch.float16): head_dim=128 # Change this to change batch size, etc. - prompt_keys = torch.rand(batch_size, num_heads, prompt_length, head_dim).to("cuda") + prompt_keys = [torch.rand(batch_size, num_heads, prompt_length, head_dim, device='cuda', dtype=dtype) for _ in range(num_layers)] - print("PCA TOPK Unoptimized") - cache1 = PcaTopKCache() - cache1.update(prompt_keys, prompt_keys, prompt_keys, 0) - micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps) - del cache1 + #print("PCA TOPK Unoptimized") + #cache1 = [PcaTopKCache() for _ in range(num_layers)] + #for i in range(num_layers): + # cache1[i].update(prompt_keys[i], prompt_keys[i], prompt_keys[i], 0) + #micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps) + #del cache1 + print("PCA TOPK Optimized") - cache2 = PcaTopKCache() - cache2.update(prompt_keys, prompt_keys, prompt_keys, 0) - micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk, num_gen_steps=num_gen_steps, use_optimised_gather=True) - del cache2 + for _ in range(10): + cache2 = PcaTopKCache() + for i in range(num_layers): + cache2.update(prompt_keys[i].transpose(0,1).contiguous(), + prompt_keys[i].transpose(0,1).contiguous(), + prompt_keys[i].transpose(0,1).contiguous(), i) + timers = Timers() + micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk, + num_gen_steps=num_gen_steps, num_layers=num_layers, + use_optimised_gather=True, timers=timers) + del cache2 + times = timers.get_times() + print(times) + + print("Average time (minus cache updates) is - ") + print(times['total'] - times['cache-update'], " s") + print("==================================") + times_pca_topk = times print("Actual Attention") - cache3= PcaTopKCache() - cache3.update(prompt_keys, prompt_keys, prompt_keys, 0) - micro_bench_actual_attention(cache3, prompt_keys, num_gen_steps=num_gen_steps) - del cache3 + for _ in range(10): + cache3= PcaTopKCache() + for i in range(num_layers): + cache3.update(prompt_keys[i], prompt_keys[i], prompt_keys[i], i) + timers = Timers() + micro_bench_actual_attention(cache3, 
prompt_keys, num_layers=num_layers, + num_gen_steps=num_gen_steps, timers=timers) + del cache3 + times = timers.get_times() + print("Average time (minus cache updates) is - ") + print(times['total'] - times['cache-update'], " s") + print(times) + print("==================================") + times_vanilla = times + return times_pca_topk, times_vanilla if __name__ == "__main__": #test_pcatopk_cache() with torch.no_grad(): - benchmark_attention(prompt_length=4096, num_gen_steps=2000, batch_size=16, topk=1024) + prompt_length = 2000 + for num_gen_steps in [1000]: + print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size={16}, topk and top r are 25%") + times_pca_topk, times_vanilla = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // 4, num_layers=1) + with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_pca_topk_opt_first_matmul.json", "w") as f: + json.dump(times_pca_topk, f, indent=2) + + with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_vanilla.json", "w") as f: + json.dump(times_vanilla, f, indent=2) diff --git a/methods/pca_topk/kernel/__init__.py b/methods/pca_topk/kernel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/methods/pca_topk/kernel/benchmark.py b/methods/pca_topk/kernel/benchmark.py new file mode 100644 index 0000000..75d35dd --- /dev/null +++ b/methods/pca_topk/kernel/benchmark.py @@ -0,0 +1,88 @@ +import torch +import triton +import numpy as np +from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized +from sparq import gather_outer_bmv, gather_inner_matrix_only_bmv + +B = 4 +NH = 32 +S = 800 +D = 128 +dtype = torch.float16 + + + +configs = [ + triton.testing.Benchmark( + x_names=["sparsity"], # Argument names to use as an x-axis for the plot + x_vals=[0.125, 0.25, 0.5, 0.75, 1.0], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. + line_vals=["torch", "triton-optimized"], # Label name for the lines + line_names=["torch (full keys and values)", "Triton (Optimized)"], # Line styles + styles=[("black", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance-" + ("fp16 (time in ms)" ), # Name for the plot, used also as a file name for saving the plot. 
+ args = {"B": B, "NH" : NH, "S": S, "D": D} + ) + ] + + +@triton.testing.perf_report(configs) +def benchmark_bmm1(sparsity, B, NH, S, D, provider): + q = torch.randn((B*NH, 1, D), device='cuda', dtype=dtype) + k = torch.randn((B*NH, S, D), device='cuda', dtype=dtype) + choice = np.concatenate([np.sort(np.random.choice(S, size=(int(S*sparsity)), replace=False)) for _ in range(B*NH)]).reshape(B*NH, -1) + token_mask = torch.tensor(choice, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.bmm(q, k.transpose(1,2)), quantiles=quantiles) + if provider == 'triton-optimized': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gather_outer_bmv_optimized(q, k.transpose(1, 2), token_mask), quantiles=quantiles) + if provider == 'triton-sparq': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gather_outer_bmv(q, k.transpose(1, 2), token_mask, chunk=256), quantiles=quantiles) + #perf = lambda ms: 2 * (B*NH) * S * D * 1e-12 / (ms * 1e-3) + #return perf(ms), perf(max_ms), perf(min_ms) + return ms, max_ms, min_ms + +@triton.testing.perf_report(configs) +def benchmark_bmm2(sparsity, B, NH, S, D, provider): + k_seq_len = S + scores_sampled = torch.randn( (B*NH, 1, int(k_seq_len*sparsity)), device='cuda', dtype=dtype) + scores = torch.randn( (B*NH, 1, k_seq_len), device='cuda', dtype=dtype) + + v = torch.randn((B*NH, k_seq_len, D), device='cuda', dtype=dtype) + choice = np.concatenate([np.sort(np.random.choice(S, size=(int(k_seq_len*sparsity)), replace=False)) for _ in range(B*NH)]).reshape(B*NH, -1) + token_mask = torch.tensor(choice, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.bmm(scores, v), quantiles=quantiles) + if provider == 'triton-optimized': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gather_inner_matrix_only_bmv_optimized(scores_sampled, v, token_mask), quantiles=quantiles) + if provider == 'triton-sparq': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gather_inner_matrix_only_bmv(scores_sampled, v, token_mask, chunk=256), quantiles=quantiles) + + return ms, max_ms, min_ms + + + +print("===== BENCHMARKING q@k.t() with various sparsities =======") +print("Batch Size : ", B) +print("Number of Heads : ", NH) +print("Number of Key Tokens (or sequence length) : ", S) +print("Hidden dimension per head : ", D) +result = benchmark_bmm1.run(print_data=True) + + + +print("===== BENCHMARKING s@v with various sparsities =======") +print("Batch Size : ", B) +print("Number of Heads : ", NH) +print("Number of Key Tokens (or sequence length) : ", S) +print("Hidden dimension per head : ", D) +result = benchmark_bmm2.run(print_data=True) + + diff --git a/methods/pca_topk/kernel/pca_topk.py b/methods/pca_topk/kernel/pca_topk.py new file mode 100644 index 0000000..c641d8b --- /dev/null +++ b/methods/pca_topk/kernel/pca_topk.py @@ -0,0 +1,340 @@ +import math +import warnings + +import torch +import triton +import triton.language as tl +from torch import Tensor + +def get_autotune_config_outer(): + return [ + triton.Config({"n_chunk": 4}), + triton.Config({"n_chunk": 8}), + triton.Config({"n_chunk": 16}), + triton.Config({"n_chunk": 32}), + triton.Config({"n_chunk": 64}), + triton.Config({"n_chunk": 128}), + triton.Config({"n_chunk": 256}), + triton.Config({"n_chunk": 512}), + triton.Config({"n_chunk": 1024}) + ] + +@triton.autotune( + configs=get_autotune_config_outer(), + key=['b', 'n', 'k'], +) +@triton.jit +def 
_kernel_gather_outer_bmv( + A_ptr, + B_ptr, + I_ptr, + Y_ptr, + k: tl.constexpr, + b: int, + n: int, + n_chunk: tl.constexpr, + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int, + I_s0: int, + I_s1: int, +): + pid_b = tl.program_id(axis=0).to(tl.int64) + pid_n = tl.program_id(axis=1).to(tl.int64) + a = tl.load(A_ptr + pid_b * A_s0 + tl.arange(0, k) * A_s2) # (k) + chunk_idx = pid_n * n_chunk + tl.arange(0, n_chunk) + i = tl.load(I_ptr + pid_b * I_s0 + chunk_idx * I_s1, mask=(chunk_idx < n))# (n_chunk) + b = tl.load( # (k x n_chunk) + B_ptr + + pid_b * B_s0 + + (tl.arange(0, k) * B_s1)[:, None] + + (i * B_s2)[None, :], + mask=(chunk_idx < n)[None, :] + ) + # # As tl.dot() is unavailable for matrix-vector + y = tl.sum((a[:, None] * b).to(tl.float32), 0).to(a.dtype) # (n_chunk) + tl.store(Y_ptr + pid_b * n + chunk_idx, y, mask=(chunk_idx < n)) + +def gather_outer_bmv_optimized(A: Tensor, B: Tensor, I: Tensor) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the matrix outer dimension. + + Dimensions: + b -- batch + k -- inner dimension, must be a power of two + n* -- (pre-gather) outer dimension + n -- (post-gather) outer dimension (n <= n*) + + A -- (b, 1, k) batch of vectors + B -- (b, k, n*) batch of matrices + I -- int(b, n) indices, in [0, n*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the outer dimension + according to `I` + """ + if A.ndim > 3: + assert B.ndim == A.ndim and I.ndim == A.ndim - 1 + return gather_outer_bmv_custom( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + I.flatten(end_dim=-2), + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 and A.shape[2] == B.shape[1] + assert I.ndim == 2 and I.shape[0] == A.shape[0] + + b, k, n = A.shape[0], A.shape[2], I.shape[1] + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + + grid = lambda META: (b, triton.cdiv(n, META["n_chunk"])) + _kernel_gather_outer_bmv[grid]( + A_ptr=A, + B_ptr=B, + I_ptr=I, + Y_ptr=Y, + b=b, + k=k, + n=n, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + I_s0=I.stride(0), + I_s1=I.stride(1), + ) + + + return Y + +def get_autotune_config_topr(): + return [ + triton.Config({"n_chunk": 32}), + triton.Config({"n_chunk": 128}), + triton.Config({"n_chunk": 512}), + triton.Config({"n_chunk": 1024}) + ] + +@triton.autotune( + configs=get_autotune_config_topr(), + key=['b', 'r'], +) +@triton.jit +def _kernel_topr_bmv( + A_ptr, + B_ptr, + Y_ptr, + k: tl.constexpr, + b: int, + n: int, + r: tl.constexpr, + n_chunk: tl.constexpr, + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int +): + pid_b = tl.program_id(axis=0).to(tl.int64) + pid_n = tl.program_id(axis=1).to(tl.int64) + a = tl.load(A_ptr + pid_b * A_s0 + tl.arange(0, r) * A_s2) # (r) + chunk_idx = pid_n * n_chunk + tl.arange(0, n_chunk) + #i = tl.load(I_ptr + pid_b * I_s0 + chunk_idx * I_s1, mask=(chunk_idx < n))# (n_chunk) + b = tl.load( # (r x n_chunk) + B_ptr + + pid_b * B_s0 + + (tl.arange(0, r) * B_s1)[:, None] + + (chunk_idx * B_s2)[None, :], + mask=(chunk_idx < n)[None, :] + ) + # # As tl.dot() is unavailable for matrix-vector + y = tl.sum((a[:, None] * b).to(tl.float32), 0).to(a.dtype) # (n_chunk) + tl.store(Y_ptr + pid_b * n + chunk_idx, y, mask=(chunk_idx < n)) + +def topr_bmv_optimized(A: Tensor, B: Tensor, r: int) -> Tensor: + """Batched 
vector-matrix multiplication, with a gather on the matrix outer dimension. + + Dimensions: + b -- batch + k -- inner dimension + n -- outer dimension (n) + + A -- (b, 1, k) batch of vectors + B -- (b, k, n) batch of matrices + r -- int only r out of k dimensions are used for the inner product + + returns -- (b, 1, n) the inner product of `A` and `B`, but only using r out of k inner dimensions + """ + if A.ndim > 3: + assert B.ndim == A.ndim + return gather_outer_bmv_custom( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + r + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 and A.shape[2] == B.shape[1] + + b, k, n = A.shape[0], A.shape[2], B.shape[-1] + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + assert r <= k + + grid = lambda META: (b, triton.cdiv(n, META["n_chunk"])) + _kernel_topr_bmv[grid]( + A_ptr=A, + B_ptr=B, + Y_ptr=Y, + b=b, + k=k, + n=n, + r=r, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + ) + + + return Y + + + +def get_autotune_config_inner(): + return [ + triton.Config({"n_chunk": 4}), + triton.Config({"n_chunk": 8}), + triton.Config({"n_chunk": 16}), + triton.Config({"n_chunk": 32}), + triton.Config({"n_chunk": 64}), + triton.Config({"n_chunk": 128}), + ] + +@triton.autotune( + configs=get_autotune_config_inner(), + key=['b', 'n', 'k'], +) +@triton.jit +def _kernel_gather_inner_bmv( + A_ptr, + B_ptr, + I_ptr, + Y_ptr, + b: int, + k: tl.constexpr, # int + k_next_pow_2: tl.constexpr, + n: int, + n_chunk: tl.constexpr, # int + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int, + I_s0: int, + I_s1: int, + gather_A: tl.constexpr, # bool +): + pid_b = tl.program_id(axis=0).to(tl.int64) + pid_n = tl.program_id(axis=1).to(tl.int64) + i = tl.load(I_ptr + pid_b * I_s0 + tl.arange(0,k_next_pow_2) * I_s1, mask = tl.arange(0,k_next_pow_2) < k) # (k) + a = tl.load(A_ptr + pid_b * A_s0 + (i if gather_A else tl.arange(0,k_next_pow_2)) * A_s2, mask = tl.arange(0,k_next_pow_2) < k) # (k) + chunk_idx = pid_n * n_chunk + tl.arange(0, n_chunk) + b = tl.load( # (k x n_chunk) + B_ptr + pid_b * B_s0 + (i * B_s1)[:, None] + (chunk_idx * B_s2)[None, :], + mask=(chunk_idx < n)[None, :] + ) + # As tl.dot() is unavailable for matrix-vector + y = tl.sum((a[:, None] * b).to(tl.float32), 0).to(a.dtype) # (n_chunk) + tl.store(Y_ptr + pid_b * n + chunk_idx, y, mask=(chunk_idx < n)) + + +def gather_inner_bmv_optimized( + A: Tensor, B: Tensor, I: Tensor, _matrix_only: bool = False +) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the inner dimension. 
+ + Dimensions: + b -- batch + k* -- (pre-gather) inner dimension + k -- (post-gather) inner dimension (k <= k*), must be a power of two + n -- outer dimension + + A -- (b, 1, k*) batch of vectors + B -- (b, k*, n) batch of matrices + I -- int(b, k) indices, in [0, k*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + _matrix_only -- bool don't use (see `gather_inner_matrix_only_bmv`) + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the inner dimension + according to `I` + """ + if A.ndim > 3: + assert B.ndim == A.ndim and I.ndim == A.ndim - 1 + return gather_inner_bmv( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + I.flatten(end_dim=-2), + chunk=chunk, + _matrix_only=_matrix_only, + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 + assert ( + I.ndim == 2 + and I.shape[0] == A.shape[0] + ) + assert A.shape[2] == (I.shape[1] if _matrix_only else B.shape[1]) + if B.stride(2) != 1: + warnings.warn( + "gather_inner_bmv(A, B, ...) `B` should be contiguous in the last dimension" + ", otherwise it is very slow" + ) + + b, k, n = A.shape[0], I.shape[1], B.shape[2] + k_next_pow_2 = triton.next_power_of_2(k) + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + + grid = lambda META: (b, triton.cdiv(n, META["n_chunk"])) + _kernel_gather_inner_bmv[grid]( + A_ptr=A, + B_ptr=B, + I_ptr=I, + Y_ptr=Y, + b=b, + k=k, + k_next_pow_2=k_next_pow_2, + n=n, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + I_s0=I.stride(0), + I_s1=I.stride(1), + gather_A=not _matrix_only, + ) + return Y + + +def gather_inner_matrix_only_bmv_optimized(A: Tensor, B: Tensor, I: Tensor) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the inner dimension of the matrix. 
+ + Dimensions: + b -- batch + k* -- (pre-gather) inner dimension + k -- (post-gather) inner dimension (k <= k*), must be a power of two + n -- outer dimension + + A -- (b, 1, k) batch of vectors + B -- (b, k*, n) batch of matrices + I -- int(b, k) indices, in [0, k*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the inner dimension + of `B` according to `I` + """ + return gather_inner_bmv_optimized(A, B, I, _matrix_only=True) + + diff --git a/methods/pca_topk/kernel/sparq.py b/methods/pca_topk/kernel/sparq.py new file mode 100644 index 0000000..3cf85b4 --- /dev/null +++ b/methods/pca_topk/kernel/sparq.py @@ -0,0 +1,212 @@ +import math +import warnings + +import torch +import triton +import triton.language as tl +from torch import Tensor + +@triton.jit +def _kernel_gather_outer_bmv( + A_ptr, + B_ptr, + I_ptr, + Y_ptr, + k: tl.constexpr, + n: int, + n_chunk: tl.constexpr, + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int, + I_s0: int, + I_s1: int, +): + pid = tl.program_id(axis=0).to(tl.int64) + a = tl.load(A_ptr + pid * A_s0 + tl.arange(0, k) * A_s2) # (k) + for chunk in range(0, tl.cdiv(n, n_chunk)): + chunk_idx = chunk * n_chunk + tl.arange(0, n_chunk) + i = tl.load(I_ptr + pid * I_s0 + chunk_idx * I_s1) # (n_chunk) + b = tl.load( # (k x n_chunk) + B_ptr + + pid * B_s0 + + (tl.arange(0, k) * B_s1)[:, None] + + (i * B_s2)[None, :], + mask=(chunk_idx < n)[None, :], + ) + # # As tl.dot() is unavailable for matrix-vector + y = tl.sum(a[:, None] * b, 0) # (n_chunk) + tl.store(Y_ptr + pid * n + chunk_idx, y, mask=(chunk_idx < n)) + + + +def gather_outer_bmv(A: Tensor, B: Tensor, I: Tensor, chunk: int) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the matrix outer dimension. 
+ + Dimensions: + b -- batch + k -- inner dimension, must be a power of two + n* -- (pre-gather) outer dimension + n -- (post-gather) outer dimension (n <= n*) + + A -- (b, 1, k) batch of vectors + B -- (b, k, n*) batch of matrices + I -- int(b, n) indices, in [0, n*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the outer dimension + according to `I` + """ + if A.ndim > 3: + assert B.ndim == A.ndim and I.ndim == A.ndim - 1 + return gather_outer_bmv( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + I.flatten(end_dim=-2), + chunk=chunk, + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 and A.shape[2] == B.shape[1] + assert I.ndim == 2 and I.shape[0] == A.shape[0] + + b, k, n = A.shape[0], A.shape[2], I.shape[1] + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + + _kernel_gather_outer_bmv[(b,)]( + A_ptr=A, + B_ptr=B, + I_ptr=I, + Y_ptr=Y, + k=k, + n=n, + n_chunk=chunk, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + I_s0=I.stride(0), + I_s1=I.stride(1), + ) + return Y + + + +@triton.jit +def _kernel_gather_inner_bmv( + A_ptr, + B_ptr, + I_ptr, + Y_ptr, + k: tl.constexpr, # int + n: int, + n_chunk: tl.constexpr, # int + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int, + I_s0: int, + I_s1: int, + gather_A: tl.constexpr, # bool +): + pid = tl.program_id(axis=0).to(tl.int64) + i = tl.load(I_ptr + pid * I_s0 + tl.arange(0, k) * I_s1) # (k) + a = tl.load(A_ptr + pid * A_s0 + (i if gather_A else tl.arange(0, k)) * A_s2) # (k) + for chunk in range(0, tl.cdiv(n, n_chunk)): + chunk_idx = chunk * n_chunk + tl.arange(0, n_chunk) + b = tl.load( # (k x n_chunk) + B_ptr + pid * B_s0 + (i * B_s1)[:, None] + (chunk_idx * B_s2)[None, :] + ) + # As tl.dot() is unavailable for matrix-vector + y = tl.sum(a[:, None] * b, 0) # (n_chunk) + tl.store(Y_ptr + pid * n + chunk_idx, y, mask=(chunk_idx < n)) + + +def gather_inner_bmv( + A: Tensor, B: Tensor, I: Tensor, chunk: int, _matrix_only: bool = False +) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the inner dimension. + + Dimensions: + b -- batch + k* -- (pre-gather) inner dimension + k -- (post-gather) inner dimension (k <= k*), must be a power of two + n -- outer dimension + + A -- (b, 1, k*) batch of vectors + B -- (b, k*, n) batch of matrices + I -- int(b, k) indices, in [0, k*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + _matrix_only -- bool don't use (see `gather_inner_matrix_only_bmv`) + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the inner dimension + according to `I` + """ + if A.ndim > 3: + assert B.ndim == A.ndim and I.ndim == A.ndim - 1 + return gather_inner_bmv( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + I.flatten(end_dim=-2), + chunk=chunk, + _matrix_only=_matrix_only, + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 + assert ( + I.ndim == 2 + and I.shape[0] == A.shape[0] + and 2 ** int(math.log2(I.shape[1])) == I.shape[1] + ) + assert A.shape[2] == (I.shape[1] if _matrix_only else B.shape[1]) + if B.stride(2) != 1: + warnings.warn( + "gather_inner_bmv(A, B, ...) 
`B` should be contiguous in the last dimension" + ", otherwise it is very slow" + ) + + b, k, n = A.shape[0], I.shape[1], B.shape[2] + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + + _kernel_gather_inner_bmv[(b,)]( + A_ptr=A, + B_ptr=B, + I_ptr=I, + Y_ptr=Y, + k=k, + n=n, + n_chunk=chunk, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + I_s0=I.stride(0), + I_s1=I.stride(1), + gather_A=not _matrix_only, + ) + return Y + + +def gather_inner_matrix_only_bmv(A: Tensor, B: Tensor, I: Tensor, chunk: int) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the inner dimension of the matrix. + + Dimensions: + b -- batch + k* -- (pre-gather) inner dimension + k -- (post-gather) inner dimension (k <= k*), must be a power of two + n -- outer dimension + + A -- (b, 1, k) batch of vectors + B -- (b, k*, n) batch of matrices + I -- int(b, k) indices, in [0, k*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the inner dimension + of `B` according to `I` + """ + return gather_inner_bmv(A, B, I, chunk=chunk, _matrix_only=True) + + diff --git a/methods/pca_topk/kernel/test.py b/methods/pca_topk/kernel/test.py new file mode 100644 index 0000000..c82516c --- /dev/null +++ b/methods/pca_topk/kernel/test.py @@ -0,0 +1,97 @@ +from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized, topr_bmv_optimized +import torch +import numpy as np +import pytest + +@pytest.mark.parametrize("B", [2, 4, 8, 16]) +@pytest.mark.parametrize("NH", [8, 16, 32]) +@pytest.mark.parametrize("S", [32, 33, 37, 64, 69, 73, 128, 255, 259, 1024, 1028, 2048, 2500]) +@pytest.mark.parametrize("D", [128]) +@pytest.mark.parametrize("sparsity", [0.125, 0.25, 0.5]) +def test_first_bmm(B, NH, S, D, sparsity, dtype=torch.float32): + """ + Test the correctness of the first bmm (q @ k.t()) + B - batch size + NH - number of heads + S - key sequence length + D - hidden dimension per head + sparsity - topk in [0,1) + """ + + k_seq_len = S + q = torch.randn((B*NH, 1, D), device='cuda', dtype=dtype) + k = torch.randn((B*NH, k_seq_len, D), device='cuda', dtype=dtype) + + choice = np.concatenate([np.sort(np.random.choice(S, size=(int(k_seq_len*sparsity)), replace=False)) for _ in range(B*NH)]).reshape(B*NH, -1) + token_mask = torch.tensor(choice, device="cuda") + + y_optimized = gather_outer_bmv_optimized(q, k.transpose(1,2), token_mask) + + for i in range(B*NH): + token_mask[i] += k_seq_len*i + + k_reshaped = k.view(-1, D) + k_sampled = torch.index_select(k_reshaped, dim=0, index=token_mask.view(-1)).reshape(B*NH, -1, D) + y_torch = torch.bmm(q, k_sampled.transpose(1,2)) + + assert torch.allclose(y_optimized, y_torch, rtol=1e-2, atol=1e-2) + + + +@pytest.mark.parametrize("B", [2, 4, 8, 16]) +@pytest.mark.parametrize("NH", [8, 16, 32]) +@pytest.mark.parametrize("S", [32, 33, 37, 64, 69, 73, 128, 255, 259, 1024, 1028, 2048, 2500]) +@pytest.mark.parametrize("D", [128]) +@pytest.mark.parametrize("sparsity", [0.125, 0.25, 0.5]) +def test_second_bmm(B, NH, S, D, sparsity, dtype=torch.float32): + """ + Test the correctness of the first bmm (q @ k.t()) + B - batch size + NH - number of heads + S - key sequence length + D - hidden dimension per head + sparsity - topk in [0,1) + """ + + k_seq_len = S + scores = torch.randn( (B*NH, 1, int(k_seq_len*sparsity)), device='cuda', dtype=dtype) + v = 
torch.randn((B*NH, k_seq_len, D), device='cuda', dtype=dtype) + + choice = np.concatenate([np.sort(np.random.choice(S, size=(int(k_seq_len*sparsity)), replace=False)) for _ in range(B*NH)]).reshape(B*NH, -1) + token_mask = torch.tensor(choice, device="cuda") + + y_optimized = gather_inner_matrix_only_bmv_optimized(scores, v, token_mask) + + for i in range(B*NH): + token_mask[i] += k_seq_len*i + + v_reshaped = v.view(-1, D) + v_sampled = torch.index_select(v_reshaped, dim=0, index=token_mask.view(-1)).reshape(B*NH, -1, D) + y_torch = torch.bmm(scores, v_sampled) + + assert torch.allclose(y_optimized, y_torch, rtol=1e-2, atol=1e-2) + +@pytest.mark.parametrize("B", [2, 4, 8, 16]) +@pytest.mark.parametrize("NH", [8, 16, 32]) +@pytest.mark.parametrize("S", [32, 33, 37, 64, 69, 73, 128, 255, 259, 1024, 1028, 2048, 2500]) +@pytest.mark.parametrize("r", [32, 64, 128]) +def test_topr_bmm(B, NH, S, r, dtype=torch.float32): + k_seq_len = S + D = 128 + q = torch.randn((B*NH, 1, D), device='cuda', dtype=dtype) + k = torch.randn((B*NH, k_seq_len, D), device='cuda', dtype=dtype) + + y_optimized = topr_bmv_optimized(q, k.transpose(1,2), r) + + y_torch = torch.bmm(q[:, :, :r], k[: , :, :r].transpose(1,2)) + + assert torch.allclose(y_optimized, y_torch, rtol=1e-2, atol=1e-2) + +if __name__ == "__main__": + test_topr_bmm( + B=2, + NH=32, + S=1024, + D=128, + sparsity=0.25 + ) diff --git a/methods/pca_topk/modify_llama_optimized.py b/methods/pca_topk/modify_llama_optimized.py index 073b8f0..dce1ca7 100644 --- a/methods/pca_topk/modify_llama_optimized.py +++ b/methods/pca_topk/modify_llama_optimized.py @@ -12,6 +12,7 @@ from .utils import mask_attn_pca_topk, get_pca_components import methods.pca_topk.external.gather_matmul as G +import methods.pca_topk.kernel.pca_topk as G import methods @@ -103,23 +104,20 @@ def modified_forward( key_states = key_states.reshape(-1, key_states.shape[-2], key_states.shape[-1]) query_states = query_states.reshape(-1, query_states.shape[-2], query_states.shape[-1]) - attn_weights = G.gather_outer_bmv( - query_states.contiguous(), - key_states.transpose(-1, -2).contiguous(), + attn_weights = G.gather_outer_bmv_optimized( + query_states, + key_states.transpose(-1, -2), key_states_topk_indices, - chunk=256 # Varying this changes performance - #chunk=min(k2, 65536 // Q.shape[-1]), ) / math.sqrt(self.head_dim) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) value_states = value_states.reshape(-1, value_states.shape[-2], value_states.shape[-1]) - attn_output = G.gather_inner_matrix_only_bmv( - attn_weights.contiguous(), - value_states.contiguous(), + attn_output = G.gather_inner_matrix_only_bmv_optimized( + attn_weights, + value_states, key_states_topk_indices, - chunk=64 ) attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim) else: diff --git a/methods/pca_topk/timers.py b/methods/pca_topk/timers.py new file mode 100644 index 0000000..dff0923 --- /dev/null +++ b/methods/pca_topk/timers.py @@ -0,0 +1,31 @@ +import torch +from collections import defaultdict + +class Timers(): + def __init__(self): + self.timers = defaultdict(list) + self.curr_index = defaultdict(int) + + def start(self, key): + index = self.curr_index[key] + timers = self.timers[key] + assert index == len(timers) or index < len(timers) + if index == len(timers): + self.timers[key].append([torch.cuda.Event(enable_timing=True) for _ in range(2)]) 
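+        # Record the start event for this occurrence of `key`; stop() records
+        # the matching end event and advances curr_index.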
+        self.timers[key][index][0].record()
+
+
+    def stop(self, key):
+        index = self.curr_index[key]
+        self.timers[key][index][1].record()
+        self.curr_index[key] += 1
+
+    def get_times(self):
+        torch.cuda.synchronize()
+        total_times = defaultdict(float)
+        for key in self.timers:
+            for events in self.timers[key]:
+                start_event, end_event = events
+                total_times[key] += start_event.elapsed_time(end_event) / 1000
+            self.curr_index[key] = 0
+        return total_times
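For reference, a minimal sketch of how the new Timers helper is driven (the "matmul" key, tensor sizes, and iteration count below are illustrative, not part of this diff): each start/stop pair records a pair of CUDA events, and get_times() synchronizes once and returns the per-key elapsed time summed in seconds.

# Hypothetical usage of the Timers class from methods/pca_topk/timers.py.
import torch
from timers import Timers   # imported the same way cache_utils.py does; adjust the path if run elsewhere

timers = Timers()
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
for _ in range(10):
    timers.start("matmul")        # records the start CUDA event for this iteration
    b = a @ a                     # region being timed
    timers.stop("matmul")         # records the matching end event
times = timers.get_times()        # synchronizes, then sums elapsed_time (ms) / 1000 per key
print(f"matmul: {times['matmul']:.4f} s over 10 iterations")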
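And, to summarize the decode path that micro_benchmark_pca_topk times, here is a simplified single-step sketch of how the three Triton kernels compose. The shapes, top_r, and top_k values are placeholders, and the query/key tensors are assumed to already be projected into the PCA basis, as in the benchmark.

import math
import torch
from methods.pca_topk.kernel.pca_topk import (
    topr_bmv_optimized, gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized)

bh, s, d = 16 * 32, 2000, 128           # batch*heads, cached tokens, head dim (illustrative)
top_r, top_k = 32, 500                  # low-rank score dim and number of kept tokens
q    = torch.randn(bh, 1, d, device="cuda", dtype=torch.float16)   # PCA-projected query
keys = torch.randn(bh, s, d, device="cuda", dtype=torch.float16)   # PCA-projected key cache
vals = torch.randn(bh, s, d, device="cuda", dtype=torch.float16)

# 1) Approximate scores using only the leading top_r components of q and k.
scores = topr_bmv_optimized(A=q, B=keys.transpose(-1, -2), r=top_r) / math.sqrt(d)
# 2) Keep the top_k highest-scoring key positions per (batch, head).
idx = torch.argsort(scores, dim=-1, descending=True)[..., :top_k].reshape(bh, top_k)
# 3) Exact scores against only the gathered keys, then softmax.
attn = gather_outer_bmv_optimized(q, keys.transpose(-1, -2), idx) / math.sqrt(d)
attn = torch.softmax(attn.float(), dim=-1).to(q.dtype)
# 4) Weighted sum over the gathered values -> (bh, 1, d) attention output.
out = gather_inner_matrix_only_bmv_optimized(attn, vals, idx)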