[RLlib] APPO accelerate (vol 18): EnvRunner sync enhancements. (ray-p…
sven1977 authored Feb 28, 2025
1 parent 6692500 commit 02d4a3a
Showing 4 changed files with 108 additions and 51 deletions.
18 changes: 9 additions & 9 deletions rllib/BUILD
@@ -2298,16 +2298,16 @@ py_test(
)

py_test(
name = "examples/evaluation/custom_evaluation_parallel_to_training",
name = "examples/evaluation/custom_evaluation_parallel_to_training_10_episodes",
main = "examples/evaluation/custom_evaluation.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/evaluation/custom_evaluation.py"],
args = ["--enable-new-api-stack", "--as-test", "--framework=torch", "--stop-reward=0.75", "--evaluation-parallel-to-training", "--num-cpus=5"]
args = ["--enable-new-api-stack", "--as-test", "--stop-reward=0.75", "--evaluation-parallel-to-training", "--num-cpus=5", "--evaluation-duration=10", "--evaluation-duration-unit=episodes"]
)

py_test(
name = "examples/evaluation/evaluation_parallel_to_training_duration_auto_torch",
name = "examples/evaluation/evaluation_parallel_to_training_duration_auto",
main = "examples/evaluation/evaluation_parallel_to_training.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
@@ -2316,7 +2316,7 @@ py_test(
)

py_test(
name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_duration_auto_torch",
name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_duration_auto",
main = "examples/evaluation/evaluation_parallel_to_training.py",
tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core"],
size = "large",
@@ -2343,7 +2343,7 @@ py_test(
)

py_test(
name = "examples/evaluation/evaluation_parallel_to_training_13_episodes_torch",
name = "examples/evaluation/evaluation_parallel_to_training_13_episodes",
main = "examples/evaluation/evaluation_parallel_to_training.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
@@ -2352,7 +2352,7 @@ py_test(
)

py_test(
name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_10_episodes_torch",
name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_10_episodes",
main = "examples/evaluation/evaluation_parallel_to_training.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
@@ -2362,17 +2362,17 @@ py_test(

# @OldAPIStack
py_test(
name = "examples/evaluation/evaluation_parallel_to_training_duration_auto_torch_old_api_stack",
name = "examples/evaluation/evaluation_parallel_to_training_duration_auto_old_api_stack",
main = "examples/evaluation/evaluation_parallel_to_training.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/evaluation/evaluation_parallel_to_training.py"],
args = ["--as-test", "--evaluation-parallel-to-training", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto"]
args = ["--as-test", "--evaluation-parallel-to-training", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto", "--evaluation-duration-unit=timesteps"]
)

# @OldAPIStack
py_test(
name = "examples/evaluation/evaluation_parallel_to_training_211_ts_torch_old_api_stack",
name = "examples/evaluation/evaluation_parallel_to_training_211_ts_old_api_stack",
main = "examples/evaluation/evaluation_parallel_to_training.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
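For orientation only (not part of this commit): the explicit --evaluation-* flags added to the BUILD args above map onto RLlib's AlgorithmConfig.evaluation() settings roughly as sketched below; PPOConfig and CartPole-v1 are placeholder choices.

from ray.rllib.algorithms.ppo import PPOConfig

# Hedged sketch: config equivalent of the CLI flags used in the test args above.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .evaluation(
        evaluation_interval=1,                 # evaluate every training iteration
        evaluation_num_env_runners=2,          # dedicated evaluation EnvRunners
        evaluation_parallel_to_training=True,  # --evaluation-parallel-to-training
        evaluation_duration=10,                # --evaluation-duration=10
        evaluation_duration_unit="episodes",   # --evaluation-duration-unit=episodes
    )
)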
29 changes: 23 additions & 6 deletions rllib/algorithms/algorithm.py
@@ -1041,10 +1041,23 @@ def evaluate(

# Sync weights to the evaluation EnvRunners.
if self.eval_env_runner_group is not None:
if self.config.enable_env_runner_and_connector_v2:
if (
self.env_runner_group is not None
and self.env_runner_group.healthy_env_runner_ids()
):
# TODO (sven): Replace this with a new ActorManager API:
# try_remote_request_till_success("get_state") -> tuple(int,
# remoteresult)
weights_src = self.env_runner_group._worker_manager._actors[
self.env_runner_group.healthy_env_runner_ids()[0]
]
else:
weights_src = self.learner_group
else:
weights_src = self.env_runner_group.local_env_runner
self.eval_env_runner_group.sync_weights(
from_worker_or_learner_group=self.learner_group
if self.config.enable_env_runner_and_connector_v2
else self.env_runner_group.local_env_runner,
from_worker_or_learner_group=weights_src,
inference_only=True,
)
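In short, on the new API stack the evaluation EnvRunners now pull their weights from a healthy training EnvRunner (whose RLModule already holds the latest inference-only weights) and only fall back to the LearnerGroup when no healthy runner is available; on the old stack the local worker remains the source. A schematic recap of that selection, condensed from the diff above (illustrative only, no new API):

# Illustrative recap of the weight-source selection in Algorithm.evaluate().
def _pick_eval_weights_source(algo):
    group = algo.env_runner_group
    if algo.config.enable_env_runner_and_connector_v2:
        healthy = group.healthy_env_runner_ids() if group is not None else []
        if healthy:
            # Any healthy training EnvRunner actor can serve its weights.
            return group._worker_manager._actors[healthy[0]]
        return algo.learner_group
    # Old API stack: use the local worker as the weight source.
    return group.local_env_runner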

@@ -1444,7 +1457,7 @@ def _evaluate_with_fixed_duration(self):

# Remote function used on healthy EnvRunners to sample, get metrics, and
# step counts.
def _env_runner_remote(worker, num, round, iter):
def _env_runner_remote(worker, num, round, iter, _force_reset):
# Sample AND get_metrics, but only return metrics (and steps actually taken)
# to save time. Also return the iteration to check whether we should
# discard an outdated result (from a slow worker).
@@ -1453,7 +1466,7 @@ def _env_runner_remote(worker, num, round, iter):
num[worker.worker_index] if unit == "timesteps" else None
),
num_episodes=(num[worker.worker_index] if unit == "episodes" else None),
force_reset=force_reset and round == 0,
force_reset=_force_reset and round == 0,
)
metrics = worker.get_metrics()
env_steps = sum(e.env_steps() for e in episodes)
@@ -1491,7 +1504,11 @@ def _env_runner_remote(worker, num, round, iter):
]
self.eval_env_runner_group.foreach_env_runner_async(
func=functools.partial(
_env_runner_remote, num=_num, round=_round, iter=algo_iteration
_env_runner_remote,
num=_num,
round=_round,
iter=algo_iteration,
_force_reset=force_reset,
),
)
results = self.eval_env_runner_group.fetch_ready_async_reqs(
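The _force_reset change turns the reset flag into an explicit parameter that functools.partial binds when the sampling function is shipped to the remote EnvRunners, rather than a name resolved from the enclosing scope. A minimal, self-contained sketch of that binding pattern (generic names, not RLlib API):

import functools

def _runner_task(worker, *, num, round, iter, _force_reset):
    # Everything arrives as explicitly bound keyword arguments; nothing is
    # captured from an enclosing scope, so the callable stays self-contained
    # when pickled and sent to a remote actor.
    return {"worker": worker, "round": round, "reset": _force_reset and round == 0}

task = functools.partial(_runner_task, num=[5], round=0, iter=3, _force_reset=True)
print(task("runner_1"))  # -> {'worker': 'runner_1', 'round': 0, 'reset': True}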
111 changes: 76 additions & 35 deletions rllib/env/env_runner_group.py
@@ -495,13 +495,9 @@ def sync_env_runner_states(
config.num_env_runners or 1
)

# Update the rl_module component of the EnvRunner states, if necessary:
if rl_module_state:
env_runner_states.update(rl_module_state)

# If we do NOT want remote EnvRunners to get their Connector states updated,
# only update the local worker here (with all state components) and then remove
# the connector components.
# only update the local worker here (with all state components, except the model
# weights) and then remove the connector components.
if not config.update_worker_filter_stats:
self.local_env_runner.set_state(env_runner_states)
env_runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None)
@@ -510,18 +506,24 @@
# If there are components in the state left -> Update remote workers with these
# state components (and maybe the local worker, if it hasn't been updated yet).
if env_runner_states:
# Put the state dictionary into Ray's object store to avoid having to make n
# pickled copies of the state dict.
ref_env_runner_states = ray.put(env_runner_states)

def _update(_env_runner: EnvRunner) -> None:
_env_runner.set_state(ray.get(ref_env_runner_states))
# Update the local EnvRunner, but NOT with the weights. If used at all for
# evaluation (through the user calling `self.evaluate`), RLlib would update
# the weights up front either way.
if config.update_worker_filter_stats:
self.local_env_runner.set_state(env_runner_states)

# Send the model weights only to remote EnvRunners.
# In case the local EnvRunner is ever needed for evaluation,
# RLlib updates its weights right before such an eval step.
if rl_module_state:
env_runner_states.update(rl_module_state)

# Broadcast updated states back to all workers.
self.foreach_env_runner(
_update,
"set_state", # Call the `set_state()` remote method.
kwargs=dict(state=env_runner_states),
remote_worker_ids=env_runner_indices_to_update,
local_env_runner=config.update_worker_filter_stats,
local_env_runner=False,
timeout_seconds=0.0, # This is a state update -> Fire-and-forget.
)
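The broadcast now addresses the remote EnvRunners by method name ("set_state") with explicit kwargs, and timeout_seconds=0.0 makes it fire-and-forget: the calls are dispatched without waiting for replies. A hedged sketch of that calling pattern, with placeholder variables:

# Sketch (placeholder `env_runner_group` / `env_runner_states`): push a state
# dict to the remote EnvRunners without blocking on their responses.
env_runner_group.foreach_env_runner(
    "set_state",                          # remote method name instead of a lambda
    kwargs=dict(state=env_runner_states),
    local_env_runner=False,               # the local runner is handled separately
    timeout_seconds=0.0,                  # fire-and-forget
)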

@@ -581,18 +583,35 @@ def sync_weights(
if policies is not None
else [COMPONENT_RL_MODULE]
)
# LearnerGroup has-a Learner has-a RLModule.
# LearnerGroup has-a Learner, which has-a RLModule.
if isinstance(weights_src, LearnerGroup):
rl_module_state = weights_src.get_state(
components=[COMPONENT_LEARNER + "/" + m for m in modules],
inference_only=inference_only,
)[COMPONENT_LEARNER]
# EnvRunner has-a RLModule.
# EnvRunner (new API stack).
elif self._remote_config.enable_env_runner_and_connector_v2:
rl_module_state = weights_src.get_state(
components=modules,
inference_only=inference_only,
)
# EnvRunner (remote) has-a RLModule.
# TODO (sven): Replace this with a new ActorManager API:
# try_remote_request_till_success("get_state") -> tuple(int,
# remoteresult)
# `weights_src` could be the ActorManager, then. Then RLlib would know
# that it has to ping the manager to try all healthy actors until the
# first returns something.
if isinstance(weights_src, ray.actor.ActorHandle):
rl_module_state = ray.get(
weights_src.get_state.remote(
components=modules,
inference_only=inference_only,
)
)
# EnvRunner (local) has-a RLModule.
else:
rl_module_state = weights_src.get_state(
components=modules,
inference_only=inference_only,
)
# RolloutWorker (old API stack).
else:
rl_module_state = weights_src.get_weights(
policies=policies,
@@ -613,22 +632,28 @@
# copies of the weights dict for each worker.
rl_module_state_ref = ray.put(rl_module_state)

def _set_weights(env_runner):
env_runner.set_state(ray.get(rl_module_state_ref))
# Sync to specified remote workers in this EnvRunnerGroup.
self.foreach_env_runner(
func="set_state",
kwargs=dict(state=rl_module_state_ref),
local_env_runner=False, # Do not sync back to local worker.
remote_worker_ids=to_worker_indices,
timeout_seconds=timeout_seconds,
)

else:
rl_module_state_ref = ray.put(rl_module_state)

def _set_weights(env_runner):
env_runner.set_weights(ray.get(rl_module_state_ref), global_vars)

# Sync to specified remote workers in this EnvRunnerGroup.
self.foreach_env_runner(
func=_set_weights,
local_env_runner=False, # Do not sync back to local worker.
remote_worker_ids=to_worker_indices,
timeout_seconds=timeout_seconds,
)
# Sync to specified remote workers in this EnvRunnerGroup.
self.foreach_env_runner(
func=_set_weights,
local_env_runner=False, # Do not sync back to local worker.
remote_worker_ids=to_worker_indices,
timeout_seconds=timeout_seconds,
)

# If `from_worker_or_learner_group` is provided, also sync to this
# EnvRunnerGroup's local worker.
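With the added branch, weights_src can now also be a remote EnvRunner's ray.actor.ActorHandle, whose state is fetched through a remote get_state() call; a LearnerGroup or local EnvRunner is still queried in-process. A simplified sketch of handing either kind of source to sync_weights (variable names are illustrative):

# Illustrative only; assumes the new-API-stack path shown above.

# 1) Use a healthy remote training EnvRunner (actor handle) as the source.
actor_handle = train_env_runner_group._worker_manager._actors[
    train_env_runner_group.healthy_env_runner_ids()[0]
]
eval_env_runner_group.sync_weights(
    from_worker_or_learner_group=actor_handle,
    inference_only=True,
)

# 2) Fall back to the LearnerGroup when no healthy training EnvRunner exists.
eval_env_runner_group.sync_weights(
    from_worker_or_learner_group=learner_group,
    inference_only=True,
)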
@@ -716,8 +741,11 @@ def stop(self) -> None:

def foreach_env_runner(
self,
func: Callable[[EnvRunner], T],
func: Union[
Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
],
*,
kwargs=None,
local_env_runner: bool = True,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
@@ -759,13 +787,18 @@ def foreach_env_runner(

local_result = []
if local_env_runner and self.local_env_runner is not None:
local_result = [func(self.local_env_runner)]
assert kwargs is None
if isinstance(func, str):
local_result = [getattr(self.local_env_runner, func)]
else:
local_result = [func(self.local_env_runner)]

if not self._worker_manager.actor_ids():
return local_result

remote_results = self._worker_manager.foreach_actor(
func,
kwargs=kwargs,
healthy_only=healthy_only,
remote_actor_ids=remote_worker_ids,
timeout_seconds=timeout_seconds,
@@ -782,18 +815,24 @@ def foreach_env_runner_with_id(

return local_result + remote_results

# TODO (sven): Deprecate this API. Users can look up the "worker index" from the
# EnvRunner object directly through `self.worker_index` (besides many other useful
# properties, like `in_evaluation`, `num_env_runners`, etc.).
def foreach_env_runner_with_id(
self,
func: Callable[[int, EnvRunner], T],
func: Union[
Callable[[int, EnvRunner], T],
List[Callable[[int, EnvRunner], T]],
str,
List[str],
],
*,
local_env_runner: bool = True,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
timeout_seconds: Optional[float] = None,
return_obj_refs: bool = False,
mark_healthy: bool = False,
# Deprecated args.
local_worker=DEPRECATED_VALUE,
) -> List[T]:
"""Calls the given function with each EnvRunner and its ID as its arguments.
@@ -850,7 +889,9 @@ def foreach_env_runner_with_id(

def foreach_env_runner_async(
self,
func: Union[Callable[[EnvRunner], T], str],
func: Union[
Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
],
*,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
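Taken together, foreach_env_runner (and its async variant) now accepts either a callable or a plain method-name string, optionally with kwargs for that method. Roughly, the two invocation styles below should behave the same for remote workers (sketch with placeholder variables):

# a) Classic form: pass a callable that receives each EnvRunner.
metrics = env_runner_group.foreach_env_runner(
    lambda env_runner: env_runner.get_metrics(),
    local_env_runner=False,
)

# b) New form: pass the method name plus keyword arguments for it.
env_runner_group.foreach_env_runner(
    "set_state",
    kwargs=dict(state=state_dict),
    local_env_runner=False,
    timeout_seconds=0.0,
)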
@@ -92,7 +92,6 @@
parser.set_defaults(
evaluation_num_env_runners=2,
evaluation_interval=1,
evaluation_duration_unit="timesteps",
)


