diff --git a/einx/backend/tracer.py b/einx/backend/tracer.py index 66bbf36..d7fe70a 100644 --- a/einx/backend/tracer.py +++ b/einx/backend/tracer.py @@ -352,10 +352,10 @@ def reduce(tensor, axis, *, op=None, **kwargs): else: for a in reversed(sorted(axes)): del shape[a] - return Op(op, args=[tensor], kwargs=kwargs | {"axis": axis}, output_shapes=np.asarray(shape)).output_tracers + return Op(op, args=[tensor], kwargs={**kwargs, **{"axis": axis}}, output_shapes=np.asarray(shape)).output_tracers def map(tensor, axis, op, *args, **kwargs): - return Op(op, args=[tensor], kwargs=kwargs | {"axis": axis}, output_shapes=np.asarray(tensor.shape)).output_tracers + return Op(op, args=[tensor], kwargs={**kwargs, **{"axis": axis}}, output_shapes=np.asarray(tensor.shape)).output_tracers def index(tensor, coordinates, update=None, op=None): return Op(op, args=[tensor, coordinates, update], output_shapes=np.asarray(coordinates[0].shape)).output_tracers diff --git a/einx/lru_cache.py b/einx/lru_cache.py index 74b3f24..a460524 100644 --- a/einx/lru_cache.py +++ b/einx/lru_cache.py @@ -34,7 +34,7 @@ def lru_cache(func=None, trace=None): if max_cache_size == 0: inner = func elif max_cache_size < 0: - inner = freeze(functools.cache(func)) # No cache limit + inner = freeze(functools.lru_cache(maxsize=None)(func)) # No cache limit else: inner = freeze(functools.lru_cache(maxsize=max_cache_size)(func)) else: diff --git a/einx/param.py b/einx/param.py index 1b01565..415c397 100644 --- a/einx/param.py +++ b/einx/param.py @@ -30,7 +30,7 @@ def instantiate(x, shape, backend, **kwargs): raise TypeError("instantiate cannot be called on None") if backend == einx.backend.tracer: if is_tensor_factory(x): - return einx.backend.tracer.Op(instantiate, [x], {"shape": shape} | kwargs, output_shapes=np.asarray(shape), pass_backend=True).output_tracers + return einx.backend.tracer.Op(instantiate, [x], {**{"shape": shape}, **kwargs}, output_shapes=np.asarray(shape), pass_backend=True).output_tracers else: return einx.backend.tracer.Op("to_tensor", [x], output_shapes=np.asarray(shape)).output_tracers else: