From 69949a35c0666b8b58880747bd2deadba7f46bba Mon Sep 17 00:00:00 2001 From: gbg141 Date: Thu, 15 Feb 2024 21:54:41 +0100 Subject: [PATCH] Updating tests for cell and combinatorial models --- test/nn/cell/test_can.py | 5 ++--- test/nn/cell/test_ccxn.py | 7 ++++--- test/nn/cell/test_cwn.py | 7 ++++--- test/nn/combinatorial/test_hmc.py | 8 +++++--- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/test/nn/cell/test_can.py b/test/nn/cell/test_can.py index c4c19eb0..245d4f02 100644 --- a/test/nn/cell/test_can.py +++ b/test/nn/cell/test_can.py @@ -18,7 +18,6 @@ def test_forward(self): out_channels=2, dropout=0.5, heads=1, - num_classes=1, n_layers=2, att_lift=False, ).to(device) @@ -36,5 +35,5 @@ def test_forward(self): adjacency_2 = adjacency_1.float().to(device) incidence_2 = adjacency_1.float().to(device) - y = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2) - assert y.shape == torch.Size([1]) + x_1 = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2) + assert x_1.shape == torch.Size([1, 2]) diff --git a/test/nn/cell/test_ccxn.py b/test/nn/cell/test_ccxn.py index a2fc4d92..9394864f 100644 --- a/test/nn/cell/test_ccxn.py +++ b/test/nn/cell/test_ccxn.py @@ -15,7 +15,6 @@ def test_forward(self): in_channels_0=2, in_channels_1=2, in_channels_2=2, - num_classes=1, n_layers=2, att=False, ).to(device) @@ -33,5 +32,7 @@ def test_forward(self): adjacency_1 = adjacency_1.float().to(device) incidence_2 = incidence_2.float().to(device) - y = model(x_0, x_1, adjacency_1, incidence_2) - assert y.shape == torch.Size([1]) + x_0, x_1, x_2 = model(x_0, x_1, adjacency_1, incidence_2) + assert x_0.shape == torch.Size([2, 2]) + assert x_1.shape == torch.Size([2, 2]) + assert x_2.shape == torch.Size([2, 2]) diff --git a/test/nn/cell/test_cwn.py b/test/nn/cell/test_cwn.py index e7e4f3ab..cb0aafd1 100644 --- a/test/nn/cell/test_cwn.py +++ b/test/nn/cell/test_cwn.py @@ -16,7 +16,6 @@ def test_forward(self): in_channels_1=2, in_channels_2=2, hid_channels=16, - num_classes=1, n_layers=2, ).to(device) @@ -36,5 +35,7 @@ def test_forward(self): incidence_2 = incidence_2.float().to(device) incidence_1_t = incidence_1_t.float().to(device) - y = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t) - assert y.shape == torch.Size([1]) + x_0, x_1, x_2 = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t) + assert x_0.shape == torch.Size([2, 16]) + assert x_1.shape == torch.Size([2, 16]) + assert x_2.shape == torch.Size([2, 16]) diff --git a/test/nn/combinatorial/test_hmc.py b/test/nn/combinatorial/test_hmc.py index a4371bc3..8270a7b1 100644 --- a/test/nn/combinatorial/test_hmc.py +++ b/test/nn/combinatorial/test_hmc.py @@ -15,7 +15,7 @@ def test_forward(self): intermediate_channels = [2, 2, 2] final_channels = [2, 2, 2] channels_per_layer = [[in_channels, intermediate_channels, final_channels]] - model = HMC(channels_per_layer, negative_slope=0.2, num_classes=2).to(device) + model = HMC(channels_per_layer, negative_slope=0.2).to(device) x_0 = torch.rand(2, 2) x_1 = torch.rand(2, 2) @@ -29,7 +29,7 @@ def test_forward(self): ) adjacency_0 = adjacency_0.float().to(device) - y = model( + x_0, x_1, x_2 = model( x_0, x_1, x_2, @@ -39,4 +39,6 @@ def test_forward(self): adjacency_0, adjacency_0, ) - assert y.shape == torch.Size([2]) + assert x_0.shape == torch.Size([2, 2]) + assert x_1.shape == torch.Size([2, 2]) + assert x_2.shape == torch.Size([2, 2])