diff --git a/README.md b/README.md index 0995ca528..52c416494 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,12 @@ pip install --upgrade pip pip install brax ``` +You may also install from [Conda](https://docs.conda.io/en/latest/) or [Mamba](https://github.com/mamba-org/mamba): + +``` +conda install -c conda-forge brax # s/conda/mamba for mamba +``` + Alternatively, to install Brax from source, clone this repo, `cd` to it, and then: ``` diff --git a/brax/experimental/biggym/tasks.py b/brax/experimental/biggym/tasks.py index 246fd05de..25049ea8c 100644 --- a/brax/experimental/biggym/tasks.py +++ b/brax/experimental/biggym/tasks.py @@ -27,8 +27,7 @@ def get_match_env_name(task_name: str, comp1: str, comp2): return f'match_{task_name}__{comp1}__{comp2}' -def race(component: str, pos: Tuple[float, float, float] = (0, 0, 0), - **component_params): +def race(component: str, pos: Tuple[float, float, float] = (0, 0, 0), **component_params): return dict( components=dict( agent1=dict( diff --git a/brax/physics/colliders.py b/brax/physics/colliders.py index 9b1cf5800..04e3e4b01 100644 --- a/brax/physics/colliders.py +++ b/brax/physics/colliders.py @@ -999,12 +999,14 @@ def get(config: config_pb2.Config, body: bodies.Body) -> List[Collider]: if config.collider_cutoff and len( bodies_a) > config.collider_cutoff and ( type_a, type_b) in supported_near_neighbors: + # pytype: disable=wrong-arg-types col_a = cls_a(bodies_a, body) col_b = cls_b(bodies_b, body) cull = NearNeighbors( cls_a(unique_bodies, body), cls_b(unique_bodies, body), (col_a.body.idx, col_b.body.idx), config.collider_cutoff) else: + # pytype: disable=wrong-arg-types cull = Pairs(cls_a(bodies_a, body), cls_b(bodies_b, body)) if b_is_frozen: collider = OneWayCollider(contact_fn, cull, config) diff --git a/brax/training/replay_buffers.py b/brax/training/replay_buffers.py index c63a7687e..a46bea8ef 100644 --- a/brax/training/replay_buffers.py +++ b/brax/training/replay_buffers.py @@ -23,7 +23,6 @@ from jax import flatten_util from jax.experimental import maps from jax.experimental import pjit -from jax.interpreters import pxla import jax.numpy as jnp State = TypeVar('State') @@ -267,7 +266,6 @@ def psize(buffer_state): return jax.pmap(psize, axis_name=axis_name)(buffer_state)[0] -# TODO: Make this multi-host and GDS compatible. class PjitWrapper(ReplayBuffer[State, Sample]): """Wrapper to distribute the buffer on multiple devices with pjit. @@ -284,7 +282,7 @@ def __init__(self, buffer: ReplayBuffer[State, Sample], mesh: maps.Mesh, axis_name: str, - batch_partion_spec: Optional[pxla.PartitionSpec] = None): + batch_partition_spec: Optional[pjit.PartitionSpec] = None): """Constructor. Args: @@ -292,7 +290,7 @@ def __init__(self, mesh: Device mesh for pjitting context. axis_name: The axis along which the replay buffer data should be partitionned. - batch_partion_spec: PartitionSpec of the inserted/sampled batch. + batch_partition_spec: PartitionSpec of the inserted/sampled batch. """ self._buffer = buffer self._mesh = mesh @@ -320,17 +318,17 @@ def sample(buffer_state: State) -> Tuple[State, Sample]: def size(buffer_state: State) -> int: return jnp.sum(jax.vmap(self._buffer.size)(buffer_state)) - partition_spec = pxla.PartitionSpec((axis_name,)) + partition_spec = pjit.PartitionSpec((axis_name,)) self._partitioned_init = pjit.pjit( init, in_axis_resources=None, out_axis_resources=partition_spec) self._partitioned_insert = pjit.pjit( insert, - in_axis_resources=(partition_spec, batch_partion_spec), + in_axis_resources=(partition_spec, batch_partition_spec), out_axis_resources=partition_spec) self._partitioned_sample = pjit.pjit( sample, in_axis_resources=partition_spec, - out_axis_resources=(partition_spec, batch_partion_spec)) + out_axis_resources=(partition_spec, batch_partition_spec)) # This will return the TOTAL size accross all devices. self._partitioned_size = pjit.pjit( size, in_axis_resources=partition_spec, out_axis_resources=None)