From 18392cf0fda1d1c7c21b971fbdb5327656504c7f Mon Sep 17 00:00:00 2001
From: Brax Team
Date: Tue, 25 Apr 2023 15:23:58 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 527090610
Change-Id: Ia46e191b03384b4a5991f9c56bff36f7bab42bf8
---
 brax/envs/wrappers/gym.py                    | 28 +++++++++-----------
 brax/training/agents/es/train.py             |  2 +-
 brax/v1/experimental/composer/agent_utils.py |  2 +-
 3 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py
index a07972f43..c0fa7d942 100644
--- a/brax/envs/wrappers/gym.py
+++ b/brax/envs/wrappers/gym.py
@@ -21,7 +21,7 @@
 from gym import spaces
 from gym.vector import utils
 import jax
-from jax import numpy as jp
+import numpy as np
 
 
 class GymWrapper(gym.Env):
@@ -44,11 +44,11 @@ def __init__(self,
     self.backend = backend
     self._state = None
 
-    obs_high = jp.inf * jp.ones(self._env.observation_size, dtype='float32')
-    self.observation_space = spaces.Box(-obs_high, obs_high, dtype='float32')
+    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
+    self.observation_space = spaces.Box(-obs, obs, dtype='float32')
 
-    action_high = jp.ones(self._env.action_size, dtype='float32')
-    self.action_space = spaces.Box(-action_high, action_high, dtype='float32')
+    action = np.ones(self._env.action_size, dtype='float32')
+    self.action_space = spaces.Box(-action, action, dtype='float32')
 
     def reset(key):
       key1, key2 = jax.random.split(key)
@@ -111,20 +111,16 @@ def __init__(self,
     self.backend = backend
     self._state = None
 
-    obs_high = jp.inf * jp.ones(self._env.observation_size, dtype='float32')
-    self.single_observation_space = spaces.Box(
-        -obs_high, obs_high, dtype='float32')
-    self.observation_space = utils.batch_space(self.single_observation_space,
-                                               self.num_envs)
+    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
+    obs_space = spaces.Box(-obs, obs, dtype='float32')
+    self.observation_space = utils.batch_space(obs_space, self.num_envs)
 
-    action_high = jp.ones(self._env.action_size, dtype='float32')
-    self.single_action_space = spaces.Box(
-        -action_high, action_high, dtype='float32')
-    self.action_space = utils.batch_space(self.single_action_space,
-                                          self.num_envs)
+    action = np.ones(self._env.action_size, dtype='float32')
+    action_space = spaces.Box(-action, action, dtype='float32')
+    self.action_space = utils.batch_space(action_space, self.num_envs)
 
     def reset(key):
-      key1, key2 = jp.random_split(key)
+      key1, key2 = jax.random.split(key)
       state = self._env.reset(key2)
       return state, state.obs, key1
 
diff --git a/brax/training/agents/es/train.py b/brax/training/agents/es/train.py
index 3b9fbeb4a..ef7d4be21 100644
--- a/brax/training/agents/es/train.py
+++ b/brax/training/agents/es/train.py
@@ -200,7 +200,7 @@ def compute_delta(
 
   Returns:
   """
-  # NOTE: The trick "len(weights) -> len(weights) * perturbation_std" is
+  # NOTE: The trick "len(weights) -> len(weights) * perturbation_std" is
   # equivalent to tuning the l2_coef.
   weights = jnp.reshape(weights, ([population_size] + [1] * (noise.ndim - 1)))
   delta = jnp.sum(noise * weights, axis=0) / population_size
diff --git a/brax/v1/experimental/composer/agent_utils.py b/brax/v1/experimental/composer/agent_utils.py
index 2d4ba49bb..ebefc2713 100644
--- a/brax/v1/experimental/composer/agent_utils.py
+++ b/brax/v1/experimental/composer/agent_utils.py
@@ -33,7 +33,7 @@
 e.g. equivalent to agent1=(..., action_agents=('agent1',), ...)
 
 agent_groups currently defines which rewards/actions belong to which agent.
-observation is the same among all agents (TODO: add optionality).
+observation is the same among all agents (TODO: add optionality).
 """
 from collections import OrderedDict as odict
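
Note on the gym.py change (not part of the patch): gym space constructors
and gym.vector.utils.batch_space operate on host-side numpy arrays, so
building the Box bounds with numpy instead of jax.numpy keeps device arrays
out of the spaces. The reset change is also a genuine bug fix: jp was bound
to jax.numpy, which has no random_split attribute, so jax.random.split is
the correct call. A minimal sketch of the resulting spaces, assuming brax
and gym are installed and using the bundled 'ant' environment purely for
illustration:

    import numpy as np

    from brax import envs
    from brax.envs.wrappers.gym import GymWrapper

    # Wrap a brax environment so it follows the Gym API.
    env = GymWrapper(envs.create('ant'))

    # After the patch, the space bounds are plain numpy arrays, which gym's
    # utilities (sampling, batching) handle natively.
    assert isinstance(env.observation_space.low, np.ndarray)
    assert np.all(env.action_space.high == 1.0)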
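
Note on the es/train.py context (not part of the patch): the unchanged lines
around the comment show how compute_delta turns per-member fitness weights
into a parameter update: the weights are reshaped to broadcast over the
parameter axes, and the fitness-weighted perturbations are averaged over the
population. A self-contained toy sketch of that computation (the shapes are
illustrative assumptions, not taken from the patch):

    import jax.numpy as jnp

    population_size = 4
    # One perturbation per population member, for a (2, 3) parameter leaf.
    noise = jnp.arange(population_size * 2 * 3, dtype=jnp.float32).reshape(
        population_size, 2, 3)
    # Fitness weight for each population member.
    weights = jnp.array([0.1, -0.2, 0.3, -0.4])

    # Reshape weights to broadcast over the parameter axes, then average the
    # weighted perturbations across the population, as compute_delta does.
    w = jnp.reshape(weights, [population_size] + [1] * (noise.ndim - 1))
    delta = jnp.sum(noise * w, axis=0) / population_size
    assert delta.shape == (2, 3)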