Skip to content

Commit

Permalink
[sml] optimize preprocessing by eliminating unnecessary where function (#608)
Browse files Browse the repository at this point in the history

# Pull Request

## What problem does this PR solve?
Small optimization in sml/preprocessing:
replace `jnp.where(expression, 1, 0)` with the boolean `expression` itself
(cast with `.astype(jnp.int_)` where an integer array is returned), which
eliminates the unnecessary `where` call.
  • Loading branch information
winnylyc authored Mar 15, 2024
1 parent 60780e5 commit e8fe5e2
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sml/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1):
ndarray of shape (n_samples, n_classes)
Shape will be (n_samples, 1) for binary problems.
"""
eq_func = lambda x: jnp.where(classes == x, 1, 0)
result = jax.vmap(eq_func)(y)
eq_func = lambda x: classes == x
result = jax.vmap(eq_func)(y).astype(jnp.int_)

if neg_label != 0 or pos_label != 1:
result = jnp.where(result, pos_label, neg_label)
Expand Down Expand Up @@ -203,7 +203,7 @@ def binarize(X, *, threshold=0.0):
Feature values below or equal to this are replaced by 0, above it by 1.
"""
return jnp.where(X > threshold, 1, 0)
return (X > threshold).astype(jnp.int_)


class Binarizer:
Expand Down Expand Up @@ -626,7 +626,7 @@ def _weighted_percentile(x, q, w):
adjusted_percentile = q / 100 * weight_cdf[-1]

def searchsorted_element(x_inner):
encoding = jnp.where(x_inner >= weight_cdf[0:-1, 0], 1, 0)
encoding = x_inner >= weight_cdf[0:-1, 0]
return jnp.sum(encoding)

percentile_idx = jax.vmap(searchsorted_element)(adjusted_percentile)
Expand Down Expand Up @@ -1112,7 +1112,7 @@ def transform(self, X):

def compute_row(bin, x, c):
def compute_element(x):
encoding = jnp.where(x >= bin[1:-1], 1, 0)
encoding = x >= bin[1:-1]
return jnp.clip(jnp.sum(encoding), 0, c - 2)

return jax.vmap(compute_element)(x)
Expand All @@ -1125,7 +1125,7 @@ def compute_element(x):

def compute_row(bin, x):
def compute_element(x):
encoding = jnp.where(x >= bin[1:-1], 1, 0)
encoding = x >= bin[1:-1]
return jnp.sum(encoding)

return jax.vmap(compute_element)(x)
Expand Down

0 comments on commit e8fe5e2

Please sign in to comment.