Objax2Tf enhancements #179

Open
AlexeyKurakin opened this issue Dec 15, 2020 · 0 comments

@AlexeyKurakin
Member

This issue tracks progress on addressing some current limitations of the Objax2Tf converter:

1. Shape polymorphism.

Currently, saving an Objax2Tf model as a TensorFlow SavedModel requires specifying the full shape of the input, including a concrete value for the batch dimension. Shape polymorphism would allow us to use None for the batch dimension, so that the generated SavedModel could be used with any batch size.

Shape polymorphism was recently added to jax2tf and should work there; however, I didn't manage to make it work with Objax2Tf.
We need to investigate why it does not work. We may also need to wait until shape polymorphism support is improved on the JAX side.

2. Updating batch norm parameters (and similar state) in the generated TensorFlow model

```python
import numpy as np
import objax

m = objax.nn.BatchNorm0D(3)
tfm = objax.util.Objax2Tf(m)

x = np.random.normal(size=(10, 3)).astype(np.float32)
# This should update the batch norm variables of the generated TensorFlow model
y = tfm(x, training=True)
```
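One possible reason the update is lost, still to be confirmed: jax2tf converts pure JAX functions, so assignments to Objax variables made while the function is being traced are not turned into TensorFlow variable updates. A common JAX approach is to pass the state in and return the updated state explicitly; a TensorFlow wrapper could then write the returned values back with `tf.Variable.assign`. The pure-JAX sketch below shows only the state-threading part. `batch_norm_train`, the `momentum` value and the state keys are hypothetical and are not meant to mirror the internals of Objax's `BatchNorm0D`.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical pure batch norm training step: running statistics are an explicit
# input and output instead of being mutated in place.
def batch_norm_train(state, x, momentum=0.9, eps=1e-6):
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    y = (x - mean) / jnp.sqrt(var + eps)
    new_state = {
        "running_mean": momentum * state["running_mean"] + (1 - momentum) * mean,
        "running_var": momentum * state["running_var"] + (1 - momentum) * var,
    }
    return y, new_state


batch_norm_train = jax.jit(batch_norm_train)

state0 = {"running_mean": jnp.zeros(3), "running_var": jnp.ones(3)}
x = np.random.RandomState(0).normal(loc=5.0, size=(10, 3)).astype(np.float32)

# The caller receives the new state; a TF wrapper would assign it to its variables.
y, state1 = batch_norm_train(state0, x)
```

Because the function returns the new statistics instead of mutating them, nothing is lost when it is traced or converted; the caller decides where the new values are stored.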

3. Investigate whether a generated Objax2Tf model can be trained or fine-tuned in TensorFlow.
