Skip to content

Commit

Permalink
Modify lt case.
Browse files Browse the repository at this point in the history
  • Loading branch information
pdx1989 committed Dec 29, 2023
1 parent ac69e96 commit f027c62
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion dicp/test/op/test_lt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def forward(self, a, b):

class TestLt():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [[Size((5,), (5, 3)), Size((5,), (5, 3))], [Size((3, 5), (5, 3)), Size((3, 5), (5, 3))], [Size((2, 3, 4), (2, 4)), Size((2, 3, 4), (2, 4))], [Size((5, 1), (4, 1)), Size((5,), (4,))]])
@pytest.mark.parametrize("sizes", [[Size((5,), (5, 3)), Size((5,), (5, 3))], [Size((3, 5), (5, 3)), Size((3, 5), (5, 3))], [Size((2, 3, 4), (2, 4)), Size((2, 3, 4), (2, 4))], [Size((4,), (4,)), Size((4, 1), (4, 1))]])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_lt(self, sizes, dtype, compiled_model):
device = get_device()
Expand All @@ -32,6 +32,11 @@ def test_torch_lt(self, sizes, dtype, compiled_model):
input1 = torch.randn(size1, dtype=dtype)
input2 = torch.randn(size2, dtype=dtype)

# for case number 4
if size1 != size2:
input1 = torch.tensor([0, 1, 2, 3])
input2 = torch.tensor([[1], [2], [3], [4]])

dicp_input1 = input1.to(device)
dicp_input2 = input2.to(device)

Expand Down

0 comments on commit f027c62

Please sign in to comment.