Skip to content

Commit

Permalink
Merge pull request #233 from patrick-kidger/experimental-test-fix
Browse files Browse the repository at this point in the history
Testing a possible upstream JAX fix
  • Loading branch information
patrick-kidger authored Feb 22, 2023
2 parents 9b89e1c + 45da0eb commit d2157bc
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ def _getkey():
return _getkey


def clear_backends():
import jax.lib
from jax._src import dispatch, pjit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

xb._clear_backends()
jax.lib.xla_bridge._backends = {}
dispatch.xla_callable.cache_clear() # type: ignore
dispatch.xla_primitive_callable.cache_clear()
pjit._pjit_lower_cached.cache_clear()
if xla_extension_version >= 124:
pjit._cpp_pjit_cache.clear()
xc._xla.PjitFunctionCache.clear_all()


clear_backends() # Test that it works


# Hugely hacky way of reducing memory usage in tests.
# JAX can be a little over-happy with its caching; this is especially noticable when
# performing tests and therefore doing an unusual amount of compilation etc.
Expand All @@ -30,7 +50,7 @@ def _getkey():
def clear_caches():
process = psutil.Process()
if process.memory_info().vms > 4 * 2**30: # >4GB memory usage
jax.clear_backends()
clear_backends()
for module_name, module in sys.modules.copy().items():
if module_name.startswith("jax"):
if module_name not in ["jax.interpreters.partial_eval"]:
Expand Down

0 comments on commit d2157bc

Please sign in to comment.