From ea0754306b389f79c67a52b1ddea988a137ae257 Mon Sep 17 00:00:00 2001 From: Zizhao Chen Date: Tue, 21 Jan 2025 22:58:55 -0500 Subject: [PATCH] module level attribute and add tests --- .../_src/dm_control_suite/__init__.py | 9 +-- mujoco_playground/_src/locomotion/__init__.py | 10 ++-- .../_src/manipulation/__init__.py | 9 +-- mujoco_playground/_src/registry.py | 6 +- mujoco_playground/_src/registry_test.py | 58 +++++++++++++++++++ 5 files changed, 77 insertions(+), 15 deletions(-) create mode 100644 mujoco_playground/_src/registry_test.py diff --git a/mujoco_playground/_src/dm_control_suite/__init__.py b/mujoco_playground/_src/dm_control_suite/__init__.py index e624a24..136fdca 100644 --- a/mujoco_playground/_src/dm_control_suite/__init__.py +++ b/mujoco_playground/_src/dm_control_suite/__init__.py @@ -101,7 +101,10 @@ } -ALL = list(_envs.keys()) +def __getattr__(name): + if name == 'ALL': + return list(_envs.keys()) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") def register_environment( @@ -118,8 +121,6 @@ def register_environment( """ _envs[env_name] = env_class _cfgs[env_name] = cfg_class - if env_name not in ALL: - ALL.append(env_name) def get_default_config(env_name: str) -> config_dict.ConfigDict: @@ -149,6 +150,6 @@ def load( An instance of the environment. """ if env_name not in _envs: - raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}") + raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") config = config or get_default_config(env_name) return _envs[env_name](config=config, config_overrides=config_overrides) diff --git a/mujoco_playground/_src/locomotion/__init__.py b/mujoco_playground/_src/locomotion/__init__.py index 476a483..ea08a0e 100644 --- a/mujoco_playground/_src/locomotion/__init__.py +++ b/mujoco_playground/_src/locomotion/__init__.py @@ -112,7 +112,11 @@ "Go1Footstand": go1_randomize.domain_randomize, } -ALL = list(_envs.keys()) + +def __getattr__(name): + if name == "ALL": + return list(_envs.keys()) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") def register_environment( @@ -129,8 +133,6 @@ def register_environment( """ _envs[env_name] = env_class _cfgs[env_name] = cfg_class - if env_name not in ALL: - ALL.append(env_name) def get_default_config(env_name: str) -> config_dict.ConfigDict: @@ -160,7 +162,7 @@ def load( An instance of the environment. """ if env_name not in _envs: - raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}") + raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") config = config or get_default_config(env_name) return _envs[env_name](config=config, config_overrides=config_overrides) diff --git a/mujoco_playground/_src/manipulation/__init__.py b/mujoco_playground/_src/manipulation/__init__.py index a01356a..d60cb01 100644 --- a/mujoco_playground/_src/manipulation/__init__.py +++ b/mujoco_playground/_src/manipulation/__init__.py @@ -58,7 +58,10 @@ "LeapCubeReorient": leap_cube_reorient.domain_randomize, } -ALL = list(_envs.keys()) +def __getattr__(name): + if name == "ALL": + return list(_envs.keys()) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") def register_environment( @@ -75,8 +78,6 @@ def register_environment( """ _envs[env_name] = env_class _cfgs[env_name] = cfg_class - if env_name not in ALL: - ALL.append(env_name) def get_default_config(env_name: str) -> config_dict.ConfigDict: @@ -106,7 +107,7 @@ def load( An instance of the environment. """ if env_name not in _envs: - raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}") + raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") config = config or get_default_config(env_name) return _envs[env_name](config=config, config_overrides=config_overrides) diff --git a/mujoco_playground/_src/registry.py b/mujoco_playground/_src/registry.py index 8b2b740..af06b32 100644 --- a/mujoco_playground/_src/registry.py +++ b/mujoco_playground/_src/registry.py @@ -28,8 +28,6 @@ Callable[[mjx.Model, jax.Array], Tuple[mjx.Model, mjx.Model]] ] -ALL_ENVS = locomotion.ALL + manipulation.ALL + dm_control_suite.ALL - def get_default_config(env_name: str): if env_name in manipulation.ALL: @@ -54,7 +52,9 @@ def load( elif env_name in dm_control_suite.ALL: return dm_control_suite.load(env_name, config, config_overrides) - raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL_ENVS}") + all_envs = manipulation.ALL + locomotion.ALL + dm_control_suite.ALL + + raise ValueError(f"Env '{env_name}' not found. Available envs: {all_envs}") def get_domain_randomizer(env_name: str) -> Optional[DomainRandomizer]: diff --git a/mujoco_playground/_src/registry_test.py b/mujoco_playground/_src/registry_test.py new file mode 100644 index 0000000..4016564 --- /dev/null +++ b/mujoco_playground/_src/registry_test.py @@ -0,0 +1,58 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the registry module.""" + +from absl.testing import absltest +import jax.numpy as jp +from ml_collections import config_dict + +from mujoco_playground._src import registry +from mujoco_playground._src.dm_control_suite import register_environment as register_dm_control_suite +from mujoco_playground._src.locomotion import register_environment as register_locomotion +from mujoco_playground._src.manipulation import register_environment as register_manipulation + + +class RegistryTest(absltest.TestCase): + + def test_new_env(self): + class DemoEnv: + + def __init__(self, config, config_overrides): + pass + + def demo_default_config(): + return config_dict.ConfigDict() + + register_dm_control_suite('DemoEnv', DemoEnv, demo_default_config) + env = registry.load('DemoEnv') + self.assertIsInstance(env, DemoEnv) + config = registry.get_default_config('DemoEnv') + self.assertEqual(config, config_dict.ConfigDict()) + + register_manipulation('DemoEnv', DemoEnv, demo_default_config) + env = registry.load('DemoEnv') + self.assertIsInstance(env, DemoEnv) + config = registry.get_default_config('DemoEnv') + self.assertEqual(config, config_dict.ConfigDict()) + + register_locomotion('DemoEnv', DemoEnv, demo_default_config) + env = registry.load('DemoEnv') + self.assertIsInstance(env, DemoEnv) + config = registry.get_default_config('DemoEnv') + self.assertEqual(config, config_dict.ConfigDict()) + + +if __name__ == '__main__': + absltest.main()