From ebd2bb7ab5b497b17e4609484363956309e9f94a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 3 Sep 2024 13:17:48 +0200 Subject: [PATCH] Fix empty load stage when two `GlobalStep`s are chained (#945) --- src/distilabel/pipeline/_dag.py | 12 +++++++++--- tests/unit/pipeline/test_dag.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/distilabel/pipeline/_dag.py b/src/distilabel/pipeline/_dag.py index 383703ccf..3ca2b569d 100644 --- a/src/distilabel/pipeline/_dag.py +++ b/src/distilabel/pipeline/_dag.py @@ -294,13 +294,19 @@ def _get_stage_last_steps(stage_steps: List[str]) -> List[str]: current_stage = [] stages_last_steps = [] - for step_name in nx.topological_sort(self.G): + steps_sorted = list(nx.topological_sort(self.G)) + for i, step_name in enumerate(steps_sorted): step: "_Step" = self.get_step(step_name)[STEP_ATTR_NAME] if not step.is_global: current_stage.append(step_name) else: - stages.append(current_stage) - stages_last_steps.append(_get_stage_last_steps(current_stage)) + previous_step = None + if i > 0: + previous_step_name = steps_sorted[i - 1] + previous_step = self.get_step(previous_step_name)[STEP_ATTR_NAME] + if not previous_step or not previous_step.is_global: + stages.append(current_stage) + stages_last_steps.append(_get_stage_last_steps(current_stage)) stages.append([step_name]) stages_last_steps.append([step_name]) current_stage = [] diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py index db0a90dcc..b1e14cf83 100644 --- a/tests/unit/pipeline/test_dag.py +++ b/tests/unit/pipeline/test_dag.py @@ -276,6 +276,37 @@ def test_get_steps_load_stages(self) -> None: ], ) + def test_get_steps_load_stages_global_steps_chained(self) -> None: + with Pipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep(name="dummy_generator_step") + dummies_0 = [DummyStep1(name=f"dummy_step_0_{i}") for i in range(3)] + global_0 = DummyGlobalStep(name="global_0") + global_1 = DummyGlobalStep(name="global_1") + + generator >> dummies_0 >> global_0 >> global_1 + + assert pipeline.dag.get_steps_load_stages() == ( + [ + [ + "dummy_generator_step", + "dummy_step_0_0", + "dummy_step_0_1", + "dummy_step_0_2", + ], + ["global_0"], + ["global_1"], + ], + [ + [ + "dummy_step_0_0", + "dummy_step_0_1", + "dummy_step_0_2", + ], + ["global_0"], + ["global_1"], + ], + ) + def test_get_steps_load_stages_simple(self) -> None: with Pipeline(name="dummy") as pipeline: generator = DummyGeneratorStep(name="dummy_generator_step")