Skip to content

Commit

Permalink
Better handling of torch<2
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Jul 22, 2024
1 parent 9e1dafd commit 75928e8
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/run_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@ jobs:
pip install --upgrade "torch==2.0.0"
EINX_FILTER_TRACEBACK=0 pytest
EINX_FILTER_TRACEBACK=0 pytest test/test_invalid_backend.py --noconftest
pip install --upgrade "torch<2"
EINX_FILTER_TRACEBACK=0 pytest test/test_invalid_torch_version.py --noconftest
1 change: 1 addition & 0 deletions einx/backend/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def create():
"torch",
"einx with PyTorch requires PyTorch version >= 2, but found "
f"{torch_.__version__}. einx functions are disabled for PyTorch.",
tensor_types=[torch_.Tensor],
)

@einx.trace
Expand Down
4 changes: 2 additions & 2 deletions einx/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def __init__(self, message):


class InvalidBackend:
def __init__(self, name, message):
def __init__(self, name, message, tensor_types=None):
self.name = name
self.message = message
self.tensor_types = []
self.tensor_types = tensor_types or []

def __getattr__(self, name):
raise InvalidBackendException(self.message)
2 changes: 1 addition & 1 deletion einx/tracer/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def new_input(x):

# Disable torch.compile for graph construction (if torch is imported)
nonlocal has_decorated, find_backend_and_construct_graph
if not has_decorated and "torch" in sys.modules:
if not has_decorated and "torch" in sys.modules and "_dynamo" in dir(sys.modules["torch"]):
import torch._dynamo as _dynamo

find_backend_and_construct_graph = _dynamo.disable(find_backend_and_construct_graph)
Expand Down
21 changes: 21 additions & 0 deletions test/test_invalid_torch_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import sys
import importlib
import einx


@pytest.mark.skipif(importlib.find_loader("torch") is None, reason="torch is not installed")
def test_import():
import torch

x = torch.zeros((10,))

major = int(torch.__version__.split(".")[0])
if major < 2:
with pytest.raises(einx.backend.InvalidBackendException):
einx.add("a, a", x, x)
with pytest.raises(einx.backend.InvalidBackendException):
einx.add("a, a", x, x, backend="torch")
else:
einx.add("a, a", x, x)
einx.add("a, a", x, x, backend="torch")

0 comments on commit 75928e8

Please sign in to comment.