Skip to content

Commit

Permalink
Rename Tracer -> Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Feb 18, 2024
1 parent 537b6ee commit 73c0c43
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions einx/backend/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def to_shape(shape):
else:
return shape

class Tracer:
class Tensor:
def __init__(self, shape):
self.shape = to_shape(shape)

Expand Down Expand Up @@ -93,7 +93,7 @@ def __eq__(self, other):
def __ne__(self, other):
return elementwise(self, other, op="not_equal")

class Input(Tracer):
class Input(Tensor):
def __init__(self, shape, index, original_type=None):
super().__init__(shape)
self.index = index
Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(self, op, args=[], kwargs={}, output_shapes=None):
self.output_tracers = einx.tree_util.tree_map_with_key(lambda shape, key: OpOutput(self, shape, key), self.output_shapes, is_leaf=is_leaf)
assert not "backend" in self.kwargs

class OpOutput(Tracer):
class OpOutput(Tensor):
def __init__(self, op, shape, key):
super().__init__(shape)
self.op = op
Expand Down Expand Up @@ -196,7 +196,7 @@ def inner_output_shapes(self):
def __call__(self, *args, **kwargs):
return OpApplication(self, args=args, kwargs=kwargs, output_shapes=self.output_shapes).output_tracers

class VmappedOpOutput(Tracer):
class VmappedOpOutput(Tensor):
def __init__(self, vmapped_op, index):
super().__init__(vmapped_op.output_shapes[index])
self.vmapped_op = vmapped_op
Expand Down Expand Up @@ -408,7 +408,7 @@ def assertion(tracer, shape):

class Graph:
def __init__(self, output, args, kwargs, backend):
assert any(isinstance(x, Tracer) for x in einx.tree_util.tree_flatten(output)), f"Expected at least one tracer in output, got {output}"
assert any(isinstance(x, Tensor) for x in einx.tree_util.tree_flatten(output)), f"Expected at least one tracer in output, got {output}"
self.output = output
self.args = args
self.kwargs = kwargs
Expand Down Expand Up @@ -537,7 +537,7 @@ def broadcast(dims):
input_shape = input_shape[1:]
elif s is None:
output_shape.append(1)
elif isinstance(s, Tracer) and s.ndim == 0:
elif isinstance(s, Tensor) and s.ndim == 0:
input_shape = input_shape[1:]
else:
raise TypeError(f"Invalid coordinate type: {type(s)}")
Expand All @@ -556,12 +556,12 @@ def to_tensor(tensor):
if isinstance(tensor, OpOutput) and tensor.op.op.op == "to_tensor":
# Merge consecutive to_tensor ops
return OpApplication("to_tensor", args=[tensor.op.args[0]], output_shapes=tensor.shape).output_tracers
if isinstance(tensor, Tracer):
if isinstance(tensor, Tensor):
return OpApplication("to_tensor", args=[tensor], output_shapes=tensor.shape).output_tracers
else:
return OpApplication("to_tensor", args=[tensor], output_shapes=einx.param.get_shape(tensor)).output_tracers

tensor = Tracer
tensor = Tensor
name = "tracer"

def op(op, tracable=False):
Expand Down

0 comments on commit 73c0c43

Please sign in to comment.