
Commit 18392cf

Internal change
PiperOrigin-RevId: 527090610
Change-Id: Ia46e191b03384b4a5991f9c56bff36f7bab42bf8
Brax Team authored and erikfrey committed Apr 25, 2023
1 parent 1be8ba6 commit 18392cf
Showing 3 changed files with 14 additions and 18 deletions.
brax/envs/wrappers/gym.py: 28 changes (12 additions, 16 deletions)
@@ -21,7 +21,7 @@
 from gym import spaces
 from gym.vector import utils
 import jax
-from jax import numpy as jp
+import numpy as np
 
 
 class GymWrapper(gym.Env):
@@ -44,11 +44,11 @@ def __init__(self,
     self.backend = backend
     self._state = None
 
-    obs_high = jp.inf * jp.ones(self._env.observation_size, dtype='float32')
-    self.observation_space = spaces.Box(-obs_high, obs_high, dtype='float32')
+    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
+    self.observation_space = spaces.Box(-obs, obs, dtype='float32')
 
-    action_high = jp.ones(self._env.action_size, dtype='float32')
-    self.action_space = spaces.Box(-action_high, action_high, dtype='float32')
+    action = np.ones(self._env.action_size, dtype='float32')
+    self.action_space = spaces.Box(-action, action, dtype='float32')
 
     def reset(key):
       key1, key2 = jax.random.split(key)
@@ -111,20 +111,16 @@ def __init__(self,
     self.backend = backend
     self._state = None
 
-    obs_high = jp.inf * jp.ones(self._env.observation_size, dtype='float32')
-    self.single_observation_space = spaces.Box(
-        -obs_high, obs_high, dtype='float32')
-    self.observation_space = utils.batch_space(self.single_observation_space,
-                                               self.num_envs)
+    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
+    obs_space = spaces.Box(-obs, obs, dtype='float32')
+    self.observation_space = utils.batch_space(obs_space, self.num_envs)
 
-    action_high = jp.ones(self._env.action_size, dtype='float32')
-    self.single_action_space = spaces.Box(
-        -action_high, action_high, dtype='float32')
-    self.action_space = utils.batch_space(self.single_action_space,
-                                          self.num_envs)
+    action = np.ones(self._env.action_size, dtype='float32')
+    action_space = spaces.Box(-action, action, dtype='float32')
+    self.action_space = utils.batch_space(action_space, self.num_envs)
 
     def reset(key):
-      key1, key2 = jp.random_split(key)
+      key1, key2 = jax.random.split(key)
       state = self._env.reset(key2)
       return state, state.obs, key1
 
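A note on the gym.py change above: gym's `spaces.Box` bounds are plain NumPy arrays, so building them with `np` rather than `jax.numpy` is the more natural fit, and `jp.random_split` was a latent bug in the vectorized wrapper, since `jax.numpy` has no such function; key splitting lives in `jax.random`. A minimal sketch of the resulting pattern (the sizes below are made-up values for illustration, not anything from the commit):

```python
import jax
import numpy as np
from gym import spaces
from gym.vector import utils

# Hypothetical sizes standing in for env.observation_size / env.action_size.
observation_size, action_size, num_envs = 17, 6, 4

# Bounds are plain NumPy arrays, matching what gym's space classes expect.
obs = np.inf * np.ones(observation_size, dtype='float32')
obs_space = spaces.Box(-obs, obs, dtype='float32')
batched_obs_space = utils.batch_space(obs_space, num_envs)

action = np.ones(action_size, dtype='float32')
action_space = spaces.Box(-action, action, dtype='float32')

# PRNG keys still come from jax.random; jax.numpy has no random_split,
# which is the bug the jp.random_split -> jax.random.split edit fixes.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
```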
brax/training/agents/es/train.py: 2 changes (1 addition, 1 deletion)
@@ -200,7 +200,7 @@ def compute_delta(
   Returns:
   """
-  # NOTE: The trick "len(weights) -> len(weights) * perturbation_std" is
+  # NOTE - -> len(weights) * perturbation_std" is
   # equivalent to tuning the l2_coef.
   weights = jnp.reshape(weights, ([population_size] + [1] * (noise.ndim - 1)))
   delta = jnp.sum(noise * weights, axis=0) / population_size
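For context on the es/train.py hunk: the two unchanged lines after the comment are the core evolution-strategies update, a fitness-weighted average of the sampled noise perturbations. A self-contained sketch of that reshape-and-reduce, with invented shapes (it illustrates the broadcasting only, not the trainer's actual data):

```python
import jax.numpy as jnp

# Invented example: population of 4 perturbations over a (3, 2) parameter block.
population_size = 4
noise = jnp.arange(24.0).reshape(population_size, 3, 2)
weights = jnp.array([1.0, 0.5, -0.5, -1.0])  # shaped fitness scores

# Reshape weights to (4, 1, 1) so they broadcast across the parameter dims,
# then average over the population axis, mirroring the two lines in the diff.
w = jnp.reshape(weights, [population_size] + [1] * (noise.ndim - 1))
delta = jnp.sum(noise * w, axis=0) / population_size
assert delta.shape == (3, 2)
```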
brax/v1/experimental/composer/agent_utils.py: 2 changes (1 addition, 1 deletion)
@@ -33,7 +33,7 @@
 e.g. equivalent to agent1=(..., action_agents=('agent1',), ...)
 agent_groups currently defines which rewards/actions belong to which agent.
-observation is the same among all agents (TODO: add optionality).
+observation is the same among all agents (TODO -.
 """
 
 from collections import OrderedDict as odict
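To make the agent_groups docstring concrete: each entry names one agent and claims a subset of the environment's actions and rewards, while observations stay shared across agents. A hypothetical two-agent config follows; only `action_agents` appears in the docstring, and `reward_agents` is an assumed analogous key rather than the module's confirmed schema:

```python
from collections import OrderedDict as odict

# Hypothetical two-agent grouping; key names other than action_agents are
# illustrative assumptions, not confirmed by this diff.
agent_groups = odict(
    agent1=dict(action_agents=('agent1',), reward_agents=('agent1',)),
    agent2=dict(action_agents=('agent2',), reward_agents=('agent2',)),
)
```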
