Commit

Fix remaining ruff errors
fferflo committed Mar 8, 2024
1 parent 456edf4 commit a6e8326
Showing 35 changed files with 442 additions and 219 deletions.
3 changes: 2 additions & 1 deletion einx/backend/__init__.py
@@ -90,7 +90,8 @@ def get(arg):
     if backend2 != numpy:
         if backend is not None and backend != backend2:
             raise ValueError(
-                f"Got tensors with conflicting backends: {backend.__name__} and {backend2.__name__}"
+                "Got tensors with conflicting backends: "
+                f"{backend.__name__} and {backend2.__name__}"
             )
         backend = backend2
     if backend is None:
3 changes: 2 additions & 1 deletion einx/backend/_numpy.py
@@ -131,5 +131,6 @@ def inner(*args):
             )
         return xs
 
-    inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, in_axes={in_axes}, out_axes={out_axes})"
+    inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, " \
+        f"in_axes={in_axes}, out_axes={out_axes})"
     return inner
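Note on the hunk above (and several below): splitting a long f-string is only safe when the fragments remain one expression, via parentheses or a continuation backslash as here. Otherwise the second fragment becomes a separate no-op statement and the assigned string is silently truncated. A minimal sketch of the pitfall, with made-up values:

    in_axes, out_axes = (0,), (0,)

    # Broken: the second line is its own expression statement, so
    # `name` only receives the first fragment.
    name = f"vmap(op, "
    f"in_axes={in_axes}, out_axes={out_axes})"
    assert name == "vmap(op, "

    # Safe: parentheses (or a trailing backslash) keep the fragments in
    # one implicitly concatenated string literal.
    name = (
        "vmap(op, "
        f"in_axes={in_axes}, out_axes={out_axes})"
    )
    assert name == "vmap(op, in_axes=(0,), out_axes=(0,))"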
12 changes: 8 additions & 4 deletions einx/backend/_tensorflow.py
@@ -113,7 +113,8 @@ def softmax(x, axis):
     if isinstance(axis, (list, tuple)):
         if len(axis) != 1:
             raise ValueError(
-                f"Tensorflow only supports softmax along a single axis, got {len(axis)} axes"
+                "Tensorflow only supports softmax along a single axis, "
+                f"got {len(axis)} axes"
             )
         axis = axis[0]
     return tf.nn.softmax(x, axis=axis)
@@ -122,7 +123,8 @@ def log_softmax(x, axis):
     if isinstance(axis, (list, tuple)):
         if len(axis) != 1:
             raise ValueError(
-                f"Tensorflow only supports log_softmax along a single axis, got {len(axis)} axes"
+                "Tensorflow only supports log_softmax along a single axis, "
+                f"got {len(axis)} axes"
             )
         axis = axis[0]
     return tf.nn.log_softmax(x, axis=axis)
@@ -135,7 +137,8 @@ def log_softmax(x, axis):
 
 def vmap(op, in_axes, out_axes, input_shapes=None, output_shapes=None):
     def inner(*args):
-        # TODO: suboptimal (?) implementation of vmap in tensorflow that transposes the vmapped axis to the front and calls tf.vectorized_map
+        # TODO: suboptimal (?) implementation of vmap in tensorflow that transposes the
+        # vmapped axis to the front and calls tf.vectorized_map
         if len(args) != len(in_axes):
             raise ValueError(f"Expected {len(in_axes)} arguments, got {len(args)}")
         value = {arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None}
@@ -176,7 +179,8 @@ def inner(*args):
 
         return tuple(xs)
 
-    inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, in_axes={in_axes}, out_axes={out_axes})"
+    inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, " \
+        f"in_axes={in_axes}, out_axes={out_axes})"
     return inner
 
 class random:
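For reference, a minimal single-input sketch of the strategy the TODO above describes: move the mapped axis to the front, delegate to tf.vectorized_map, then move it back. This is an illustration only, not einx's implementation, which handles multiple arguments, None axes, and shape bookkeeping:

    import tensorflow as tf

    def vmap_single(op, in_axis, out_axis):
        def inner(x):
            # Move the mapped axis to the front, where tf.vectorized_map iterates.
            x = tf.experimental.numpy.moveaxis(x, in_axis, 0)
            y = tf.vectorized_map(op, x)
            # Move the result axis from the front back to the requested position.
            return tf.experimental.numpy.moveaxis(y, 0, out_axis)
        return inner

    # Maps softmax over axis 1 of a (2, 3, 4) tensor; output keeps shape (2, 3, 4).
    f = vmap_single(lambda t: tf.nn.softmax(t, axis=-1), in_axis=1, out_axis=1)
    y = f(tf.random.uniform((2, 3, 4)))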
15 changes: 10 additions & 5 deletions einx/backend/_torch.py
@@ -19,7 +19,8 @@ def make_torch_backend():
 
     version = tuple(int(i) for i in torch_.__version__.split(".")[:2])
     if version < (2, 0):
-        message = f"einx with PyTorch requires PyTorch version >= 2, but found {torch_.__version__}. einx functions are disabled for PyTorch."
+        message = "einx with PyTorch requires PyTorch version >= 2, but found " \
+            f"{torch_.__version__}. einx functions are disabled for PyTorch."
         print(f"WARNING: {message}")
         return ErrorBackend(message)
 
@@ -115,8 +116,10 @@ def get_at(tensor, coordinates):
         else:
             # Fix for https://github.com/pytorch/functorch/issues/747
             # Scalar coordinates cause problems with torch.vmap and throw an error:
-            # "RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor ..."
-            # As a workaround, we add a dummy dimension and remove it after the indexing operation.
+            # "RuntimeError: vmap: It looks like you're calling .item() on a Tensor.
+            # We don't support vmap over calling .item() on a Tensor ..."
+            # As a workaround, we add a dummy dimension and remove it after the indexing
+            # operation.
             return tensor[tuple(c[None] for c in coordinates)][0]
     else:
         if isinstance(coordinates, (slice, int)) or coordinates.ndim > 0:
@@ -151,7 +154,8 @@ def softmax(tensor, axis):
     if isinstance(axis, (list, tuple)):
         if len(axis) != 1:
             raise ValueError(
-                f"PyTorch only supports softmax along a single axis, got {len(axis)} axes"
+                "PyTorch only supports softmax along a single axis, "
+                f"got {len(axis)} axes"
             )
         axis = axis[0]
     return torch_.softmax(tensor, axis)
@@ -160,7 +164,8 @@ def log_softmax(tensor, axis):
     if isinstance(axis, (list, tuple)):
         if len(axis) != 1:
             raise ValueError(
-                f"PyTorch only supports log_softmax along a single axis, got {len(axis)} axes"
+                "PyTorch only supports log_softmax along a single axis, "
+                f"got {len(axis)} axes"
             )
         axis = axis[0]
     return torch_.nn.functional.log_softmax(tensor, axis)
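A small usage sketch of the dummy-dimension workaround in get_at (names invented; assumes torch >= 2.0): under torch.vmap each coordinate arrives as a 0-dim tensor, and indexing via c[None] avoids the internal .item() call:

    import torch

    def get_at(tensor, coordinates):
        # Index with length-1 dummy dimensions, then drop them again.
        return tensor[tuple(c[None] for c in coordinates)][0]

    table = torch.arange(12.0).reshape(3, 4)
    rows = torch.tensor([0, 2, 1])
    cols = torch.tensor([1, 3, 0])
    # vmap feeds get_at one (row, col) pair of 0-dim tensors per call.
    out = torch.vmap(get_at, in_dims=(None, (0, 0)))(table, (rows, cols))
    # out == tensor([1., 11., 4.])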
18 changes: 13 additions & 5 deletions einx/backend/tracer.py
@@ -293,9 +293,12 @@ def define_function(self, graph):
 
         lines = []
         output = function_scope.eval(graph.output)
+        params = [function_scope.eval(x) for x in graph.input_tracers] + [
+            op + "=" + op for op in function_scope.used_functions
+        ]
         lines.insert(
             0,
-            f"def {name}({', '.join([function_scope.eval(x) for x in graph.input_tracers] + [op + '=' + op for op in function_scope.used_functions])}):",
+            f"def {name}({', '.join(params)}):",
         )
         for line in function_scope.lines:
             lines.append(f"    {line}")
@@ -375,7 +378,8 @@ def slice_to_str(s):
             kwargs = self.eval(kwargs)
 
             self.lines.append(
-                f"{name} = backend.apply({op}, args={args}, kwargs={kwargs}, output_shapes={self.eval(x.output_shapes)})"
+                f"{name} = backend.apply({op}, args={args}, kwargs={kwargs}, "
+                f"output_shapes={self.eval(x.output_shapes)})"
             )
         else:
             op = self.eval(x.op)
@@ -404,11 +408,14 @@ def assertion(tracer, shape):
             name = self.declare_local_name_for(None, prefix="op")
             if self.backend == einx.backend.tracer:
                 self.lines.append(
-                    f"{name} = backend.vmap({old_name}, in_axes={self.eval(x.in_axes)}, out_axes={self.eval(x.out_axes)}, input_shapes={self.eval(x.input_shapes)}, output_shapes={self.eval(x.output_shapes)})"
+                    f"{name} = backend.vmap({old_name}, in_axes={self.eval(x.in_axes)}, "
+                    f"out_axes={self.eval(x.out_axes)}, input_shapes={self.eval(x.input_shapes)}, "
+                    f"output_shapes={self.eval(x.output_shapes)})"
                 )
             else:
                 self.lines.append(
-                    f"{name} = backend.vmap({old_name}, in_axes={self.eval(x.in_axes)}, out_axes={self.eval(x.out_axes)})"
+                    f"{name} = backend.vmap({old_name}, in_axes={self.eval(x.in_axes)}, "
+                    f"out_axes={self.eval(x.out_axes)})"
                 )
         elif isinstance(x, str):
             name = f'"{x}"'
@@ -747,7 +754,8 @@ def einsum(eq, *tensors):
             expr = expr.strip().replace(" ", "")
             if len(expr) != len(tensor.shape):
                 raise ValueError(
-                    f"Expected {len(expr)} axes, got {len(tensor.shape)} for {i}-th (zero-based) input tensor"
+                    f"Expected {len(expr)} axes, got {len(tensor.shape)} for {i}-th "
+                    "(zero-based) input tensor"
                 )
             for axis, value in zip(expr, tensor.shape):
                 if axis in values:
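The params refactor above also highlights how the tracer injects functions into generated code: each used function becomes a default argument (op=op) of the emitted def. A minimal, hypothetical sketch of that code-generation pattern:

    import numpy as np

    # Inputs become positional parameters; used backend functions are bound
    # as default arguments of the generated function.
    used_functions = {"op0": np.add}
    lines = [
        "def graph_fn(x0, x1, op0=op0):",
        "    return op0(x0, x1)",
    ]
    scope = dict(used_functions)
    exec("\n".join(lines), scope)  # trusted, internally generated source
    print(scope["graph_fn"](1, 2))  # 3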
15 changes: 10 additions & 5 deletions einx/expr/solver.py
@@ -159,13 +159,15 @@ def solve(equations):
         if isinstance(t1, Variable) and isinstance(t2, Constant):
             if constants.get(t1.id, t2.value) != t2.value:
                 raise SolveException(
-                    f"Found contradictory values { {constants[t1.id], t2.value} } for expression '{t1.name}'"
+                    f"Found contradictory values { {constants[t1.id], t2.value} } for "
+                    f"expression '{t1.name}'"
                 )
             constants[t1.id] = t2.value
         elif isinstance(t1, Constant) and isinstance(t2, Variable):
             if constants.get(t2.id, t1.value) != t1.value:
                 raise SolveException(
-                    f"Found contradictory values { {constants[t2.id], t1.value} } for expression '{t2.name}'"
+                    f"Found contradictory values { {constants[t2.id], t1.value} } for "
+                    f"expression '{t2.name}'"
                 )
             constants[t2.id] = t1.value
         elif isinstance(t1, Constant) and isinstance(t2, Constant):
@@ -195,11 +197,13 @@ def solve(equations):
                     names = {variables[a].name for a in eclass}
                     if len(names) == 1:
                         raise SolveException(
-                            f"Found contradictory values {class_constants} for expression '{next(iter(names))}'"
+                            f"Found contradictory values {class_constants} for expression "
+                            f"'{next(iter(names))}'"
                         )
                     else:
                         raise SolveException(
-                            f"Found contradictory values {class_constants} for equivalent expressions {names}"
+                            f"Found contradictory values {class_constants} for equivalent "
+                            f"expressions {names}"
                         )
                 v = Constant(next(iter(class_constants)))
             else:
@@ -232,7 +236,8 @@ def replace(t):
         if isinstance(t1, Constant) and isinstance(t2, Constant):
             if t1.value != t2.value:
                 raise SolveException(
-                    f"Found contradictory values {t1.value} != {t2.value} for same equivalence class"
+                    f"Found contradictory values {t1.value} != {t2.value} "
+                    "for same equivalence class"
                 )
         elif t1 != t2:
             equations2.append((t1, t2))
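Aside: the unusual spacing in { {constants[t1.id], t2.value} } is deliberate. Doubled braces in an f-string are escapes for literal braces, so a set display needs surrounding whitespace to be parsed as an expression:

    a, b = 3, 4
    print(f"values { {a, b} }")  # values {3, 4} - the set display is evaluated
    print(f"values {{a, b}}")    # values {a, b} - doubled braces are literal braces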
27 changes: 18 additions & 9 deletions einx/expr/stage2.py
@@ -131,7 +131,8 @@ def __init__(self, children, ellipsis_indices):
         for c in children:
             if len(c) != 1:
                 raise ValueError(
-                    f"Concatenation can only be used on expressions of length 1, but got expression '{c}'"
+                    "Concatenation can only be used on expressions of length 1, "
+                    f"but got expression '{c}'"
                 )
         self.children = children
         for c in children:
@@ -204,9 +205,11 @@ def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, m
             exprs1, exprs2, expansions1, expansions2, depths1, depths2
         ):
             self.message += "    "
-            self.message += f"{einx.expr.util._to_str(expr1)} (expansion={einx.expr.util._to_str(expansion1)} at depth={depth1})"
+            self.message += f"{einx.expr.util._to_str(expr1)} (expansion=" \
+                f"{einx.expr.util._to_str(expansion1)} at depth={depth1})"
             self.message += " = "
-            self.message += f"{einx.expr.util._to_str(expr2)} (expansion={einx.expr.util._to_str(expansion2)} at depth={depth2})"
+            self.message += f"{einx.expr.util._to_str(expr2)} (expansion=" \
+                f"{einx.expr.util._to_str(expansion2)} at depth={depth2})"
             self.message += "\n"
         super().__init__(self.message)
 
@@ -337,7 +340,7 @@ def solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2):
     except solver.SolveException as e:
         raise SolveDepthException(
             exprs1, exprs2, expansions1, expansions2, depths1, depths2, str(e)
-        )
+        ) from e
     expr_depths = {}
     for k, v in solutions.items():
         if k.startswith("symbolic_expr_depths["):
@@ -412,7 +415,8 @@ def solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2):
                 f"symbolic_expr_expansions[{id(expr)},{depth}]", f"{expr} at depth {depth}"
             )
 
-    # Add equations: Expansion of an expression at depth d (less than own depth) is equal to the expansion of each child at depth d
+    # Add equations: Expansion of an expression at depth d (less than own depth)
+    # is equal to the expansion of each child at depth d
     for root in exprs1 + exprs2:
         if root is not None:
             for expr in root.all():
@@ -460,7 +464,8 @@ def solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2):
                     expansions2,
                     depths1,
                     depths2,
-                    f"Expansion '{expansion1}' of expression '{expr1}' does not match expansion '{expansion2}' of expression '{expr2}'",
+                    f"Expansion '{expansion1}' of expression '{expr1}' does not match expansion "
+                    f"'{expansion2}' of expression '{expr2}'",
                 )
 
             if expansion1 is not None and expansion2 is not None:
@@ -532,7 +537,7 @@ def solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2):
     except solver.SolveException as e:
         raise SolveExpansionException(
             exprs1, exprs2, expansions1, expansions2, depths1, depths2, str(e)
-        )
+        ) from e
 
     def to_key(k):
         return int(id_expr), int(depth)
@@ -599,7 +604,10 @@ def get_unnamed_value(expr):
         value = get_unnamed_value(expr.inner)
         if value != 1:  # TODO: implement this
             raise NotImplementedError(
-                f"Found unnamed and unexpanded ellipsis '{expr}'. We currently disallow this case, since it could can take on multiple values ('2...' could have values 2, 4, ...) that should be resolved in the solver and then checked to be consistent with these constraints."
+                f"Found unnamed and unexpanded ellipsis '{expr}'. We currently disallow this "
+                "case, since it can take on multiple values ('2...' could have values "
+                "2, 4, ...) that should be resolved in the solver and then checked to be "
+                "consistent with these constraints."
             )
         return 1
     else:
@@ -779,7 +787,8 @@ def remove_duplicates(common_expr):
         print("CSE: Removed duplicates")
        for v in common_exprs:
             print(
-                f"    {[' '.join([str(y) for y in x]) for x in v]} {[[id(y) for y in x] for x in v]}"
+                f"    {[' '.join([str(y) for y in x]) for x in v]} "
+                f"{[[id(y) for y in x] for x in v]}"
             )
 
         # Remove singletons
5 changes: 3 additions & 2 deletions einx/expr/stage3.py
@@ -163,7 +163,8 @@ def __init__(self, children):
         for c in children:
             if len(c) != 1:
                 raise ValueError(
-                    f"Concatenation can only be used on expressions of length 1, but got expression '{c}'"
+                    "Concatenation can only be used on expressions of length 1, "
+                    f"but got expression '{c}'"
                 )
             c.parent = self
 
@@ -313,7 +314,7 @@ def solve(exprs1, exprs2):
     try:
         solutions = solver.solve(equations)
     except solver.SolveException as e:
-        raise SolveValueException(exprs1, exprs2, str(e))
+        raise SolveValueException(exprs1, exprs2, str(e)) from e
     axis_values = {}
     for k, v in solutions.items():
         if k.startswith("symbolic_expr_values["):
13 changes: 8 additions & 5 deletions einx/expr/util.py
@@ -27,8 +27,8 @@ def _input_expr(expr):
     else:
         try:
             expr = np.asarray(expr)
-        except e:
-            raise ValueError(f"Invalid expression '{expr}'")
+        except Exception as e:
+            raise ValueError(f"Invalid expression '{expr}'") from e
         if not np.issubdtype(expr.dtype, np.integer):
             raise ValueError(f"Invalid expression '{expr}', must be integers")
         expr = " ".join([str(i) for i in expr.flatten()])
@@ -45,7 +45,8 @@ def __init__(self, expr1, expr2=None, depth1=0, depth2=0):
         self.depth2 = None if expr2 is None else depth2
 
     def __repr__(self):
-        return f"{self.expr} = {self.value.tolist()} (expansion={self.expansion} at depth={self.depth})"
+        return f"{self.expr} = {self.value.tolist()} (expansion={self.expansion} at " \
+            f"depth={self.depth})"
 
 
 def _to_str(l):  # Print numpy arrays in a single line rather than with line breaks
@@ -78,7 +79,8 @@ def solve(
             exprs1, exprs2, expansions1, expansions2, depths1, depths2
         ):
             print(
-                f"    {_to_str(expr1)} (expansion={_to_str(expansion1)} at depth={depth1}) = {_to_str(expr2)} (expansion={_to_str(expansion2)} at depth={depth2})"
+                f"    {_to_str(expr1)} (expansion={_to_str(expansion1)} at depth={depth1}) = "
+                f"{_to_str(expr2)} (expansion={_to_str(expansion2)} at depth={depth2})"
             )
 
     exprs1 = [(stage1.parse(expr) if isinstance(expr, str) else expr) for expr in exprs1]
@@ -99,7 +101,8 @@ def solve(
             exprs1, exprs2, expansions1, expansions2, depths1, depths2
         ):
            print(
-                f"    {_to_str(expr1)} (expansion={_to_str(expansion1)} at depth={depth1}) = {_to_str(expr2)} (expansion={_to_str(expansion2)} at depth={depth2})"
+                f"    {_to_str(expr1)} (expansion={_to_str(expansion1)} at depth={depth1}) = "
+                f"{_to_str(expr2)} (expansion={_to_str(expansion2)} at depth={depth2})"
             )
 
     exprs1, exprs2 = stage2.solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2)
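The first hunk above fixes two distinct problems: `except e:` raises a NameError the moment an exception actually occurs (e is an undefined name there, not an exception class), and the re-raise now chains the cause via `from e` (ruff rule B904), keeping the original error visible in the traceback. A minimal sketch of the corrected pattern:

    import numpy as np

    def parse(expr):
        try:
            return np.asarray(expr, dtype=int)
        except Exception as e:
            # "from e" attaches the original error as __cause__ instead of
            # reporting it as a second, unrelated exception.
            raise ValueError(f"Invalid expression '{expr}'") from e

    parse([1, 2])   # array([1, 2])
    parse("a b c")  # ValueError: Invalid expression 'a b c', chained to the numpy error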
9 changes: 6 additions & 3 deletions einx/lru_cache.py
@@ -71,8 +71,10 @@ def func_with_warn(*args, **kwargs):
                 # Print warning
                 has_warned = True
                 print(
-                    f"WARNING (einx): The following call stack has resulted in {warn_on_retrace_num} retraces of an einx function.\n"
-                    f"A retrace happens when the function is called with different signatures of input arguments.\n"
+                    "WARNING (einx): The following call stack has resulted in "
+                    f"{warn_on_retrace_num} retraces of an einx function.\n"
+                    "A retrace happens when the function is called with "
+                    "different signatures of input arguments.\n"
                     f"Call stack (most recent call last):\n"
                     f"{trace}"
                 )
@@ -107,7 +109,8 @@ def lru_cache(func=None, trace=None):
         else:
             inner = freeze(functools.lru_cache(maxsize=max_cache_size)(func))
     else:
-        # Arguments are traced: Create cache for graph, then wrap cache in a function that executes graph
+        # Arguments are traced: Create cache for graph, then wrap
+        # cache in a function that executes graph
         @lru_cache
         def construct_graph(*args, backend, **kwargs):
            output_tracers = func(*args, **kwargs, backend=einx.backend.tracer)
(Diffs for the remaining 25 files are not shown here.)
