Model compiling twice when using jax==0.2.10 or later #212
Hi, I tried to run this code in colab (with …
If I run the code on GPU (you need to select …
So I recommend upgrading to the latest jax and jaxlib versions and seeing whether performance improves. If you want to experiment with the latest version of JAX in colab, you can do it here: https://colab.sandbox.google.com/
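For reference, here is a minimal sketch of checking and upgrading the versions inside a notebook (the upgrade line is just the generic pip command, not a specific pin):

# In a Colab cell, upgrade first:
# !pip install --upgrade jax jaxlib
# Then confirm which versions are actually loaded.
import jax
import jaxlib
print(jax.__version__)
print(jaxlib.__version__)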
A few other general suggestions about the code you posted:
@objax.Function.with_vars(gp_model.vars() + opt.vars())
def train_op():
    dE, E = energy()
    opt(lr_adam, dE)
    return E

train_op = objax.Jit(train_op)
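(Note: wrapping train_op with objax.Function.with_vars before passing it to objax.Jit tells Jit which variables the function reads and writes, so the compiled code can track the model and optimizer state.)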
class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """

    def __init__(self, variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def __call__(self, y, post_mean, post_cov):
        """
        Compute the expected Gaussian log-likelihood of y under a
        posterior with the given mean and variance.
        """
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik


class GP:
    def __init__(self):
        self.likelihood = objax.Vectorize(GaussianLikelihood())

    def energy(self):
        return np.sum(self.likelihood( ... ))
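For example, here is a toy sketch (my own example, not code from this issue) of how Vectorize maps the per-example likelihood over the batch; the three entries in batch_axis say that each argument is mapped over its leading axis, while variables such as self.variance stay shared:

import objax
import jax.numpy as np

lik = objax.Vectorize(GaussianLikelihood(), batch_axis=(0, 0, 0))
y = np.ones([8, 1])        # batch of 8 observations
mean = np.zeros([8, 1])    # per-example posterior means
cov = np.ones([8, 1])      # per-example posterior variances
print(lik(y, mean, cov).shape)  # one expected log-likelihood per example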
I'm closing this issue for now, feel free to reopen it if you have more questions.
Hi @AlexeyKurakin, here is a colab notebook which shows the double-compile issue: https://colab.research.google.com/drive/13yKlZ1-fI_pIG3gt_J5WFuiEco1PkcYw?usp=sharing For the latest versions of jax and jaxlib, this gives …
whereas running the line …
Can you spot something I'm doing wrong, or do you think this is a bug?
I should also mention that I suspect there is some double-compiling going on because, when I run much larger models with a similar setup, I have observed the …
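One way to confirm recompilation (assuming the jax_log_compiles option exists in the installed jax version) is to turn on compile logging, which prints a message for every XLA compilation; seeing it on any step after the first indicates a retrace:

import jax
jax.config.update('jax_log_compiles', True)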
@AlexeyKurakin I also tried out your suggestion of using Vectorize rather than vmap, and now it seems like the model compiles on the first three iterations. I have checked this in Google Colab. Could we reopen the issue (I can't reopen it myself)?

import objax
import jax.numpy as np
import time
class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """

    def __init__(self, variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def __call__(self, y, post_mean, post_cov):
        """
        Compute the expected Gaussian log-likelihood of y under a
        posterior with the given mean and variance.
        """
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik
class GP(objax.Module):
    """
    A GP model
    """

    def __init__(self, likelihood, X, Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        self.likelihood = objax.Vectorize(likelihood, batch_axis=(0, 0, 0))
        self.posterior_mean = objax.StateVar(np.zeros([X.shape[0], 1, 1]))
        self.posterior_variance = objax.StateVar(np.ones([X.shape[0], 1, 1]))

    def energy(self):
        """
        Sum of expected log-likelihoods over all data points.
        """
        mean_f, cov_f = self.posterior_mean.value, self.posterior_variance.value
        E = self.likelihood(self.Y, mean_f, cov_f)
        return np.sum(E)
# generate some data
N = 1000000
x = np.linspace(-10, 100, num=N)
y = np.sin(x)
# set up the model
lik = GaussianLikelihood(variance=1.0)
gp_model = GP(likelihood=lik, X=x, Y=y)
energy = objax.GradValues(gp_model.energy, gp_model.vars())
lr_adam = 0.1
iters = 10
opt = objax.optimizer.Adam(gp_model.vars())
@objax.Function.with_vars(gp_model.vars() + opt.vars())
def train_op():
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    opt(lr_adam, dE)
    return E

train_op = objax.Jit(train_op)

t0 = time.time()
for i in range(1, iters + 1):
    t2 = time.time()
    loss = train_op()
    t3 = time.time()
    # print('iter %2d, energy: %1.4f' % (i, loss[0]))
    print('iter time: %2.2f secs' % (t3 - t2))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1 - t0))

gives
OK, let me debug it and then I'll get back to you.
Hi,
I recently updated JAX and noticed that my runtime increased. I have managed to isolate the issue: my objax model is compiling itself twice, i.e., on the second training iteration the model seems to recompile for some reason. This only happens for JAX versions 0.2.10 or later.
Any idea what the cause might be?
I hope this toy example is clear enough. I am using objax==1.3.1 and jaxlib==0.1.60.
Running this script with jax==0.2.9 gives
Running the script with jax==0.2.10 gives
As you can see, there is a significant difference in the 2nd iteration, as if the model is re-compiling itself.
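For context, here is a generic toy example (my own illustration, unrelated to objax internals) of when jax.jit recompiles: the compiled program is cached per abstract input signature, so any change in shape, dtype, or weak type between calls triggers a fresh trace and compile:

import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    return x * 2

double(jnp.ones(3))  # first call: traces and compiles
double(jnp.ones(3))  # same signature: served from the cache
double(jnp.ones(4))  # new shape: traces and compiles again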