Alex Lin Wang 王帅

RL with JAX

February 26, 2026

Recently I've been working on RL written in JAX, using jax.lax.scan to speed up different policy gradient architectures. I thought it might be useful to share this list of super useful resources.
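To illustrate the jax.lax.scan pattern, here is a minimal sketch of replacing a Python loop with a scan when computing discounted returns. The function name and reward values are my own for illustration:

```python
import jax
import jax.numpy as jnp

def discounted_returns(rewards, gamma=0.99):
    # Scan backwards over the trajectory, accumulating
    # G_t = r_t + gamma * G_{t+1}; scan compiles to a single
    # fused loop instead of unrolled Python iterations.
    def step(carry, r):
        g = r + gamma * carry
        return g, g
    _, returns = jax.lax.scan(step, 0.0, rewards, reverse=True)
    return returns

rewards = jnp.array([1.0, 0.0, 1.0])
print(discounted_returns(rewards))  # [1.9801, 0.99, 1.0]
```

With reverse=True the carry runs from the last timestep to the first, but the outputs come back in the original time order.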

RL Resources

Advice from a Smart Guy from My Lab

  • If you haven't already, read Costa Huang's blog post about PPO implementation details.
  • Add metrics. In particular, track the mean ratio and the entropy. You will almost certainly find that the ratio spikes and the entropy drops before a crash.
  • Make sure the ratio is exactly 1 at the first gradient step on every new batch of data.
  • Unit test your advantage estimation. Pull the function out, feed it mock values, hand-calculate the advantages, and check they match. Test everything, including masking.
  • Unit test your loss function. Do the same as for the advantages, but for this more complicated function, with the log-probabilities of the two policies as inputs.
  • Compare against a reference implementation such as CleanRL and see if anything differs.
  • If all of these checks pass, you don't have a common bug. Try playing with hyperparameters: reduce the learning rate, for example; 3e-4 is a good bet with Adam. Reducing Adam's beta2 and lowering the GAE lambda can also help.
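The advantage unit test above can be sketched like this: implement GAE with lax.scan, feed it tiny hand-picked values (including a done mask), and compare against advantages worked out by hand. The function, its signature, and all the numbers are my own assumptions for the test, not from any particular codebase:

```python
import jax
import jax.numpy as jnp

def gae(rewards, values, next_values, dones, gamma=0.5, lam=0.5):
    # TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    deltas = rewards + gamma * next_values * (1.0 - dones) - values
    # A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}, scanned backwards.
    def step(carry, x):
        delta, done = x
        adv = delta + gamma * lam * (1.0 - done) * carry
        return adv, adv
    _, advantages = jax.lax.scan(step, 0.0, (deltas, dones), reverse=True)
    return advantages

# Two steps, episode ends at t=1, all values zero so deltas equal rewards.
rewards = jnp.array([1.0, 1.0])
values = jnp.array([0.0, 0.0])
next_values = jnp.array([0.0, 0.0])
dones = jnp.array([0.0, 1.0])
# By hand: delta_1 = 1, so A_1 = 1; delta_0 = 1, so A_0 = 1 + 0.25 * 1 = 1.25
advs = gae(rewards, values, next_values, dones)
assert jnp.allclose(advs, jnp.array([1.25, 1.0]))
```

Deliberately ugly values of gamma and lambda (0.5 rather than 0.99) make hand-calculated mistakes in the masking much easier to spot.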
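The ratio check and the loss unit test can be combined into one assertion: when the two policies are identical, the ratio exp(logp_new - logp_old) is exactly 1 everywhere, so the clipped PPO objective reduces to -mean(advantages). This is a sketch with made-up log-probabilities and advantages, not a reference implementation:

```python
import jax.numpy as jnp

def ppo_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    # Importance ratio between the new and old policies.
    ratio = jnp.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Pessimistic (clipped) surrogate objective, negated for minimization.
    return -jnp.mean(jnp.minimum(unclipped, clipped))

logp = jnp.array([-1.0, -2.0, -0.5])
advantages = jnp.array([1.0, -1.0, 0.5])
# First gradient step on a fresh batch: logp_new == logp_old, ratio == 1,
# so the loss must equal -mean(advantages) exactly.
loss = ppo_loss(logp, logp, advantages)
assert jnp.allclose(loss, -jnp.mean(advantages))
```

If this assertion fails on your own loss function, the ratio is not 1 on fresh data, which usually points at stale log-probabilities or a preprocessing mismatch between rollout and update.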

Questions or feedback? Feel free to reach out!