I would like to extend the PyTensor backend of Pathfinder so it can be compiled with JAX by setting `compile_kwargs=dict(mode="JAX")` in `pmx.fit`. I'm not yet sure what the speed advantage (if any) would be, but I think the solution to the problem below might not be too difficult.

A required fix may be to implement a JAX conversion for the `LogLike` operator below. (The reason for having the `LogLike` `Op` was to vectorise an existing compiled `model.logp()` function, which takes a flattened array of the model parameters.)
`pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py`
Lines 693 to 716 in 00a4ca3
Minimum working example:
Output: