Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
aturker-synnada committed Jan 20, 2025
1 parent 4f19b6e commit e53a922
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 26 deletions.
43 changes: 18 additions & 25 deletions tests/scripts/test_constant_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,13 @@ def assert_all_backends_device_dtype(model: Model, inference: bool = False):
params=randomized_inputs,
)

# Check if gradients have correct device and dtype
for grad in grads.values():
assert (
backend.backend_type == "mlx" or get_array_device(grad, _type) == device
)
assert get_array_precision(grad, _type) == DtypeBits[dtype.name].value
# Check if gradients have correct device and dtype
for grad in grads.values():
assert (
backend.backend_type == "mlx"
or get_array_device(grad, _type) == device
)
assert get_array_precision(grad, _type) == DtypeBits[dtype.name].value

# In final step. we compare used inputs (used inputs are given as input to the
# either to comp_model.evaluate() or comp_model.evaluate_gradients()) with their
Expand Down Expand Up @@ -1068,7 +1069,7 @@ def test_bool_tensor_numpy_64():
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=NumpyBackend(precision=64), inference=True
model=model, backend=NumpyBackend(dtype=ml.float64), inference=True
)
output = comp_model.evaluate()["output"]
assert isinstance(output, np.ndarray)
Expand All @@ -1085,9 +1086,7 @@ def test_bool_tensor_torch_32():
model += add_1(
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=TorchBackend(precision=32), inference=True
)
comp_model = ml.compile(model=model, backend=TorchBackend(), inference=True)
output = comp_model.evaluate()["output"]
assert isinstance(output, torch.Tensor)
out = output.numpy()
Expand All @@ -1105,7 +1104,7 @@ def test_bool_tensor_torch_64():
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=TorchBackend(precision=64), inference=True
model=model, backend=TorchBackend(dtype=ml.float64), inference=True
)
output = comp_model.evaluate()["output"]
assert isinstance(output, torch.Tensor)
Expand All @@ -1123,9 +1122,7 @@ def test_bool_tensor_jax_32():
model += add_1(
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=JaxBackend(precision=32), inference=True
)
comp_model = ml.compile(model=model, backend=JaxBackend(), inference=True)
output = np.array(comp_model.evaluate()["output"])
np.testing.assert_allclose(output, ref)
assert output.dtype == np.float32
Expand All @@ -1141,7 +1138,7 @@ def test_bool_tensor_jax_64():
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=JaxBackend(precision=64), inference=True
model=model, backend=JaxBackend(dtype=ml.float64), inference=True
)
output = np.array(comp_model.evaluate()["output"])
np.testing.assert_allclose(output, ref)
Expand All @@ -1157,9 +1154,7 @@ def test_bool_tensor_mlx_32():
model += add_1(
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=JaxBackend(precision=32), inference=True
)
comp_model = ml.compile(model=model, backend=JaxBackend(), inference=True)
output = np.array(comp_model.evaluate()["output"])
np.testing.assert_allclose(output, ref)
assert output.dtype == np.float32
Expand All @@ -1175,7 +1170,7 @@ def test_bool_tensor_mlx_64():
left=Tensor([7.0, 8.0]), right=not_1.output, output=IOKey(name="output")
)
comp_model = ml.compile(
model=model, backend=JaxBackend(precision=64), inference=True
model=model, backend=JaxBackend(dtype=ml.float64), inference=True
)
output = np.array(comp_model.evaluate()["output"])
np.testing.assert_allclose(output, ref)
Expand Down Expand Up @@ -1358,9 +1353,7 @@ def test_static_input_6():
model_2 += model_1(left=add_3.left, right=add_3.right, out2=IOKey(name="output_1"))

backend = JaxBackend()
comp_model = ml.compile(
model=model_2, backend=backend, jit=False, inference=True
)
comp_model = ml.compile(model=model_2, backend=backend, jit=False, inference=True)
output = comp_model.evaluate()

assert model_1.left.metadata.value == 3.0 # type: ignore # It is Tensor type.
Expand Down Expand Up @@ -1604,7 +1597,7 @@ def test_composite_5():
model += add_model_1(left=IOKey(value=list1, name="left1"), right=list1)
model += add_model_2(left=add_model_1.output, right=list2)
model += add_model_3(left=add_model_2.output, right=list3)

assert_all_backends_device_dtype(model, inference=True)


Expand Down Expand Up @@ -1669,7 +1662,7 @@ def test_composite_7():
model += add_model_1(left=IOKey(name="left1", value=Tensor([[1]])), right=list1)
model += add_model_2(left=add_model_1.output, right=list2)
model += add_model_3(left=add_model_2.output, right=list3)

assert_all_backends_device_dtype(model, inference=True)


Expand All @@ -1687,7 +1680,7 @@ def test_composite_7_set_values():
model.set_values({add_model_2.right: list2})
model += add_model_3(left=add_model_2.output)
model.set_values({add_model_3.right: list3})

assert_all_backends_device_dtype(model, inference=True)


Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/test_flatmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_integration_with_all_defined():
add = Add()
add.set_types(left=Tensor, right=Tensor)
model += add(left="a", right="b", output="c")
backend = JaxBackend(dtype=ml.float64)
backend = JaxBackend(dtype=ml.float64)

pm_short = ml.compile(model, backend)
pm_long = ml.compile(model, backend, use_short_namings=False)
Expand Down

0 comments on commit e53a922

Please sign in to comment.