CUDA + PyTorch implementation of Flash Attention for transformers

ulrikisdahl/Flash-Attention


Flash attention (v1) implementation


Benchmark:

I compare the Flash Attention implementation with a pure PyTorch implementation of the attention algorithm (a profiling sketch follows the numbers below):

- Pure torch attention CUDA time total: 64.838ms
- Flash Attention CUDA time total: 9.864ms
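
For reference, below is a minimal sketch of how such CUDA timings can be collected with torch.profiler. It is not the repository's actual benchmark script: the naive attention function and the tensor shapes are illustrative assumptions, and the built-in scaled_dot_product_attention stands in for the spot where this repo's custom Flash Attention kernel would be called.

```python
# Hedged sketch of a CUDA-time benchmark with torch.profiler.
# NOTE: shapes and the naive_attention helper are assumptions for illustration;
# in this repo the custom Flash Attention kernel would replace the "fused" entry.
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

def naive_attention(q, k, v):
    # Materializes the full (seq_len x seq_len) score matrix,
    # which is exactly what Flash Attention avoids.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

batch, heads, seq_len, head_dim = 8, 12, 1024, 64
q, k, v = (torch.randn(batch, heads, seq_len, head_dim, device="cuda") for _ in range(3))

for name, fn in [("naive", naive_attention), ("fused", F.scaled_dot_product_attention)]:
    fn(q, k, v)                      # warm-up so one-time launch costs are excluded
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        fn(q, k, v)
        torch.cuda.synchronize()
    print(f"\n{name} attention:")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
```

The "CUDA time total" column of the profiler table is the number quoted above for each implementation.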

Validity of implementation:

The CUDA implementation is compared with a correct, pure PyTorch implementation and the official PyTorch implementation of the attention mechanism in main.py. The tests show that the Flash Attention implementation is correct.

Sample results:

All close test: True
Average value in torch-attention: 0.0013573728501796722
Average value in scaled_dot_prod-attention: 0.0013573728501796722
Average value in flash-attention: 0.0013573728501796722

Run tests: python3 -m test.main
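
For illustration, here is a minimal sketch of this kind of correctness check. It is not copied from the repo's test code: the reference attention function, tensor shapes, and tolerance are assumptions, and the repo's own tests additionally include the custom Flash Attention kernel as a third candidate.

```python
# Hedged correctness-check sketch: compare a reference PyTorch attention against
# torch's fused scaled_dot_product_attention with torch.allclose.
# NOTE: shapes, seed, and tolerance are illustrative assumptions.
import torch
import torch.nn.functional as F

def reference_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(4, 8, 256, 64, device="cuda") for _ in range(3))

ref = reference_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)

# Loose tolerance, since the fused kernel accumulates in a different order.
print("All close test:", torch.allclose(ref, fused, atol=1e-4))
print("Average value in torch-attention:", ref.mean().item())
print("Average value in scaled_dot_prod-attention:", fused.mean().item())
```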
