Hi @imoneoi. You are correct: capsule-heightmap is not implemented (feel free to send over an impl). So, assuming your end-effector is a capsule, complex terrain is limited to capsule-mesh, capsule-box, and capsule-plane collisions. That said, randomizing over other params should get some form of sim2real working. Check out the domain randomization wrapper pushed to experimental: https://github.com/google/brax/blob/main/brax/experimental/tracing/wrappers.py. @cdfreeman-google can provide insight if you have questions about domain randomization.
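As a rough illustration of "randomizing over other params", the sketch below samples an independent parameter set per environment once, at initialization. The parameter names and ranges (`FRICTION_RANGE`, `MASS_SCALE_RANGE`, `sample_params`, `randomize_batch`) are hypothetical and are not part of Brax or the linked wrapper; only `jax.random` and `jax.vmap` are real APIs.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter ranges; these names are illustrative, not Brax fields.
FRICTION_RANGE = (0.6, 1.2)
MASS_SCALE_RANGE = (0.8, 1.2)


def sample_params(key):
  """Draws one environment's physics parameters from uniform ranges."""
  k_friction, k_mass = jax.random.split(key)
  friction = jax.random.uniform(
      k_friction, minval=FRICTION_RANGE[0], maxval=FRICTION_RANGE[1])
  mass_scale = jax.random.uniform(
      k_mass, minval=MASS_SCALE_RANGE[0], maxval=MASS_SCALE_RANGE[1])
  return {"friction": friction, "mass_scale": mass_scale}


def randomize_batch(key, num_envs):
  """Samples an independent parameter set per environment, once, at init."""
  keys = jax.random.split(key, num_envs)
  return jax.vmap(sample_params)(keys)


params = randomize_batch(jax.random.PRNGKey(0), num_envs=4)
print(params["friction"].shape)  # (4,)
```

Each env gets its own key, so the batch covers the parameter ranges while staying a single vectorized computation.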
-
Thanks! I have some additional questions about sim2real and domain randomization:
-
One extra comment on point number 2: the current domain randomization code is written this way, yes, where the randomness is fixed at environment initialization time. It's fairly straightforward to extend this to have the randomness be part of the environment seed (i.e., so that environment randomness is refreshed after every reset). You would simply need to modify the
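A minimal sketch of the per-reset variant, assuming a hypothetical environment whose state carries its randomized parameters (`EnvState`, `reset`, and the ranges are illustrative). It does not modify the Brax wrapper, since the comment is cut off before naming what to change; it only shows drawing the parameters from the reset key so they are refreshed on every reset.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvState(NamedTuple):
  # Hypothetical minimal state: a position plus the randomized parameter.
  pos: jnp.ndarray
  friction: jnp.ndarray
  rng: jnp.ndarray


def reset(key):
  """Resamples domain parameters from the reset key instead of at init."""
  key, k_pos, k_friction = jax.random.split(key, 3)
  pos = jax.random.uniform(k_pos, (3,), minval=-0.1, maxval=0.1)
  friction = jax.random.uniform(k_friction, minval=0.6, maxval=1.2)
  return EnvState(pos=pos, friction=friction, rng=key)


batched_reset = jax.jit(jax.vmap(reset))

# Each reset with fresh keys yields fresh parameters for every env.
first = batched_reset(jax.random.split(jax.random.PRNGKey(0), 4))
second = batched_reset(jax.random.split(jax.random.PRNGKey(1), 4))
```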
-
Thanks, @btaba! As I see in the code, there may be … I'm considering writing an impl of … BTW, I think things like BVHs are very hard to implement in JAX. Maybe for now we can only have fast collision detection on heightmaps, not meshes?
-
@imoneoi yes, for roughly the reasons you describe capsule-mesh is slow, which is why we'd want capsule-heightmap to check only a subset of the heightmap rather than the whole mesh. I think capsule-heightmap is a higher priority than box-heightmap, but feel free to fix box-heightmap as well! What's likely needed in both is a jittable collision check between a line segment and a subset of the heightmap's triangles.

Re the viewer: what issue are you seeing with meshes and heightmaps in the viz? These are supported here: https://github.com/google/brax/blob/main/js/system.js#L251-L257

Re BVH: branching logic is difficult to implement in JAX. We started looking into mesh-mesh collisions and how to make them fast in JAX, but don't have a fleshed-out solution. We recently added JAX-jittable box-box collisions.
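A simplified, branch-free sketch of the kind of jittable check described above. Instead of a segment-triangle test, it measures the distance from the capsule's axis segment to each heightmap vertex in a fixed-size set and reports the deepest penetration. `capsule_vertex_penetration` and `point_segment_distance` are hypothetical names, and testing vertices rather than triangles is an approximation that can miss contacts between grid points.

```python
import jax
import jax.numpy as jnp


def point_segment_distance(p, a, b):
  """Distance from points p (N, 3) to the segment a-b, without branching."""
  ab = b - a
  # Guard against a degenerate (zero-length) segment, i.e. a sphere.
  denom = jnp.maximum(jnp.dot(ab, ab), 1e-12)
  # Projection parameter, clamped to [0, 1] so it stays on the segment.
  t = jnp.clip(jnp.dot(p - a, ab) / denom, 0.0, 1.0)
  closest = a + t[:, None] * ab
  return jnp.linalg.norm(p - closest, axis=-1)


@jax.jit
def capsule_vertex_penetration(a, b, radius, vertices):
  """Max penetration depth of a capsule into a fixed-size set of vertices.

  Approximation: tests heightmap vertices, not triangles.
  """
  dist = point_segment_distance(vertices, a, b)
  return jnp.max(radius - dist)  # > 0 means contact


# A flat 3x3 patch of terrain at z = 0.
xs, ys = jnp.meshgrid(jnp.arange(3.0), jnp.arange(3.0))
verts = jnp.stack([xs.ravel(), ys.ravel(), jnp.zeros(9)], axis=-1)
a = jnp.array([0.5, 1.0, 0.05])
b = jnp.array([1.5, 1.0, 0.05])
depth = capsule_vertex_penetration(a, b, 0.1, verts)
```

Because every vertex goes through the same arithmetic (`jnp.clip` instead of `if` branches), the function traces to a fixed-shape computation that `jax.jit` and `jax.vmap` can handle.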
-
@btaba I'm working on a fast, jittable capsule-heightmap or box-heightmap collision impl. I think we should identify the local heightmap points and test collision only against those, but JAX's fixed-shape requirement may make finding the local points annoying.

Also, can we avoid vmapping the heightmap data with domain randomization? (10240 envs * 1024 * 1024 really consumes a lot of VRAM.)

For the viewer, sorry, I missed that line of code and thought it wasn't supported.
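For scale, assuming float32 heights, the numbers in the question work out as below. The second half sketches one alternative: keep a single heightmap and batch only per-env data by passing `in_axes=None` for the heightmap to `jax.vmap`. `height_at` is a hypothetical nearest-cell lookup, not a Brax function.

```python
import jax
import jax.numpy as jnp

# Back-of-the-envelope memory for one float32 heightmap per env.
NUM_ENVS, H, W = 10240, 1024, 1024
BYTES_F32 = 4
per_env_bytes = H * W * BYTES_F32
batched_bytes = NUM_ENVS * per_env_bytes
print(per_env_bytes / 2**20, "MiB per env")       # 4.0 MiB per env
print(batched_bytes / 2**30, "GiB for all envs")  # 40.0 GiB for all envs


def height_at(heightmap, xy):
  """Nearest-cell terrain height at grid coordinate xy (hypothetical helper)."""
  i = jnp.clip(jnp.round(xy[0]).astype(jnp.int32), 0, heightmap.shape[0] - 1)
  j = jnp.clip(jnp.round(xy[1]).astype(jnp.int32), 0, heightmap.shape[1] - 1)
  return heightmap[i, j]


# in_axes=(None, 0): the heightmap is left unbatched, so one copy is shared
# by every env; only the per-env start positions get a batch axis.
batched_height = jax.jit(jax.vmap(height_at, in_axes=(None, 0)))

shared_map = jnp.arange(16.0).reshape(4, 4)  # small stand-in for 1024x1024
starts = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 3.0]])
heights = batched_height(shared_map, starts)
```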
-
Hey @imoneoi, sorry for the late reply! You could try creating a bounding box for the capsule with a worst-case static size, and check the heightmap against that fixed bounding box at each step.

If you're randomizing the heightmap (so not just one heightmap), I'm not sure how to avoid storing those, unless the heights are generated on the fly by some function, say f(x, y, key) = jax.random.uniform(key+x+y), so that each env gets a different key and its heights are computed on demand.
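A sketch of the static-size bounding box idea, assuming a square grid where `heightmap[i, j]` sits at `(i * CELL, j * CELL)` and a hypothetical worst-case `WINDOW` large enough to cover the capsule's footprint. `jax.lax.dynamic_slice` accepts a traced start index but needs a static slice size, which is what keeps the shape fixed under `jit`. `local_window`, `WINDOW`, and `CELL` are illustrative names.

```python
import jax
import jax.numpy as jnp

WINDOW = 8  # worst-case static window size in cells (assumed)
CELL = 0.1  # grid spacing in meters (assumed)


@jax.jit
def local_window(heightmap, center_xy):
  """Returns the WINDOW x WINDOW patch around center_xy as (N, 3) vertices."""
  start = jnp.floor(center_xy / CELL).astype(jnp.int32) - WINDOW // 2
  # Clip explicitly (dynamic_slice would also clamp) so the coordinates
  # computed below match the patch that is actually extracted.
  start = jnp.clip(start, 0, jnp.array(heightmap.shape) - WINDOW)
  patch = jax.lax.dynamic_slice(
      heightmap, (start[0], start[1]), (WINDOW, WINDOW))
  rows = (start[0] + jnp.arange(WINDOW)) * CELL
  cols = (start[1] + jnp.arange(WINDOW)) * CELL
  gx, gy = jnp.meshgrid(rows, cols, indexing="ij")
  return jnp.stack([gx, gy, patch], axis=-1).reshape(-1, 3)


terrain = jnp.arange(1024.0).reshape(32, 32)
verts = local_window(terrain, jnp.array([1.65, 1.65]))
```

The output always has shape `(WINDOW * WINDOW, 3)` regardless of where the capsule is, so it can be fed to any fixed-size collision test and vmapped over envs.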
-
Using a static size is a good idea. I may also try something like
-
Yeah, using the same heightmap and randomizing start positions may avoid that memory blow-up and likely get you similar domain-randomized behavior. But I'm not really in the weeds here, so I'll defer to you; a working capsule-heightmap collision function would already be a great win!

With the implicit-function approach, memory would scale with the worst-case bounding box projected onto the heightmap (depending on the impl and the size of things), so it is likely smaller than 1024x1024.

Thanks for taking a look at this!
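A sketch of the implicit-function approach, assuming each env owns a PRNG key and heights are drawn per integer grid cell. Rather than `key+x+y`, it uses `jax.random.fold_in` to derive a per-cell key; this is deterministic, so overlapping windows agree on shared cells, and only the worst-case window is ever materialized. `cell_height`, `implicit_window`, `WINDOW`, and `MAX_HEIGHT` are hypothetical, and uniform per-cell noise is just a placeholder terrain generator.

```python
import jax
import jax.numpy as jnp

WINDOW = 8         # worst-case window size in cells (assumed)
MAX_HEIGHT = 0.05  # height amplitude in meters (assumed)


def cell_height(key, i, j, width=4096):
  """Deterministic height for cell (i, j); assumes 0 <= i, j < width."""
  # fold_in mixes the cell id into the env key, so the same (key, i, j)
  # always produces the same height without storing a heightmap.
  cell_key = jax.random.fold_in(key, i * width + j)
  return MAX_HEIGHT * jax.random.uniform(cell_key)


@jax.jit
def implicit_window(key, start_ij):
  """Heights for a WINDOW x WINDOW patch starting at cell start_ij.

  Memory scales with WINDOW**2, not with the full (virtual) heightmap.
  """
  ii = start_ij[0] + jnp.arange(WINDOW)
  jj = start_ij[1] + jnp.arange(WINDOW)
  gi, gj = jnp.meshgrid(ii, jj, indexing="ij")
  heights = jax.vmap(lambda i, j: cell_height(key, i, j))(gi.ravel(), gj.ravel())
  return heights.reshape(WINDOW, WINDOW)


key = jax.random.PRNGKey(0)
w1 = implicit_window(key, jnp.array([10, 10]))
w2 = implicit_window(key, jnp.array([12, 10]))  # overlaps w1 by 6 rows
```

Giving each env a different key yields a different terrain per env while keeping per-env memory at one window.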
-
Is Brax suitable for sim2real RL training, and are there any examples?
Also, can we use complex terrain in simulation? There is a heightmap among the geometries; however, only heightmap-box collision is implemented. Does that mean we can only use boxes to compose the agent, with a heightmap for the terrain?