Explore Objax to Tensorflow conversion #109
I have run some preliminary experiments on conversion using jax2tf. It semi-works, but in certain cases it hits a JAX error. The following example works:

```python
import numpy as np
import objax
from jax.experimental import jax2tf
from objax.zoo.wide_resnet import WideResNet

wrn_width = 2   # Width of WideResNet
wrn_depth = 28  # Depth of WideResNet
batch_size = 4

# Model
model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)
model_vars = model.vars()

# Prediction operation
predict_op = lambda x: objax.functional.softmax(model(x, training=False))
predict_op = objax.Jit(predict_op, model_vars)

# Run prediction on a random batch
x = objax.random.normal((batch_size, 3, 32, 32))
pred_y = predict_op(x)
print(pred_y)

# Convert model to Tensorflow and run it on the same batch
predict_tf = jax2tf.convert(predict_op)
print(predict_tf(np.array(x)))
```

However, an attempt to re-run

I don't see how it could escape, though; it looks like something specific to jax2tf.

I guess you're right. Maybe I have to add a functional wrapper for jax2tf (similar to Jit, Grad, etc.). I'll look into this more.

After further investigation I made a version which seems to work without issues:

```python
from typing import List

import numpy as np
import objax
import tensorflow as tf
from jax.experimental import jax2tf
from objax.typing import JaxArray

class Objax2Tf(tf.Module):
    def __init__(self, module: objax.Module):
        super().__init__()
        assert isinstance(module, objax.Module), 'Input argument to Objax2Tf must be an Objax module.'
        module_vars = module.vars()

        # Wrap the module as a pure function of its variables so jax2tf can trace it;
        # the original variable values are restored after the call.
        def wrapped_op(tensor_list: List[JaxArray], kwargs, *args):
            original_values = module_vars.tensors()
            try:
                module_vars.assign(tensor_list)
                return module(*args, **kwargs)
            finally:
                module_vars.assign(original_values)

        tf_function = jax2tf.convert(wrapped_op)
        # Copy the Objax variables into tf.Variables owned by this tf.Module.
        self._all_vars = [tf.Variable(v) for v in module_vars.tensors()]
        self._call = tf.function(
            lambda *args, **kwargs: tf_function(self._all_vars, kwargs, *args),
            autograph=False)

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)

predict_tf = Objax2Tf(predict_op)
print(predict_tf(np.array(x)))
```

It also can be saved and loaded as a TensorFlow SavedModel. I still need to do more testing of various corner cases.
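As a rough sketch of that save/load round trip (not from the thread; the export path and the `serve` attribute name are made up for illustration, and the input signature fixes the batch size because fully dynamic shapes would need jax2tf's `polymorphic_shapes` support):

```python
import numpy as np
import tensorflow as tf

export_dir = '/tmp/objax2tf_demo'  # illustrative path

# Attach a tf.function with an explicit input signature so the SavedModel
# gets a well-defined serving entry point.
predict_tf.serve = tf.function(
    lambda images: predict_tf(images),
    input_signature=[tf.TensorSpec((batch_size, 3, 32, 32), tf.float32)])

tf.saved_model.save(predict_tf, export_dir, signatures=predict_tf.serve)

# Reload and check that the restored function produces the same predictions.
loaded = tf.saved_model.load(export_dir)
print(loaded.serve(tf.constant(np.array(x))))
```

Since the SavedModel contains the TensorFlow graph produced by jax2tf, loading it back should only require TensorFlow, not Objax or JAX.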
Some other JAX frameworks provide APIs to convert JAX models into TensorFlow:
Such a conversion could be useful because TensorFlow allows saving trained models in the SavedModel format (which contains both weights and network architecture), to be used later in production settings.
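A hedged sketch of that production-side workflow, assuming the SavedModel exported in the sketch above (the path and the `images` argument name are illustrative assumptions): the directory can be loaded in an environment with only TensorFlow installed and queried through its serving signature.

```python
import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load('/tmp/objax2tf_demo')    # illustrative path
serving_fn = loaded.signatures['serving_default']     # default signature key
batch = np.random.randn(4, 3, 32, 32).astype(np.float32)
outputs = serving_fn(images=tf.constant(batch))       # returns a dict of output tensors
print({name: t.shape for name, t in outputs.items()})
```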