From 42c31ab3c4b390d20901e59ec9a931959af6783e Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Mon, 2 Dec 2024 11:08:28 +0000 Subject: [PATCH] Add doc strings for the added functions Signed-off-by: Reese Wang --- transformer_engine/jax/attention.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index efe75bbcb9..53451b6a78 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -103,18 +103,33 @@ class QKVLayout(Enum): THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD def get_qkv_format(self): + """ + Return the corresponding qkv_format (BSHD, SBHD, THD) + """ return QKVFormat(nvte_get_qkv_format(self.value)) def is_qkvpacked(self): + """ + Return True if the query, key, value is packed + """ return self in [QKVLayout.BS3HD, QKVLayout.T3HD] def is_kvpacked(self): + """ + Return True if the key, value is packed + """ return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD] def is_separate(self): + """ + Return True if the query, key, value are three separate tensors + """ return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD] def is_thd(self): + """ + Return True if the layout belongs to THD + """ return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]