This issue tracks progress on addressing some current limitations of the `Objax2Tf` converter:
**1. Shape polymorphism.**
Right now, saving an `Objax2Tf` model as a TensorFlow SavedModel requires specifying the input shape, including a concrete value for the batch dimension. Shape polymorphism would let us use `None` for the batch dimension, so the generated SavedModel could be used with any batch size.
Shape polymorphism was recently added to `jax2tf` and should work there; however, I didn't manage to make it work with `Objax2Tf`.
We need to investigate why it does not work. We may also need to wait until shape polymorphism support improves on the JAX side.
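**2. Update of batch norm parameters (and similar state) in the generated TensorFlow model.**

For example, calling the converted model in training mode should update its batch norm statistics: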
```python
import numpy as np
import objax

m = objax.nn.BatchNorm0D(3)
tfm = objax.util.Objax2Tf(m)
x = np.random.normal(size=(10, 3)).astype(np.float32)
# Desired behavior: this call should update the batch norm variables
# of the generated TensorFlow model.
y = tfm(x, training=True)
```
**3. Investigate whether a generated `Objax2Tf` model can be trained or fine-tuned in TensorFlow.**