diff --git a/test/conftest.py b/test/conftest.py index 595cc63a..b01fe2f9 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -21,26 +21,6 @@ def _getkey(): return _getkey -def clear_backends(): - import jax.lib - from jax._src import dispatch, pjit - from jax._src.lib import xla_bridge as xb - from jax._src.lib import xla_client as xc - from jax._src.lib import xla_extension_version - - xb._clear_backends() - jax.lib.xla_bridge._backends = {} - dispatch.xla_callable.cache_clear() # type: ignore - dispatch.xla_primitive_callable.cache_clear() - pjit._pjit_lower_cached.cache_clear() - if xla_extension_version >= 124: - pjit._cpp_pjit_cache.clear() - xc._xla.PjitFunctionCache.clear_all() - - -clear_backends() # Test that it works - - # Hugely hacky way of reducing memory usage in tests. # JAX can be a little over-happy with its caching; this is especially noticable when # performing tests and therefore doing an unusual amount of compilation etc. @@ -50,7 +30,7 @@ def clear_backends(): def clear_caches(): process = psutil.Process() if process.memory_info().vms > 4 * 2**30: # >4GB memory usage - clear_backends() + jax.clear_backends() for module_name, module in sys.modules.copy().items(): if module_name.startswith("jax"): if module_name not in ["jax.interpreters.partial_eval"]: @@ -58,6 +38,7 @@ def clear_caches(): obj = getattr(module, obj_name) if hasattr(obj, "cache_clear"): try: + print(f"Clearing {obj}") obj.cache_clear() except Exception: pass