diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 6acc7dc0..58a0e158 100644 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -42,8 +42,8 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1): ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - eq_func = lambda x: jnp.where(classes == x, 1, 0) - result = jax.vmap(eq_func)(y) + eq_func = lambda x: classes == x + result = jax.vmap(eq_func)(y).astype(jnp.int_) if neg_label != 0 or pos_label != 1: result = jnp.where(result, pos_label, neg_label) @@ -203,7 +203,7 @@ def binarize(X, *, threshold=0.0): Feature values below or equal to this are replaced by 0, above it by 1. """ - return jnp.where(X > threshold, 1, 0) + return (X > threshold).astype(jnp.int_) class Binarizer: @@ -626,7 +626,7 @@ def _weighted_percentile(x, q, w): adjusted_percentile = q / 100 * weight_cdf[-1] def searchsorted_element(x_inner): - encoding = jnp.where(x_inner >= weight_cdf[0:-1, 0], 1, 0) + encoding = x_inner >= weight_cdf[0:-1, 0] return jnp.sum(encoding) percentile_idx = jax.vmap(searchsorted_element)(adjusted_percentile) @@ -1112,7 +1112,7 @@ def transform(self, X): def compute_row(bin, x, c): def compute_element(x): - encoding = jnp.where(x >= bin[1:-1], 1, 0) + encoding = x >= bin[1:-1] return jnp.clip(jnp.sum(encoding), 0, c - 2) return jax.vmap(compute_element)(x) @@ -1125,7 +1125,7 @@ def compute_element(x): def compute_row(bin, x): def compute_element(x): - encoding = jnp.where(x >= bin[1:-1], 1, 0) + encoding = x >= bin[1:-1] return jnp.sum(encoding) return jax.vmap(compute_element)(x)