Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pathfinder: allow mode="JAX" for pytensor backend compiler #425

Open
aphc14 opened this issue Feb 13, 2025 · 0 comments
Open

Pathfinder: allow mode="JAX" for pytensor backend compiler #425

aphc14 opened this issue Feb 13, 2025 · 0 comments

Comments

@aphc14
Copy link
Contributor

aphc14 commented Feb 13, 2025

I would like to extend the pytensor backend of Pathfinder so that it can compile with JAX by setting `compile_kwargs=dict(mode="JAX")` in `pmx.fit`. I'm not yet sure what speed advantage, if any, this would bring. However, I think the fix for the problem below might not be too difficult.

The required fix may be to implement a JAX conversion (a `jax_funcify` dispatch) for the `LogLike` Op below. (The `LogLike` Op exists to vectorise an existing compiled `model.logp()` function, which takes a flattened array of the model parameters.)

class LogLike(Op):
    """
    Op that evaluates a compiled log-density function over a batch of points.

    ``logp_func`` is applied independently along the last axis of the input
    with ``np.apply_along_axis``, so the output has one dimension fewer than
    the input. Non-finite results (nan, +inf, -inf) are replaced by ``-inf``.
    This matters because ``np.argmax`` returns the first nan index it meets,
    which would otherwise let an invalid point win an argmax downstream.

    Note: there is no ``jax_funcify`` conversion for this Op. Compiling a graph
    that contains it with ``mode="JAX"`` therefore raises
    ``NotImplementedError("No JAX conversion for the given `Op`: ...")``.

    Parameters
    ----------
    logp_func : Callable
        Function mapping one flattened parameter vector to a scalar log density.
        NOTE(review): assumed to return a scalar; a non-scalar return would
        break the declared output rank.

    Raises
    ------
    PathInvalidLogP
        From ``perform``, when every evaluated log density is non-finite.
    """

    __props__ = ("logp_func",)

    def __init__(self, logp_func: Callable):
        self.logp_func = logp_func
        super().__init__()

    def make_node(self, inputs):
        inputs = pt.as_tensor(inputs)
        # perform() reduces exactly one axis (the last), so the output rank is
        # the input rank minus one. For the usual 3-D input this is 2-D, as
        # before.
        out_ndim = max(inputs.type.ndim - 1, 0)
        outputs = pt.tensor(dtype="float64", shape=(None,) * out_ndim)
        return Apply(self, [inputs], [outputs])

    def perform(self, node: Apply, inputs, outputs) -> None:
        phi = inputs[0]
        logP = np.apply_along_axis(self.logp_func, axis=-1, arr=phi)
        # Replace nan/±inf with -inf: np.argmax would otherwise return the
        # first nan index.
        mask = ~np.isfinite(logP)
        if np.all(mask):
            raise PathInvalidLogP()
        # Cast so the stored value matches the float64 type declared in
        # make_node, even if logp_func returns another float dtype.
        outputs[0][0] = np.asarray(np.where(mask, -np.inf, logP), dtype="float64")

Minimal working example:

def eight_schools_model():
    """Build the eight-schools hierarchical model (non-centred form).

    Returns
    -------
    pm.Model
        Model with free variables ``mu``, ``tau`` and ``theta`` and the
        observed variable ``obs``.
    """
    n_schools = 8
    observed_y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    observed_sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    with pm.Model() as eight_schools:
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        tau = pm.HalfCauchy("tau", beta=5.0)

        # Standard-normal offsets, scaled and shifted below (non-centred form).
        theta = pm.Normal("theta", mu=0, sigma=1, shape=n_schools)
        school_means = mu + tau * theta

        pm.Normal(
            "obs",
            mu=school_means,
            sigma=observed_sigma,
            shape=n_schools,
            observed=observed_y,
        )

    return eight_schools

model = eight_schools_model()

# Multi-path Pathfinder on the PyMC (pytensor) backend, asking pytensor to
# compile with the JAX linker. This currently fails with NotImplementedError
# because the LogLike Op has no JAX conversion.
fit_options = dict(
    method="pathfinder",
    num_paths=20,
    jitter=12.0,
    random_seed=41,
    inference_backend="pymc",
    compile_kwargs=dict(mode="JAX"),  # <--- enable JAX mode
)

with model:
    idata = pmx.fit(**fit_options)

Output:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[4], line 2
      1 with model:
----> 2     idata = pmx.fit(
      3         method="pathfinder",
      4         num_paths=20,
      5         jitter=12.0,
      6         random_seed=41,
      7         inference_backend="pymc",
      8         compile_kwargs=dict(mode="JAX"),
      9     )

File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/fit.py:35, in fit(method, **kwargs)
     32 if method == "pathfinder":
     33     from pymc_extras.inference.pathfinder import fit_pathfinder
---> 35     return fit_pathfinder(**kwargs)
     37 if method == "laplace":
     38     from pymc_extras.inference.laplace import fit_laplace

File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:1685, in fit_pathfinder(model, num_paths, num_draws, num_draws_per_path, maxcor, maxiter, ftol, gtol, maxls, num_elbo_draws, jitter, epsilon, importance_sampling, progressbar, concurrent, random_seed, postprocessing_backend, inference_backend, pathfinder_kwargs, compile_kwargs)
   1682     maxcor = max(maxcor, 5)
   1684 if inference_backend == "pymc":
-> 1685     mp_result = multipath_pathfinder(
   1686         model,
   1687         num_paths=num_paths,
   1688         num_draws=num_draws,
   1689         num_draws_per_path=num_draws_per_path,
   1690         maxcor=maxcor,
   1691         maxiter=maxiter,
   1692         ftol=ftol,
   1693         gtol=gtol,
   1694         maxls=maxls,
   1695         num_elbo_draws=num_elbo_draws,
   1696         jitter=jitter,
   1697         epsilon=epsilon,
   1698         importance_sampling=importance_sampling,
   1699         progressbar=progressbar,
   1700         concurrent=concurrent,
   1701         random_seed=random_seed,
   1702         pathfinder_kwargs=pathfinder_kwargs,
   1703         compile_kwargs=compile_kwargs,
   1704     )
   1705     pathfinder_samples = mp_result.samples
   1706 elif inference_backend == "blackjax":

File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:1506, in multipath_pathfinder(model, num_paths, num_draws, num_draws_per_path, maxcor, maxiter, ftol, gtol, maxls, num_elbo_draws, jitter, epsilon, importance_sampling, progressbar, concurrent, random_seed, pathfinder_kwargs, compile_kwargs)
   1493 pathfinder_config = PathfinderConfig(
   1494     num_draws=num_draws_per_path,
   1495     maxcor=maxcor,
   (...)
   1502     epsilon=epsilon,
   1503 )
   1505 compile_start = time.time()
-> 1506 single_pathfinder_fn = make_single_pathfinder_fn(
   1507     model,
   1508     **asdict(pathfinder_config),
   1509     pathfinder_kwargs=pathfinder_kwargs,
   1510     compile_kwargs=compile_kwargs,
   1511 )
   1512 compile_end = time.time()
   1514 # NOTE: from limited tests, no concurrency is faster than thread, and thread is faster than process. But I suspect this also depends on the model size and maxcor setting.

File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:939, in make_single_pathfinder_fn(model, num_draws, maxcor, maxiter, ftol, gtol, maxls, num_elbo_draws, jitter, epsilon, pathfinder_kwargs, compile_kwargs)
    936 lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
    938 # pathfinder body
--> 939 pathfinder_body_fn = make_pathfinder_body(
    940     logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
    941 )
    942 rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)
    944 def single_pathfinder_fn(random_seed: int) -> PathfinderResult:

File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:857, in make_pathfinder_body(logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs)
    853 logP_psi = loglike(psi)
    855 # return psi, logP_psi, logQ_psi, elbo_argmax
--> 857 pathfinder_body_fn = compile_pymc(
    858     [x_full, g_full],
    859     [psi, logP_psi, logQ_psi, elbo_argmax],
    860     **compile_kwargs,
    861 )
    862 pathfinder_body_fn.trust_input = True
    863 return pathfinder_body_fn

File ~/projects/pymc-devs/pymc/pymc/pytensorf.py:956, in compile_pymc(*args, **kwargs)
    951 def compile_pymc(*args, **kwargs):
    952     warnings.warn(
    953         "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
    954         FutureWarning,
    955     )
--> 956     return compile(*args, **kwargs)

File ~/projects/pymc-devs/pymc/pymc/pytensorf.py:941, in compile(inputs, outputs, random_seed, mode, **kwargs)
    939 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    940 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 941 pytensor_function = pytensor.function(
    942     inputs,
    943     outputs,
    944     updates={**rng_updates, **kwargs.pop("updates", {})},
    945     mode=mode,
    946     **kwargs,
    947 )
    948 return pytensor_function

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/__init__.py:318, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    312     fn = orig_function(
    313         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    314     )
    315 else:
    316     # note: pfunc will also call orig_function -- orig_function is
    317     #      a choke point that all compilation must pass through
--> 318     fn = pfunc(
    319         params=inputs,
    320         outputs=outputs,
    321         mode=mode,
    322         updates=updates,
    323         givens=givens,
    324         no_default_updates=no_default_updates,
    325         accept_inplace=accept_inplace,
    326         name=name,
    327         rebuild_strict=rebuild_strict,
    328         allow_input_downcast=allow_input_downcast,
    329         on_unused_input=on_unused_input,
    330         profile=profile,
    331         output_keys=output_keys,
    332     )
    333 return fn

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/pfunc.py:465, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    451     profile = ProfileStats(message=profile)
    453 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    454     params,
    455     outputs,
   (...)
    462     fgraph=fgraph,
    463 )
--> 465 return orig_function(
    466     inputs,
    467     cloned_outputs,
    468     mode,
    469     accept_inplace=accept_inplace,
    470     name=name,
    471     profile=profile,
    472     on_unused_input=on_unused_input,
    473     output_keys=output_keys,
    474     fgraph=fgraph,
    475 )

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/types.py:1769, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1757     m = Maker(
   1758         inputs,
   1759         outputs,
   (...)
   1766         fgraph=fgraph,
   1767     )
   1768     with config.change_flags(compute_test_value="off"):
-> 1769         fn = m.create(defaults)
   1770 finally:
   1771     if profile and fn:

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/types.py:1661, in FunctionMaker.create(self, input_storage, storage_map)
   1658 start_import_time = pytensor.link.c.cmodule.import_time
   1660 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1661     _fn, _i, _o = self.linker.make_thunk(
   1662         input_storage=input_storage_lists, storage_map=storage_map
   1663     )
   1665 end_linker = time.perf_counter()
   1667 linker_time = end_linker - start_linker

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)
    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/jax/linker.py:67, in JAXLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
     64         fgraph.inputs.remove(new_inp)
     65         fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
---> 67 return jax_funcify(
     68     fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
     69 )

File ~/miniconda3/envs/python-3.10/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:54, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     47 @jax_funcify.register(FunctionGraph)
     48 def jax_funcify_FunctionGraph(
     49     fgraph,
   (...)
     52     **kwargs,
     53 ):
---> 54     return fgraph_to_python(
     55         fgraph,
     56         jax_funcify,
     57         type_conversion_fn=jax_typify,
     58         fgraph_name=fgraph_name,
     59         **kwargs,
     60     )

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/miniconda3/envs/python-3.10/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:44, in jax_funcify(op, node, storage_map, **kwargs)
     41 @singledispatch
     42 def jax_funcify(op, node=None, storage_map=None, **kwargs):
     43     """Create a JAX compatible function from an PyTensor `Op`."""
---> 44     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: LogLike{logp_func=<function make_single_pathfinder_fn.<locals>.logp_func at 0x7f7cac13f010>}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant