Skip to content

Commit

Permalink
Updating tests for cell and combinatorial models
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Feb 15, 2024
1 parent 8789bc6 commit 69949a3
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
5 changes: 2 additions & 3 deletions test/nn/cell/test_can.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_forward(self):
out_channels=2,
dropout=0.5,
heads=1,
num_classes=1,
n_layers=2,
att_lift=False,
).to(device)
Expand All @@ -36,5 +35,5 @@ def test_forward(self):
adjacency_2 = adjacency_1.float().to(device)
incidence_2 = adjacency_1.float().to(device)

y = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2)
assert y.shape == torch.Size([1])
x_1 = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2)
assert x_1.shape == torch.Size([1, 2])
7 changes: 4 additions & 3 deletions test/nn/cell/test_ccxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def test_forward(self):
in_channels_0=2,
in_channels_1=2,
in_channels_2=2,
num_classes=1,
n_layers=2,
att=False,
).to(device)
Expand All @@ -33,5 +32,7 @@ def test_forward(self):
adjacency_1 = adjacency_1.float().to(device)
incidence_2 = incidence_2.float().to(device)

y = model(x_0, x_1, adjacency_1, incidence_2)
assert y.shape == torch.Size([1])
x_0, x_1, x_2 = model(x_0, x_1, adjacency_1, incidence_2)
assert x_0.shape == torch.Size([2, 2])
assert x_1.shape == torch.Size([2, 2])
assert x_2.shape == torch.Size([2, 2])
7 changes: 4 additions & 3 deletions test/nn/cell/test_cwn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def test_forward(self):
in_channels_1=2,
in_channels_2=2,
hid_channels=16,
num_classes=1,
n_layers=2,
).to(device)

Expand All @@ -36,5 +35,7 @@ def test_forward(self):
incidence_2 = incidence_2.float().to(device)
incidence_1_t = incidence_1_t.float().to(device)

y = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
assert y.shape == torch.Size([1])
x_0, x_1, x_2 = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
assert x_0.shape == torch.Size([2, 16])
assert x_1.shape == torch.Size([2, 16])
assert x_2.shape == torch.Size([2, 16])
8 changes: 5 additions & 3 deletions test/nn/combinatorial/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_forward(self):
intermediate_channels = [2, 2, 2]
final_channels = [2, 2, 2]
channels_per_layer = [[in_channels, intermediate_channels, final_channels]]
model = HMC(channels_per_layer, negative_slope=0.2, num_classes=2).to(device)
model = HMC(channels_per_layer, negative_slope=0.2).to(device)

x_0 = torch.rand(2, 2)
x_1 = torch.rand(2, 2)
Expand All @@ -29,7 +29,7 @@ def test_forward(self):
)
adjacency_0 = adjacency_0.float().to(device)

y = model(
x_0, x_1, x_2 = model(
x_0,
x_1,
x_2,
Expand All @@ -39,4 +39,6 @@ def test_forward(self):
adjacency_0,
adjacency_0,
)
assert y.shape == torch.Size([2])
assert x_0.shape == torch.Size([2, 2])
assert x_1.shape == torch.Size([2, 2])
assert x_2.shape == torch.Size([2, 2])

0 comments on commit 69949a3

Please sign in to comment.