diff --git a/gokart/task.py b/gokart/task.py index 4f0f6de5..c5360380 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -1,3 +1,4 @@ +import functools import hashlib import inspect import os @@ -118,6 +119,11 @@ def __init__(self, *args, **kwargs): self._rerun_state = self.rerun self._lock_at_dump = True + # Cache to_str_params to avoid slow task creation in a deep task tree. + # For example, gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) results in O(n^2) calls to to_str_params. + # However, @lru_cache cannot be used as a decorator because luigi.Task employs metaclass tricks. + self.to_str_params = functools.lru_cache(maxsize=None)(self.to_str_params) # type: ignore[method-assign] + if self.complete_check_at_run: self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index 8e8ef9dd..924c036e 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -584,6 +584,17 @@ def test_serialize_and_deserialize_default_values(self): deserialized: gokart.TaskOnKart = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params()) self.assertDictEqual(task.to_str_params(), deserialized.to_str_params()) + def test_to_str_params_changes_on_values_and_flags(self): + class _DummyTaskWithParams(gokart.TaskOnKart): + task_namespace = __name__ + param: str = luigi.Parameter() + + t1 = _DummyTaskWithParams(param='a') + self.assertEqual(t1.to_str_params(), t1.to_str_params()) # cache + self.assertEqual(t1.to_str_params(), _DummyTaskWithParams(param='a').to_str_params()) # same value + self.assertNotEqual(t1.to_str_params(), _DummyTaskWithParams(param='b').to_str_params()) # different value + self.assertNotEqual(t1.to_str_params(), t1.to_str_params(only_significant=True)) + def test_should_lock_run_when_set(self): class _DummyTaskWithLock(gokart.TaskOnKart): def run(self):