Skip to content

Commit

Permalink
Fix memory leak when using scan (#1469)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Aug 14, 2022
1 parent 589b352 commit 0e63cfa
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions numpyro/ops/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def tree_flatten(self):
# set to None to avoid leaks during tracing by JAX
kwargs["rng_key"] = None
aux_trace[name][key] = kwargs
elif key == "infer":
kwargs = site["infer"].copy()
if "_scan_current_index" in kwargs:
# set to None to avoid leaks during tracing by JAX
kwargs["_scan_current_index"] = None
aux_trace[name][key] = kwargs
else:
aux_trace[name][key] = site[key]
# keep the site order information because in JAX, flatten and unflatten do not preserve
Expand Down

0 comments on commit 0e63cfa

Please sign in to comment.