Skip to content

Commit

Permalink
Collected task infos before aggregating step results
Browse files: browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed Jul 9, 2024
1 parent 992629e commit c0797b0
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -681,6 +681,34 @@ def collect_step_across_all_task_samplers(
su.action_list(action_space, flat_actions)
)

# Save after task completion metrics
if self.task_batch_size == 0:
for step_result in outputs:
if step_result.info is not None:
if COMPLETE_TASK_METRICS_KEY in step_result.info:
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in step_result.info:
self.single_process_task_callback_data.append(
step_result.info[COMPLETE_TASK_CALLBACK_KEY]
)
del step_result.info[COMPLETE_TASK_CALLBACK_KEY]
else:
for batched_step_result in outputs:
for info in batched_step_result.info:
if COMPLETE_TASK_METRICS_KEY in info:
self.single_process_metrics.append(
info[COMPLETE_TASK_METRICS_KEY]
)
del info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in info:
self.single_process_task_callback_data.append(
info[COMPLETE_TASK_CALLBACK_KEY]
)
del info[COMPLETE_TASK_CALLBACK_KEY]

rewards: Union[List, torch.Tensor]
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]

Expand All @@ -689,24 +717,10 @@ def collect_step_across_all_task_samplers(
observations = sum(observations, [])
rewards = sum(rewards, [])
dones = sum(dones, [])
infos = sum(infos, []) # unused
# infos = sum(infos, []) # unused
new_shape = tuple(flat_actions.shape)[:-2] + (flat_actions.shape[-2] * self.task_batch_size, flat_actions.shape[-1] // self.task_batch_size)
flat_actions = flat_actions.view(new_shape)

# Save after task completion metrics
for info in infos:
if info is not None:
if COMPLETE_TASK_METRICS_KEY in info:
self.single_process_metrics.append(
info[COMPLETE_TASK_METRICS_KEY]
)
del info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in info:
self.single_process_task_callback_data.append(
info[COMPLETE_TASK_CALLBACK_KEY]
)
del info[COMPLETE_TASK_CALLBACK_KEY]

rewards = torch.tensor(
rewards, dtype=torch.float, device=self.device, # type:ignore
)
Expand Down

0 comments on commit c0797b0

Please sign in to comment.