Skip to content

Commit

Permalink
edge_type_constraints is removed
Browse files Browse the repository at this point in the history
  • Loading branch information
mehmetozsoy-synnada committed Feb 17, 2025
1 parent e4cecd5 commit 107d667
Showing 1 changed file with 0 additions and 33 deletions.
33 changes: 0 additions & 33 deletions mithril/framework/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
IOHyperEdge,
MaxNestedListDepth,
PossibleValues,
ScalarValueType,
ShapeRepr,
Tensor,
ToBeDetermined,
Expand All @@ -58,7 +57,6 @@
from .utils import find_list_base_type, is_union

__all__ = [
"edge_type_constraint",
"general_tensor_type_constraint",
"scalar_slice_type_constraint",
"indexer_initial_type_constraint",
Expand Down Expand Up @@ -184,36 +182,6 @@ def set_edge_type(edge: IOHyperEdge, new_type: Any) -> Updates:
return edge.set_type(type)


def edge_type_constraint(
output: IOHyperEdge, *inputs: IOHyperEdge
) -> ConstrainResultType:
updates = Updates()

### Forward Inference ###
if any(input.is_tensor for input in inputs):
# if any Tensor input exists, set output type to Tensor.
updates |= output.set_type(Tensor[int | float | bool])

elif all(input.is_scalar for input in inputs):
# if all types of input is scalar, set output type to scalar.
updates |= output.set_type(ScalarValueType)

### Reverse Inference ###
elif output.is_tensor:
# If there is only one untyped input, set it as Tensor.
untyped_inputs = {input for input in inputs if input.is_polymorphic}
if len(untyped_inputs) == 1:
updates |= untyped_inputs.pop().set_type(Tensor[int | float | bool])

elif output.is_scalar:
# Scalar output means all inputs are scalar.
for input in inputs:
updates |= input.set_type(ScalarValueType)

# If no polymorphic edge_type exists, return True.
return not any(input.is_polymorphic for input in inputs), updates


def general_forward_constraint(
*keys: IOHyperEdge, callable: Callable[..., Any]
) -> ConstrainResultType:
Expand Down Expand Up @@ -4159,7 +4127,6 @@ def polynomial_kernel_constraint(
constrain_fn_dict = {key: fn for key, fn in globals().items() if callable(fn)}

constraint_type_map: dict[ConstraintFunctionType, list[UpdateType]] = {
edge_type_constraint: [UpdateType.TYPE],
general_tensor_type_constraint: [UpdateType.TYPE],
scalar_slice_type_constraint: [UpdateType.TYPE],
indexer_initial_type_constraint: [UpdateType.TYPE],
Expand Down

0 comments on commit 107d667

Please sign in to comment.