diff --git a/tests/learner/test_AL.py b/tests/learner/test_AL.py index 051a230..75219f8 100644 --- a/tests/learner/test_AL.py +++ b/tests/learner/test_AL.py @@ -6,6 +6,7 @@ path as os_path, listdir as os_listdir ) +import pandas as pd from tempfile import TemporaryDirectory # Import your main function from main.py @@ -104,16 +105,25 @@ def test_basic(self): "--data_folder", os_path.join(TestCLI.input_folder, "top50"), "--parameter_file", TestCLI.parameter_file, "--output_folder", temp_dir, - "--save_plot", + # "--save_plot", "--seed", f"{TestCLI.seed}" ] with patch.object(sys, 'argv', test_args): main() - results = compare_folders(temp_dir, TestCLI.reference_output_folder) - for file, details in results.items(): - # Check if all files are identical - self.assertTrue(details[0] == "Identical", f"{file}: {details}") + # Compare dfs read from next samples files + next_samples_expected_df = pd.read_csv( + os_path.join(TestCLI.reference_output_folder, "next_sampling_ei50.csv") + ) + next_samples_output_df = pd.read_csv( + os_path.join(temp_dir, "next_sampling_ei50.csv") + ) + pd.testing.assert_frame_equal(next_samples_expected_df, next_samples_output_df) + + # results = compare_folders(temp_dir, TestCLI.reference_output_folder) + # for file, details in results.items(): + # # Check if all files are identical + # self.assertTrue(details[0] == "Identical", f"{file}: {details}")