feat: Range type support #167

Merged · 4 commits · Jan 23, 2025
Changes from 3 commits

4 changes: 3 additions & 1 deletion examples/gpt/run_sample.py
@@ -48,7 +48,9 @@ def run_sample(
):
# TODO: This recursion limit is minimum we can have for now.
# We may further improve this limit in the future.
sys.setrecursionlimit(734)
# NOTE: Recursion limit is set according to the pytest function call
# limit. If you run this script directly, limit can be set to 692.
sys.setrecursionlimit(716)
# Model Configuration
block_size = 100
gpt = create_gpt(
142 changes: 96 additions & 46 deletions mithril/framework/common.py
@@ -39,7 +39,7 @@
Dtype,
constant_type_table,
)
from ..utils.utils import PaddingType, find_dominant_type
from ..utils.utils import PaddingType
from .utils import (
align_shapes,
find_intersection_type,
@@ -241,33 +241,30 @@ class KeyType(Enum):
)
# Available types for Tensor type ("_type" attribute of Tensor class).
_TensorTypes = type[int] | type[float] | type[bool] | UnionType
# Nested Sequence type values for Tensor class.
_TensorValueType = (
int
| float
| bool
| Sequence[int | float | bool]
| Sequence[Sequence[int | float | bool]]
| Sequence[Sequence[Sequence[int | float | bool]]]
| Sequence[Sequence[Sequence[Sequence[int | float | bool]]]]
| Sequence[Sequence[Sequence[Sequence[Sequence[int | float | bool]]]]]
ValType = int | float | bool
SequenceValType = (
Sequence[ValType]
| Sequence[Sequence[ValType]]
| Sequence[Sequence[Sequence[ValType]]]
| Sequence[Sequence[Sequence[Sequence[ValType]]]]
| Sequence[Sequence[Sequence[Sequence[Sequence[ValType]]]]]
)
ListValType = (
list[ValType]
| list[list[ValType]]
| list[list[list[ValType]]]
| list[list[list[list[ValType]]]]
| list[list[list[list[list[ValType]]]]]
)
# Nested Sequence type values for Tensor class.
_TensorValueType = ValType | SequenceValType
# Logical value types for Tensor class (i.e. "value" attribute of
# Tensor class).
TensorValueType = _TensorValueType | Constant

# TODO: This kind of type definition will be updated to a recursive
# definition when mypy supports recursive types.
TensorToListType = (
int
| float
| bool
| list[int | float | bool]
| list[list[int | float | bool]]
| list[list[list[int | float | bool]]]
| list[list[list[list[int | float | bool]]]]
| list[list[list[list[list[int | float | bool]]]]]
)
TensorToListType = ValType | ListValType

MaxNestedListDepth = 5
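
An illustrative sketch (not part of the diff) of the kinds of values the refactored aliases above are meant to cover, assuming they stay importable from mithril.framework.common:

# Hypothetical demonstration helper; only the alias import is taken from the PR.
from mithril.framework.common import TensorValueType

def describe(value: TensorValueType) -> str:
    # Return the runtime type name of the given tensor value.
    return type(value).__name__

print(describe(3))               # ValType: a bare scalar
print(describe((True, False)))   # SequenceValType: one level of nesting
print(describe([[1.0], [2.0]]))  # SequenceValType: two levels; up to five are supported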

@@ -637,9 +634,19 @@ def get_shapes(
AllValueType = TensorValueType | ScalarValueType | ToBeDetermined


@overload
def _find_type(value: Constant) -> type[int] | type[float] | type[bool]: ...
@overload
def _find_type(value: Tensor[Any]) -> type[Tensor[Any]]: ...
@overload
def _find_type(value: range) -> list[int]: ...
@overload
def _find_type(value: ScalarValueType) -> ScalarType: ...


def _find_type(
value: Tensor[Any] | ScalarValueType,
) -> type[Tensor[Any]] | ScalarType:
value: Tensor[Any] | ScalarValueType | range,
) -> type[Tensor[Any]] | ScalarType | list[int]:
typ: type
if isinstance(value, Tensor):
typ = Tensor[value.type] # type: ignore
@@ -650,23 +657,54 @@ def _find_type(
return typ


def list_shape(ndarray: TensorValueType) -> list[int]:
# TODO: Handle TOBeDetermined case.
if isinstance(ndarray, list | tuple):
# More dimensions, so make a recursive call
outermost_size = len(ndarray)
row_shape = list_shape(ndarray[0])
for item in ndarray[1:]:
shape = list_shape(item)
if row_shape != shape:
raise ValueError(
f"Shape mismatch: expected {row_shape}, but got {shape}. The list "
"should not be ragged."
)
return [outermost_size, *row_shape]
else:
# No more dimensions, so we're done
return []
def process_value(
value: TensorValueType,
) -> tuple[list[int], TensorValueType, type[int] | type[float] | type[bool]]:
def check_uniformity(sublist: _TensorValueType) -> None:
"""Check if all sublists have the same length."""
if isinstance(sublist, Sequence):
lengths = [
len(item) if isinstance(item, Sequence) else -1 for item in sublist
]
if len(set(lengths)) > 1:
raise ValueError("Inconsistent dimensions found in the list.")

# If value is not a sequence, directly return the empty shape, the value
# and its type.
if not isinstance(value, tuple | list | range):
return (
[],
value,
type(value) if not isinstance(value, Constant) else _find_type(value),
) # type: ignore

# Convert range types into list.
if isinstance(value, range):
value = list(value)

# Check for incompatible dimensions.
check_uniformity(value)

# Initialize result as an empty sequence of same type as value.
result: list[Any] | tuple[Any, ...] = list() if isinstance(value, list) else tuple()

dominant_type: type[bool] | type[int] | type[float] = bool
for item in value:
# Recursively determine the shape, value and type of sublists.
sub_shape, sub_val, sub_type = process_value(item)
assert not isinstance(sub_val, Constant)

if isinstance(result, list):
result.append(sub_val)
else:
result += (sub_val,)

if sub_type is float:
dominant_type = float
elif sub_type is int and dominant_type is bool:
dominant_type = int

return [len(result)] + sub_shape, result, dominant_type
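
A small usage sketch of the new helper (illustrative only; it assumes process_value is importable from mithril.framework.common as added in this PR):

from mithril.framework.common import process_value

# Nested list: returns the shape, the (unchanged) value and the dominant element type.
shape, val, typ = process_value([[1, 2.0], [3, 4]])
print(shape, typ)        # expected: [2, 2] <class 'float'>

# range inputs are materialized into a list of ints.
shape, val, typ = process_value(range(4))
print(shape, val, typ)   # expected: [4] [0, 1, 2, 3] <class 'int'>

# Ragged nesting is rejected.
try:
    process_value([[1, 2], [3]])
except ValueError as err:
    print(err)           # Inconsistent dimensions found in the list.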


class Tensor(Generic[TypeVarTensorType]):
@@ -709,17 +747,18 @@ def set_value(self, value: TensorValueType) -> Updates:
f"Value is set before as {self.value}. A value can not be reset."
)
updates = Updates()
# Find and set type.
updates |= self.set_type(find_dominant_type(value))
# Set value.
if self.value is TBD:
# Infer shape, final_value and type from the value.
shape, val, typ = process_value(value)
# Set type.
updates |= self.set_type(typ)
# Set shape.
updates |= self.shape.set_values(shape)
# Add all referee edges into the updates.
for edge in self.referees:
updates.add(edge)
self.value = value
# Infer shape from the value and set.
shape = list_shape(value)
updates |= self.shape.set_values(shape)
self.value = val
return updates

def match(self, other: Tensor[Any]) -> Updates:
@@ -1276,6 +1315,17 @@ class BaseKey:
type: UnionType | type | type[Tensor[Any]] | ScalarType | None = None
interval: list[float | int] | None = None

# TODO: Add __post_init__ to check types and values
# def __post_init__(self) -> None:
# if not isinstance(self.value, ToBeDetermined):
# value_type = _find_type(self.value)
# if self.type is not None and
# find_intersection_type(value_type, self.type) is None:
# raise TypeError(
# f"type of the given value and given type does not match. Given "
# f"type is {self.type} while type of value is {value_type}"
# )
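
A standalone sketch of the validation this TODO describes (check_value_matches_type is a hypothetical helper; it only assumes that _find_type and find_intersection_type behave as they are used elsewhere in this PR):

from mithril.framework.common import _find_type
from mithril.framework.utils import find_intersection_type

def check_value_matches_type(value, declared_type) -> None:
    # Reject values whose inferred type has no overlap with the declared type.
    value_type = _find_type(value)
    if declared_type is not None and find_intersection_type(value_type, declared_type) is None:
        raise TypeError(
            f"type of the given value and given type does not match. Given "
            f"type is {declared_type} while type of value is {value_type}"
        )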


class IOKey(TemplateBase):
def __init__(
13 changes: 6 additions & 7 deletions mithril/framework/constraints.py
@@ -31,7 +31,7 @@
is_tuple_int_or_none,
is_tuple_of_two_ints,
)
from ..utils.utils import PaddingType, find_dominant_type
from ..utils.utils import PaddingType
from .common import (
DNF,
TBD,
@@ -49,7 +49,7 @@
UpdateType,
Variadic,
_TensorTypes,
list_shape,
process_value,
)
from .utils import (
find_intersection_type,
@@ -3404,11 +3404,9 @@ def to_tensor_constraints(
), "Invalid input value!"

if not isinstance(input_val, ToBeDetermined):
shape = []
shape: list[int] = []
if isinstance(input_val, list | tuple):
# list_shape function takes only list type input.
shape = [idx for idx in list_shape(list(input_val))]
typ = find_dominant_type(input_val)
shape, _, typ = process_value(input_val)
updates |= output.set_type(Tensor[typ]) # type: ignore
updates.add(output, update_type=UpdateType.TYPE)
elif isinstance(input_val, float | int):
@@ -3453,7 +3451,8 @@ def tensor_to_list_constraints(
if not isinstance(output_val, ToBeDetermined):
shape: list[Uniadic] = []
if isinstance(output_value, list | tuple):
shape = [Uniadic(idx) for idx in list_shape(list(output_val))]
shp, *_ = process_value(output_val)
shape = [Uniadic(idx) for idx in shp]

updates |= input_shape.inner_match(prefix=shape)
status = True
17 changes: 11 additions & 6 deletions mithril/framework/utils.py
@@ -239,7 +239,7 @@ def find_intersection_type(
subtypes_2 = set(type_2.__args__) if type(type_2) is UnionType else {type_2}
intersect = subtypes_1 & subtypes_2

# Any (typing.Any) type can be coerced to all types, handle it.
# Handle coercion of Any (typing.Any) type to all other types.
if Any in subtypes_1:
intersect.update(subtypes_2)
subtypes_1.remove(Any)
@@ -254,17 +254,20 @@

for s_types in (subtypes_1, subtypes_2):
other_set = subtypes_2 if s_types == subtypes_1 else subtypes_1
for orig_type in (list, tuple):
for orig_type in (list, tuple, range):
if orig_type in s_types:
for typ in other_set:
# if isinstance(typ, GenericAlias) and typ.__origin__ == orig_type:
if isinstance(typ, GenericAlias):
if typ.__origin__ == orig_type:
intersect.add(typ)
elif typ.__origin__ == Sequence:
intersect.add(
orig_type[reduce(lambda x, y: x | y, typ.__args__)]
)
if orig_type is range:
if find_intersection_type(int, typ.__args__[0]):
intersect.add(range)
else:
intersect.add(
orig_type[reduce(lambda x, y: x | y, typ.__args__)] # type: ignore
)

# Take tuple types from remaining sets and find intersection types
# of all consistent pairs of cartesian product.
@@ -385,6 +388,8 @@ def find_type[T](connection: T) -> type[T]:
else:
result: UnionType | type = reduce(lambda x, y: x | y, element_types)
return list[result] # type: ignore
# elif isinstance(connection, range):
# return list[int]
else:
return type(connection)

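
A rough illustration of the new range handling in find_intersection_type (expected behaviour only; it assumes the module's Sequence alias is collections.abc.Sequence and that the function returns the intersected type or None):

from collections.abc import Sequence
from mithril.framework.utils import find_intersection_type

# A bare `range` now intersects with a Sequence alias whose element type is
# compatible with int; otherwise there is no overlap.
print(find_intersection_type(range, Sequence[int]))  # expected: range
print(find_intersection_type(range, Sequence[str]))  # expected: None
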
5 changes: 1 addition & 4 deletions tests/scripts/test_constraints.py
@@ -5580,10 +5580,7 @@ def test_tensor_to_list_backward_2():
assert_constraint_results(
shapes, {}, {}, {}, tensor_to_list_constraints, False, set(), scalar_info
)
assert (
str(err_info.value)
== "Shape mismatch: expected [3], but got [2]. The list should not be ragged."
)
assert str(err_info.value) == "Inconsistent dimensions found in the list."


def test_item_constraints_1():