Skip to content

Commit

Permalink
Update layers.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aturker-synnada authored Dec 20, 2024
1 parent b26995b commit 62d7cbe
Showing 1 changed file with 42 additions and 32 deletions.
74 changes: 42 additions & 32 deletions examples/flux/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
# limitations under the License.

from mithril import IOKey
from mithril.models import Buffer, Model, Reshape, ScaledDotProduct, Transpose, Arange, Sine, Cosine, Concat
from mithril.models import (
Arange,
Buffer,
Concat,
Cosine,
Model,
Reshape,
ScaledDotProduct,
Sine,
Transpose,
)


def rms_norm(dim: int) -> Model:
Expand Down Expand Up @@ -78,45 +88,45 @@ def attention() -> Model:
return block


def embed_nd(theta: int, axes_dim: list[int]) -> Model:
block = Model()
input = IOKey("input")

for i in range(len(axes_dim)):
rope_B = rope(axes_dim[i], theta)
block += rope_B(input=input[..., i], output=f"out{i}")

block += Concat(n=len(axes_dim), axis=-3)(
**{f"input{i+1}": f"out{i}" for i in range(len(axes_dim))}, output="concat_out"
)

block += Buffer()(block.concat_out[:, None], output=IOKey("output"))

return block


def rope(dim: int, theta: int) -> Model:
assert dim % 2 == 0
block = Model()
input = IOKey("input")

block += Arange(0, dim, 2)(output="arange")
omega = 1 / (theta**(block.arange / dim))

input = input[..., None]
out = input * omega
block += Arange(start=0, stop=dim, step=2)(output="arange")

omega = 1.0 / (theta ** (block.arange / dim)) # type: ignore
out = input[..., None] * omega

out_shape = out.shape()
B, N, D = out_shape[0], out_shape[1], out_shape[2]

block += Cosine()(out, output="cos")
block += Sine()(out, output="sin")

block += Concat(n=4, axis=-1)(input1=block.cos.reshape(shape=(B, N, D, 1)),
input2=-block.sin.reshape(shape=(B, N, D, 1)),
input3=block.sin.reshape(shape=(B, N, D, 1)),
input4=block.cos.reshape(shape=(B, N, D, 1)))

block += Concat(n=4, axis=-1)(
input1=block.cos[..., None], # type: ignore
input2=-block.sin[..., None], # type: ignore
input3=block.sin[..., None], # type: ignore
input4=block.cos[..., None], # type: ignore
)
rope_shape = (B, N, D, 2, 2)
block += Reshape()(shape=rope_shape, output=IOKey("output"))
return block


def EmbedND(dim: int, theta: int, axes_dim: list[int]) -> Model:
block = Model()
input = IOKey("input")
# in original implementation range equal to range(n_axes) but
# n_axes=input.shape()[-1] can't be interpreted as an integer so i can't solve this how can i use it as an integer
n_axes=3

for i in range(len(axes_dim)): # in original implementation range equal to range(n_axes) but
rope_B = rope(axes_dim[i], theta)
block += rope_B(input=input[..., i], output=f"out{i}")

block += Concat(n=len(axes_dim), axis=-3)(**{f"input{i+1}": f"out{i}"
for i in range(len(axes_dim))}, output="concat_out")

block += Buffer()(block.concat_out[:, None], output=IOKey("output"))

block.set_canonical_input("input")
return block

0 comments on commit 62d7cbe

Please sign in to comment.