In this repository, we present a minimal implementation of the research paper NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis by Ben Mildenhall et al., written with the Keras 3 API on the JAX backend.
Specifically, we:
- Port the existing NeRF Keras tutorial (written for the TensorFlow backend) from Keras 2 to Keras 3 ✨
- Use JAX as the backend instead of TensorFlow
- Achieve a 4x training speed-up over the TensorFlow implementation
- Adopt a completely stateless API design
To get started, open the notebooks/nerf.ipynb notebook directly or run train.py.
If you are interested in going deeper into NeRF, check out our 3-part blog series on PyImageSearch.
- NeRF repository: The official repository for NeRF.
- NeRF paper: The paper on NeRF.
- Manim repository: We used Manim to build all the animations.
- MathWorks: For the camera calibration article.
- Mathew's video: A great video on NeRF.