From 07aca8cb824c8d6c22c1aaefb0c035e89d8dbd55 Mon Sep 17 00:00:00 2001
From: Florian Fervers
Date: Thu, 27 Jun 2024 00:08:27 +0200
Subject: [PATCH] Add pre-commit

---
 .pre-commit-config.yaml     |  7 ++++++
 einx/backend/__init__.py    |  2 +-
 einx/backend/_dask.py       |  5 ++--
 einx/backend/_jax.py        |  5 ++--
 einx/backend/_mlx.py        |  5 ++--
 einx/backend/_numpy.py      |  5 ++--
 einx/backend/_tensorflow.py |  5 ++--
 einx/backend/_tinygrad.py   |  9 ++++----
 einx/backend/_torch.py      | 10 ++++----
 einx/backend/register.py    |  8 +++----
 einx/expr/solver.py         |  2 +-
 einx/expr/stage2.py         |  2 +-
 einx/nn/equinox.py          |  6 ++---
 einx/nn/flax.py             |  6 ++---
 einx/nn/haiku.py            |  8 +++----
 einx/nn/keras.py            |  6 ++---
 einx/nn/torch.py            |  2 +-
 einx/op/util.py             |  2 +-
 einx/tracer/compile.py      | 46 ++++++++++++++++++++++---------------
 einx/tracer/decorator.py    |  2 +-
 einx/tracer/input.py        |  4 ++--
 einx/tracer/optimize.py     |  4 ++--
 einx/tracer/tensor.py       | 13 ++++++-----
 einx/tracer/tracer.py       | 42 +++++++++++++++++++++------------
 ruff.toml                   |  2 +-
 test/conftest.py            |  1 -
 test/test_shapes.py         | 10 ++++----
 test/test_values.py         |  3 +--
 28 files changed, 128 insertions(+), 94 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..cb15a7a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.4.4
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
\ No newline at end of file
diff --git a/einx/backend/__init__.py b/einx/backend/__init__.py
index a10e05b..f943564 100644
--- a/einx/backend/__init__.py
+++ b/einx/backend/__init__.py
@@ -1,4 +1,4 @@
-from .register import register_for_module, register, get, backends, numpy
+from .register import register_for_module, register, get, backends
 from .base import Backend, get_default
 from . import _numpy as numpy
 
diff --git a/einx/backend/_dask.py b/einx/backend/_dask.py
index f50eef3..adc871f 100644
--- a/einx/backend/_dask.py
+++ b/einx/backend/_dask.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 
 
diff --git a/einx/backend/_jax.py b/einx/backend/_jax.py
index 96a4ec0..fe591ba 100644
--- a/einx/backend/_jax.py
+++ b/einx/backend/_jax.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 
 
diff --git a/einx/backend/_mlx.py b/einx/backend/_mlx.py
index f8ae725..4c5973a 100644
--- a/einx/backend/_mlx.py
+++ b/einx/backend/_mlx.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 
 
diff --git a/einx/backend/_numpy.py b/einx/backend/_numpy.py
index 2fea5cd..2d92f02 100644
--- a/einx/backend/_numpy.py
+++ b/einx/backend/_numpy.py
@@ -1,8 +1,9 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary, vmap_forloop
 import einx.tracer as tracer
 from einx.tracer.tensor import op
 import numpy as np
-import einx, types
+import einx
+import types
 from functools import partial
 
 
diff --git a/einx/backend/_tensorflow.py b/einx/backend/_tensorflow.py
index 34e8d6c..133ab10 100644
--- a/einx/backend/_tensorflow.py
+++ b/einx/backend/_tensorflow.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 
 
diff --git a/einx/backend/_tinygrad.py b/einx/backend/_tinygrad.py
index f3d61ee..7aeb429 100644
--- a/einx/backend/_tinygrad.py
+++ b/einx/backend/_tinygrad.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 import functools
 
@@ -28,7 +29,7 @@ def outer(*args):
         if convert_all_to_tensor:
             args = [scalar_to_tensor(a) for a in args]
         else:
-            args = [a for a in args]
+            args = list(args)
             args[0] = scalar_to_tensor(args[0])
 
         return op.elementwise(func)(*args)
@@ -95,7 +96,7 @@ def einsum(backend, equation, *tensors):
         if len(inputs) != len(tensors):
             raise ValueError("Invalid equation")
         inputs = [x.strip().replace(" ", "") for x in inputs]
-        tensors = [t for t in tensors]
+        tensors = list(tensors)
 
         scalars = []
         for i in list(range(len(inputs)))[::-1]:
diff --git a/einx/backend/_torch.py b/einx/backend/_torch.py
index 3bdcd97..21afa5f 100644
--- a/einx/backend/_torch.py
+++ b/einx/backend/_torch.py
@@ -1,8 +1,10 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary, ErrorBackend
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
+import functools
 
 
 def create():
@@ -47,8 +49,6 @@ def wrapper(*args, **kwargs):
 
             return wrapper
 
-    MARKER_DECORATED_CONSTRUCT_GRAPH = "__einx_decorated_construct_graph"
-
     ttorch = tracer.import_("torch")
     import torch as torch_
 
@@ -98,7 +98,7 @@ class torch(Backend):
         @staticmethod
        @einx.trace
         def to_tensor(arg, shape):
-            assert False
+            raise NotImplementedError("to_tensor is not implemented for PyTorch")
 
         @staticmethod
         @einx.trace
diff --git a/einx/backend/register.py b/einx/backend/register.py
index 770ecd4..d1fcd0e 100644
--- a/einx/backend/register.py
+++ b/einx/backend/register.py
@@ -19,7 +19,7 @@ def register_for_module(module_name, backend_factory):
         register(backend_factory())
     else:
         # Module is not yet imported -> register factory
-        if not module_name in backend_factories:
+        if module_name not in backend_factories:
             backend_factories[module_name] = []
         backend_factories[module_name].append(backend_factory)
 
@@ -68,7 +68,7 @@ def _update():
 
 def _get1(tensor):
     backend = tensortype_to_backend.get(type(tensor), None)
-    if not backend is None:
+    if backend is not None:
         return backend
 
     _update()
@@ -103,7 +103,7 @@ def get(arg):
     for tensor in tensors:
         if tensor is not None:
             backend2 = _get1(tensor)
-            if not backend2 is None:
+            if backend2 is not None:
                 if (
                     backend is not None
                     and backend != backend2
@@ -117,6 +117,6 @@ def get(arg):
                 if backend is None or backend2 != numpy:
                     backend = backend2
     if backend is None:
-        raise ValueError(f"Could not determine the backend to use in this operation")
+        raise ValueError("Could not determine the backend to use in this operation")
     else:
         return backend
diff --git a/einx/expr/solver.py b/einx/expr/solver.py
index fd95715..5b39d5b 100644
--- a/einx/expr/solver.py
+++ b/einx/expr/solver.py
@@ -94,7 +94,7 @@ def __str__(self):
         return " + ".join(str(c) for c in self.children)
 
     def sympy(self):
-        return sum([c.sympy() for c in self.children])
+        return sum(c.sympy() for c in self.children)
 
 
 class Product(Expression):
diff --git a/einx/expr/stage2.py b/einx/expr/stage2.py
index edb4264..8874e0b 100644
--- a/einx/expr/stage2.py
+++ b/einx/expr/stage2.py
@@ -934,7 +934,7 @@ def replace(expr):
         while i < len(expr):
             # Check if a subexpression starts at position i
             exprlist_found = None
-            for idx, common_expr in enumerate(common_exprs):
+            for idx, common_expr in enumerate(common_exprs):  # noqa: B007
                 for exprlist in common_expr:
                     for j in range(len(exprlist)):
                         if i + j >= len(expr) or id(exprlist[j]) != id(expr[i + j]):
diff --git a/einx/nn/equinox.py b/einx/nn/equinox.py
index 548b4f3..c778a97 100644
--- a/einx/nn/equinox.py
+++ b/einx/nn/equinox.py
@@ -65,9 +65,9 @@ def __init__(self, name, init, dtype):
         self.dtype = dtype
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
 
         if name is None:
             raise ValueError("Must specify name for tensor factory eqx.Module")
diff --git a/einx/nn/flax.py b/einx/nn/flax.py
index 0f45527..4f7bc16 100644
--- a/einx/nn/flax.py
+++ b/einx/nn/flax.py
@@ -73,9 +73,9 @@ def __init__(self, name, init, dtype, col, param_type):
         self.param_type = param_type
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
         col = self.col
 
         if name is None:
diff --git a/einx/nn/haiku.py b/einx/nn/haiku.py
index 433980c..7159f39 100644
--- a/einx/nn/haiku.py
+++ b/einx/nn/haiku.py
@@ -60,9 +60,9 @@ def __init__(self, name, init, dtype, param_type, depend_on):
         self.depend_on = depend_on
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
 
         if name is None:
             raise ValueError("Must specify name for tensor factory hk.get_{parameter|state}")
@@ -90,7 +90,7 @@ def __call__(self, shape, kwargs):
         elif self.param_type == "state":
             func = thk.get_state
         else:
-            assert False
+            raise AssertionError(f"Unknown parameter type '{self.param_type}'")
 
         return einx.tracer.apply(
             func,
diff --git a/einx/nn/keras.py b/einx/nn/keras.py
index 6097129..7a1a4dd 100644
--- a/einx/nn/keras.py
+++ b/einx/nn/keras.py
@@ -78,9 +78,9 @@ def __init__(self, name, init, dtype, trainable):
         self.trainable = trainable
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
 
         if name is None:
             raise ValueError("Must specify name for tensor factory keras.layers.Layer")
diff --git a/einx/nn/torch.py b/einx/nn/torch.py
index c91a6e8..6a43c0c 100644
--- a/einx/nn/torch.py
+++ b/einx/nn/torch.py
@@ -50,7 +50,7 @@ def __init__(self, init):
         self.init = init
 
     def __call__(self, shape, kwargs):
-        init = self.init if not self.init is None else kwargs.get("init", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
 
         x = self
 
diff --git a/einx/op/util.py b/einx/op/util.py
index 7071fd5..43a4480 100644
--- a/einx/op/util.py
+++ b/einx/op/util.py
@@ -188,7 +188,7 @@ def _unflatten(exprs_in, tensors_in, expr_out, backend):
 def unflatten(exprs_in, tensors_in, exprs_out, *, backend):
     if len(exprs_in) != len(tensors_in):
         raise ValueError("Got different number of input expressions and tensors")
-    assert not backend is None
+    assert backend is not None
 
     iter_exprs_in = iter(exprs_in)
     iter_tensors_in = iter(tensors_in)
diff --git a/einx/tracer/compile.py b/einx/tracer/compile.py
index eaee72a..ca060a6 100644
--- a/einx/tracer/compile.py
+++ b/einx/tracer/compile.py
@@ -10,19 +10,19 @@ def __init__(self, parent=None):
         self.variables = {}
         self.parent = parent
         self.children = []
-        if not self.parent is None:
+        if self.parent is not None:
             self.parent.children.append(self)
 
     def fork(self):
         return Variables(parent=self)
 
     def __contains__(self, name):
-        return name in self.variables or (not self.parent is None and name in self.parent)
+        return name in self.variables or (self.parent is not None and name in self.parent)
 
     def _is_free(self, name):
         if name in self.variables:
             return False
-        if not self.parent is None and name in self.parent:
+        if self.parent is not None and name in self.parent:
             return False
         for child in self.children:
             if name in child:
@@ -47,7 +47,7 @@ def add(self, value, prefix=None, name=None):
     def __getitem__(self, name):
         if name in self.variables:
             return self.variables[name]
-        if not self.parent is None:
+        if self.parent is not None:
             return self.parent[name]
         raise ValueError(f"Variable '{name}' is not set")
 
@@ -60,7 +60,7 @@ def __init__(self, variables, parent):
 
     def is_parent_of(self, other):
         return other.parent is self or (
-            not other.parent is None and self.is_parent_of(other.parent)
+            other.parent is not None and self.is_parent_of(other.parent)
         )
 
     @property
@@ -144,7 +144,7 @@ def name(self):
     def is_variable(self):
         if self.overwritten:
             raise ValueError("Trying to access overwritten definition")
-        return not self._code is None and self._code.isidentifier()
+        return self._code is not None and self._code.isidentifier()
 
     def is_pytree(self):
         if self.overwritten:
@@ -175,7 +175,7 @@ def __init__(self, objects):
         for definition in self.constants:
             line = f"# {definition.name}: {str(type(definition.value))}"
             value_str = str(definition.value)
-            if not "\n" in value_str:
+            if "\n" not in value_str:
                 line += f" = {value_str}"
             self.code = line + "\n" + self.code
 
@@ -205,7 +205,7 @@ def join_blocks(self, blocks):
 
     def execute_application(self, application):
         assert isinstance(application, Application)
-        comment = f" # {application.comment}" if not application.comment is None else ""
+        comment = f" # {application.comment}" if application.comment is not None else ""
 
         # Find block at which to execute the application (i.e. where all dependencies are defined)
         in_defs = [self.get_definition_of(x) for x in application.dependencies]
@@ -215,10 +215,10 @@ def execute_application(self, application):
         if isinstance(application.op, Import):
             import_str = f"import {application.op.import_}"
             name = application.op.import_
-            if not application.op.as_ is None:
+            if application.op.as_ is not None:
                 import_str = f"{import_str} as {application.op.as_}"
                 name = application.op.as_
-            if not application.op.from_ is None:
+            if application.op.from_ is not None:
                 import_str = f"from {application.op.from_} {import_str}"
 
             # Import only once
@@ -343,14 +343,17 @@ def slice_to_str(s):
         def check(output):
             definition = self.get_definition_of(output)
             if isinstance(definition.value, Tensor):
-                line = f"assert {definition.code}.shape == {self.get_definition_of(output.shape).code}"
+                line = (
+                    f"assert {definition.code}.shape == "
+                    f"{self.get_definition_of(output.shape).code}"
+                )
                 block.code.append(line)
 
         einx.tree_util.tree_map(check, application.output)
 
     def _add_definition(self, definition):
         if id(definition.value) in self.definitions:
-            raise ValueError(f"Trying to add definition for existing value")
+            raise ValueError("Trying to add definition for existing value")
         self.definitions[id(definition.value)] = definition
 
         # If value is a pytree, add definition for all leaves
@@ -363,7 +366,7 @@ def store(x, key):
             elif isinstance(k, str):
                 code += f'["{k}"]'
             else:
-                assert False
+                raise AssertionError(f"Invalid key {k}")
             self.new_value_definition(x, definition.block, code)
 
         einx.tree_util.tree_map_with_key(store, definition.value)
@@ -394,7 +397,7 @@ def get_definition_of(self, x):
         if id(x) in self.definitions:
             definition = self.definitions[id(x)]
             if definition.is_overwritten():
-                raise ValueError(f"Trying to access overwritten variable")
+                raise ValueError("Trying to access overwritten variable")
             return definition
 
         if isinstance(x, TracableFunction):
@@ -415,7 +418,7 @@ def get_definition_of(self, x):
             arg_defs = [
                 self.new_variable_definition(arg, function_block, prefix="i") for arg in x.args
             ]  # TODO: not using kwargs
-            virtual_arg_defs = [
+            [
                 self.new_empty_definition(virtual_arg, function_block)
                 for virtual_arg in x.virtual_args
             ]
@@ -442,7 +445,7 @@ def get_definition_of(self, x):
                 assert id(x) in self.definitions
                 return self.definitions[id(x)]
             else:
-                assert False, f"{type(x.origin)}"
+                raise AssertionError(f"Invalid origin {x.origin}")
         elif isinstance(x, str):
             return Definition(x, self.root_block, f'"{x}"')
         elif isinstance(x, tuple):
@@ -461,9 +464,16 @@ def get_definition_of(self, x):
             return Definition(x, self.root_block, str(x))
         elif isinstance(x, slice):
             if x.step is not None:
-                code = f"slice({self.get_definition_of(x.start).code}, {self.get_definition_of(x.stop).code}, {self.get_definition_of(x.step).code})"
+                code = (
+                    f"slice({self.get_definition_of(x.start).code}, "
+                    f"{self.get_definition_of(x.stop).code}, "
+                    f"{self.get_definition_of(x.step).code})"
+                )
             elif x.stop is not None:
-                code = f"slice({self.get_definition_of(x.start).code}, {self.get_definition_of(x.stop).code})"
+                code = (
+                    f"slice({self.get_definition_of(x.start).code}, "
+                    f"{self.get_definition_of(x.stop).code})"
+                )
             else:
                 code = f"slice({self.get_definition_of(x.start).code})"
             return Definition(x, self.root_block, code)
diff --git a/einx/tracer/decorator.py b/einx/tracer/decorator.py
index b42588b..3f77f01 100644
--- a/einx/tracer/decorator.py
+++ b/einx/tracer/decorator.py
@@ -205,7 +205,7 @@ def func_jit(*args, backend=None, graph=False, **kwargs):
 
         def new_input(x):
            value, key = einx.tracer.input.concrete_to_value_and_key(x)
-            if not value is None:
+            if value is not None:
                 traced_input_values.append(value)
             return key
 
diff --git a/einx/tracer/input.py b/einx/tracer/input.py
index 2633be0..72374f8 100644
--- a/einx/tracer/input.py
+++ b/einx/tracer/input.py
@@ -90,7 +90,7 @@ def concrete_to_value_and_key(x):
     elif isinstance(x, Input):
         # Custom input
         return x.to_value_and_key()
-    elif not (x2 := apply_registered_tensor_factory(x)) is None:
+    elif (x2 := apply_registered_tensor_factory(x)) is not None:
         # Registered tensor factory
         return x2
     elif callable(x):
@@ -117,7 +117,7 @@ def key_to_tracer(x, backend, virtual_arg):
     def map(x):
         if isinstance(x, CacheKey):
             arg, x = x.to_tracer(backend, virtual_arg)
-            if not arg is None:
+            if arg is not None:
                 args.append(arg)
             return x
         else:
diff --git a/einx/tracer/optimize.py b/einx/tracer/optimize.py
index 3d3f72c..fa58c88 100644
--- a/einx/tracer/optimize.py
+++ b/einx/tracer/optimize.py
@@ -3,7 +3,7 @@
 
 
 def get_signature(node):
-    if not node.origin is None:
+    if node.origin is not None:
         return node.origin.signature
     else:
         return None
@@ -70,7 +70,7 @@ def __call__(self, node):
         )
 
         def store(new_node, node):
-            assert not id(node) in self.optimized_nodes
+            assert id(node) not in self.optimized_nodes
             self.optimized_nodes[id(node)] = new_node
 
         einx.tree_util.tree_map(store, new_output_nodes, node.origin.output)
diff --git a/einx/tracer/tensor.py b/einx/tracer/tensor.py
index 205ccfd..35a5b43 100644
--- a/einx/tracer/tensor.py
+++ b/einx/tracer/tensor.py
@@ -55,7 +55,8 @@ def einsum(eq, *tensors, **kwargs):
             if axis in values:
                 if values[axis] != value:
                     raise ValueError(
-                        f"Got conflicting values for axis {axis}: {values[axis]} and {value}"
+                        f"Got conflicting values for axis {axis}: "
+                        f"{values[axis]} and {value}"
                     )
             else:
                 values[axis] = value
@@ -114,7 +115,7 @@ def elementwise(*args, **kwargs):
                 while len(shape2) < len(shape):
                     shape2 = (1,) + shape2
                 shape = np.maximum(shape, shape2)
-    assert not shape is None  # TODO: can this happen?
+    assert shape is not None  # TODO: can this happen?
 
     return apply(op, args=args, kwargs=kwargs, output=Tensor(shape))
 
@@ -266,7 +267,7 @@ def __init__(self, shape):
         try:
             self.shape = tuple(int(i) for i in shape)
         except:
-            raise ValueError(f"Invalid shape: {shape}")
+            raise ValueError(f"Invalid shape: {shape}") from None
 
     @property
     def ndim(self):
@@ -281,7 +282,7 @@ def __getitem__(self, key):
 
     def __setitem__(self, key, value):
         if (
-            not value.origin is None
+            value.origin is not None
             and isinstance(value.origin.op, AssignAt)
             and value.origin.op != "="
             and value.origin.args[0] is self
@@ -291,8 +292,8 @@ def __setitem__(self, key, value):
             # 1. x1 = __getitem__(tensor, key)
             # 2. x2 = __iadd__(x1, update)
             # 3. x3 = __setitem__(tensor, key, x2)
-            # The output of the second line already returns the results of the AssignAt (see below), so
-            # we can skip the third line.
+            # The output of the second line already returns the results of the AssignAt
+            # (see below), so we can skip the third line.
             return value
 
         return op.update_at(AssignAt("="), inplace=True)(self, key, value)
diff --git a/einx/tracer/tracer.py b/einx/tracer/tracer.py
index 3099a3b..d1b55c3 100644
--- a/einx/tracer/tracer.py
+++ b/einx/tracer/tracer.py
@@ -41,14 +41,22 @@ def dependencies(self):
 
 def apply(
     op,
-    args=[],
-    kwargs={},
+    args=None,
+    kwargs=None,
     output=None,
     signature=None,
-    inplace_updates=[],
+    inplace_updates=None,
     comment=None,
-    depend_on=[],
+    depend_on=None,
 ):
+    if args is None:
+        args = []
+    if kwargs is None:
+        kwargs = {}
+    if inplace_updates is None:
+        inplace_updates = []
+    if depend_on is None:
+        depend_on = []
     if isinstance(op, partial):
         return apply(
             op.func,
@@ -63,7 +71,7 @@ def apply(
     elif isinstance(op, TracableFunction):
         assert len(inplace_updates) == 0
         got_output = op(*args, **kwargs)
-        if not output is None:
+        if output is not None:
 
             def check(got_output, expected_output):
                 if type(got_output) != type(expected_output):
@@ -197,20 +205,24 @@ def __call__(self, *args, **kwargs):
 
 
 class TracableFunction(Tracer):
-    def __init__(self, func=None, args=None, kwargs=None, virtual_args=[], output=None, name=None):
+    def __init__(
+        self, func=None, args=None, kwargs=None, virtual_args=None, output=None, name=None
+    ):
+        if virtual_args is None:
+            virtual_args = []
         Tracer.__init__(self)
         if isinstance(func, Tracer):
-            raise ValueError(f"func cannot be a tracer object")
-        if not output is None and args is None and kwargs is None:
-            raise ValueError(f"Cannot create a TracableFunction with an output but no input")
+            raise ValueError("func cannot be a tracer object")
+        if output is not None and args is None and kwargs is None:
+            raise ValueError("Cannot create a TracableFunction with an output but no input")
 
-        if args is None and not kwargs is None:
+        if args is None and kwargs is not None:
             args = []
-        if not args is None and kwargs is None:
+        if args is not None and kwargs is None:
             kwargs = {}
 
-        if not func is None and output is None and (not args is None or not kwargs is None):
+        if func is not None and output is None and (args is not None or kwargs is not None):
             output = func(*args, **kwargs)
 
         self.func = func
@@ -223,7 +235,7 @@ def __init__(self, func=None, args=None, kwargs=None, virtual_args=No
     def __call__(self, *args, **kwargs):
         if self.func is None:
             raise NotImplementedError(
-                f"Cannot call a TracableFunction that was created without a callable function"
+                "Cannot call a TracableFunction that was created without a callable function"
             )
         return self.func(*args, **kwargs)
 
@@ -233,7 +245,7 @@ def __init__(self, tracers):
         self.usages = {}  # tracer-id: [using-applications]
 
         def _capture_usages(x):
-            if not id(x) in self.usages:
+            if id(x) not in self.usages:
                 self.usages[id(x)] = []
             if isinstance(x, (list, tuple)):
                 for y in x:
@@ -245,7 +257,7 @@ def _capture_usages(x):
             for y in x.origin.dependencies:
                 if isinstance(y, Tracer):
                     # Add x.origin to y's usages
-                    if not id(y) in self.usages:
+                    if id(y) not in self.usages:
                         self.usages[id(y)] = []
                     for usage in self.usages[id(y)]:
                         if id(usage) == id(x.origin):
diff --git a/ruff.toml b/ruff.toml
index f17e791..0d5b9de 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -4,7 +4,7 @@ target-version = "py38"
 
 [lint]
 select = ["B", "C", "F", "W", "YTT", "ASYNC", "E", "UP"]
-ignore = ["F401", "F403", "E722", "F821", "E402", "E741", "C901", "B017", "B023", "B020"]
+ignore = ["F401", "F403", "E722", "F821", "E402", "E741", "C901", "B017", "B023", "B020", "E731", "F405", "B011"]
 
 [format]
 preview = true
diff --git a/test/conftest.py b/test/conftest.py
index 9a6336b..602150b 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -106,7 +106,6 @@ def run(result, exception):
 
 
 # numpy is always available
-import numpy as np
 backend = einx.backend.numpy.create()
 
 
diff --git a/test/test_shapes.py b/test/test_shapes.py
index c6324d0..fc32602 100644
--- a/test/test_shapes.py
+++ b/test/test_shapes.py
@@ -1,4 +1,3 @@
-import einx
 import pytest
 import numpy as np
 import conftest
@@ -241,7 +240,7 @@ def w(shape):
     v = setup.full((2, 4, 100))
     with pytest.raises(Exception):
         einx.dot("b t (h ck), b t (h cv) -> b h ck cv", k, v, h=32, graph=True)
-    
+
     x = setup.full((10, 20))
     y = setup.full((10, 24))
     z = setup.full((3, 24))
@@ -255,7 +254,8 @@ def w(shape):
     assert einx.dot("a b c, a c, d c -> b d", x, y, z).shape == (20, 3)
     with pytest.raises(Exception):
         einx.dot("[a] b [c], a c, d [c] -> b d", x, y, z)
-    
+
+
 @pytest.mark.parametrize("test", conftest.tests)
 def test_shape_reduce(test):
     einx, backend, setup = test
@@ -778,7 +778,7 @@ def make_coords(shape):
 
 @pytest.mark.parametrize("test", conftest.tests)
 def test_shape_vmap_with_axis(test):
-    einx, backend, setup = test
+    einx, _, setup = test
 
     x = setup.full((10, 10))
     assert einx.flip("a [b] -> a [b]", x).shape == (10, 10)
@@ -805,7 +805,7 @@ def test_shape_vmap_with_axis(test):
 
 @pytest.mark.parametrize("test", conftest.tests)
 def test_shape_solve(test):
-    einx, backend, setup = test
+    einx, _, setup = test
 
     x = setup.full((2, 3, 4))
     assert einx.matches("a b c", x)
diff --git a/test/test_values.py b/test/test_values.py
index d0c9d5c..0649354 100644
--- a/test/test_values.py
+++ b/test/test_values.py
@@ -1,5 +1,4 @@
 import conftest
-import einx
 import pytest
 import numpy as np
 
@@ -154,7 +153,7 @@ def test_values(test):
 
 @pytest.mark.parametrize("test", conftest.tests)
 def test_compare_backends(test):
-    einx, backend, setup = test
+    einx, _, setup = test
 
     x = np.random.uniform(size=(10, 3, 10)).astype("float32")
     y = setup.to_tensor(x)
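-- 
A note on the einx/tracer/tracer.py hunks above: they fix ruff's B006 check
(mutable default arguments). A default such as `args=[]` is evaluated once,
when the function is defined, so every call that omits the argument shares
and mutates the same list object. Below is a minimal standalone sketch of
the pitfall and of the None-sentinel fix the patch applies; the function
names are illustrative only and not part of einx:

    def append_shared(value, items=[]):
        # Bug: the single default list persists and grows across calls.
        items.append(value)
        return items

    def append_fresh(value, items=None):
        # Fix: use None as the sentinel and allocate a new list per call.
        if items is None:
            items = []
        items.append(value)
        return items

    assert append_shared(1) == [1]
    assert append_shared(2) == [1, 2]  # state leaked from the first call
    assert append_fresh(1) == [1]
    assert append_fresh(2) == [2]      # each call starts clean

With the new .pre-commit-config.yaml in place, `pre-commit install` sets up
the git hook, and `pre-commit run --all-files` applies ruff (with --fix) and
ruff-format to the whole tree in one pass.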