From 59d2378d9cae71b03ce6b7c9e35c41a868b6e645 Mon Sep 17 00:00:00 2001 From: Jordi Salvador Date: Wed, 8 May 2024 19:45:24 +0200 Subject: [PATCH] Pass number of threads in task sampler process as `nthread` --- allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py b/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py index 2ae2b019..c92a907b 100644 --- a/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py +++ b/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py @@ -1644,13 +1644,14 @@ def _task_sampling_loop_worker( auto_resample_when_done: bool, should_log: bool, thread_barrier: threading.Barrier, + thread_num: int, ) -> None: """process worker for creating and interacting with the Tasks/TaskSampler.""" assert len(sampler_fn_args_list) == 1 sampler_fn_args_list = [ - {**cur_kwargs, "thread_id": worker_id, "thread_barrier": thread_barrier} + {**cur_kwargs, "thread_id": worker_id, "thread_barrier": thread_barrier, "nthread": thread_num} for cur_kwargs in sampler_fn_args_list ] @@ -1751,6 +1752,7 @@ def _start_workers( auto_resample_when_done=self._auto_resample_when_done, should_log=self.should_log, thread_barrier=barrier, + thread_num=self._num_task_samplers, ), )