Skip to content

Commit

Permalink
Add doc strings for the added functions
Browse files Browse the repository at this point in the history
Signed-off-by: Reese Wang <rewang@nvidia.com>
  • Loading branch information
zlsh80826 committed Dec 10, 2024
1 parent fea31d4 commit dfd406f
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,33 @@ class QKVLayout(Enum):
THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD

def get_qkv_format(self):
"""
Return the corresponding qkv_format (BSHD, SBHD, THD)
"""
return QKVFormat(nvte_get_qkv_format(self.value))

def is_qkvpacked(self):
"""
Return True if the query, key, value is packed
"""
return self in [QKVLayout.BS3HD, QKVLayout.T3HD]

def is_kvpacked(self):
"""
Return True if the key, value is packed
"""
return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD]

def is_separate(self):
"""
Return True if the query, key, value are three separate tensors
"""
return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD]

def is_thd(self):
"""
Return True if the layout belongs to THD
"""
return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]


Expand Down

0 comments on commit dfd406f

Please sign in to comment.