Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add VALUE type to constraints #176

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5f9193c
feat: Now People will be assigned based on applied labels, github act…
mehmetozsoy-synnada Nov 14, 2024
8309838
feat: Now actions bot requests a change in the case of tests failed
mehmetozsoy-synnada Nov 14, 2024
aea5265
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 19, 2024
5916d8f
fix: bug fix attempt 1
mehmetozsoy-synnada Nov 21, 2024
f3e9083
fix: resolve conflicts
mehmetozsoy-synnada Nov 22, 2024
6365adb
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 27, 2024
78ccbfa
merged with upstream
mehmetozsoy-synnada Dec 2, 2024
7c2ce69
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 4, 2024
74ce570
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 5, 2024
7835062
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 12, 2024
6533572
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 17, 2024
202366e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 20, 2024
fea39af
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 25, 2024
4ebe35a
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 26, 2024
4bc57d2
merged with upstream
mehmetozsoy-synnada Dec 26, 2024
796a447
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
7293b6e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
408cddf
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
b32eecc
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 31, 2024
d9b3b5c
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Jan 2, 2025
91b447a
merged with upstream
mehmetozsoy-synnada Jan 22, 2025
5618042
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Jan 23, 2025
eb5e320
merged with upstream
mehmetozsoy-synnada Jan 23, 2025
f3dff98
constraint graph structure is implemented
mehmetozsoy-synnada Jan 26, 2025
c3ed35e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Jan 26, 2025
1f4f379
Merge branch 'main' of https://github.com/mehmetozsoy-synnada/mithril…
mehmetozsoy-synnada Jan 26, 2025
c83f970
set_constraint name changed to add_constraint
mehmetozsoy-synnada Jan 26, 2025
9102b59
comments are added
mehmetozsoy-synnada Jan 26, 2025
f34bc9f
'VALUE' UpdateType is introduced. all constraints are seperated into …
mehmetozsoy-synnada Jan 28, 2025
8c2f298
reviews are applied
mehmetozsoy-synnada Jan 29, 2025
bbeb3cc
conflicts are resolved partially
mehmetozsoy-synnada Jan 30, 2025
1edf5c6
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Jan 30, 2025
f4c4c1d
small bug in argmax test is fixed, merged with main
mehmetozsoy-synnada Jan 30, 2025
532b788
reviews are applied
mehmetozsoy-synnada Feb 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 33 additions & 45 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class ToBeDetermined(SingletonObject):
class UpdateType(Enum):
SHAPE = 1
TYPE = 2
VALUE = 3


class KeyType(Enum):
Expand Down Expand Up @@ -338,10 +339,9 @@ def solver_loop(self, constraints: set[Constraint]) -> Updates:
while constraints:
constr = constraints.pop()
if (not constr.parents) and (constr in self.constraint_map):
constraint_type = constr.type
hyper_edges = self.constraint_map[constr]
status, newly_added_symbols = constr(hyper_edges)
if constraint_type is UpdateType.SHAPE:
if UpdateType.SHAPE in constr.types:
self.update_shapes(newly_added_symbols)
updates |= newly_added_symbols
new_constraints = newly_added_symbols.constraints
Expand Down Expand Up @@ -533,41 +533,27 @@ def add(
symbol: IOHyperEdge | Uniadic | Variadic,
update_type: UpdateType = UpdateType.SHAPE,
) -> None:
# TODO: Use match case here
if update_type == UpdateType.SHAPE:
if isinstance(symbol, Uniadic):
match symbol:
case Uniadic():
self._add_uniadic(symbol)
elif isinstance(symbol, Variadic):
case Variadic():
self._add_variadic(symbol)
else:
self._add_edge(symbol)

# TODO: Fill here after type_updates added to class
elif update_type == UpdateType.TYPE:
assert isinstance(symbol, IOHyperEdge)
self._add_type_update(symbol)

def _add_edge(self, symbol: IOHyperEdge) -> None:
self.value_updates.add(symbol)
self.constraints |= symbol.shape_constraints
case IOHyperEdge():
self.constraints |= symbol.constraints[update_type]

def _add_uniadic(self, symbol: Uniadic) -> None:
self.uniadic_updates.add(symbol)
for repr in symbol.metadata.reprs_dict:
for edge in repr.node.referees:
self.shape_updates.add(edge)
self.constraints |= edge.shape_constraints
self.constraints |= edge.constraints[UpdateType.SHAPE]

def _add_variadic(self, symbol: Variadic) -> None:
# self.symbol_updates.add(symbol)
for repr in symbol.reprs:
self.node_updates.add(repr.node)
for edge in repr.node.referees:
self.shape_updates.add(edge)
self.constraints |= edge.shape_constraints

def _add_type_update(self, symbol: IOHyperEdge) -> None:
self.constraints |= symbol.type_constraints
self.constraints |= edge.constraints[UpdateType.SHAPE]

def __ior__(self, other: Updates) -> Updates:
self.constraints |= other.constraints
Expand Down Expand Up @@ -896,7 +882,8 @@ def set_value(self, value: TensorValueType) -> Updates:
updates |= self.shape.set_values(shape)
# Add all referee edges into the updates.
for edge in self.referees:
updates.add(edge)
updates.add(edge, update_type=UpdateType.VALUE)
updates.add(edge, update_type=UpdateType.SHAPE)
self.value = val
return updates

Expand Down Expand Up @@ -947,8 +934,9 @@ def __init__(
interval: list[float | int] | None = None,
) -> None:
self.key_origin = key_origin
self.shape_constraints: set[Constraint] = set()
self.type_constraints: set[Constraint] = set()
self.constraints: dict[UpdateType, set[Constraint]] = {
type: set() for type in UpdateType
}
self._temp_shape: ShapeRepr | None = None # set random repr
self.differentiable: bool = False
self.interval: list[float | int] | None = interval
Expand Down Expand Up @@ -985,7 +973,8 @@ def is_valued(self) -> bool:

@property
def all_constraints(self) -> set[Constraint]:
return self.shape_constraints | self.type_constraints
result: set[Constraint] = set().union(*self.constraints.values())
return result

@property
def value(self) -> _TensorValueType | ScalarValueType | ToBeDetermined:
Expand Down Expand Up @@ -1103,6 +1092,7 @@ def set_value(
self._value.shape.referees.add(self)
# Add self as a type update since type has just updated to Tensor.
updates.add(self, UpdateType.TYPE)
updates.add(self, UpdateType.SHAPE)
# TODO: When two edges set to the same tensor value using
# different Tensor objects, we need to merge their nodes into
# a single node. In order to track this, we need to add all
Expand All @@ -1113,8 +1103,9 @@ def set_value(
else:
updates |= self.set_type(_find_type(value))
self._value = value
updates.add(self, UpdateType.VALUE)
updates.value_updates.add(self)
# Add self to updates.
updates.add(self)
self.differentiable = self.value is TBD
return updates

Expand All @@ -1139,11 +1130,11 @@ def match(self, other: IOHyperEdge) -> Updates:
updates.value_updates.discard(other)
updates.shape_updates.discard(other)
# After modifications done, propagate other constraints into self.
self.shape_constraints |= other.shape_constraints
self.type_constraints |= other.type_constraints
# Set other's constraints to empty.
other.shape_constraints = set()
other.type_constraints = set()

for type in UpdateType:
self.constraints[type] |= other.constraints[type]
other.constraints[type] = set()

# Update differentiability.
if isinstance(self._value, Tensor) and self._value.value is TBD:
is_diff = self.differentiable | other.differentiable
Expand All @@ -1152,15 +1143,12 @@ def match(self, other: IOHyperEdge) -> Updates:
return updates

def add_constraint(self, constraint: Constraint) -> None:
if constraint.type == UpdateType.SHAPE:
self.shape_constraints.add(constraint)
elif constraint.type == UpdateType.TYPE:
self.type_constraints.add(constraint)
for type in constraint.types:
self.constraints[type].add(constraint)

def remove_constraint(self, constraint: Constraint) -> None:
# TODO: check why pop raises!
self.shape_constraints.discard(constraint)
self.type_constraints.discard(constraint)
for type in constraint.types:
self.constraints[type].discard(constraint)


class TemplateBase:
Expand Down Expand Up @@ -2651,7 +2639,7 @@ def merge(self, other: ShapeNode) -> Updates:

if add_constraint:
for tensor in self.referees:
updates.constraints |= tensor.shape_constraints
updates.constraints |= tensor.constraints[UpdateType.SHAPE]

for repr in resolved_reprs:
# remove_repr_from_symbols(repr)
Expand Down Expand Up @@ -3103,25 +3091,25 @@ def clear(self) -> None:
@dataclass
class Constraint:
fn: ConstraintFunctionType
type: UpdateType = UpdateType.SHAPE
types: list[UpdateType] = field(default_factory=lambda: [UpdateType.SHAPE])
call_counter: int = 0
parents: set[Constraint] = field(default_factory=lambda: set())
children: set[Constraint] = field(default_factory=lambda: set())

def __call__(self, keys: list[IOHyperEdge]) -> ConstrainResultType:
status = False
updates = Updates()
if self.type == UpdateType.SHAPE:
if UpdateType.SHAPE in self.types:
tensor_keys = [key for key in keys if key.is_tensor]
for reprs in product(*[key.shape.reprs for key in tensor_keys]): # type: ignore
for idx, repr in enumerate(reprs):
tensor_keys[idx]._temp_shape = repr
status, newly_added_symbols = self.fn(*keys)
updates |= newly_added_symbols
# Clear temp_shape.
for idx, _ in enumerate(reprs):
for idx in range(len(reprs)):
tensor_keys[idx]._temp_shape = None
elif self.type == UpdateType.TYPE:
else:
status, newly_added_symbols = self.fn(*keys)
updates |= newly_added_symbols
if status:
Expand Down
52 changes: 32 additions & 20 deletions mithril/framework/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,13 @@
"tensor_to_list_constraints",
"to_list_constraints",
"where_constrains",
"validate_bcast",
"eye_constraints",
"item_constraints",
"indexer_constraints",
"to_tuple_constraints",
"tensor_to_list_type_constraint",
"reduce_type_constraint",
"type_constraints",
"constraint_type_map",
"padding_1d_constraint",
"padding_2d_constraint",
"stride_constraint",
Expand Down Expand Up @@ -3234,17 +3233,18 @@
assert isinstance(N.value, int)
n_uni.set_value(N.value)
updates.add(n_uni)

elif n_uni_valued and not n_valued:
N.set_value(n_uni.value)
updates.add(N)
updates |= N.set_value(n_uni.value)

Check warning on line 3238 in mithril/framework/constraints.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/constraints.py#L3238

Added line #L3238 was not covered by tests

if m_valued and not m_uni_valued:
assert isinstance(M.value, int | NoneType)
m_uni.set_value(M.value)
updates.add(m_uni)

elif m_uni_valued and not m_valued:
M.set_value(m_uni.value)
updates.add(M)
updates |= M.set_value(m_uni.value)

Check warning on line 3246 in mithril/framework/constraints.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/constraints.py#L3246

Added line #L3246 was not covered by tests

all_items: list[IOHyperEdge | Uniadic] = [N, M, n_uni, m_uni]
return all(isinstance(s.value, int) for s in all_items), updates

Expand Down Expand Up @@ -3515,8 +3515,7 @@
# output value appears only once in input sequence, write its
# index as the value of index argument.
if input.value.count(output.value) == 1:
index.set_value(input.value.index(output.value))
updates._add_edge(index)
updates |= index.set_value(input.value.index(output.value))
status = True
return status, updates

Expand Down Expand Up @@ -3998,16 +3997,29 @@

constrain_fn_dict = {key: fn for key, fn in globals().items() if callable(fn)}

type_constraints: set[ConstraintFunctionType] = {
edge_type_constraint,
general_tensor_type_constraint,
floor_divide_type_constraint,
scalar_slice_type_constraint,
indexer_initial_type_constraint,
indexer_type_constraint,
tensor_to_list_type_constraint,
reduce_type_constraint,
relational_operator_type_constraint,
divide_type_constraint,
buffer_constraint,
constraint_type_map: dict[ConstraintFunctionType, list[UpdateType]] = {
edge_type_constraint: [UpdateType.TYPE],
general_tensor_type_constraint: [UpdateType.TYPE],
floor_divide_type_constraint: [UpdateType.TYPE],
scalar_slice_type_constraint: [UpdateType.TYPE],
indexer_initial_type_constraint: [UpdateType.TYPE],
indexer_type_constraint: [UpdateType.TYPE],
slice_constraints: [UpdateType.VALUE],
bcast: [UpdateType.SHAPE],
bcast_matrix_mult: [UpdateType.SHAPE],
to_tensor_constraints: [UpdateType.SHAPE, UpdateType.TYPE],
tensor_to_list_constraints: [UpdateType.SHAPE, UpdateType.TYPE],
to_list_constraints: [UpdateType.VALUE],
where_constrains: [UpdateType.SHAPE],
item_constraints: [UpdateType.SHAPE],
to_tuple_constraints: [UpdateType.VALUE],
tensor_to_list_type_constraint: [UpdateType.TYPE],
reduce_type_constraint: [UpdateType.TYPE],
padding_1d_constraint: [UpdateType.VALUE],
padding_2d_constraint: [UpdateType.VALUE],
stride_constraint: [UpdateType.VALUE],
tuple_converter_constraint: [UpdateType.VALUE],
buffer_constraint: [UpdateType.TYPE, UpdateType.VALUE],
relational_operator_type_constraint: [UpdateType.TYPE],
divide_type_constraint: [UpdateType.TYPE],
}
16 changes: 6 additions & 10 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
create_shape_repr,
get_shapes,
)
from ..constraints import type_constraints
from ..constraints import constraint_type_map

__all__ = ["BaseModel", "ExtendInfo"]

Expand Down Expand Up @@ -421,7 +421,7 @@ def _add_constraint(
self,
fn: ConstraintFunctionType,
keys: list[str],
type: UpdateType | None = None,
types: list[UpdateType] | None = None,
dependencies: set[Constraint] | None = None,
) -> Constraint:
all_conns = self.conns.all
Expand All @@ -432,13 +432,9 @@ def _add_constraint(
unresolved_dependencies = (
dependencies & self.constraint_solver.constraint_map.keys()
)
if type is None:
# TODO: separate type_constraints and shape constraints into two files under
# constraints folder. Then, check if fn is not in any of those types set
# _type to None. If _type and type are both None or one is UpdateType.SHAPE
# while other one is UpdateType.Type, raise Exception!
type = UpdateType.TYPE if fn in type_constraints else UpdateType.SHAPE
constr = Constraint(fn=fn, type=type)
if types is None:
types = constraint_type_map.get(fn, [UpdateType.SHAPE, UpdateType.VALUE])
constr = Constraint(fn=fn, types=types)
constr.add_dependencies(*unresolved_dependencies)

self.constraint_solver.constraint_map[constr] = hyper_edges
Expand All @@ -452,7 +448,7 @@ def add_constraint(
self,
fn: ConstraintFunctionType,
keys: list[str],
type: UpdateType = UpdateType.SHAPE,
type: list[UpdateType] | None = None,
dependencies: set[Constraint] | None = None,
) -> Constraint:
self.assigned_constraints.append({"fn": fn.__name__, "keys": keys})
Expand Down
17 changes: 13 additions & 4 deletions mithril/framework/physical/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Tensor,
ToBeDetermined,
Updates,
UpdateType,
is_type_adjustment_required,
)
from .flat_graph import FlatGraph
Expand Down Expand Up @@ -126,12 +127,20 @@ def _clear_constraints(self, key: str) -> None:
if key not in self._all_data:
return

shape_constraints = self._all_data[key].shape_constraints
type_constraints = self._all_data[key].type_constraints
shape_constraints = self._all_data[key].constraints[UpdateType.SHAPE]
type_constraints = self._all_data[key].constraints[UpdateType.TYPE]
value_constraints = self._all_data[key].constraints[UpdateType.VALUE]
for source_key in self.graph.get_source_keys(key):
if source_key in self._all_data:
self._all_data[source_key].shape_constraints -= shape_constraints
self._all_data[source_key].type_constraints -= type_constraints
self._all_data[source_key].constraints[UpdateType.SHAPE] -= (
shape_constraints
)
self._all_data[source_key].constraints[UpdateType.TYPE] -= (
type_constraints
)
self._all_data[source_key].constraints[UpdateType.VALUE] -= (
value_constraints
)

def update_cached_data(self, updated_data: Updates) -> set[str]:
# If any data value is found by shape inference algorithms
Expand Down
3 changes: 2 additions & 1 deletion mithril/framework/physical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ToBeDetermined,
UniadicRecord,
Updates,
UpdateType,
Variadic,
create_shape_map,
find_intersection_type,
Expand Down Expand Up @@ -552,7 +553,7 @@
# there can exist some inferred intermediate scalar keys in logical model.
# find those keys and add to cached datas
if not value.is_tensor and (value.value is not TBD):
updates.add(value)
updates.add(value, update_type=UpdateType.VALUE)

Check warning on line 556 in mithril/framework/physical/model.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/physical/model.py#L556

Added line #L556 was not covered by tests

self.data_store.update_cached_data(updates)

Expand Down
3 changes: 2 additions & 1 deletion tests/scripts/test_backend_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,8 @@ def test_argmax(self, backendcls, device, dtype):
]
fn_kwargs: dict = {}

ref_output = array_fn([1, 1], device, "int64")
precision = "int32" if backendcls is MlxBackend else "int64"
ref_output = array_fn([1, 1], device, precision)

assert_backend_results_equal(
backend,
Expand Down
Loading