Explore Objax to Tensorflow conversion #109

Closed
AlexeyKurakin opened this issue Oct 15, 2020 · 5 comments
Labels
feature request New feature or request

Comments

@AlexeyKurakin
Member

Some other JAX frameworks provide an API to convert JAX models to TensorFlow:

  1. https://trax-ml.readthedocs.io/en/latest/notebooks/tf_numpy_and_keras.html#2.-Convert-Trax-to-Keras
  2. https://source.corp.google.com/piper///depot/google3/third_party/py/jax/experimental/jax2tf/examples/stax_to_tf_module.py
  3. https://github.com/google/jax/tree/master/jax/experimental/jax2tf

Such a conversion could be useful because TensorFlow can save trained models in the SavedModel format (which contains both the weights and the network architecture) for later use in production settings.

@AlexeyKurakin added the feature request label Oct 15, 2020
@AlexeyKurakin self-assigned this Nov 10, 2020
@AlexeyKurakin
Member Author

I ran some preliminary conversion experiments using jax2tf. It partially works, but in certain cases it hits a JAX error.

The following example works:

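import numpy as np
import objax
from jax.experimental import jax2tf
from objax.zoo.wide_resnet import WideResNet

wrn_width = 2            # Width of WideResNet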
wrn_depth = 28           # Depth of WideResNet
batch_size = 4

# Model
model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)
model_vars = model.vars()

# Prediction operation
predict_op = lambda x: objax.functional.softmax(model(x, training=False))
predict_op = objax.Jit(predict_op, model_vars)

# Run prediction on random batch
x = objax.random.normal((batch_size, 3, 32, 32))
pred_y = predict_op(x)
print(pred_y)

# Convert model to Tensorflow and run it on the same batch
predict_tf = jax2tf.convert(predict_op)
print(predict_tf(np.array(x)))

However, calling predict_tf a second time raises UnexpectedTracerError: "Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function." This is probably related to the issue we are observing in #158.
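The thread does not settle the root cause. One plausible way to get this error is for a traced value to be written into a module's persistent state during conversion and then read by the next trace. This would be consistent with the try/finally restore in the later Objax2Tf version. The sketch below is a toy model in plain Python, not JAX's or Objax's implementation; Tracer, ToyTracer, Model, leaky_step and safe_step are all hypothetical names. It only illustrates two things: why a second call can fail when the first call leaves a placeholder behind, and why restoring the original value in a finally block avoids the failure.

```python
# Toy model only: NOT how JAX or Objax work internally. It mimics the rule
# behind UnexpectedTracerError: a placeholder ("tracer") created during one
# trace must not be used after that trace has finished.


class UnexpectedTracerError(Exception):
    pass


class Tracer:
    """Hypothetical placeholder, valid only inside the trace that created it."""

    def __init__(self, trace_id):
        self.trace_id = trace_id


class ToyTracer:
    """Hypothetical tracing context that tags each new tracer with a trace id."""

    def __init__(self):
        self.trace_id = 0

    def trace(self, fn):
        # Start a new trace and call fn with a fresh placeholder input.
        self.trace_id += 1
        return fn(Tracer(self.trace_id))

    def use(self, value):
        # A tracer from an earlier trace has "escaped" and is rejected.
        if isinstance(value, Tracer) and value.trace_id != self.trace_id:
            raise UnexpectedTracerError("tracer escaped from a previous trace")
        return value


class Model:
    """Hypothetical stateful model with one persistent variable."""

    def __init__(self):
        self.state = 0.0  # concrete value outside of any trace


def leaky_step(ctx, model):
    def fn(new_state):
        ctx.use(model.state)      # read persistent state
        model.state = new_state   # tracer stored in state and never restored
        return new_state
    return fn


def safe_step(ctx, model):
    def fn(new_state):
        original = model.state
        try:
            ctx.use(model.state)
            model.state = new_state   # temporary substitution during the trace
            return new_state
        finally:
            model.state = original    # restore the concrete value
    return fn


ctx, model = ToyTracer(), Model()
ctx.trace(leaky_step(ctx, model))          # first trace succeeds
try:
    ctx.trace(leaky_step(ctx, model))      # second trace reads the stale tracer
except UnexpectedTracerError as err:
    print("second call failed:", err)

ctx, model = ToyTracer(), Model()
for _ in range(3):
    ctx.trace(safe_step(ctx, model))       # every trace succeeds
print("state after safe traces:", model.state)
```

In this toy model, only the first call of the leaky version succeeds, while the safe version can be traced repeatedly because no placeholder outlives its trace.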

@david-berthelot
Contributor

I don't see how it could escape, though; it looks like something specific to jax2tf.
Otherwise, the error would also show up when running predict_op (the JAX-jitted one) twice.

@AlexeyKurakin
Member Author

I guess you're right. Maybe I need to add a functional wrapper for jax2tf (similar to Jit, Grad, etc.). I'll look into this more.

@AlexeyKurakin
Member Author

After further investigation, I made a version that seems to work without issues:

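from typing import List

import numpy as np
import objax
import tensorflow as tf
from jax.experimental import jax2tf
from objax.typing import JaxArray


class Objax2Tf(tf.Module):

    def __init__(self, module: objax.Module):
        assert isinstance(module, objax.Module), 'Input argument to Objax2Tf must be an Objax module.'

        module_vars = module.vars()

        # Pure function of (variable values, kwargs, inputs): temporarily load the
        # given values into the module, run it, then restore the original values
        # so that no traced values are left behind in the module's variables.
        def wrapped_op(tensor_list: List[JaxArray], kwargs, *args):
            original_values = module_vars.tensors()
            try:
                module_vars.assign(tensor_list)
                return module(*args, **kwargs)
            finally:
                module_vars.assign(original_values)

        tf_function = jax2tf.convert(wrapped_op)
        # Copy the current weights into TF variables owned by this tf.Module,
        # so they are tracked and included when saving a SavedModel.
        self._all_vars = [tf.Variable(v) for v in module_vars.tensors()]
        self._call = tf.function(
            lambda *args, **kwargs: tf_function(self._all_vars, kwargs, *args),
            autograph=False)

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)


predict_tf = Objax2Tf(predict_op)
print(predict_tf(np.array(x)))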

It can also be saved and loaded as a TensorFlow SavedModel.

I still need to do more testing of various corner cases.
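The core of the Objax2Tf version above is wrapped_op. It turns the stateful module into a function that takes the variable values as an explicit argument. That is what lets jax2tf.convert be fed the TF-owned tf.Variable list (self._all_vars) in place of the module's own weights. The plain-Python sketch below isolates just that swap-and-restore pattern so it can be checked without JAX or TensorFlow. ToyVarCollection, ToyLinear and make_functional are hypothetical names. ToyVarCollection only mimics the tensors()/assign() calls used above and is not objax.VarCollection.

```python
from typing import Any, Callable, Dict, List


class ToyVarCollection:
    """Hypothetical stand-in exposing only tensors() and assign()."""

    def __init__(self, owner: Any, names: List[str]):
        self.owner = owner
        self.names = names

    def tensors(self) -> List[float]:
        # Current values, in a fixed order.
        return [getattr(self.owner, n) for n in self.names]

    def assign(self, values: List[float]) -> None:
        # Overwrite every variable, in the same order as tensors().
        assert len(values) == len(self.names), "need one value per variable"
        for n, v in zip(self.names, values):
            setattr(self.owner, n, v)


class ToyLinear:
    """Hypothetical stateful module computing scale * (w * x + b)."""

    def __init__(self, w: float, b: float):
        self.w, self.b = w, b

    def vars(self) -> ToyVarCollection:
        return ToyVarCollection(self, ["w", "b"])

    def __call__(self, x: float, scale: float = 1.0) -> float:
        return scale * (self.w * x + self.b)


def make_functional(module) -> Callable:
    """Return wrapped_op(tensor_list, kwargs, *args), mirroring the pattern above."""
    module_vars = module.vars()

    def wrapped_op(tensor_list: List[float], kwargs: Dict[str, Any], *args):
        original_values = module_vars.tensors()
        try:
            module_vars.assign(tensor_list)      # use the caller's variable values
            return module(*args, **kwargs)
        finally:
            module_vars.assign(original_values)  # always restore, even on error

    return wrapped_op


layer = ToyLinear(w=2.0, b=1.0)
op = make_functional(layer)
print(op([10.0, 0.0], {}, 3.0))                # 30.0: external w=10, b=0
print(op([10.0, 0.0], {"scale": 0.5}, 3.0))    # 15.0: keyword args pass through
try:
    op([10.0, 0.0], {"bogus": 1}, 3.0)         # TypeError from the module call
except TypeError:
    pass
print(layer.w, layer.b)                        # 2.0 1.0: module values restored
```

Because the restore happens in finally, the module's own values come back even when the wrapped call raises, so the wrapper never leaves substituted values behind.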

@AlexeyKurakin
Member Author

The Objax2Tf converter is implemented.
Some follow-up improvements will be tracked in separate issues.
