Skip to content

Commit

Permalink
merged with upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
mehmetozsoy-synnada committed Feb 18, 2025
2 parents 5f80323 + d3a1142 commit 3e634ea
Show file tree
Hide file tree
Showing 27 changed files with 496 additions and 583 deletions.
2 changes: 1 addition & 1 deletion examples/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def rms_norm(dim: int, *, name: str | None = None):
block = Model(name=name)
input = IOKey("input")
weight = IOKey(
"weight", shape=[dim], differantiable=True
"weight", shape=[dim], differentiable=True
) # TODO: weight must be initialized with ones.
rrms = input / ((input**2).mean(axis=-1, keepdim=True) + 1e-5).sqrt()
block += Multiply()(left=rrms, right=weight, output=IOKey("output"))
Expand Down
19 changes: 10 additions & 9 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,6 @@ def __init__(
expose: bool | None = None,
differentiable: bool = False,
interval: list[float | int] | None = None,
connections: set[ConnectionData | str] | None = None,
) -> None:
# If shape is provided, type should be Tensor.
if shape is not None:
Expand All @@ -1202,9 +1201,6 @@ def __init__(

self.name = name
self.expose = expose
if connections is None:
connections = set()
self.connections: set[ConnectionData | str] = connections
# TODO: Shape should not be [] also!
if (
value is not TBD
Expand Down Expand Up @@ -1242,12 +1238,13 @@ class ConnectionData:
connections.
"""

def __init__(
self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool
) -> None:
def __init__(self, key: str, metadata: IOHyperEdge) -> None:
self.key = key
self.metadata = metadata
self.is_key_autogenerated = is_key_autogenerated

@property
def is_autogenerated(self) -> bool:
return self.key.startswith("$")

def __hash__(self) -> int:
return hash(id(self))
Expand All @@ -1263,6 +1260,10 @@ def set_differentiability(self, differentiable: bool = True) -> Updates:

return updates

@property
def differentiable(self) -> bool:
return self.metadata.differentiable


ShapeResultType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None]

Expand Down Expand Up @@ -1369,7 +1370,7 @@ def set_connection_type(
con_type: KeyType,
safe: bool = True,
) -> None:
if safe and con_type == KeyType.OUTPUT and connection.is_key_autogenerated:
if safe and con_type == KeyType.OUTPUT and connection.is_autogenerated:
raise KeyError("Connection without a name cannot be set as output")
key = connection.key
if connection in self.couts and con_type == KeyType.INTERNAL:
Expand Down
Loading

0 comments on commit 3e634ea

Please sign in to comment.