Skip to content

Commit

Permalink
module level attribute and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzizhao committed Jan 22, 2025
1 parent 44f1e92 commit ea07543
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 15 deletions.
9 changes: 5 additions & 4 deletions mujoco_playground/_src/dm_control_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@
}


ALL = list(_envs.keys())
def __getattr__(name):
if name == 'ALL':
return list(_envs.keys())
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def register_environment(
Expand All @@ -118,8 +121,6 @@ def register_environment(
"""
_envs[env_name] = env_class
_cfgs[env_name] = cfg_class
if env_name not in ALL:
ALL.append(env_name)


def get_default_config(env_name: str) -> config_dict.ConfigDict:
Expand Down Expand Up @@ -149,6 +150,6 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}")
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)
10 changes: 6 additions & 4 deletions mujoco_playground/_src/locomotion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@
"Go1Footstand": go1_randomize.domain_randomize,
}

ALL = list(_envs.keys())

def __getattr__(name):
if name == "ALL":
return list(_envs.keys())
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def register_environment(
Expand All @@ -129,8 +133,6 @@ def register_environment(
"""
_envs[env_name] = env_class
_cfgs[env_name] = cfg_class
if env_name not in ALL:
ALL.append(env_name)


def get_default_config(env_name: str) -> config_dict.ConfigDict:
Expand Down Expand Up @@ -160,7 +162,7 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}")
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
9 changes: 5 additions & 4 deletions mujoco_playground/_src/manipulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
"LeapCubeReorient": leap_cube_reorient.domain_randomize,
}

ALL = list(_envs.keys())
def __getattr__(name):
if name == "ALL":
return list(_envs.keys())
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def register_environment(
Expand All @@ -75,8 +78,6 @@ def register_environment(
"""
_envs[env_name] = env_class
_cfgs[env_name] = cfg_class
if env_name not in ALL:
ALL.append(env_name)


def get_default_config(env_name: str) -> config_dict.ConfigDict:
Expand Down Expand Up @@ -106,7 +107,7 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}")
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
6 changes: 3 additions & 3 deletions mujoco_playground/_src/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
Callable[[mjx.Model, jax.Array], Tuple[mjx.Model, mjx.Model]]
]

ALL_ENVS = locomotion.ALL + manipulation.ALL + dm_control_suite.ALL


def get_default_config(env_name: str):
if env_name in manipulation.ALL:
Expand All @@ -54,7 +52,9 @@ def load(
elif env_name in dm_control_suite.ALL:
return dm_control_suite.load(env_name, config, config_overrides)

raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL_ENVS}")
all_envs = manipulation.ALL + locomotion.ALL + dm_control_suite.ALL

raise ValueError(f"Env '{env_name}' not found. Available envs: {all_envs}")


def get_domain_randomizer(env_name: str) -> Optional[DomainRandomizer]:
Expand Down
58 changes: 58 additions & 0 deletions mujoco_playground/_src/registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the registry module."""

from absl.testing import absltest
import jax.numpy as jp
from ml_collections import config_dict

from mujoco_playground._src import registry
from mujoco_playground._src.dm_control_suite import register_environment as register_dm_control_suite
from mujoco_playground._src.locomotion import register_environment as register_locomotion
from mujoco_playground._src.manipulation import register_environment as register_manipulation


class RegistryTest(absltest.TestCase):

def test_new_env(self):
class DemoEnv:

def __init__(self, config, config_overrides):
pass

def demo_default_config():
return config_dict.ConfigDict()

register_dm_control_suite('DemoEnv', DemoEnv, demo_default_config)
env = registry.load('DemoEnv')
self.assertIsInstance(env, DemoEnv)
config = registry.get_default_config('DemoEnv')
self.assertEqual(config, config_dict.ConfigDict())

register_manipulation('DemoEnv', DemoEnv, demo_default_config)
env = registry.load('DemoEnv')
self.assertIsInstance(env, DemoEnv)
config = registry.get_default_config('DemoEnv')
self.assertEqual(config, config_dict.ConfigDict())

register_locomotion('DemoEnv', DemoEnv, demo_default_config)
env = registry.load('DemoEnv')
self.assertIsInstance(env, DemoEnv)
config = registry.get_default_config('DemoEnv')
self.assertEqual(config, config_dict.ConfigDict())


if __name__ == '__main__':
absltest.main()

0 comments on commit ea07543

Please sign in to comment.