SingleProcess decoupled from thread ids
jordis-ai2 committed Apr 30, 2024
1 parent e435565 commit 21c6182
Showing 1 changed file with 15 additions and 49 deletions.
allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py (64 changes: 15 additions & 49 deletions)
@@ -682,6 +682,17 @@ def mp_ctx(self):
         """
         return self._mp_ctx
 
+    # def next_task(self, **kwargs):
+    #     if "force_advance_scene" in kwargs:
+    #         # TODO just send it to the first thread in each process if in multithreading
+    #         self.command_at()
+    #         # TODO just send it to all the rest of subprocesses in each process (this sucks big time)
+    #         return self.command(
+    #             commands=NEXT_TASK_COMMAND, data_list=[kwargs] * self.num_unpaused_tasks
+    #         )
+    #
+    #     return super().next_task(**kwargs)
+
     @staticmethod
     def _task_sampling_loop_worker(
         worker_id: Union[int, str],
@@ -942,7 +953,7 @@ def command(
     ) -> List[Any]:
         """"""
         return super().command(
-            commands, data_list, partition_fn=self._partition_to_processes
+            commands, data_list, partition_fn=kwargs.get("partition_function") or self._partition_to_processes
         )
 
     def call(
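Editorial note: the one-line change above makes the partition strategy overridable per call. Below is a minimal standalone sketch of the `kwargs.get("partition_function") or default` fallback idiom; every name except the `partition_function` kwarg is invented for illustration.

# Sketch only: "partition_function" is the kwarg name used in the diff;
# everything else here is invented for illustration.

def default_partition(indices):
    # Trivial default: a single group holding every sampler index.
    return {0: list(indices)}

def command(commands, data_list, **kwargs):
    # Optional override with fallback, mirroring the changed line above.
    partition_fn = kwargs.get("partition_function") or default_partition
    return partition_fn(range(len(data_list)))

print(command("next_task", [{}] * 3))
# -> {0: [0, 1, 2]}
print(command("next_task", [{}] * 3,
              partition_function=lambda idx: {i: [i] for i in idx}))
# -> {0: [0], 1: [1], 2: [2]}

Because the fallback uses `or`, passing `partition_function=None` behaves the same as omitting it.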
@@ -1006,7 +1017,6 @@ def __init__(
         callback_sensor_suite: Optional[SensorSuite] = None,
         auto_resample_when_done: bool = True,
         should_log: bool = True,
-        local_worker_id: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
 
@@ -1021,8 +1031,6 @@ def __init__(
 
         self.should_log = should_log
 
-        self.local_worker_id = local_worker_id
-
         self._vector_task_generators: List[Generator] = self._create_generators(
             make_sampler_fn=make_sampler_fn,
             sampler_fn_args=[{"mp_ctx": None, **args} for args in sampler_fn_args_list],
@@ -1069,15 +1077,10 @@ def _task_sampling_loop_generator_fn(
         callback_sensor_suite: Optional[SensorSuite],
         auto_resample_when_done: bool,
         should_log: bool,
-        local_worker_id: Optional[int],
     ) -> Generator:
         """Generator for working with Tasks/TaskSampler."""
 
-        task_sampler_args = {**sampler_fn_args}
-        if local_worker_id is not None:
-            task_sampler_args["thread_id"] = local_worker_id
-
-        task_sampler = make_sampler_fn(**task_sampler_args)
+        task_sampler = make_sampler_fn(**sampler_fn_args)
         current_task = task_sampler.next_task()
 
         if current_task is None:
@@ -1230,7 +1233,6 @@ def _create_generators(
                     callback_sensor_suite=callback_sensor_suite,
                     auto_resample_when_done=self._auto_resample_when_done,
                     should_log=self.should_log,
-                    local_worker_id=self.local_worker_id,
                 )
             )
 
@@ -1653,13 +1655,14 @@ def _task_sampling_loop_worker(
         """process worker for creating and interacting with the
         Tasks/TaskSampler."""
 
+        sampler_fn_args_list = [{**cur_kwargs, "thread_id": worker_id} for cur_kwargs in sampler_fn_args_list]
+
         sp_vector_sampled_tasks = SingleProcessVectorSampledTasks(
             make_sampler_fn=make_sampler_fn,
             sampler_fn_args_list=sampler_fn_args_list,
             callback_sensor_suite=callback_sensor_suite,
             auto_resample_when_done=auto_resample_when_done,
             should_log=should_log,
-            local_worker_id=worker_id,
         )
 
         try:
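Editorial note: the added line above merges the process worker's id into every sampler's kwargs, replacing the removed `local_worker_id` constructor plumbing. A standalone sketch of the merge idiom, with invented sampler arguments:

# Standalone sketch of the kwargs-merge added above; the sampler
# arguments ("env_seed") are invented, only the merge idiom is real.

worker_id = 2
sampler_fn_args_list = [{"env_seed": 0}, {"env_seed": 1}]

sampler_fn_args_list = [
    {**cur_kwargs, "thread_id": worker_id} for cur_kwargs in sampler_fn_args_list
]

print(sampler_fn_args_list)
# -> [{'env_seed': 0, 'thread_id': 2}, {'env_seed': 1, 'thread_id': 2}]

This is the decoupling named in the commit title: `SingleProcessVectorSampledTasks` no longer knows about worker/thread ids; a sampler factory that accepts `thread_id` simply receives it through its ordinary kwargs.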
@@ -1670,29 +1673,8 @@ def _task_sampling_loop_worker(
                     raise NotImplementedError(
                         f"got {read_input}, but only one sampler per thread is implemented"
                     )
-                    # sampler_index, command, data = read_input
-                    #
-                    # assert command != CLOSE_COMMAND, "Must close all threads at once."
-                    # assert (
-                    #     command != RESUME_COMMAND
-                    # ), "Must resume all task samplers at once."
-                    #
-                    # if command == PAUSE_COMMAND:
-                    #     sp_vector_sampled_tasks.pause_at(sampler_index=sampler_index)
-                    #     connection_write_fn("done")
-                    # else:
-                    #     connection_write_fn(
-                    #         sp_vector_sampled_tasks.command_at(
-                    #             sampler_index=sampler_index, command=command, data=data,
-                    #         )
-                    #     )
                 else:
                     commands, data_list = read_input
-                    # print(f"worker {worker_id} commands {commands} data_list {data_list}")
-
-                    # assert (
-                    #     commands != PAUSE_COMMAND
-                    # ), "Cannot pause all task samplers at once."
 
                     if commands == PAUSE_COMMAND:
                         get_logger().info(
@@ -1709,22 +1691,6 @@ def _task_sampling_loop_worker(
                         sp_vector_sampled_tasks.resume_all()
                         connection_write_fn("done")
 
-                    # elif commands == CALL_COMMAND:
-                    #     if isinstance(data_list[0], str):
-                    #         commands = [
-                    #             data_list[0]
-                    #         ] * sp_vector_sampled_tasks.num_unpaused_tasks
-                    #         data_list = [data_list[1]] * sp_vector_sampled_tasks.num_unpaused_tasks
-                    #     else:
-                    #         commands = data_list[0]
-                    #         data_list = data_list[1]
-                    #
-                    #     connection_write_fn(
-                    #         sp_vector_sampled_tasks.command(
-                    #             commands=commands, data_list=data_list
-                    #         )
-                    #     )
-
                     else:
                         if isinstance(commands, str):
                             commands = [
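Editorial note: the last hunk is truncated by the diff viewer, so the right-hand side of `commands = [` is not shown. Judging from the commented-out `CALL_COMMAND` block removed above, the visible branch broadcasts a single string command to every unpaused sampler. A hedged sketch of that convention, with `num_unpaused_tasks` standing in for the attribute on `sp_vector_sampled_tasks`:

# Hedged sketch: the completion of the truncated "commands = [" line is
# an assumption based on the commented-out code removed earlier in this
# diff; num_unpaused_tasks is a stand-in value.

num_unpaused_tasks = 4
commands = "next_task"

if isinstance(commands, str):
    commands = [commands] * num_unpaused_tasks

print(commands)
# -> ['next_task', 'next_task', 'next_task', 'next_task']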
