Skip to content

Commit

Permalink
Debugging segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 22, 2023
1 parent d2157bc commit 27c126d
Showing 1 changed file with 2 additions and 21 deletions.
23 changes: 2 additions & 21 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,6 @@ def _getkey():
return _getkey


def clear_backends():
import jax.lib
from jax._src import dispatch, pjit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

xb._clear_backends()
jax.lib.xla_bridge._backends = {}
dispatch.xla_callable.cache_clear() # type: ignore
dispatch.xla_primitive_callable.cache_clear()
pjit._pjit_lower_cached.cache_clear()
if xla_extension_version >= 124:
pjit._cpp_pjit_cache.clear()
xc._xla.PjitFunctionCache.clear_all()


clear_backends() # Test that it works


# Hugely hacky way of reducing memory usage in tests.
# JAX can be a little over-happy with its caching; this is especially noticable when
# performing tests and therefore doing an unusual amount of compilation etc.
Expand All @@ -50,14 +30,15 @@ def clear_backends():
def clear_caches():
process = psutil.Process()
if process.memory_info().vms > 4 * 2**30: # >4GB memory usage
clear_backends()
jax.clear_backends()
for module_name, module in sys.modules.copy().items():
if module_name.startswith("jax"):
if module_name not in ["jax.interpreters.partial_eval"]:
for obj_name in dir(module):
obj = getattr(module, obj_name)
if hasattr(obj, "cache_clear"):
try:
print(f"Clearing {obj}")
obj.cache_clear()
except Exception:
pass
Expand Down

0 comments on commit 27c126d

Please sign in to comment.