The einx.rearrange function returns a Union[einx.Tensor, Tuple[einx.Tensor, ...]], which is problematic when assigning the result to a variable annotated as einx.Tensor: the function might return a Tuple[einx.Tensor, ...] instead, which is not compatible with the expected type. This causes false positives in type checkers such as pyright. The errors are justified in a sense, because as the function is currently written, its return type cannot be determined from the argument types. These false positives appear in many common use cases:
Rearranging a single tensor turns its type into a union, so linters report errors when accessing basic attributes:
single_tensor = einx.rearrange('b c h w -> b h w c', torch.randn((16, 3, 256, 256)))
# Cannot access attribute "shape" for class "Tuple[Tensor, ...]"
#   Attribute "shape" is unknown  (Pylance: reportAttributeAccessIssue)
single_tensor.shape
After performing a rearrange, I cannot assign the result back to the same variable if I annotate it as Tensor, because the result is a union:
image: Tensor = torchvision.io.image.read_image('./resources/image.png')
# "Tuple[Tensor, ...]" is not assignable to "Tensor"  (Pylance: reportAssignmentType)
image = einx.rearrange('c h w -> h w c', image)
plt.imshow(image)
A quick but repetitive workaround is to explicitly cast the Union back to Tensor:
image=cast(Tensor, einx.rearrange('c h w -> h w c', image))
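One way to keep the convenience of cast while still catching mistakes is a small narrowing helper that also checks at runtime. The sketch below is a minimal, hypothetical example (`expect_single` is not part of einx); it assumes tensors themselves are never tuples, so an `isinstance(..., tuple)` check is enough to tell the two cases apart:

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T")


def expect_single(result: Union[T, Tuple[T, ...]]) -> T:
    """Narrow a 'single value or tuple' result to the single value.

    Hypothetical helper. Unlike typing.cast, it checks at runtime, so a tuple
    returned by a multi-output expression fails loudly instead of slipping through.
    """
    if isinstance(result, tuple):
        raise TypeError(f"expected a single tensor, got a tuple of {len(result)}")
    return result
```

With it, the assignment would become `image = expect_single(einx.rearrange('c h w -> h w c', image))`, and a checker such as pyright should infer `Tensor` for `image`, assuming `einx.rearrange` is annotated as described above.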
Could the function signature be changed to support return type inference based on the argument types (using @overload)? This might, however, result in type conflicts between concatenation, splitting, reshaping, and so on. I would like to use the einx.rearrange functionality without having to fall back on einops.rearrange, which always returns a Tensor but is less flexible.
The specific return type depends on the value of the expression string (i.e., whether there is a comma on the right-hand side), so I am not sure a type checker could determine which overload to use without running the code. I am not very familiar with typing, though. The problem also appears in other functions (einx.vmap, einx.vmap_with_axis).
There is typing.Literal, which could be used to overload einx.rearrange for specific expressions with something like
@overload
def rearrange(expr: Literal['c h w -> h w c'], ...) -> Tensor: ...
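A runnable version of the Literal-overload idea, using a hypothetical stand-in `output_parts` instead of einx itself: it returns the right-hand side of an expression as one string, or as a tuple of strings when there is a comma, so it has the same "single value or tuple" return shape as einx.rearrange:

```python
from typing import Literal, Tuple, Union, overload


# Hypothetical stand-in for einx.rearrange: only the return-type shape matters here.
@overload
def output_parts(expr: Literal["c h w -> h w c"]) -> str: ...
@overload
def output_parts(expr: Literal["b (x + y) -> b x, b y"]) -> Tuple[str, ...]: ...
@overload
def output_parts(expr: str) -> Union[str, Tuple[str, ...]]: ...
def output_parts(expr: str) -> Union[str, Tuple[str, ...]]:
    # Split off the right-hand side, then split it on commas.
    rhs = expr.split("->")[-1]
    parts = tuple(p.strip() for p in rhs.split(","))
    # One output -> plain value; several outputs -> tuple.
    return parts[0] if len(parts) == 1 else parts
```

A type checker picks the first matching overload, so `output_parts("c h w -> h w c")` is inferred as `str`. Any expression that is not spelled out as a Literal still falls through to the final `str` overload and gets the Union back, which is why this approach only covers a fixed list of expressions.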
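However, as far as I know, there isn't a more general option that could handle all cases of commas in literal strings. A hypothetical construct (here called Satisfies, which does not exist in typing) would look like this: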
def no_comma_in_output(expr: str) -> bool:
    return "," not in expr.split("->")[-1]
@overload
def rearrange(expr: Satisfies[no_comma_in_output], ...) -> Tensor: ...
This seems to me like it would be the ideal solution.
Another repetitive solution would be to disable type checking for the respective lines with # type: ignore:
single_tensor.shape # type: ignore
There are also backwards-incompatible solutions like always returning a tuple from rearrange (and similar functions) or changing the return type of rearrange to typing.Any. Or one could introduce a new function like einx.rearrange1 which must return a single tensor, although this seems a little tedious as well. To be honest, I am not sure how to best address this problem at the moment.