JAX port of lenstronomy for parallelized, GPU-accelerated, and differentiable gravitational lensing and image simulations.
Disclaimer: This project is still in an early development phase and serves as a skeleton for someone taking the lead on it :)
The goal of this library is to reimplement lenstronomy functionalities in pure JAX to allow for automatic differentiation, GPU acceleration, and batched computations.
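The following rough sketch shows what a pure-JAX implementation makes possible, using a toy singular isothermal sphere (SIS) lens. The function names (`sis_deflection`, `ray_shoot`, `source_scatter`) and the toy loss are hypothetical illustrations, not jaxtronomy's or lenstronomy's API. `jax.grad` differentiates the loss with respect to the Einstein radius `theta_E`. `jax.jit` compiles it, targeting a GPU when one is available. `jax.vmap` evaluates it for a whole grid of `theta_E` values in one batched call.

```python
import jax
import jax.numpy as jnp


def sis_deflection(x, y, theta_E):
    """Hypothetical SIS deflection, alpha = theta_E * (x, y) / r (not jaxtronomy's API)."""
    r = jnp.sqrt(x**2 + y**2)
    return theta_E * x / r, theta_E * y / r


def ray_shoot(x, y, theta_E):
    """Map image-plane positions to the source plane via the lens equation beta = theta - alpha."""
    alpha_x, alpha_y = sis_deflection(x, y, theta_E)
    return x - alpha_x, y - alpha_y


def source_scatter(theta_E, x, y):
    """Hypothetical toy loss: spread of the ray-shot source positions of several images.
    It vanishes when all images map back to a single source point."""
    beta_x, beta_y = ray_shoot(x, y, theta_E)
    return jnp.var(beta_x) + jnp.var(beta_y)


# Two images on opposite sides of an SIS lens at x = 1.5 and x = -0.5;
# both map to the same source point (x = 0.5) when theta_E = 1.0.
x_img = jnp.array([1.5, -0.5])
y_img = jnp.array([0.0, 0.0])

# Automatic differentiation: d(loss)/d(theta_E), compiled with jit.
grad_fn = jax.jit(jax.grad(source_scatter))

# Batched computation: evaluate the loss for many theta_E values at once.
theta_grid = jnp.linspace(0.5, 1.5, 11)
losses = jax.vmap(source_scatter, in_axes=(0, None, None))(theta_grid, x_img, y_img)
```

For these two images the toy loss reduces analytically to (1 - theta_E)^2, so `grad_fn(0.8, x_img, y_img)` should return roughly -0.4, and `losses` is smallest at `theta_E = 1.0`.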
Guiding Principles:
- Strive to be a drop-in replacement for lenstronomy, i.e. provide a close match to the lenstronomy API.
- Each function/feature will be tested against the reference lenstronomy implementation.
- This package aims to be a subset of lenstronomy, i.e. it only contains functions that have a reference lenstronomy implementation.
- Implementations should be easy to read and understand.
- Code should be pip-installable on any machine, with no compilation required.
- Any notable differences between the JAX and reference implementations will be clearly documented.
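The testing principle above could look roughly like the following pytest-style check. It is a sketch only: lenstronomy is not imported here, `reference_sis_deflection` is a hypothetical NumPy stand-in for the reference implementation, and `sis_potential` / `jax_sis_deflection` are hypothetical names rather than part of any existing API. In a real test, the reference values would come from the corresponding lenstronomy profile.

```python
import numpy as np
import jax
import jax.numpy as jnp

# JAX defaults to float32; enable float64 so the comparison matches NumPy precision.
jax.config.update("jax_enable_x64", True)


def reference_sis_deflection(x, y, theta_E):
    """Hypothetical NumPy stand-in for the lenstronomy reference:
    analytic SIS deflection alpha = theta_E * (x, y) / r."""
    r = np.sqrt(x**2 + y**2)
    return theta_E * x / r, theta_E * y / r


def sis_potential(x, y, theta_E):
    """Hypothetical JAX implementation under test: SIS potential psi = theta_E * r."""
    return theta_E * jnp.sqrt(x**2 + y**2)


# Deflection = gradient of the potential (via autodiff), vectorized over positions.
jax_sis_deflection = jax.jit(
    jax.vmap(jax.grad(sis_potential, argnums=(0, 1)), in_axes=(0, 0, None))
)


def test_sis_deflection_matches_reference():
    rng = np.random.default_rng(seed=0)
    x = rng.uniform(-5.0, 5.0, size=1000)
    y = rng.uniform(-5.0, 5.0, size=1000)
    theta_E = 1.2
    ref_x, ref_y = reference_sis_deflection(x, y, theta_E)
    jax_x, jax_y = jax_sis_deflection(x, y, theta_E)
    # Compare element-wise within a tight relative tolerance.
    np.testing.assert_allclose(np.asarray(jax_x), ref_x, rtol=1e-10)
    np.testing.assert_allclose(np.asarray(jax_y), ref_y, rtol=1e-10)


test_sis_deflection_matches_reference()
```

Comparing against the reference within a tolerance, rather than requiring exact equality, leaves room for small floating-point differences between implementations. The default float32 precision in JAX is one example of the kind of difference worth documenting.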
The following lensing software packages use JAX-accelerated computing and were, in part, inspired by or make use of lenstronomy functions: