Commit: review changes
aturker-synnada committed Dec 20, 2024
1 parent 8305f66 commit 39db3ef
Showing 1 changed file with 30 additions and 33 deletions.
63 changes: 30 additions & 33 deletions examples/flux/layers.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from mithril import IOKey
from mithril.models import (
Buffer,
@@ -30,7 +32,7 @@
)


def mlp_embedder(hidden_dim, name: str | None = None):
def mlp_embedder(hidden_dim: int, name: str | None = None):
block = Model(name=name)
block += Linear(hidden_dim, name="in_layer")(input="input")
block += SiLU()
@@ -106,9 +108,6 @@ def qk_norm(dim: int, name: str | None = None):
query_norm = rms_norm(dim, name="query_norm")
key_norm = rms_norm(dim, name="key_norm")

query_norm.name = "query_norm"
key_norm.name = "key_norm"

block += query_norm(input="q_in", output=IOKey("q_out"))
block += key_norm(input="k_in", output=IOKey("k_out"))
return block
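
For reference, a minimal NumPy sketch of the per-head RMS normalization that query_norm and key_norm apply; the eps value and the omission of any learnable scale rms_norm may carry are assumptions here, not taken from this diff:

import numpy as np

def rms_norm_np(x, eps=1e-6):
    # Normalize by the root-mean-square over the feature (head_dim) axis.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

q = np.random.randn(2, 4, 16, 8)  # (batch, heads, tokens, head_dim)
q_out = rms_norm_np(q)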
@@ -132,6 +131,16 @@ def modulation(dim: int, double: bool, name: str | None = None):
return block


def rearrange(num_heads: int):
block = Model()
input = IOKey("input")
input_shape = input.shape()
B, L = input_shape[0], input_shape[1]
block += Reshape()(shape=(B, L, 3, num_heads, -1))
block += Transpose(axes=(2, 0, 3, 1, 4))(output=IOKey("output"))
return block


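For intuition, a NumPy sketch of the tensor movement this new helper performs, replacing the inline Reshape/Transpose pairs removed below (concrete sizes invented for illustration):

import numpy as np

# (B, L, 3 * num_heads * head_dim) -> (3, B, num_heads, L, head_dim)
B, L, num_heads, head_dim = 2, 16, 4, 8
qkv = np.random.randn(B, L, 3 * num_heads * head_dim)
out = qkv.reshape(B, L, 3, num_heads, -1).transpose(2, 0, 3, 1, 4)
assert out.shape == (3, B, num_heads, L, head_dim)
q, k, v = out[0], out[1], out[2]  # unpack along the leading qkv axis
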
def double_stream_block(
hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
):
@@ -157,16 +166,13 @@ def double_stream_block(
)

# Rearrange
img_qkv_shape = block.img_qkv.shape() # type: ignore[attr-defined]
B, L = img_qkv_shape[0], img_qkv_shape[1]
block += Reshape()(shape=(B, L, 3, num_heads, -1))
block += Transpose(axes=(2, 0, 3, 1, 4))(output="transpose_out")

img_q, img_k, img_v = (
block.transpose_out[0], # type: ignore[attr-defined]
block.transpose_out[1], # type: ignore[attr-defined]
block.transpose_out[2], # type: ignore[attr-defined]
block += rearrange(num_heads=num_heads)(
input=block.img_qkv, # type: ignore[attr-defined]
output="img_rearrange_out",
)

rearrange_out = block.img_rearrange_out # type: ignore[attr-defined]
img_q, img_k, img_v = (rearrange_out[0], rearrange_out[1], rearrange_out[2])
block += qk_norm(hidden_size // num_heads, name="img_attn_norm")(
q_in=img_q, k_in=img_k, q_out="q_out", k_out="k_out"
)
@@ -184,16 +190,12 @@ def double_stream_block(
block += Linear(hidden_size * 3, use_bias=qkv_bias, name="txt_attn_qkv")(
txt_modulated, output=IOKey("txt_qkv")
)
txt_qkv = block.txt_qkv # type: ignore[attr-defined]

# Rearrange
txt_qkv_shape = txt_qkv.shape()
B, L = txt_qkv_shape[0], txt_qkv_shape[1]
block += Reshape()(shape=(B, L, 3, num_heads, -1))
block += Transpose(axes=(2, 0, 3, 1, 4))(output="txt_transpose_out")
block += rearrange(num_heads)(input=block.txt_qkv, output="txt_rearrange_out") # type: ignore[attr-defined]

transpose_out = block.txt_transpose_out # type: ignore[attr-defined]
txt_q, txt_k, txt_v = transpose_out[0], transpose_out[1], transpose_out[2]
rearrange_out = block.txt_rearrange_out # type: ignore[attr-defined]
txt_q, txt_k, txt_v = rearrange_out[0], rearrange_out[1], rearrange_out[2]
block += qk_norm(hidden_size // num_heads, name="txt_attn_norm")(
q_in=txt_q, k_in=txt_k, q_out="txt_q_out", k_out="txt_k_out"
)
@@ -220,6 +222,9 @@ def double_stream_block(
img_mlp += Gelu(approximate=True)
img_mlp += Linear(hidden_size, name="2")(output="output")

txt_mlp = deepcopy(img_mlp)
txt_mlp.name = "txt_mlp"

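For context, this deepcopy replaces the hand-built duplicate txt_mlp that the removed block further down constructed layer by layer. A toy sketch, outside mithril, of why a deep copy is needed rather than a plain assignment:

from copy import deepcopy

# A plain assignment would alias the same layer objects; deepcopy yields an
# independent clone that can be renamed and parameterized separately.
class ToyBlock:
    def __init__(self, name, layers):
        self.name = name
        self.layers = layers

img_mlp_toy = ToyBlock("img_mlp", [{"w": 0.0}, {"w": 0.0}])
txt_mlp_toy = deepcopy(img_mlp_toy)
txt_mlp_toy.name = "txt_mlp"
assert txt_mlp_toy.layers is not img_mlp_toy.layers
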
block += img_mlp(
input=(1 + block.img_mod_2[1]) * img_norm_2 + block.img_mod_2[0], # type: ignore[attr-defined]
output="img_mlp",
@@ -230,11 +235,6 @@ def double_stream_block(
txt_attn = block.attn[:, :256] # type: ignore[attr-defined]
block += Linear(hidden_size, name="txt_attn_proj")(txt_attn, output="txt_proj")
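A small illustrative split of that slice, assuming the fused attention output lays out the 256 text tokens first and the image tokens after (the image length here is invented):

import numpy as np

attn = np.random.randn(2, 256 + 1024, 64)  # (batch, txt_len + img_len, dim)
txt_attn = attn[:, :256]   # text-token rows
img_attn = attn[:, 256:]   # image-token rows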

txt_mlp = Model(name="txt_mlp")
txt_mlp += Linear(mlp_hidden_dim, name="0")(input="input")
txt_mlp += Gelu(approximate=True)
txt_mlp += Linear(hidden_size, name="2")(output="output")

txt = txt + block.txt_mod_1[2] * block.txt_proj # type: ignore[attr-defined]

block += LayerNorm(use_scale=False, use_bias=False, name="txt_norm2", eps=1e-6)(
@@ -283,14 +283,11 @@ def single_stream_block(hidden_size: int, num_heads: int, mlp_ratio: float = 4.0
mlp = block.lin1_out[..., 3 * hidden_size :] # type: ignore[attr-defined]

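A NumPy sketch of this packed-projection split, with sizes invented for illustration (in the real block the MLP width is presumably hidden_size * mlp_ratio):

import numpy as np

# lin1 packs the fused QKV (3 * hidden_size channels) and the MLP branch
# into a single projection; both are sliced off the last axis.
hidden_size, mlp_hidden = 64, 256
lin1_out = np.random.randn(2, 16, 3 * hidden_size + mlp_hidden)
qkv = lin1_out[..., : 3 * hidden_size]
mlp = lin1_out[..., 3 * hidden_size :]
assert qkv.shape[-1] == 3 * hidden_size and mlp.shape[-1] == mlp_hidden
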
# Rearrange
qkv_shape = qkv.shape()
B, L = qkv_shape[0], qkv_shape[1]
block += Reshape()(input=qkv, shape=(B, L, 3, num_heads, -1))
block += Transpose(axes=(2, 0, 3, 1, 4))(output="transpose_out")

q = block.transpose_out[0] # type: ignore[attr-defined]
k = block.transpose_out[1] # type: ignore[attr-defined]
v = block.transpose_out[2] # type: ignore[attr-defined]
block += rearrange(num_heads)(input=qkv, output="rearrange_out")

q = block.rearrange_out[0] # type: ignore[attr-defined]
k = block.rearrange_out[1] # type: ignore[attr-defined]
v = block.rearrange_out[2] # type: ignore[attr-defined]

block += qk_norm(dim=head_dim, name="norm")(
q_in=q, k_in=k, q_out="q_out", k_out="k_out"
