
Internal change
PiperOrigin-RevId: 480982351
Change-Id: I3732de7f73ecc5a73dfd8f1ba7d05d29b76e4d81
Brax Team authored and btaba committed Oct 13, 2022
1 parent 2d10e87 commit b3e75f9
Showing 4 changed files with 14 additions and 9 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -43,6 +43,12 @@ pip install --upgrade pip
 pip install brax
 ```
 
+You may also install from [Conda](https://docs.conda.io/en/latest/) or [Mamba](https://github.com/mamba-org/mamba):
+
+```
+conda install -c conda-forge brax # s/conda/mamba for mamba
+```
+
 Alternatively, to install Brax from source, clone this repo, `cd` to it, and then:
 
 ```
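Either install path can be smoke-tested by stepping one of the bundled environments. A minimal sketch, assuming the standard `brax.envs` API at the time of this commit:

```python
import jax
import jax.numpy as jnp
from brax import envs

# Build an environment, reset it, and take a single zero-action step.
env = envs.create(env_name='ant')
state = env.reset(rng=jax.random.PRNGKey(0))
state = env.step(state, jnp.zeros(env.action_size))
print(state.reward)
```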
3 changes: 1 addition & 2 deletions brax/experimental/biggym/tasks.py
@@ -27,8 +27,7 @@ def get_match_env_name(task_name: str, comp1: str, comp2):
   return f'match_{task_name}__{comp1}__{comp2}'
 
 
-def race(component: str, pos: Tuple[float, float, float] = (0, 0, 0),
-         **component_params):
+def race(component: str, pos: Tuple[float, float, float] = (0, 0, 0), **component_params):
   return dict(
       components=dict(
           agent1=dict(
2 changes: 2 additions & 0 deletions brax/physics/colliders.py
@@ -999,12 +999,14 @@ def get(config: config_pb2.Config, body: bodies.Body) -> List[Collider]:
     if config.collider_cutoff and len(
         bodies_a) > config.collider_cutoff and (
             type_a, type_b) in supported_near_neighbors:
+      # pytype: disable=wrong-arg-types
       col_a = cls_a(bodies_a, body)
       col_b = cls_b(bodies_b, body)
       cull = NearNeighbors(
           cls_a(unique_bodies, body), cls_b(unique_bodies, body),
           (col_a.body.idx, col_b.body.idx), config.collider_cutoff)
     else:
+      # pytype: disable=wrong-arg-types
       cull = Pairs(cls_a(bodies_a, body), cls_b(bodies_b, body))
     if b_is_frozen:
       collider = OneWayCollider(contact_fn, cull, config)
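For context on the two added directives (general pytype behavior, not specific to Brax): `# pytype: disable=wrong-arg-types` suppresses that error class from the comment onward, until a matching `enable` comment or the end of the file. A minimal sketch:

```python
def greet(name: str) -> str:
  return 'hello, ' + str(name)

# pytype: disable=wrong-arg-types
message = greet(42)  # pytype would flag this int argument; the directive silences it
# pytype: enable=wrong-arg-types

print(message)  # fine at runtime: 'hello, 42'
```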
12 changes: 5 additions & 7 deletions brax/training/replay_buffers.py
@@ -23,7 +23,6 @@
 from jax import flatten_util
 from jax.experimental import maps
 from jax.experimental import pjit
-from jax.interpreters import pxla
 import jax.numpy as jnp
 
 State = TypeVar('State')
@@ -267,7 +266,6 @@ def psize(buffer_state):
   return jax.pmap(psize, axis_name=axis_name)(buffer_state)[0]
 
 
-# TODO: Make this multi-host and GDS compatible.
 class PjitWrapper(ReplayBuffer[State, Sample]):
   """Wrapper to distribute the buffer on multiple devices with pjit.
@@ -284,15 +282,15 @@ def __init__(self,
                buffer: ReplayBuffer[State, Sample],
                mesh: maps.Mesh,
                axis_name: str,
-               batch_partion_spec: Optional[pxla.PartitionSpec] = None):
+               batch_partition_spec: Optional[pjit.PartitionSpec] = None):
     """Constructor.
 
     Args:
       buffer: The buffer to replicate.
       mesh: Device mesh for pjitting context.
       axis_name: The axis along which the replay buffer data should be
         partitionned.
-      batch_partion_spec: PartitionSpec of the inserted/sampled batch.
+      batch_partition_spec: PartitionSpec of the inserted/sampled batch.
     """
     self._buffer = buffer
     self._mesh = mesh
@@ -320,17 +318,17 @@ def sample(buffer_state: State) -> Tuple[State, Sample]:
     def size(buffer_state: State) -> int:
       return jnp.sum(jax.vmap(self._buffer.size)(buffer_state))
 
-    partition_spec = pxla.PartitionSpec((axis_name,))
+    partition_spec = pjit.PartitionSpec((axis_name,))
     self._partitioned_init = pjit.pjit(
         init, in_axis_resources=None, out_axis_resources=partition_spec)
     self._partitioned_insert = pjit.pjit(
         insert,
-        in_axis_resources=(partition_spec, batch_partion_spec),
+        in_axis_resources=(partition_spec, batch_partition_spec),
         out_axis_resources=partition_spec)
     self._partitioned_sample = pjit.pjit(
         sample,
         in_axis_resources=partition_spec,
-        out_axis_resources=(partition_spec, batch_partion_spec))
+        out_axis_resources=(partition_spec, batch_partition_spec))
     # This will return the TOTAL size accross all devices.
     self._partitioned_size = pjit.pjit(
         size, in_axis_resources=partition_spec, out_axis_resources=None)
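Because the commit renames the public constructor argument (`batch_partion_spec` → `batch_partition_spec`) and swaps `pxla.PartitionSpec` for `pjit.PartitionSpec`, downstream callers need a matching update. A minimal usage sketch against the post-commit API; the mesh layout and the `UniformSamplingQueue` arguments are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps, pjit

from brax.training import replay_buffers

# An ordinary single-device replay buffer to wrap (illustrative sizes).
buffer = replay_buffers.UniformSamplingQueue(
    max_replay_size=1024,
    dummy_data_sample=jnp.zeros(8),
    sample_batch_size=32)

# A one-axis device mesh spanning all local devices.
mesh = maps.Mesh(np.array(jax.local_devices()), ('buffer_axis',))

wrapper = replay_buffers.PjitWrapper(
    buffer=buffer,
    mesh=mesh,
    axis_name='buffer_axis',
    # Renamed parameter; the spec type now comes from jax.experimental.pjit.
    batch_partition_spec=pjit.PartitionSpec('buffer_axis'))
```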
