[WIP] fast compute kernels for the compute benchmark #11

Open · wants to merge 10 commits into base: pca_attn
infer.py (34 changes: 23 additions & 11 deletions)
@@ -41,15 +41,21 @@ def setup_parser():
parser.add_argument("--gen-length", type=int, default=32, help="Generation length")
parser.add_argument("--seed", type=int, default=1234, help="Seed")
parser.add_argument("--use-optimized-code", action='store_true', default=False)
parser.add_argument("--warmup-iters", type=int, default=5)
parser.add_argument("--total-iters", type=int, default=10)

return parser

def load_prompts(tokenizer, batch_size, prompt_length):
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
total_tokens = encodings.input_ids.shape[1]
start_index = min(random.randint(0, total_tokens), total_tokens - batch_size * prompt_length)
input_ids = encodings.input_ids[:, start_index : start_index + batch_size * prompt_length].reshape(batch_size, prompt_length)
input_ids = []
for _ in range(batch_size):
start_index = min(random.randint(0, total_tokens), total_tokens - prompt_length)
tokens = encodings.input_ids[:, start_index : start_index + prompt_length].reshape(1, prompt_length)
input_ids.append(tokens)
input_ids = torch.cat(input_ids, dim=0)
return input_ids

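For reference, a minimal standalone sketch of the per-sample window sampling that the revised load_prompts performs; the helper name and the simplified randint bound are illustrative, not part of the PR.

import random
import torch

def sample_prompt_windows(encodings_input_ids, batch_size, prompt_length):
    # encodings_input_ids: (1, total_tokens) LongTensor of pre-tokenized corpus text.
    total_tokens = encodings_input_ids.shape[1]
    rows = []
    for _ in range(batch_size):
        # Each batch element gets its own random window instead of one contiguous slab.
        start = random.randint(0, total_tokens - prompt_length)
        rows.append(encodings_input_ids[:, start : start + prompt_length])
    return torch.cat(rows, dim=0)  # (batch_size, prompt_length)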
if __name__ == "__main__":
@@ -62,8 +68,8 @@ def load_prompts(tokenizer, batch_size, prompt_length):
set_seed(args.seed)

if args.method == "pca-topk":
args.top_k = args.prompt_length
args.top_r = 128
args.top_k = int(0.25 * args.prompt_length)
args.top_r = 32
args.rotary_type = "postrotary"

if args.use_optimized_code:
@@ -84,20 +90,26 @@ def load_prompts(tokenizer, batch_size, prompt_length):

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
generations = []

input_ids = tokenized_prompts.cuda()
with torch.autocast(device_type='cuda', dtype=dtype):
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length, num_beams=4)

# warmup iters
for _ in range(args.warmup_iters):
with torch.autocast(device_type='cuda', dtype=dtype):
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length)

# timed iters
start_event.record()
for _ in range(args.total_iters - args.warmup_iters):
with torch.autocast(device_type='cuda', dtype=dtype):
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length)
end_event.record()


generated_tokens = outputs.numel() - input_ids.numel()
total_generated_tokens += generated_tokens

torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event)
total_time = start_event.elapsed_time(end_event) / (args.total_iters - args.warmup_iters)
tput = total_generated_tokens * 1000 / total_time

output_ids = outputs[:, args.prompt_length:]
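For reference, the warmup-then-timed generation pattern introduced above, condensed into a self-contained helper. The function and argument names are illustrative; the model, inputs, and iteration counts come from the caller.

import torch

def time_generation(model, input_ids, dtype, warmup_iters, total_iters, gen_length):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Untimed warmup iterations.
    for _ in range(warmup_iters):
        with torch.autocast(device_type='cuda', dtype=dtype):
            outputs = model.generate(input_ids, do_sample=True, max_new_tokens=gen_length)

    # Timed iterations, measured with CUDA events.
    timed_iters = total_iters - warmup_iters
    start_event.record()
    for _ in range(timed_iters):
        with torch.autocast(device_type='cuda', dtype=dtype):
            outputs = model.generate(input_ids, do_sample=True, max_new_tokens=gen_length)
    end_event.record()
    torch.cuda.synchronize()  # wait for all queued GPU work before reading the timer

    ms_per_iter = start_event.elapsed_time(end_event) / timed_iters
    generated_tokens = outputs.numel() - input_ids.numel()  # tokens from the last iteration
    tokens_per_sec = generated_tokens * 1000.0 / ms_per_iter
    return ms_per_iter, tokens_per_sec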
methods/pca_topk/cache_utils.py (254 changes: 171 additions & 83 deletions)
@@ -5,8 +5,8 @@
import time

import torch
import external.gather_matmul as G

#import external.gather_matmul as G
import kernel.pca_topk as G


topk_time = 0
@@ -16,6 +16,8 @@
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = False

from timers import Timers
import json

# Work In Progress
class PcaTopKCache(Cache): # Not used anymore
@@ -120,6 +122,10 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

def reset(self):
self.key_cache: List[torch.Tensor] = [] # Stores the reduced keys for each layer
self.value_cache: List[torch.Tensor] = []

def test_pcatopk_cache():
cache = PcaTopKCache(2, 4)

@@ -139,65 +145,98 @@ def test_pcatopk_cache():



def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=2000, use_optimised_gather=False):
def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_layers, timers,
num_gen_steps=2000, use_optimised_gather=False):
import time
torch.set_float32_matmul_precision("highest")

head_dim = prompt_keys.shape[-1]
bs = prompt_keys.shape[0]
num_heads = prompt_keys.shape[1]

generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
head_dim = prompt_keys[0].shape[-1]
bs = prompt_keys[0].shape[0]
num_heads = prompt_keys[0].shape[1]
dtype = prompt_keys[0].dtype
prompt_seq_length = prompt_keys[0].shape[2]

print ("Starting microbenchmark")
matmul_time = 0
top_keys = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda")
top_vals = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda")
pca_projection_mat = torch.randn(num_heads, head_dim, head_dim, dtype=dtype, device='cuda')


assert use_optimised_gather
if use_optimised_gather:
timers.start('total')
for i in range(num_gen_steps):
keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False)
torch.cuda.synchronize()

start = time.time()
attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim)

# Get top-k keys and top-k values based on the attention scores
key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda")
key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1)
key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1])

keys = keys.reshape(-1, keys.shape[-2] , keys.shape[-1])
vals = vals.reshape(-1, vals.shape[-2] , vals.shape[-1])

attn_weights = G.gather_outer_bmv(
generative_query.reshape(-1, 1, head_dim),
keys.transpose(-1, -2),
key_states_topk_indices,
#.squeeze(0).squeeze(-1),
chunk=256
#chunk=min(k2, 65536 // Q.shape[-1]),
) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1)

attn_output = G.gather_inner_matrix_only_bmv(
attn_weights, vals, key_states_topk_indices, chunk=64
)

torch.cuda.synchronize()
end = time.time()

if i > 5:
matmul_time += end - start
for layer in range(num_layers):
timers.start('qk-gen')
generative_query = torch.rand(bs, num_heads, 1, head_dim, device='cuda', dtype=dtype)
generative_key = torch.rand(bs, num_heads, 1, head_dim, device='cuda', dtype=dtype)
timers.stop('qk-gen')

timers.start('project')
generative_key = generative_key.squeeze().transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2)
generative_query = generative_query.squeeze().transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2)
timers.stop('project')

timers.start('cache-update')
keys, vals = cache.update(generative_key, generative_key, generative_query, layer, False)
timers.stop('cache-update')

timers.start('qk-matmul-1')
#attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim)
nh, bs, s, r = keys.shape
attn_weights = G.topr_bmv_optimized(A=generative_query.view(nh*bs, 1, r), B=keys.view(nh*bs, s, r).transpose(-1,-2),
r=top_r)
attn_weights = attn_weights.view(nh, bs, 1, s)
timers.stop('qk-matmul-1')

# Get top-k keys and top-k values based on the attention scores
timers.start('top-k')
#key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1, sorted=False).indices.to("cuda")
#key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1)
key_states_topk_indices = torch.argsort(attn_weights, dim=-1, descending=True)[:,:,:,:top_k]
timers.stop('top-k')


timers.start('reshape-0')
key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1])
timers.stop('reshape-0')

timers.start('reshape-1')
keys = keys.view(-1, keys.shape[-2] , keys.shape[-1])
vals = vals.view(-1, vals.shape[-2] , vals.shape[-1])
timers.stop('reshape-1')

timers.start('qk-matmul-2')
attn_weights = G.gather_outer_bmv_optimized(
generative_query.reshape(-1, 1, head_dim),
keys.transpose(-1, -2),
key_states_topk_indices,
#.squeeze(0).squeeze(-1),
#chunk=256
#chunk=min(k2, 65536 // Q.shape[-1]),
) / math.sqrt(head_dim)
timers.stop('qk-matmul-2')

timers.start('softmax')
attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype)
timers.stop('softmax')

timers.start('sv-matmul')
attn_output = G.gather_inner_matrix_only_bmv_optimized(
attn_weights, vals, key_states_topk_indices)
timers.stop('sv-matmul')

timers.start('reshape-output')
attn_output = attn_output.view(num_heads, bs, 1, head_dim).transpose(0,1).transpose(1,2).contiguous()
timers.stop('reshape-output')
timers.stop('total')
else:
for i in range(num_gen_steps):
keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False)
torch.cuda.synchronize()

start = time.time()
attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim)

# Get top-k keys and top-k values based on the attention scores
key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda")
key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1)
@@ -212,70 +251,119 @@ def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=200
torch.cuda.synchronize()
end = time.time()

if i > 5:
matmul_time += end - start
print (f"Matmul Time: {matmul_time}")

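As a plain-PyTorch reference for what the fused kernels above (G.topr_bmv_optimized, G.gather_outer_bmv_optimized, G.gather_inner_matrix_only_bmv_optimized) compute in one decode step, here is a hedged sketch. It assumes queries, keys, and values are already flattened to (batch*heads, ..., head_dim) and already projected into the PCA basis; it is a reference for shapes and semantics, not a drop-in replacement for the custom kernels.

import math
import torch

def topk_attention_step(q, keys, vals, top_r, top_k):
    # q: (N, 1, d); keys, vals: (N, S, d) with N = batch * heads.
    d = q.shape[-1]
    # Rank-r scoring: use only the first top_r PCA components to rank the keys.
    approx_scores = torch.bmm(q[..., :top_r], keys[..., :top_r].transpose(-1, -2)) / math.sqrt(d)
    topk_idx = torch.topk(approx_scores, top_k, dim=-1).indices          # (N, 1, top_k)
    # Gather the selected keys/values and run exact attention over them.
    gather_idx = topk_idx.squeeze(1).unsqueeze(-1).expand(-1, -1, d)     # (N, top_k, d)
    sel_keys = keys.gather(1, gather_idx)
    sel_vals = vals.gather(1, gather_idx)
    attn = torch.softmax(torch.bmm(q, sel_keys.transpose(-1, -2)) / math.sqrt(d), dim=-1)
    return torch.bmm(attn, sel_vals)                                     # (N, 1, d)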
def micro_bench_actual_attention(cache, prompt_keys, num_gen_steps=2000):
def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen_steps=2000):
import time
torch.set_float32_matmul_precision("highest")

head_dim = prompt_keys.shape[-1]
bs = prompt_keys.shape[0]
num_heads = prompt_keys.shape[1]

generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
head_dim = prompt_keys[0].shape[-1]
bs = prompt_keys[0].shape[0]
num_heads = prompt_keys[0].shape[1]
dtype = prompt_keys[0].dtype

print ("Starting microbenchmark")
matmul_time = 0
for i in range(num_gen_steps):
keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False)
torch.cuda.synchronize()

start = time.time()
attn_weights = torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, vals)
torch.cuda.synchronize()
end = time.time()
timers.start('total')
for i in range(num_gen_steps):
for layer in range(num_layers):
timers.start('qk-gen')
generative_query = torch.rand(bs, num_heads, 1, head_dim, dtype=dtype, device='cuda')
generative_key = torch.rand(bs, num_heads, 1, head_dim, dtype=dtype, device='cuda')
timers.stop('qk-gen')

timers.start('cache-update')
keys, vals = cache.update(generative_key, generative_key, generative_query, layer, False)
timers.stop('cache-update')

timers.start('qk-matmul-1')
attn_weights = torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim)
timers.stop('qk-matmul-1')

timers.start('softmax')
attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype)
timers.stop('softmax')

timers.start('sv-matmul')
attn_output = torch.matmul(attn_weights, vals)
timers.stop('sv-matmul')

timers.start('reshape-output')
attn_output = attn_output.transpose(1, 2).contiguous()
timers.stop('reshape-output')


if i > 5:
matmul_time += end - start
print (f"Matmul Time: {matmul_time}")
timers.stop('total')

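The Timers object imported from timers is repo-internal; the benchmarks only rely on start, stop, and get_times. A minimal CUDA-aware stand-in with that interface could look like the sketch below (an assumption about the interface, not the actual implementation).

import time
from collections import defaultdict
import torch

class Timers:
    def __init__(self):
        self._totals = defaultdict(float)
        self._starts = {}

    def start(self, name):
        torch.cuda.synchronize()  # make CPU timestamps meaningful for queued GPU work
        self._starts[name] = time.time()

    def stop(self, name):
        torch.cuda.synchronize()
        self._totals[name] += time.time() - self._starts.pop(name)

    def get_times(self):
        return dict(self._totals)  # accumulated seconds per label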
@torch.no_grad()
def benchmark_attention(batch_size=1,
num_heads=32,
num_gen_steps=128,
prompt_length=3072,
topk=256):
topk=256,
num_layers=32,
dtype=torch.float16):

head_dim=128
# Change this to change batch size, etc.
prompt_keys = torch.rand(batch_size, num_heads, prompt_length, head_dim).to("cuda")
prompt_keys = [torch.rand(batch_size, num_heads, prompt_length, head_dim, device='cuda', dtype=dtype) for _ in range(num_layers)]


print("PCA TOPK Unoptimized")
cache1 = PcaTopKCache()
cache1.update(prompt_keys, prompt_keys, prompt_keys, 0)
micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps)
del cache1
#print("PCA TOPK Unoptimized")
#cache1 = [PcaTopKCache() for _ in range(num_layers)]
#for i in range(num_layers):
# cache1[i].update(prompt_keys[i], prompt_keys[i], prompt_keys[i], 0)
#micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps)
#del cache1


print("PCA TOPK Optimized")
cache2 = PcaTopKCache()
cache2.update(prompt_keys, prompt_keys, prompt_keys, 0)
micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk, num_gen_steps=num_gen_steps, use_optimised_gather=True)
del cache2
for _ in range(10):
cache2 = PcaTopKCache()
for i in range(num_layers):
cache2.update(prompt_keys[i].transpose(0,1).contiguous(),
prompt_keys[i].transpose(0,1).contiguous(),
prompt_keys[i].transpose(0,1).contiguous(), i)
timers = Timers()
micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk,
num_gen_steps=num_gen_steps, num_layers=num_layers,
use_optimised_gather=True, timers=timers)
del cache2
times = timers.get_times()
print(times)

print("Average time (minus cache updates) is - ")
print(times['total'] - times['cache-update'], " s")
print("==================================")
times_pca_topk = times

print("Actual Attention")
cache3= PcaTopKCache()
cache3.update(prompt_keys, prompt_keys, prompt_keys, 0)
micro_bench_actual_attention(cache3, prompt_keys, num_gen_steps=num_gen_steps)
del cache3
for _ in range(10):
cache3= PcaTopKCache()
for i in range(num_layers):
cache3.update(prompt_keys[i], prompt_keys[i], prompt_keys[i], i)
timers = Timers()
micro_bench_actual_attention(cache3, prompt_keys, num_layers=num_layers,
num_gen_steps=num_gen_steps, timers=timers)
del cache3
times = timers.get_times()
print("Average time (minus cache updates) is - ")
print(times['total'] - times['cache-update'], " s")
print(times)
print("==================================")
times_vanilla = times
return times_pca_topk, times_vanilla

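A small helper, assuming the timing dicts returned above, that mirrors the "total minus cache-update" accounting used in the printed summaries (the helper name is illustrative).

def attention_speedup(times_pca_topk, times_vanilla):
    # Exclude cache updates from both sides, as the prints above do.
    pca = times_pca_topk['total'] - times_pca_topk['cache-update']
    vanilla = times_vanilla['total'] - times_vanilla['cache-update']
    return vanilla / pca  # > 1.0 means the PCA-top-k attention path is faster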
if __name__ == "__main__":
#test_pcatopk_cache()
with torch.no_grad():
benchmark_attention(prompt_length=4096, num_gen_steps=2000, batch_size=16, topk=1024)
prompt_length = 2000
for num_gen_steps in [1000]:
print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size={16}, topk and top r are 25%")
times_pca_topk, times_vanilla = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // 4, num_layers=1)
with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_pca_topk_opt_first_matmul.json", "w") as f:
json.dump(times_pca_topk, f, indent=2)

with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_vanilla.json", "w") as f:
json.dump(times_vanilla, f, indent=2)

