[JAX] Fixes for CI failures with the latest JAX #6911
lint.yml (on: pull_request)
PyTorch C++: 19s
PyTorch Python: 1m 59s
JAX C++: 20s
JAX Python: 20s