Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Make IOKey dataclass derived from BaseKey #117

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/flux/auto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def attn_block(n_channels: int, name: str | None = None):
key = block.key # type: ignore[attr-defined]
value = block.value # type: ignore[attr-defined]

shape = query.shape()
shape = query.shape

query = query.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1]))
key = key.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1]))
Expand All @@ -92,7 +92,7 @@ def downsample(n_channels: int):
def upsample(n_channels: int, name: str | None = None):
block = Model(enforce_jit=False, name=name) # TODO: Remove enfor jit false
input = IOKey("input")
input_shape = input.shape()
input_shape = input.shape

B, C, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3]
input = input[:, :, :, None, :, None]
Expand Down
14 changes: 7 additions & 7 deletions examples/flux/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def apply_rope() -> Model:
xk = IOKey("xk")
freqs_cis = IOKey("freqs_cis")

xq_shape = xq.shape()
xk_shape = xk.shape()
xq_shape = xq.shape
xk_shape = xk.shape
B, L, H = xq_shape[0], xq_shape[1], xq_shape[2]
block += Reshape()(xq, shape=(B, L, H, -1, 1, 2), output="xq_")
B, L, H = xk_shape[0], xk_shape[1], xk_shape[2]
Expand Down Expand Up @@ -96,7 +96,7 @@ def attention() -> Model:
)

# We can get named connection as model.'connection_name'
context_shape = block.context.shape() # type: ignore[attr-defined]
context_shape = block.context.shape # type: ignore[attr-defined]
block += Transpose(axes=(0, 2, 1, 3))(block.context) # type: ignore[attr-defined]
# NOTE: Reshape input is automatically connected to Transpose output
block += Reshape()(
Expand Down Expand Up @@ -137,7 +137,7 @@ def modulation(dim: int, double: bool, name: str | None = None):
def rearrange(num_heads: int):
block = Model()
input = IOKey("input")
input_shaepe = input.shape()
input_shaepe = input.shape
B, L = input_shaepe[0], input_shaepe[1]
block += Reshape()(shape=(B, L, 3, num_heads, -1))
block += Transpose(axes=(2, 0, 3, 1, 4))(output=IOKey("output"))
Expand Down Expand Up @@ -209,7 +209,7 @@ def double_stream_block(
block += Concat(axis=2, n=2)(input1=txt_v, input2=img_v, output="v_concat")

block += attention()(q="q_concat", k="k_concat", v="v_concat", pe=pe, output="attn")
# TODO: use'[:, txt.shape()[1] :]' when fixed.
# TODO: use'[:, txt.shape[1] :]' when fixed.
img_attn = block.attn[:, 256:] # type: ignore[attr-defined]

block += Linear(hidden_size, name="img_attn_proj")(img_attn, output="img_proj")
Expand All @@ -234,7 +234,7 @@ def double_stream_block(
)
img = img + block.img_mod_2[2] * block.img_mlp # type: ignore[attr-defined]

# TODO: Use txt.shape()[1]]
# TODO: Use txt.shape[1]]
txt_attn = block.attn[:, :256] # type: ignore[attr-defined]
block += Linear(hidden_size, name="txt_attn_proj")(txt_attn, output="txt_proj")

Expand Down Expand Up @@ -355,7 +355,7 @@ def rope(dim: int, theta: int) -> Model:
omega = 1.0 / (theta ** (block.arange / dim)) # type: ignore
out = input[..., None] * omega

out_shape = out.shape()
out_shape = out.shape
B, N, D = out_shape[0], out_shape[1], out_shape[2]

block += Cosine()(out, output="cos")
Expand Down
2 changes: 1 addition & 1 deletion examples/gpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def causal_attention(input_dim, num_heads, bias=True):
model += Linear(input_dim * 3, name="c_attn")("input", output="c_attn_out")

t_axes = (0, 2, 1, 3)
shp_con = model.input.shape() # type: ignore
shp_con = model.input.shape # type: ignore
reshape_con = (shp_con[0], shp_con[1], num_heads, -1)

model += Split(3, axis=-1)(model.c_attn_out, output="split_out") # type: ignore
Expand Down
48 changes: 27 additions & 21 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
constant_type_table,
epsilon_table,
)
from ..utils.utils import OrderedSet, PaddingType, find_dominant_type
from ..utils.utils import PaddingType, find_dominant_type
from .utils import (
NestedListType,
align_shapes,
Expand Down Expand Up @@ -1001,6 +1001,7 @@
def len(self):
return ExtendTemplate(connections=[self], model="len")

@property
def shape(self):
return ExtendTemplate(connections=[self], model="shape")

Expand Down Expand Up @@ -1100,6 +1101,14 @@
self.output_connection = None


@dataclass
class BaseKey:
value: TensorValueType | MainValueType | ToBeDetermined | str = TBD
shape: ShapeTemplateType | None = None
type: NestedListType | UnionType | type | None = None
interval: list[float | int] | None = None


class IOKey(TemplateBase):
def __init__(
self,
Expand All @@ -1109,38 +1118,35 @@
type: NestedListType | UnionType | type | None = None,
expose: bool | None = None,
interval: list[float | int] | None = None,
connections: list[Connection | str] | None = None,
connections: set[Connection | str] | None = None,
) -> None:
super().__init__()
self._name = name
self._value = value
self._shape = shape
self._type = type
self._expose = expose
self._interval = interval
self._connections: OrderedSet[ConnectionData | str] = OrderedSet()
self.name = name
self.expose = expose
if connections is None:
connections = set()
self.connections: set[Connection | str] = connections
self.data = BaseKey(value, shape, type, interval)

# TODO: Shape should not be [] also!
if self._value is not TBD and self._shape is not None and self._shape != []:
if (
self.data.value is not TBD
and self.data.shape is not None
and self.data.shape != []
):
raise ValueError(
f"Scalar values are shapeless, shape should be None or []. "
f"Got {self._shape}."
f"Got {self.data.shape}."
)

if self._value is not TBD and self._type is not None:
value_type = find_type(self._value)
if find_intersection_type(value_type, self._type) is None:
if self.data.value is not TBD and self.data.type is not None:
value_type = find_type(self.data.value)
if find_intersection_type(value_type, self.data.type) is None:

Check warning on line 1144 in mithril/framework/common.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/common.py#L1143-L1144

Added lines #L1143 - L1144 were not covered by tests
raise TypeError(
f"type of the given value and given type does not match. Given "
f"type is {self._type} while type of value is {value_type}"
f"type is {self.data.type} while type of value is {value_type}"
)

connections = connections or []
for item in connections:
conn: ConnectionData | str
conn = item.data if isinstance(item, Connection) else item
self._connections.add(conn)


class Connection(TemplateBase):
def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool):
Expand Down
20 changes: 8 additions & 12 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,32 +95,28 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo:
continue
match con:
case Connection():
kwargs[key] = IOKey(value=val, connections=[con])
kwargs[key] = IOKey(value=val, connections={con})
# TODO: Maybe we could check con's value if matches with val
case item if isinstance(item, MainValueInstance) and con != val:
raise ValueError(
f"Given value {con} for local key: '{key}' "
f"has already being set to {val}!"
)
case str():
kwargs[key] = IOKey(con, value=val, expose=False)
kwargs[key] = IOKey(name=con, value=val, expose=False)
case IOKey():
if con._value is not TBD and con._value != val:
if con.data.value is not TBD and con.data.value != val:
raise ValueError(
f"Given IOKey for local key: '{key}' is not valid!"
)
else:
_conns: list[Connection | str] = [
item.conn if isinstance(item, ConnectionData) else item
for item in con._connections
]
kwargs[key] = IOKey(
name=con._name,
name=con.name,
expose=con.expose,
connections=con.connections,
type=con.data.type,
shape=con.data.shape,
value=val,
shape=con._shape,
type=con._type,
expose=con._expose,
connections=_conns,
)
case ExtendTemplate():
raise ValueError(
Expand Down
Loading
Loading