Skip to content

Commit

Permalink
macos fail
Browse files Browse the repository at this point in the history
  • Loading branch information
aturker-synnada committed Jan 3, 2025
1 parent 43b5569 commit 9ef4044
Showing 1 changed file with 49 additions and 35 deletions.
84 changes: 49 additions & 35 deletions tests/scripts/test_backend_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .test_utils import get_array_device, get_array_precision

# Create instances of each backend
backends = [ml.NumpyBackend, ml.JaxBackend, ml.TorchBackend, ml.MlxBackend]
backends = [ml.NumpyBackend, ml.TorchBackend, ml.MlxBackend]


testing_fns: dict[type[ml.Backend], Callable] = {}
Expand Down Expand Up @@ -115,25 +115,33 @@ def assert_backend_results_equal(
):
ref_output_device = ref_output_device.split(":")[0]
testing_fn = testing_fns[backend.__class__]

print("fn is here")
output = fn(*fn_args, **fn_kwargs)
print("fn is after")
print("output:", output, type(output))
assert not isinstance(output, tuple | list) ^ isinstance(ref_output, tuple | list)
if not isinstance(output, tuple | list):
output = (output,)
print("new_out:", output)
if not isinstance(ref_output, tuple | list):
ref_output = (ref_output,)
print("ref:", ref_output, type(ref_output))
# for out, ref in zip(output, ref_output, strict=False):
print("get type")
assert tuple(output[0].shape) == tuple(ref_output[0].shape)
assert (
backend.backend_type == "mlx"
or get_array_device(output[0], backend.backend_type) == ref_output_device
)
assert (
get_array_precision(output[0], backend.backend_type)
== DtypeBits[ref_output_dtype.name].value
)
print("testing_fn", output[0], ref_output[0], rtol, atol, type(output[0]))
assert testing_fn(output[0], ref_output[0], rtol=rtol, atol=atol)
print("testing fn")

for out, ref in zip(output, ref_output, strict=False):
assert tuple(out.shape) == tuple(ref.shape)
assert (
backend.backend_type == "mlx"
or get_array_device(out, backend.backend_type) == ref_output_device
)
assert (
get_array_precision(out, backend.backend_type)
== DtypeBits[ref_output_dtype.name].value
)
assert testing_fn(out, ref, rtol=rtol, atol=atol)
print("test is done?")


unsupported_device_dtypes = [
Expand Down Expand Up @@ -172,36 +180,39 @@ def assert_backend_results_equal(
"backendcls, device, dtype", backends_with_device_dtype, ids=names
)
class TestArray:
def test_array_int(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
array_fn = array_fns[backend.__class__]
fn = backend.array
fn_args = [[1, 2, 3]]
fn_kwargs: dict = {}

ref_output = array_fn(
[1, 2, 3], str(device), f"int{DtypeBits[dtype.name].value}"
)
assert_backend_results_equal(
backend,
fn,
fn_args,
fn_kwargs,
ref_output,
device,
dtype,
tolerances[dtype],
tolerances[dtype],
)
# def test_array_int(self, backendcls, device, dtype):
# print(backendcls, device, dtype.name, "array_int")
# backend = backendcls(device=device, dtype=dtype)
# array_fn = array_fns[backend.__class__]
# fn = backend.array
# fn_args = [[1, 2, 3]]
# fn_kwargs: dict = {}

# ref_output = array_fn(
# [1, 2, 3], str(device), f"int{DtypeBits[dtype.name].value}"
# )
# assert_backend_results_equal(
# backend,
# fn,
# fn_args,
# fn_kwargs,
# ref_output,
# device,
# dtype,
# tolerances[dtype],
# tolerances[dtype],
# )

def test_array_float(self, backendcls, device, dtype):
print(backendcls, device, dtype.name, "array_float")
backend = backendcls(device=device, dtype=dtype)
array_fn = array_fns[backend.__class__]
fn = backend.array
fn_args = [[1.0, 2, 3]]
fn_kwargs: dict = {}

print("before")
ref_output = array_fn([1, 2, 3], str(device), dtype.name)
print("after")
assert_backend_results_equal(
backend,
fn,
Expand All @@ -213,8 +224,10 @@ def test_array_float(self, backendcls, device, dtype):
tolerances[dtype],
tolerances[dtype],
)
print("test done")

def test_array_edge_case(self, backendcls, device, dtype):
print(backendcls, device, dtype.name, "array_edge")
backend = backendcls(device=device, dtype=dtype)
array_fn = array_fns[backend.__class__]
fn = backend.array
Expand All @@ -240,6 +253,7 @@ def test_array_edge_case(self, backendcls, device, dtype):
)
class TestZeros:
def test_zeros(self, backendcls, device, dtype):
print(backendcls, device, dtype.name, "zeros")
array_fn = array_fns[backendcls]
backend = backendcls(device=device, dtype=dtype)

Expand Down

0 comments on commit 9ef4044

Please sign in to comment.