RL with JAX
February 26, 2026
Recently I've been working on RL written in JAX, using jax.lax.scan to speed up rollouts and training loops across different policy gradient architectures. I thought it may be useful to share this list of super useful resources, along with some debugging advice from my lab (and a few code sketches along the way).
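First, though, here's a minimal sketch of the scan-based rollout pattern I mean. The names `rollout`, `policy_apply`, and `env_step` are mine, and the policy and environment below are toy stand-ins, so treat this as a shape-of-the-API example rather than working training code:

```python
import functools

import jax
import jax.numpy as jnp


def rollout(rng, obs, params, policy_apply, env_step, num_steps):
    """Collect a fixed-length trajectory with lax.scan instead of a Python loop."""

    def step(carry, _):
        rng, obs = carry
        rng, key = jax.random.split(rng)
        action = jax.random.categorical(key, policy_apply(params, obs))
        next_obs, reward, done = env_step(obs, action)
        return (rng, next_obs), (obs, action, reward, done)

    (_, final_obs), trajectory = jax.lax.scan(step, (rng, obs), None, length=num_steps)
    return final_obs, trajectory


# Toy stand-ins (placeholders, not a real policy or environment), just to show usage:
def policy_apply(params, obs):
    return obs @ params  # linear policy logits


def env_step(obs, action):
    return jnp.roll(obs, 1), jnp.float32(1.0), jnp.bool_(False)  # dummy dynamics


rng = jax.random.PRNGKey(0)
params = jnp.ones((4, 2))
rollout_jit = jax.jit(functools.partial(
    rollout, policy_apply=policy_apply, env_step=env_step, num_steps=128))
final_obs, (obs_t, act_t, rew_t, done_t) = rollout_jit(rng, jnp.ones(4), params)
```

Because the whole loop body is traced once and compiled as a single program, you avoid the per-step Python dispatch overhead of a plain for-loop over environment steps. On to the resources: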
- RL Debugging Guide
- PPO Implementation Details (ICLR Blog Track)
- Why is Machine Learning Hard?
- An Opinionated Guide to ML Research
- Spinning Up in Deep RL (OpenAI)
- A Very Visual RL Guide
- Lighter alternative to Sutton & Barto
RL Resources
- Trivial reference envs for tests
- Test examples
- Hyperparams give crazy variance
- Hyperparam optimization (small)
- Hyperparam optimization (big)
- Super important reward shaping paper
- All common RL algos
- RL implementations in JAX (PureJaxRL)
- MARL implementations in JAX (JaxMARL)
- Interesting applications:
Advice from a Smart Guy from My Lab
- If you haven't already, read Costa Huang's blog post about PPO implementation details.
- Add metrics. You should track the mean ratio and the entropy in particular. You will almost certainly find that the ratio spikes and the entropy drops right before a crash (there's a logging sketch after this list).
- Make sure that the ratio is 1 at the first gradient step of every new set of data.
- Unit test your advantage estimation. Pull the function out, feed it some mock values, hand-calculate the advantages, and check they match. Make sure you test everything, including the done masking (see the GAE test sketch after this list).
- Unit test your loss function. Do the same thing as for the advantages, even though this one is more complicated. Take the log-probabilities of the two policies as inputs (there's a sketch of that after the list too).
- Compare with a reference implementation such as CleanRL and see if anything is different.
- If all of this checks out, you probably don't have a common bug. Try playing with hyperparams: reduce the learning rate, for example; 3e-4 is a good bet with Adam. Reducing beta2 and lowering the GAE lambda can also help (a small optax snippet is below).
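For the metrics, something like the following is a reasonable starting point. The function and argument names (`log_prob_new`, `log_prob_old`, `logits_new`) are mine, not from any particular codebase; wire it into whatever logging you already have:

```python
import jax
import jax.numpy as jnp


def ppo_diagnostics(log_prob_new, log_prob_old, logits_new):
    """Per-update diagnostics worth logging (argument names are illustrative)."""
    ratio = jnp.exp(log_prob_new - log_prob_old)
    log_probs = jax.nn.log_softmax(logits_new, axis=-1)
    entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)
    return {
        # On the first minibatch after collecting fresh data the two policies
        # are identical, so mean_ratio should be exactly 1.0.  A ratio that
        # spikes or an entropy that collapses usually precedes a crash.
        "mean_ratio": ratio.mean(),
        "max_ratio": ratio.max(),
        "approx_kl": (log_prob_old - log_prob_new).mean(),
        "entropy": entropy.mean(),
    }
```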
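For the advantage unit test, a hand-checked example might look like this. The `compute_gae` here is a small reference implementation written just for the test; swap in your own function and keep the hand-calculated expectations:

```python
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, dones, last_value, gamma=0.9, lam=0.8):
    """Reference GAE with done-masking, scanning backwards over time."""

    def step(carry, inputs):
        next_adv, next_value = carry
        reward, value, done = inputs
        not_done = 1.0 - done
        delta = reward + gamma * next_value * not_done - value
        adv = delta + gamma * lam * not_done * next_adv
        return (adv, value), adv

    last_value = jnp.asarray(last_value, dtype=jnp.float32)
    init = (jnp.zeros((), dtype=jnp.float32), last_value)
    _, advantages = jax.lax.scan(step, init, (rewards, values, dones), reverse=True)
    return advantages


def test_gae_matches_hand_calculation():
    rewards = jnp.array([1.0, 1.0])
    values = jnp.array([0.5, 0.5])
    dones = jnp.array([0.0, 1.0])  # episode terminates after the second step
    adv = compute_gae(rewards, values, dones, last_value=0.5)
    # Worked by hand: delta_1 = 1 - 0.5 = 0.5 (bootstrap masked by done),
    # delta_0 = 1 + 0.9*0.5 - 0.5 = 0.95, A_0 = 0.95 + 0.9*0.8*0.5 = 1.31.
    assert jnp.allclose(adv, jnp.array([1.31, 0.5]), atol=1e-5)
```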
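And the same idea for the loss. This sketch tests only the clipped policy term, with the two policies' log-probabilities as inputs; if your loss also has value and entropy terms, extend the hand calculation accordingly:

```python
import jax.numpy as jnp


def ppo_clip_loss(log_prob_new, log_prob_old, advantages, clip_eps=0.2):
    """Clipped surrogate objective (policy term only)."""
    ratio = jnp.exp(log_prob_new - log_prob_old)
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    return -jnp.mean(jnp.minimum(ratio * advantages, clipped * advantages))


def test_ppo_clip_loss_matches_hand_calculation():
    log_prob_old = jnp.zeros(2)
    log_prob_new = jnp.log(jnp.array([1.5, 0.5]))  # ratios of 1.5 and 0.5
    advantages = jnp.array([1.0, -1.0])
    # By hand: min(1.5*1, 1.2*1) = 1.2 and min(0.5*-1, 0.8*-1) = -0.8,
    # so the loss is -(1.2 - 0.8) / 2 = -0.2.
    loss = ppo_clip_loss(log_prob_new, log_prob_old, advantages)
    assert jnp.allclose(loss, -0.2, atol=1e-5)
```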
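Finally, the hyperparameter suggestions translate to something like this with optax; the 3e-4 comes from the advice above, while the lower b2 (default 0.999) and the smaller GAE lambda are just illustrative starting points:

```python
import optax

# Illustrative values, not a recommendation for any particular environment.
optimizer = optax.adam(learning_rate=3e-4, b2=0.99)  # lower b2 than the 0.999 default
gae_lambda = 0.9                                     # lower than the common 0.95
```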
Questions or feedback? Feel free to reach out!