Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to mujoco env #381

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions allenact/base_abstractions/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def __init__(
max_steps: int,
task_sampler: TaskSampler,
task_classes: List[type(Task)],
state_views: List,
callback_sensor_suite: Optional[SensorSuite],
**kwargs,
) -> None:
Expand All @@ -407,6 +408,11 @@ def __init__(
self.tasks = [task_classes[0](env=env, sensors=sensors, task_info=task_info, max_steps=max_steps, batch_index=0, **kwargs)]
self.tasks[0].batch_index = 0

self.frames = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see why you're doing this, but I guess I had a thought on overloading the render() command, so that it produces all the observations we might want to extract from the environment (including, for example, also expert actions). I'm not sure we should be explicit at this level about which type of data are rendered, but I'm happy to discuss it.

self.depths = None
self.segs = None
self.state_views = state_views

# If task_batch_size greater than 1, instantiate the rest of tasks (with task_batch_size set to 1)
if self.task_sampler.task_batch_size > 1:
for it in range(1, self.task_sampler.task_batch_size):
Expand All @@ -428,10 +434,22 @@ def observation_space(self):

def get_observations(self, **kwargs) -> List[Any]: #-> Dict[str, Any]:
# Render all tasks in batch
self.tasks[0].env.render() # assume this is stored locally in the env class
self.frames = self.render("rgb")
self.depths = self.render("depth")
self.segs = self.render("seg")

# return {"batch_observations": [task.get_observations() for task in self.tasks]}
return [task.get_observations() for task in self.tasks]
return [task.get_observations(
frame=self.frames[idx],
depth=self.depths[idx],
seg=self.segs[idx],
) for idx, task in enumerate(self.tasks)]

def update_state_views(self):
for idx, state_view in enumerate(self.state_views):
updated_state_view = self.tasks[0].env.call(state_view)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know I introduced this notation, but we might prefer to use the env reference in the task sampler rather than that in task 0? I was not sure of what seems cleaner, but I'm currently inclined for the task sampler reference as the better option. Again, happy to discuss what seems best :)

for idy, task in enumerate(self.tasks):
task.state_views[idx] = updated_state_view[idy]

@property
@abc.abstractmethod
Expand All @@ -456,7 +474,7 @@ def render(self, mode: str = "rgb", *args, **kwargs) -> np.ndarray:

An numpy array corresponding to the requested render.
"""
raise NotImplementedError()
return self.tasks[0].env.render(mode=mode, *args, **kwargs)

def step(self, action: Any) -> RLStepResult:
srs = self._step(action=action)
Expand Down
31 changes: 18 additions & 13 deletions tests/make_it_batch/batch_ai2thor_controller.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
from typing import List
from ai2thor.controller import Controller
from allenact_plugins.ithor_plugin.ithor_environment import IThorEnvironment


class BatchController:
def __init__(
self,
task_batch_size: int,
**kwargs,
self,
task_batch_size: int,
**kwargs,
):
self.task_batch_size = task_batch_size
self.controllers = [Controller(**kwargs) for _ in range(task_batch_size)]
self.controllers = [IThorEnvironment(**kwargs) for _ in range(task_batch_size)]
self._frames = []

def step(self, actions: List[str]):
assert len(actions) == self.task_batch_size
for controller, action in zip(self.controllers, actions):
controller.step(action)
controller.step(action=action if action != "End" else "Pass")
self._frames = []
return self.batch_last_event()

def get_agent_location(self):
return None

def reset(
self,
idx: int,
scene: str,
self,
idx: int,
scene: str,
):
self.controllers[idx].reset(scene)

def batch_reset(
self,
scenes: List[str],
self,
scenes: List[str],
):
for controller, scene in zip(self.controllers, scenes):
controller.reset(scene)
Expand All @@ -42,7 +48,6 @@ def batch_last_event(self):
return [controller.last_event for controller in self.controllers]

def render(self):
frames = []
assert len(self._frames) == 0
for controller in self.controllers:
frames.append(controller.last_event.frame)
return frames
self._frames.append(controller.last_event.frame)
49 changes: 1 addition & 48 deletions tests/make_it_batch/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,55 +31,8 @@
from allenact_plugins.ithor_plugin.ithor_environment import IThorEnvironment
from allenact.base_abstractions.misc import RLStepResult

from tests.make_it_batch.batch_ai2thor_controller import BatchController

class BatchController:
def __init__(
self,
task_batch_size: int,
**kwargs,
):
self.task_batch_size = task_batch_size
self.controllers = [IThorEnvironment(**kwargs) for _ in range(task_batch_size)]
self._frames = []

def step(self, actions: List[str]):
assert len(actions) == self.task_batch_size
for controller, action in zip(self.controllers, actions):
controller.step(action=action if action != "End" else "Pass")
self._frames = []
return self.batch_last_event()

def get_agent_location(self):
return None

def reset(
self,
idx: int,
scene: str,
):
self.controllers[idx].reset(scene)

def batch_reset(
self,
scenes: List[str],
):
for controller, scene in zip(self.controllers, scenes):
controller.reset(scene)

def stop(self):
for controller in self.controllers:
controller.stop()

def last_event(self, idx: int):
return self.controllers[idx].last_event

def batch_last_event(self):
return [controller.last_event for controller in self.controllers]

def render(self):
assert len(self._frames) == 0
for controller in self.controllers:
self._frames.append(controller.last_event.frame)


class BatchableObjectNaviThorGridTask(ObjectNaviThorGridTask):
Expand Down
Loading