I compare the Flash Attention implementation with a pure PyTorch implementation of the attention algorithm:
- Flash Attention CUDA time total: 9.864ms
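The "CUDA time total" figure above is the kind of summary reported by `torch.profiler`. Below is a minimal sketch of how such a number can be collected; the `flash_attention` import is a placeholder for this repo's compiled extension, and the tensor shapes are illustrative assumptions, not the actual benchmark configuration.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Hypothetical import: the name of this repo's compiled CUDA extension is an assumption.
# from flash_attention import flash_attention

def torch_attention(q, k, v):
    # Plain PyTorch attention used as the timing baseline: softmax(QK^T / sqrt(d)) V.
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

# Illustrative shapes: (batch, heads, sequence length, head dimension).
q = torch.randn(32, 8, 1024, 64, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    torch_attention(q, k, v)
    # flash_attention(q, k, v)  # time the CUDA kernel the same way

# The printed table ends with a "Self CUDA time total" summary line.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```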
The CUDA implementation is compared against a correct, pure PyTorch implementation and the official PyTorch implementation of the attention mechanism in `main.py`. The tests confirm that the Flash Attention implementation is correct.
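The correctness check described above might look roughly like the sketch below. It is not the repo's actual test code: the `flash_attention` import, the tensor shapes, and the tolerance are assumptions, and the official PyTorch path is `torch.nn.functional.scaled_dot_product_attention`.

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical import: the name of this repo's compiled CUDA extension is an assumption.
# from flash_attention import flash_attention

def torch_attention(q, k, v):
    # Reference "pure PyTorch" attention: softmax(QK^T / sqrt(d)) V.
    scale = 1.0 / math.sqrt(q.shape[-1])
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

# Illustrative shapes: (batch, heads, sequence length, head dimension).
q, k, v = (torch.randn(4, 8, 256, 64, device="cuda") for _ in range(3))

out_torch = torch_attention(q, k, v)
out_sdpa = F.scaled_dot_product_attention(q, k, v)   # official PyTorch attention
# out_flash = flash_attention(q, k, v)                # CUDA kernel under test

# "All close test": the implementations should agree within a small tolerance.
print("All close test:", torch.allclose(out_torch, out_sdpa, atol=1e-3))
print("Average value in torch-attention:", out_torch.mean().item())
print("Average value in scaled_dot_prod-attention:", out_sdpa.mean().item())
```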
Sample results:
All close test: True
Average value in torch-attention: 0.0013573728501796722
Average value in scaled_dot_prod-attention: 0.0013573728501796722
Average value in flash-attention: 0.0013573728501796722
Run the tests: `python3 -m test.main`