wandb support, callback func for PipelineStage, and cache handling #382

Merged: 26 commits, merged Jul 22, 2024.

Changes shown are from 13 of the 26 commits.
233 changes: 149 additions & 84 deletions allenact/algorithms/onpolicy_sync/engine.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions allenact/algorithms/onpolicy_sync/losses/a2cacktr.py
@@ -1,4 +1,5 @@
"""Implementation of A2C and ACKTR losses."""

from typing import cast, Tuple, Dict, Optional

import torch
@@ -99,7 +100,9 @@ def loss( # type: ignore
**kwargs,
):
losses_per_step = self.loss_per_step(
step_count=step_count, batch=batch, actor_critic_output=actor_critic_output,
step_count=step_count,
batch=batch,
actor_critic_output=actor_critic_output,
)
losses = {
key: (loss.mean(), weight)
@@ -169,4 +172,7 @@ def __init__(
)


A2CConfig = dict(value_loss_coef=0.5, entropy_coef=0.01,)
A2CConfig = dict(
value_loss_coef=0.5,
entropy_coef=0.01,
)
@@ -68,4 +68,6 @@ def loss( # type: ignore
torch.log((probs_tensor * expert_group_actions_mask).sum(-1))
).mean()

return total_loss, {"grouped_action_cross_entropy": total_loss.item(),}
return total_loss, {
"grouped_action_cross_entropy": total_loss.item(),
}
12 changes: 8 additions & 4 deletions allenact/algorithms/onpolicy_sync/losses/imitation.py
@@ -149,7 +149,9 @@ def loss( # type: ignore
ready_actions[group_name] = expert_action

current_loss, expert_successes = self.group_loss(
cd, expert_action, expert_action_masks,
cd,
expert_action,
expert_action_masks,
)

should_report_loss = (
@@ -204,7 +206,9 @@ def loss( # type: ignore
)
return (
total_loss,
{"expert_cross_entropy": total_loss.item(), **losses}
if should_report_loss
else {},
(
{"expert_cross_entropy": total_loss.item(), **losses}
if should_report_loss
else {}
),
)
28 changes: 17 additions & 11 deletions allenact/algorithms/onpolicy_sync/losses/ppo.py
@@ -115,15 +115,17 @@ def add_trailing_dims(t: torch.Tensor):
"action": (action_loss, None),
"entropy": (dist_entropy.mul_(-1.0), self.entropy_coef), # type: ignore
},
{
"ratio": ratio,
"ratio_clamped": clamped_ratio,
"ratio_used": torch.where(
cast(torch.Tensor, use_clamped), clamped_ratio, ratio
),
}
if self.show_ratios
else {},
(
{
"ratio": ratio,
"ratio_clamped": clamped_ratio,
"ratio_used": torch.where(
cast(torch.Tensor, use_clamped), clamped_ratio, ratio
),
}
if self.show_ratios
else {}
),
)

def loss( # type: ignore
@@ -135,7 +137,9 @@ def loss( # type: ignore
**kwargs
):
losses_per_step, ratio_info = self.loss_per_step(
step_count=step_count, batch=batch, actor_critic_output=actor_critic_output,
step_count=step_count,
batch=batch,
actor_critic_output=actor_critic_output,
)
losses = {
key: (loss.mean(), weight)
@@ -210,7 +214,9 @@ def loss( # type: ignore

return (
value_loss,
{"value": value_loss.item(),},
{
"value": value_loss.item(),
},
)


98 changes: 62 additions & 36 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -1,4 +1,5 @@
"""Defines the reinforcement learning `OnPolicyRunner`."""

import copy
import enum
import glob
@@ -44,6 +45,7 @@
ScalarMeanTracker,
set_deterministic_cudnn,
set_seed,
download_checkpoint_from_wandb,
)
from allenact.utils.misc_utils import (
NumpyJSONEncoder,
@@ -542,9 +544,9 @@ def start_train(
config=self.config,
callback_sensors=self._get_callback_sensors,
results_queue=self.queues["results"],
checkpoints_queue=self.queues["checkpoints"]
if self.running_validation
else None,
checkpoints_queue=(
self.queues["checkpoints"] if self.running_validation else None
),
checkpoints_dir=self.checkpoint_dir(),
seed=self.seed,
deterministic_cudnn=self.deterministic_cudnn,
@@ -555,9 +557,9 @@
distributed_port=distributed_port,
max_sampler_processes_per_worker=max_sampler_processes_per_worker,
save_ckpt_after_every_pipeline_stage=save_ckpt_after_every_pipeline_stage,
initial_model_state_dict=initial_model_state_dict
if model_hash is None
else model_hash,
initial_model_state_dict=(
initial_model_state_dict if model_hash is None else model_hash
),
first_local_worker_id=worker_ids[0],
distributed_preemption_threshold=self.distributed_preemption_threshold,
valid_on_initial_weights=valid_on_initial_weights,
@@ -782,9 +784,11 @@ def checkpoint_dir(
self, start_time_str: Optional[str] = None, create_if_none: bool = True
):
path_parts = [
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag),
(
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag)
),
start_time_str or self.local_start_time_str,
]
if self.save_dir_fmt == SaveDirFormat.NESTED:
@@ -816,9 +820,11 @@ def log_writer_path(self, start_time_str: str) -> str:
)
path = os.path.join(
self.output_dir,
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag),
(
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag)
),
start_time_str,
"train_tb",
)
@@ -827,9 +833,11 @@ def log_writer_path(self, start_time_str: str) -> str:
path = os.path.join(
self.output_dir,
"tb",
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag),
(
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag)
),
start_time_str,
)
if self.mode == TEST_MODE_STR:
Expand All @@ -850,19 +858,23 @@ def metric_path(self, start_time_str: str) -> str:
return os.path.join(
self.output_dir,
"metrics",
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag),
(
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag)
),
start_time_str,
)
else:
raise NotImplementedError

def save_project_state(self):
path_parts = [
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag),
(
self.config.tag()
if self.extra_tag == ""
else os.path.join(self.config.tag(), self.extra_tag)
),
self.local_start_time_str,
]
if self.save_dir_fmt == SaveDirFormat.NESTED:
@@ -1091,16 +1103,23 @@ def update_keys_metric(
f" AllenAct, please report this issue at https://github.com/allenai/allenact/issues."
)
else:
scalar_name_to_total_storage_experience[
scalar_name
] = total_exp_for_storage
scalar_name_to_total_experiences_key[
scalar_name
] = storage_uuid_to_total_experiences_key[storage_uuid]
scalar_name_to_total_storage_experience[scalar_name] = (
total_exp_for_storage
)
scalar_name_to_total_experiences_key[scalar_name] = (
storage_uuid_to_total_experiences_key[storage_uuid]
)

assert all_equal(
checkpoint_file_name
), f"All {mode} logging packages must have the same checkpoint_file_name."
if any(checkpoint_file_name):
ckpt_to_store = None
for ckpt in checkpoint_file_name:
if ckpt is not None:
ckpt_to_store = ckpt
assert ckpt_to_store is not None
checkpoint_file_name = [ckpt_to_store]
# assert all_equal(
# checkpoint_file_name
# ), f"All {mode} logging packages must have the same checkpoint_file_name."

message = [
f"{mode.upper()}: {training_steps} rollout steps ({pkgs[0].storage_uuid_to_total_experiences})"
@@ -1156,9 +1175,9 @@ def update_keys_metric(
stage_component_uuid,
)
callback_metric_means[approx_eps_key] = eps
scalar_name_to_total_experiences_key[
approx_eps_key
] = storage_uuid_to_total_experiences_key[storage_uuid]
scalar_name_to_total_experiences_key[approx_eps_key] = (
storage_uuid_to_total_experiences_key[storage_uuid]
)

if log_writer is not None:
log_writer.add_scalar(
@@ -1194,6 +1213,7 @@ def update_keys_metric(
metrics=metric_dicts_list,
metric_means=callback_metric_means,
step=training_steps,
checkpoint_file_name=checkpoint_file_name[0],
tasks_data=tasks_callback_data,
scalar_name_to_total_experiences_key=scalar_name_to_total_experiences_key,
)
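
The hunk above adds a `checkpoint_file_name` keyword to what appears to be the call that forwards logging data to registered callbacks, so a callback can tie the reported metrics to the checkpoint they were produced from. A hypothetical sketch of a callback consuming it follows; the hook name and signature are assumptions for illustration, not taken from this diff:

```python
# Hypothetical callback sketch: `on_valid_log` and its signature are assumptions;
# only the forwarded `checkpoint_file_name` argument comes from this diff.
from typing import Any, Dict, List, Optional

import wandb


class WandbCheckpointCallback:
    def on_valid_log(
        self,
        metrics: List[Dict[str, Any]],
        metric_means: Dict[str, float],
        step: int,
        checkpoint_file_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Log the metric means and, when a checkpoint path was forwarded,
        # attach that checkpoint to the active wandb run as an artifact.
        if wandb.run is None:
            return
        wandb.run.log(metric_means, step=step)
        if checkpoint_file_name is not None:
            artifact = wandb.Artifact(name=f"checkpoint-step-{step}", type="model")
            artifact.add_file(checkpoint_file_name)
            wandb.run.log_artifact(artifact)
```
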
@@ -1358,9 +1378,11 @@ def log_and_close(
self.process_valid_package(
log_writer=log_writer,
pkg=package,
all_results=eval_results
if self._collect_valid_results
else None,
all_results=(
eval_results
if self._collect_valid_results
else None
),
)

if metrics_file is not None:
@@ -1479,6 +1501,10 @@ def get_checkpoint_files(
checkpoint_path_dir_or_pattern: str,
approx_ckpt_step_interval: Optional[int] = None,
):
if "wandb://" == checkpoint_path_dir_or_pattern[:8]:
eval_dir = "wandb_ckpts_to_eval/{}".format(self.local_start_time_str)
os.makedirs(eval_dir, exist_ok=True)
return download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, eval_dir, only_allow_one_ckpt=False)

if os.path.isdir(checkpoint_path_dir_or_pattern):
# The fragment is a path to a directory, lets use this directory
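
The new `wandb://` branch in `get_checkpoint_files` delegates to `download_checkpoint_from_wandb`, which is imported near the top of runner.py (see the import hunk above) but whose implementation is not part of this diff. As a hedged sketch only, assuming a `wandb://entity/project/run_id` URI layout and `.pt` checkpoint files (both assumptions, not taken from the PR), such a helper could be built on the public wandb API roughly like this:

```python
# Sketch of a possible wandb checkpoint downloader. This is NOT the helper added
# by this PR; the URI layout and the ".pt" filter are illustrative assumptions.
import os
from typing import List

import wandb


def download_checkpoint_from_wandb_sketch(
    checkpoint_path_dir_or_pattern: str,
    eval_dir: str,
    only_allow_one_ckpt: bool = True,
) -> List[str]:
    # Strip the "wandb://" scheme and look the run up via the public wandb API.
    run_path = checkpoint_path_dir_or_pattern[len("wandb://"):]
    run = wandb.Api().run(run_path)

    downloaded: List[str] = []
    for f in run.files():
        if f.name.endswith(".pt"):  # assume checkpoints were saved as .pt files
            f.download(root=eval_dir, replace=True)
            downloaded.append(os.path.join(eval_dir, f.name))

    if only_allow_one_ckpt and len(downloaded) != 1:
        raise ValueError(
            f"Expected exactly one checkpoint for {checkpoint_path_dir_or_pattern}, "
            f"found {len(downloaded)}."
        )
    return sorted(downloaded)
```

In the hunk above, the returned list is handed back from `get_checkpoint_files` and then treated the same way as checkpoints discovered on disk.
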
27 changes: 20 additions & 7 deletions allenact/algorithms/onpolicy_sync/storage.py
@@ -121,7 +121,8 @@ def empty(self) -> bool:
class MiniBatchStorageMixin(abc.ABC):
@abc.abstractmethod
def batched_experience_generator(
self, num_mini_batch: int,
self,
num_mini_batch: int,
) -> Generator[Dict[str, Any], None, None]:
raise NotImplementedError

@@ -183,7 +184,8 @@ def initialize(
self.action_space = action_space

self.memory_first_last: Memory = self.create_memory(
spec=self.memory_specification, num_samplers=num_samplers,
spec=self.memory_specification,
num_samplers=num_samplers,
).to(self.device)
for key in self.memory_specification:
self.flattened_to_unflattened["memory"][key] = [key]
@@ -249,7 +251,10 @@ def observations(self) -> Memory:
return self._observations_full.slice(dim=0, start=0, stop=self.step + 1)

@staticmethod
def create_memory(spec: Optional[FullMemorySpecType], num_samplers: int,) -> Memory:
def create_memory(
spec: Optional[FullMemorySpecType],
num_samplers: int,
) -> Memory:
if spec is None:
return Memory()

@@ -290,7 +295,9 @@ def to(self, device: torch.device):
self.device = device

def insert_observations(
self, observations: ObservationType, time_step: int,
self,
observations: ObservationType,
time_step: int,
):
self.insert_tensors(
storage=self._observations_full,
@@ -300,7 +307,9 @@ def insert_observations(
)

def insert_memory(
self, memory: Optional[Memory], time_step: int,
self,
memory: Optional[Memory],
time_step: int,
):
if memory is None:
assert len(self.memory_first_last) == 0
@@ -519,7 +528,10 @@ def before_updates(
):
assert len(kwargs) == 0
self.compute_returns(
next_value=next_value, use_gae=use_gae, gamma=gamma, tau=tau,
next_value=next_value,
use_gae=use_gae,
gamma=gamma,
tau=tau,
)

self._advantages = self.returns[:-1] - self.value_preds[:-1]
@@ -587,7 +599,8 @@ def compute_returns(
)

def batched_experience_generator(
self, num_mini_batch: int,
self,
num_mini_batch: int,
):
assert self._before_update_called, (
"self._before_update_called() must be called before"