Skip to content

Commit a1f3c6d

Browse files
Investigating if we can drop the typeguard dependency.
1 parent 44154e1 commit a1f3c6d

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

diffrax/_integrate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
198198
try:
199199
with jax.numpy_dtype_promotion("standard"):
200200
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
201-
except Exception as e:
201+
except ValueError as e:
202202
# ValueError may also arise from mismatched tree structures
203203
pretty_term = wl.pformat(terms)
204204
pretty_expected = wl.pformat(term_structure)

diffrax/_typing.py

+25-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
import sys
32
import types
43
from typing import (
54
Annotated,
@@ -14,33 +13,41 @@
1413
)
1514
from typing_extensions import TypeAlias
1615

17-
import typeguard
18-
1916

2017
# We don't actually care what people have subscripted with.
2118
# In practice this should be thought of as TypeLike = Union[type, types.UnionType]. Plus
2219
# maybe type(Literal) and so on?
2320
TypeLike: TypeAlias = Any
2421

2522

26-
def better_isinstance(x, annotation) -> bool:
27-
"""As `isinstance`, but supports general type hints."""
23+
_T = TypeVar("_T")
2824

29-
@typeguard.typechecked
30-
def f(y: annotation):
31-
pass
3225

33-
try:
34-
f(x)
35-
except TypeError:
36-
return False
37-
else:
38-
return True
26+
class _Foo(Generic[_T]):
27+
pass
28+
3929

30+
_generic_alias_types = (types.GenericAlias, type(_Foo[int]))
31+
_union_origins = (Union, types.UnionType)
32+
del _Foo, _T
4033

41-
_union_types: list = [Union]
42-
if sys.version_info >= (3, 10):
43-
_union_types.append(types.UnionType)
34+
35+
def better_isinstance(x, annotation) -> bool:
36+
"""As `isinstance`, but supports a few other types that are useful to us."""
37+
origin = get_origin(annotation)
38+
if origin in _union_origins:
39+
return any(better_isinstance(x, arg) for arg in get_args(annotation))
40+
elif isinstance(annotation, _generic_alias_types):
41+
assert origin is not None
42+
return better_isinstance(x, origin)
43+
elif annotation is Any:
44+
return True
45+
elif isinstance(annotation, type):
46+
return isinstance(x, annotation)
47+
else:
48+
raise NotImplementedError(
49+
f"Do not know how to check whether `{x}` is an instance of `{annotation}`."
50+
)
4451

4552

4653
def get_origin_no_specials(x, error_msg: str) -> Optional[type]:
@@ -59,7 +66,7 @@ def get_origin_no_specials(x, error_msg: str) -> Optional[type]:
5966
As `get_origin`, specifically either `None` or a class.
6067
"""
6168
origin = get_origin(x)
62-
if origin in _union_types:
69+
if origin in _union_origins:
6370
raise NotImplementedError(f"Cannot use unions in `{error_msg}`.")
6471
elif origin is Annotated:
6572
# We do allow Annotated, just because it's easy to handle.

0 commit comments

Comments
 (0)