-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
switch to mujoco env #381
switch to mujoco env #381
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -392,6 +392,7 @@ def __init__( | |
max_steps: int, | ||
task_sampler: TaskSampler, | ||
task_classes: List[type(Task)], | ||
state_views: List, | ||
callback_sensor_suite: Optional[SensorSuite], | ||
**kwargs, | ||
) -> None: | ||
|
@@ -407,6 +408,11 @@ def __init__( | |
self.tasks = [task_classes[0](env=env, sensors=sensors, task_info=task_info, max_steps=max_steps, batch_index=0, **kwargs)] | ||
self.tasks[0].batch_index = 0 | ||
|
||
self.frames = None | ||
self.depths = None | ||
self.segs = None | ||
self.state_views = state_views | ||
|
||
# If task_batch_size greater than 1, instantiate the rest of tasks (with task_batch_size set to 1) | ||
if self.task_sampler.task_batch_size > 1: | ||
for it in range(1, self.task_sampler.task_batch_size): | ||
|
@@ -428,10 +434,22 @@ def observation_space(self): | |
|
||
def get_observations(self, **kwargs) -> List[Any]: #-> Dict[str, Any]: | ||
# Render all tasks in batch | ||
self.tasks[0].env.render() # assume this is stored locally in the env class | ||
self.frames = self.render("rgb") | ||
self.depths = self.render("depth") | ||
self.segs = self.render("seg") | ||
|
||
# return {"batch_observations": [task.get_observations() for task in self.tasks]} | ||
return [task.get_observations() for task in self.tasks] | ||
return [task.get_observations( | ||
frame=self.frames[idx], | ||
depth=self.depths[idx], | ||
seg=self.segs[idx], | ||
) for idx, task in enumerate(self.tasks)] | ||
|
||
def update_state_views(self): | ||
for idx, state_view in enumerate(self.state_views): | ||
updated_state_view = self.tasks[0].env.call(state_view) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know I introduced this notation, but we might prefer to use the env reference in the task sampler rather than that in task 0? I was not sure of what seems cleaner, but I'm currently inclined for the task sampler reference as the better option. Again, happy to discuss what seems best :) |
||
for idy, task in enumerate(self.tasks): | ||
task.state_views[idx] = updated_state_view[idy] | ||
|
||
@property | ||
@abc.abstractmethod | ||
|
@@ -456,7 +474,7 @@ def render(self, mode: str = "rgb", *args, **kwargs) -> np.ndarray: | |
|
||
An numpy array corresponding to the requested render. | ||
""" | ||
raise NotImplementedError() | ||
return self.tasks[0].env.render(mode=mode, *args, **kwargs) | ||
|
||
def step(self, action: Any) -> RLStepResult: | ||
srs = self._step(action=action) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see why you're doing this, but I guess I had a thought on overloading the render() command, so that it produces all the observations we might want to extract from the environment (including, for example, also expert actions). I'm not sure we should be explicit at this level about which type of data are rendered, but I'm happy to discuss it.