From 3d24d89999edc4336e9b46b236e24619c6740d3b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 10 Jan 2025 11:00:23 -0800 Subject: [PATCH] jax.numpy.clip: update use of deprecated arguments. - a is now positional-only - a_min is now min - a_max is now max The old argument names have been deprecated since JAX v0.4.27. PiperOrigin-RevId: 714108798 --- init2winit/model_lib/binarize_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/init2winit/model_lib/binarize_layers.py b/init2winit/model_lib/binarize_layers.py index 17d9285c..ca1181a3 100644 --- a/init2winit/model_lib/binarize_layers.py +++ b/init2winit/model_lib/binarize_layers.py @@ -185,7 +185,7 @@ def binarize(self, x: jnp.ndarray) -> jnp.ndarray: scale = jnp.divide(1.0, self.bound) x = jnp.multiply(x, scale) clip_bound = 1.0 - self.epsilon - x = jnp.clip(x, a_min=-clip_bound, a_max=clip_bound).astype(self.dtype) + x = jnp.clip(x, min=-clip_bound, max=clip_bound).astype(self.dtype) x = floor_with_gradient(x) + 0.5 # x is either -0.5 or +0.5 x = jnp.divide(x, scale) return x.astype(self.dtype)