[Feature Request] Let ttnn.transformer.scaled_dot_product_attention support dropout_p (dropout probability) #16022
Labels: feature-request, External feature request
Is your feature request related to a problem? Please describe.
I am trying to lower aten._scaled_dot_product_flash_attention to ttnn.transformer.scaled_dot_product_attention. The related issues are:
- ttnn.transformer.scaled_dot_product_attention infers 0 batch size from 1 #16021
- aten._scaled_dot_product_flash_attention pytorch2.0_ttnn#569

Describe the solution you'd like
Add a floating-point parameter dropout_p to ttnn.transformer.scaled_dot_product_attention. Its behavior should match torch.nn.functional.scaled_dot_product_attention.
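For reference, torch.nn.functional.scaled_dot_product_attention applies dropout to the post-softmax attention weights, just before the final matmul with value. A minimal sketch of those semantics (simplified from the reference implementation in the PyTorch documentation; the function name here is illustrative only):

```python
import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0):
    # Reference semantics of torch.nn.functional.scaled_dot_product_attention
    # (simplified: no is_causal handling). dropout_p drops entries of the
    # softmaxed attention weights before the final matmul with value.
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # This is the behavior the requested dropout_p parameter would need to match.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
```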
Describe alternatives you've considered
Split out the final matrix-multiplication step, so a dropout op can be inserted there?
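To illustrate that alternative, here is a hypothetical sketch (written in torch for clarity; a real lowering would use the corresponding ttnn elementwise-multiply and matmul ops, which are assumed here): the fused op stops after softmax, dropout is emulated with a Bernoulli keep-mask, and the final matmul with value becomes a separate op.

```python
import torch

def dropout_then_value_matmul(attn_weight, value, dropout_p):
    # Hypothetical lowering for the alternative: dropout is emulated as a
    # Bernoulli keep-mask times the softmaxed attention weights, with
    # inverted-dropout scaling by 1/(1 - p), followed by the now-separate
    # final matmul with value. Assumes dropout_p < 1.0.
    keep_mask = torch.bernoulli(torch.full_like(attn_weight, 1.0 - dropout_p))
    attn_weight = attn_weight * keep_mask / (1.0 - dropout_p)
    return attn_weight @ value
```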