Skip to content

Prefer "own" type vars (owned by current function) when building "any of" constraints #18986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def infer_constraints_for_callable(
param_spec_arg_names = []
param_spec_arg_kinds = []

own_vars = {t.id for t in callee.variables}
type_state.constraints_targets.append(own_vars)

incomplete_star_mapping = False
for i, actuals in enumerate(formal_to_actual): # TODO: isn't this `enumerate(arg_types)`?
for actual in actuals:
Expand Down Expand Up @@ -273,6 +276,7 @@ def infer_constraints_for_callable(
if any(isinstance(v, ParamSpecType) for v in callee.variables):
# As a perf optimization filter imprecise constraints only when we can have them.
constraints = filter_imprecise_kinds(constraints)
type_state.constraints_targets.pop()
return constraints


Expand Down Expand Up @@ -512,7 +516,7 @@ def handle_recursive_union(template: UnionType, actual: Type, direction: int) ->
) or infer_constraints(UnionType.make_union(type_var_items), actual, direction)


def any_constraints(options: list[list[Constraint] | None], eager: bool) -> list[Constraint]:
def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> list[Constraint]:
"""Deduce what we can from a collection of constraint lists.

It's a given that at least one of the lists must be satisfied. A
Expand Down Expand Up @@ -547,14 +551,19 @@ def any_constraints(options: list[list[Constraint] | None], eager: bool) -> list
if option in trivial_options:
continue
merged_options.append([merge_with_any(c) for c in option])
return any_constraints(list(merged_options), eager)
return any_constraints(list(merged_options), eager=eager)

# If normal logic didn't work, try excluding trivially unsatisfiable constraint (due to
# upper bounds) from each option, and comparing them again.
filtered_options = [filter_satisfiable(o) for o in options]
if filtered_options != options:
return any_constraints(filtered_options, eager=eager)

# Try harder: if that didn't work, try to strip typevars not owned by current function.
filtered_options = [filter_own(o) for o in options]
if filtered_options != options:
return any_constraints(filtered_options, eager=eager)

# Otherwise, there are either no valid options or multiple, inconsistent valid
# options. Give up and deduce nothing.
return []
Expand All @@ -569,6 +578,7 @@ def filter_satisfiable(option: list[Constraint] | None) -> list[Constraint] | No
"""
if not option:
return option

satisfiable = []
for c in option:
if isinstance(c.origin_type_var, TypeVarType) and c.origin_type_var.values:
Expand All @@ -583,6 +593,16 @@ def filter_satisfiable(option: list[Constraint] | None) -> list[Constraint] | No
return satisfiable


def filter_own(option: list[Constraint] | None) -> list[Constraint] | None:
"""Keep only constraints that reference type vars local to current function, if any."""

if not option or not type_state.constraints_targets:
return option
own_vars = type_state.constraints_targets[-1]

return [c for c in option if c.type_var in own_vars] or None


def is_same_constraints(x: list[Constraint], y: list[Constraint]) -> bool:
for c1 in x:
if not any(is_same_constraint(c1, c2) for c2 in y):
Expand Down
4 changes: 4 additions & 0 deletions mypy/typestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class TypeState:
_assuming_proper: Final[list[tuple[Type, Type]]]
# Ditto for inference of generic constraints against recursive type aliases.
inferring: Final[list[tuple[Type, Type]]]
# When building constraints for a callable, prefer these variables when we encounter
# ambiguous set in `any_constraints`
constraints_targets: Final[list[set[TypeVarId]]]
# Whether to use joins or unions when solving constraints, see checkexpr.py for details.
infer_unions: bool
# Whether to use new type inference algorithm that can infer polymorphic types.
Expand All @@ -112,6 +115,7 @@ def __init__(self) -> None:
self._assuming = []
self._assuming_proper = []
self.inferring = []
self.constraints_targets = []
self.infer_unions = False
self.infer_polymorphic = False

Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3963,3 +3963,18 @@ def f() -> None:

# The type below should not be Any.
reveal_type(x) # N: Revealed type is "builtins.int"

[case testInferenceMappingTypeVarGet]
from typing import Generic, TypeVar, Union

_T = TypeVar("_T")
_K = TypeVar("_K")
_V = TypeVar("_V")

class Mapping(Generic[_K, _V]):
def get(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...

def check(mapping: Mapping[str, _T]) -> None:
ok1 = mapping.get("", "")
ok2: Union[_T, str] = mapping.get("", "")
[builtins fixtures/tuple.pyi]