From 7c07c5d1a4b699e94550844195e94b78afcc75c9 Mon Sep 17 00:00:00 2001 From: KuoHaoZeng Date: Tue, 2 Jul 2024 17:12:20 -0700 Subject: [PATCH] update task --- allenact/base_abstractions/task.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/allenact/base_abstractions/task.py b/allenact/base_abstractions/task.py index 96ec651b..e91dfce3 100644 --- a/allenact/base_abstractions/task.py +++ b/allenact/base_abstractions/task.py @@ -392,6 +392,7 @@ def __init__( max_steps: int, task_sampler: TaskSampler, task_classes: List[type(Task)], + state_views: List, callback_sensor_suite: Optional[SensorSuite], **kwargs, ) -> None: @@ -410,6 +411,7 @@ def __init__( self.frames = None self.depths = None self.segs = None + self.state_views = state_views # If task_batch_size greater than 1, instantiate the rest of tasks (with task_batch_size set to 1) if self.task_sampler.task_batch_size > 1: @@ -443,6 +445,12 @@ def get_observations(self, **kwargs) -> List[Any]: #-> Dict[str, Any]: seg=self.segs[idx], ) for idx, task in enumerate(self.tasks)] + def update_state_views(self): + for idx, state_view in enumerate(self.state_views): + updated_state_view = self.tasks[0].env.call(state_view) + for idy, task in enumerate(self.tasks): + task.state_views[idx] = updated_state_view[idy] + @property @abc.abstractmethod def action_space(self) -> gym.Space: