Replies: 1 comment
Yes. When you use Keras Core with the JAX backend, NumPy arrays are converted to JAX arrays automatically as soon as they reach a backend operation. You can see this in two key files of the JAX backend (shown here in simplified form):
    import jax.numpy as jnp

    def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
        # Automatically converts NumPy arrays to JAX arrays via jnp.asarray()
        return jnp.asarray(x, dtype=dtype)

    def add(x1, x2):
        x1 = convert_to_tensor(x1)  # NumPy -> JAX array conversion
        x2 = convert_to_tensor(x2)
        return jnp.add(x1, x2)

So you can keep developing with NumPy arrays: they are converted to JAX arrays when they pass through your model's backend operations, and that is where JAX's performance benefits apply. Computation you do in plain NumPy before that point still runs in NumPy.
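To check the conversion yourself, here is a minimal sketch that uses only NumPy and JAX. It assumes the two helpers behave like the simplified versions quoted above; `convert_to_tensor` and `add` below are local stand-ins, not the library's full implementations, and the sample arrays `a` and `b` are made up for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp


def convert_to_tensor(x, dtype=None):
    # Local stand-in (assumption): mirrors the simplified helper quoted above.
    return jnp.asarray(x, dtype=dtype)


def add(x1, x2):
    # Both inputs are converted before the JAX op runs.
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.add(x1, x2)


# Plain NumPy inputs, as you would have while developing a model.
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([10.0, 20.0, 30.0], dtype=np.float32)

result = add(a, b)

print(isinstance(a, np.ndarray))       # the input is a NumPy array
print(isinstance(result, jax.Array))   # the output is a JAX array
print(np.asarray(result))              # convert back to NumPy for inspection
```

The inputs stay NumPy arrays on your side; only the values handed to the backend op become JAX arrays, and `np.asarray` brings a result back to NumPy when you need it.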
When you use Keras Core with the JAX backend, will the NumPy arrays used while developing the model eventually be converted into JAX NumPy arrays, so that the model benefits from JAX's performance and efficiency?