Functions that Return Unions are not Friendly with Linters #17

Open
fusedbloxxer opened this issue Jan 9, 2025 · 1 comment

fusedbloxxer commented Jan 9, 2025

The einx.rearrange function returns Union[einx.Tensor, Tuple[einx.Tensor, ...]]. This is a problem when assigning the result to a variable annotated as einx.Tensor, because the function might return a Tuple[einx.Tensor, ...] instead, which is not compatible with the annotated type. Linters such as Pyright therefore report errors. In practice these are false positives, but the linter is justified in reporting them: as the signature is currently written, you cannot tell from the argument types what the function returns. They show up in many common use cases:

1. Performing a rearrange on a single tensor turns its type into a union, and linters report errors when accessing basic attributes:

```python
import einx
import torch

single_tensor = einx.rearrange('b c h w -> b h w c', torch.randn((16, 3, 256, 256)))

# Cannot access attribute "shape" for class "Tuple[Tensor, ...]"
#   Attribute "shape" is unknown Pylance(reportAttributeAccessIssue)
single_tensor.shape
```
2. After performing a rearrange, I cannot assign the result back to the same variable if I type-hint it as a Tensor, because the result is a union:

```python
import einx
import matplotlib.pyplot as plt
import torchvision
from torch import Tensor

image: Tensor = torchvision.io.image.read_image('./resources/image.png')

# "Tuple[Tensor, ...]" is not assignable to "Tensor" Pylance(reportAssignmentType)
image = einx.rearrange('c h w -> h w c', image)

plt.imshow(image)
```

A quick but repetitive workaround is to explicitly cast the union back to Tensor:

```python
from typing import cast

image = cast(Tensor, einx.rearrange('c h w -> h w c', image))
```
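One variation on the cast, sketched here as a hypothetical helper (`single` is not part of einx), narrows the union with an `isinstance` check instead. A checker should then infer `Tensor` from the union without naming the type at each call site, and a wrong expression fails at runtime rather than being silently cast:

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T")


def single(result: Union[T, Tuple[T, ...]]) -> T:
    """Hypothetical helper: narrow a may-be-tuple result to a single value.

    Unlike typing.cast, this also verifies at runtime that the result is
    not a tuple, so an expression with several outputs fails loudly.
    """
    if isinstance(result, tuple):
        raise TypeError(f"expected a single tensor, got a tuple of {len(result)}")
    return result
```

Usage would look like `image = single(einx.rearrange('c h w -> h w c', image))`. It is still one extra call per line, so it shares the repetitiveness of `cast`.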

Could the function signature be changed to infer the return type from the argument types (using @overload)? This might, however, cause type conflicts between concatenation, splitting, reshaping, and so on. I would like to use the einx.rearrange functionality rather than fall back on einops.rearrange, which always returns a Tensor but is less flexible.

Owner

fferflo commented Jan 13, 2025

The specific return type depends on the value of the string expression (i.e. whether there is a comma on the right-hand side), so I am not sure a type checker could determine which overload to use without running the code. I am not very familiar with typing, though. The problem also appears in other functions (einx.vmap, einx.vmap_with_axis).
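To make this concrete, here is a toy stand-in (the hypothetical `toy_rearrange`, not einx's implementation, with plain lists of floats standing in for tensors) whose number of outputs depends only on the contents of the string. Both calls pass a `str` and a list, so a static checker sees identical argument types and can only report the union:

```python
from typing import List, Tuple, Union

# Assumption for illustration: a "tensor" is just a list of floats.
Tensor = List[float]


def toy_rearrange(expr: str, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
    """Hypothetical stand-in for einx.rearrange, NOT its real implementation.

    It only mimics the property under discussion: the number of outputs is
    decided by the commas on the right-hand side of the string `expr`.
    """
    outputs = expr.split("->")[-1].split(",")
    if len(outputs) == 1:
        return x  # one output: a single "tensor"
    # several outputs: split x into that many equal-sized chunks
    n = len(outputs)
    size = len(x) // n
    return tuple(x[i * size:(i + 1) * size] for i in range(n))
```

`toy_rearrange("x -> x", [1.0, 2.0, 3.0, 4.0])` returns a list, while `toy_rearrange("x -> y, z", [1.0, 2.0, 3.0, 4.0])` returns a tuple of two lists. The difference lives entirely in the string's value, which ordinary `str` annotations do not capture.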

There is typing.Literal, which could be used to overload einx.rearrange for specific expressions with something like

```python
@overload
def rearrange(expr: Literal['c h w -> h w c'], ...) -> Tensor: ...
```
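A self-contained sketch of this Literal approach follows. The placeholder `Tensor` class and the two-argument signature are assumptions (einx.rearrange's real signature differs), and the body is a stand-in. A checker such as Pyright should select the first overload only when that exact string literal is passed, and fall back to the union for any other string:

```python
from typing import Literal, Tuple, Union, overload


class Tensor:
    """Placeholder tensor type (assumption; stands in for einx.Tensor)."""


@overload
def rearrange(expr: Literal["c h w -> h w c"], x: Tensor) -> Tensor: ...
@overload
def rearrange(expr: str, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]: ...


def rearrange(expr: str, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
    # Hypothetical body: counts outputs from the commas after "->" and
    # repeats x instead of actually rearranging or splitting it.
    n_outputs = len(expr.split("->")[-1].split(","))
    return x if n_outputs == 1 else (x,) * n_outputs
```

Statically, `rearrange("c h w -> h w c", t)` should be typed as `Tensor` and `rearrange("b c h w -> b h w c", t)` as the union. Every expression would need its own overload, which is why this does not scale beyond a few fixed strings.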

However, as far as I know, there is no more general option that could cover every literal string depending on whether it has a comma, something like:

```python
def no_comma_in_output(expr: str) -> bool:
    return "," not in expr.split("->")[-1]

# Hypothetical: the typing module has no Satisfies construct
@overload
def rearrange(expr: Satisfies[no_comma_in_output], ...) -> Tensor: ...
```

This seems to me like it would be the ideal solution.

Another repetitive solution would be to disable type checking for the respective lines with # type: ignore:

```python
single_tensor.shape  # type: ignore
```

There are also backwards-incompatible solutions like always returning a tuple from rearrange (and similar functions) or changing the return type of rearrange to typing.Any. Or one could introduce a new function like einx.rearrange1 which must return a single tensor, although this seems a little tedious as well. To be honest, I am not sure how to best address this problem at the moment.
