From 75928e8fe5a2941a911c081df72e96e2327d9427 Mon Sep 17 00:00:00 2001 From: Florian Fervers Date: Mon, 22 Jul 2024 20:27:25 +0200 Subject: [PATCH] Better handling of torch<2 --- .github/workflows/run_pytest.yml | 2 ++ einx/backend/_torch.py | 1 + einx/backend/base.py | 4 ++-- einx/tracer/decorator.py | 2 +- test/test_invalid_torch_version.py | 21 +++++++++++++++++++++ 5 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 test/test_invalid_torch_version.py diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml index 1d93de3..562b77a 100644 --- a/.github/workflows/run_pytest.yml +++ b/.github/workflows/run_pytest.yml @@ -65,3 +65,5 @@ jobs: pip install --upgrade "torch==2.0.0" EINX_FILTER_TRACEBACK=0 pytest EINX_FILTER_TRACEBACK=0 pytest test/test_invalid_backend.py --noconftest + pip install --upgrade "torch<2" + EINX_FILTER_TRACEBACK=0 pytest test/test_invalid_torch_version.py --noconftest diff --git a/einx/backend/_torch.py b/einx/backend/_torch.py index f3c8d89..cab0ca2 100644 --- a/einx/backend/_torch.py +++ b/einx/backend/_torch.py @@ -16,6 +16,7 @@ def create(): "torch", "einx with PyTorch requires PyTorch version >= 2, but found " f"{torch_.__version__}. einx functions are disabled for PyTorch.", + tensor_types=[torch_.Tensor], ) @einx.trace diff --git a/einx/backend/base.py b/einx/backend/base.py index d0e6409..f598b44 100644 --- a/einx/backend/base.py +++ b/einx/backend/base.py @@ -218,10 +218,10 @@ def __init__(self, message): class InvalidBackend: - def __init__(self, name, message): + def __init__(self, name, message, tensor_types=None): self.name = name self.message = message - self.tensor_types = [] + self.tensor_types = tensor_types or [] def __getattr__(self, name): raise InvalidBackendException(self.message) diff --git a/einx/tracer/decorator.py b/einx/tracer/decorator.py index 3f77f01..b0a24bd 100644 --- a/einx/tracer/decorator.py +++ b/einx/tracer/decorator.py @@ -213,7 +213,7 @@ def new_input(x): # Disable torch.compile for graph construction (if torch is imported) nonlocal has_decorated, find_backend_and_construct_graph - if not has_decorated and "torch" in sys.modules: + if not has_decorated and "torch" in sys.modules and "_dynamo" in dir(sys.modules["torch"]): import torch._dynamo as _dynamo find_backend_and_construct_graph = _dynamo.disable(find_backend_and_construct_graph) diff --git a/test/test_invalid_torch_version.py b/test/test_invalid_torch_version.py new file mode 100644 index 0000000..b01cb56 --- /dev/null +++ b/test/test_invalid_torch_version.py @@ -0,0 +1,21 @@ +import pytest +import sys +import importlib +import einx + + +@pytest.mark.skipif(importlib.find_loader("torch") is None, reason="torch is not installed") +def test_import(): + import torch + + x = torch.zeros((10,)) + + major = int(torch.__version__.split(".")[0]) + if major < 2: + with pytest.raises(einx.backend.InvalidBackendException): + einx.add("a, a", x, x) + with pytest.raises(einx.backend.InvalidBackendException): + einx.add("a, a", x, x, backend="torch") + else: + einx.add("a, a", x, x) + einx.add("a, a", x, x, backend="torch")