Skip to content

Commit

Permalink
Merge pull request #32 from chenzizhao:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718443715
Change-Id: Ibcf69a25c204f4287f03412cb234aac738263ba9
  • Loading branch information
copybara-github committed Jan 22, 2025
2 parents 7f2783a + ea07543 commit 532ecd7
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 9 deletions.
7 changes: 5 additions & 2 deletions mujoco_playground/_src/dm_control_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@
}


ALL = list(_envs.keys())
def __getattr__(name):
if name == 'ALL':
return list(_envs.keys())
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def register_environment(
Expand Down Expand Up @@ -147,6 +150,6 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}")
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)
8 changes: 6 additions & 2 deletions mujoco_playground/_src/locomotion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@
"Go1Footstand": go1_randomize.domain_randomize,
}

ALL = list(_envs.keys())

def __getattr__(name):
if name == "ALL":
return list(_envs.keys())
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def register_environment(
Expand Down Expand Up @@ -158,7 +162,7 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}")
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
7 changes: 5 additions & 2 deletions mujoco_playground/_src/manipulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
"LeapCubeReorient": leap_cube_reorient.domain_randomize,
}

ALL = list(_envs.keys())
def __getattr__(name):
if name == "ALL":
return list(_envs.keys())
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def register_environment(
Expand Down Expand Up @@ -104,7 +107,7 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL}")
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
6 changes: 3 additions & 3 deletions mujoco_playground/_src/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
Callable[[mjx.Model, jax.Array], Tuple[mjx.Model, mjx.Model]]
]

ALL_ENVS = locomotion.ALL + manipulation.ALL + dm_control_suite.ALL


def get_default_config(env_name: str):
if env_name in manipulation.ALL:
Expand All @@ -54,7 +52,9 @@ def load(
elif env_name in dm_control_suite.ALL:
return dm_control_suite.load(env_name, config, config_overrides)

raise ValueError(f"Env '{env_name}' not found. Available envs: {ALL_ENVS}")
all_envs = manipulation.ALL + locomotion.ALL + dm_control_suite.ALL

raise ValueError(f"Env '{env_name}' not found. Available envs: {all_envs}")


def get_domain_randomizer(env_name: str) -> Optional[DomainRandomizer]:
Expand Down
58 changes: 58 additions & 0 deletions mujoco_playground/_src/registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the registry module."""

from absl.testing import absltest
import jax.numpy as jp
from ml_collections import config_dict

from mujoco_playground._src import registry
from mujoco_playground._src.dm_control_suite import register_environment as register_dm_control_suite
from mujoco_playground._src.locomotion import register_environment as register_locomotion
from mujoco_playground._src.manipulation import register_environment as register_manipulation


class RegistryTest(absltest.TestCase):

def test_new_env(self):
class DemoEnv:

def __init__(self, config, config_overrides):
pass

def demo_default_config():
return config_dict.ConfigDict()

register_dm_control_suite('DemoEnv', DemoEnv, demo_default_config)
env = registry.load('DemoEnv')
self.assertIsInstance(env, DemoEnv)
config = registry.get_default_config('DemoEnv')
self.assertEqual(config, config_dict.ConfigDict())

register_manipulation('DemoEnv', DemoEnv, demo_default_config)
env = registry.load('DemoEnv')
self.assertIsInstance(env, DemoEnv)
config = registry.get_default_config('DemoEnv')
self.assertEqual(config, config_dict.ConfigDict())

register_locomotion('DemoEnv', DemoEnv, demo_default_config)
env = registry.load('DemoEnv')
self.assertIsInstance(env, DemoEnv)
config = registry.get_default_config('DemoEnv')
self.assertEqual(config, config_dict.ConfigDict())


if __name__ == '__main__':
absltest.main()

0 comments on commit 532ecd7

Please sign in to comment.