Skip to content

Commit

Permalink
feat: Add state output logic (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
kberat-synnada authored Mar 8, 2025
1 parent d506aae commit 048686e
Show file tree
Hide file tree
Showing 7 changed files with 875 additions and 40 deletions.
14 changes: 14 additions & 0 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ class KeyType(Enum):
LATENT_OUTPUT = 5


class StateValue(Enum):
ZEROS = 0
ONES = 1
# RANDOM = 0 # TODO: Implement random state value


@dataclass
class StateKey:
in_key: str
out_key: str
is_exposed: bool
initial_value: MainValueInstance | StateValue | NullConnection = NOT_GIVEN


type FixedValueType = (
None
| int
Expand Down
51 changes: 51 additions & 0 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ShapeNode,
ShapeTemplateType,
ShapeType,
StateValue,
Tensor,
ToBeDetermined,
UniadicRecord,
Expand All @@ -53,6 +54,9 @@
__all__ = ["BaseModel", "BaseKey", "ConnectionData", "ConnectionDataType"]


StateValueType = StateValue | MainValueInstance | NullConnection


class ConnectionData:
# TODO: This class updated as mutable. Update docstrings accordingly!
"""Immutable dataclass object which holds model instance, key
Expand Down Expand Up @@ -393,6 +397,9 @@ def __init__(
self.safe_shapes: dict[str, ShapeTemplateType] = {}
self.is_frozen = False
self.inter_key_count = 0
self.state_connections: dict[
ConnectionData, tuple[ConnectionData, StateValueType]
] = {}

@property
def formula_key(self) -> str | None:
Expand Down Expand Up @@ -439,6 +446,49 @@ def expose_keys(
connections.append(conn_data)
self.dependency_map.update_globals(OrderedSet(connections))

def bind_state_keys(
self,
input: ConnectionData | str,
output: ConnectionData | str,
initial_value: StateValueType = NOT_GIVEN,
) -> None:
if self.is_frozen:
raise AttributeError("Frozen model's bind_state_keys is not allowed!")
# Get connections.
in_con = self.conns.get_extracted_connection(input)
out_con = self.conns.get_extracted_connection(output)
if self.conns.get_type(in_con) not in {KeyType.INPUT, KeyType.LATENT_INPUT}:
raise KeyError("Input connection should be an input key!")
if self.conns.get_type(out_con) in {KeyType.INPUT, KeyType.LATENT_INPUT}:
raise KeyError("Output connection should be an output key!")
for _out, (_in, _) in self.state_connections.items():
if _in.metadata is in_con.metadata or _out.metadata is out_con.metadata:
raise KeyError("Binded connections could not be binded again!")

# Set connection types to latent.
self.conns.set_connection_type(in_con, KeyType.LATENT_INPUT)
self.conns.couts.discard(out_con)
if self.conns.get_type(out_con) == KeyType.OUTPUT:
self.conns.set_connection_type(out_con, KeyType.LATENT_OUTPUT)

updates = Updates()
# Set differentiability of input connection to False.
updates |= in_con.set_differentiability(False)
# Merge types.
updates |= in_con.metadata.set_type(out_con.metadata._type)
updates |= out_con.metadata.set_type(in_con.metadata._type)
if in_con.metadata.is_tensor:
# Merge shapes if connections are Tensors.
assert isinstance(in_con.metadata._value, Tensor)
assert isinstance(out_con.metadata._value, Tensor)
updates |= in_con.metadata._value.match_shapes(
out_con.metadata._value.shape
)
self.constraint_solver(updates)

# Save state connections.
self.state_connections[out_con] = (in_con, initial_value)

def _check_multi_write(
self,
local_input: bool,
Expand Down Expand Up @@ -830,6 +880,7 @@ def extend(
model._freeze()

updates = Updates()
self.state_connections |= model.state_connections

shape_info: dict[str, ShapeTemplateType] = {}
submodel_dag: dict[str, ConnectionData] = {}
Expand Down
Loading

0 comments on commit 048686e

Please sign in to comment.