From 0f3adda84f2a2ab55e9d9aaf7311c917518ec25c Mon Sep 17 00:00:00 2001 From: Baruch Tabanpour Date: Wed, 29 Jan 2025 13:02:54 -0800 Subject: [PATCH] Revert back to old observation_size API. PiperOrigin-RevId: 721080636 Change-Id: I09f74527dab8281fea0aa349936130055e13756b --- .../_src/dm_control_suite/dm_control_suite_test.py | 2 +- mujoco_playground/_src/locomotion/locomotion_test.py | 1 + mujoco_playground/_src/manipulation/manipulation_test.py | 1 + mujoco_playground/_src/mjx_env.py | 7 +++++-- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py b/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py index f8ec3ab..124838e 100644 --- a/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py +++ b/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py @@ -33,7 +33,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.reset)(jax.random.PRNGKey(42)) state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) - self.assertEqual(state.obs.shape, env.observation_size) + self.assertEqual(state.obs.shape[0], env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/locomotion/locomotion_test.py b/mujoco_playground/_src/locomotion/locomotion_test.py index 971a849..6afd8f8 100644 --- a/mujoco_playground/_src/locomotion/locomotion_test.py +++ b/mujoco_playground/_src/locomotion/locomotion_test.py @@ -36,6 +36,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs) + obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape self.assertEqual(obs_shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/manipulation/manipulation_test.py b/mujoco_playground/_src/manipulation/manipulation_test.py index 971aa19..0b52d2e 100644 --- a/mujoco_playground/_src/manipulation/manipulation_test.py +++ b/mujoco_playground/_src/manipulation/manipulation_test.py @@ -36,6 +36,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs) + obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape self.assertEqual(obs_shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py index 4871554..4e558df 100644 --- a/mujoco_playground/_src/mjx_env.py +++ b/mujoco_playground/_src/mjx_env.py @@ -270,8 +270,11 @@ def n_substeps(self) -> int: @property def observation_size(self) -> ObservationSize: - out = jax.eval_shape(self.reset, jax.random.PRNGKey(0)) - return jax.tree_util.tree_map(lambda x: x.shape, out.obs) + abstract_state = jax.eval_shape(self.reset, jax.random.PRNGKey(0)) + obs = abstract_state.obs + if isinstance(obs, Mapping): + return jax.tree_util.tree_map(lambda x: x.shape, obs) + return obs.shape[-1] def render( self,