jax==0.5.0
optax==0.2.4
numpy
matplotlib
