Skip to content

Commit

Permalink
Add Detect singletons algorithm in JAX (#774)
Browse files Browse the repository at this point in the history
* init

* modularize
  • Loading branch information
juanitorduz authored Jan 2, 2025
1 parent 93c9580 commit cf91a1a
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 4 deletions.
137 changes: 137 additions & 0 deletions pyfixest/estimation/detect_singletons_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames=("n_samples", "n_features", "max_fixef"))
def _process_features_jax(
ids, non_singletons, n_non_singletons, n_samples, n_features, max_fixef
):
"""JIT-compiled inner loop for processing features with static shapes."""

def process_feature(carry, j):
non_singletons, n_non_singletons = carry
col = ids[:, j]

# Initialize counts array
counts = jnp.zeros(max_fixef + 1, dtype=jnp.int32)

# Count occurrences and track singletons
def count_loop(i, state):
counts, n_singletons = state
e = col[non_singletons[i]]
c = counts[e]
# Exactly match Numba: n_singletons += (c == 0) - (c == 1)
n_singletons = n_singletons + (c == 0) - (c == 1)
counts = counts.at[e].add(1)
return (counts, n_singletons)

counts, n_singletons = jax.lax.fori_loop(
0, n_non_singletons, count_loop, (counts, 0)
)

# Early return if no singletons found
def no_singletons(_):
return (non_singletons, n_non_singletons)

# Update non_singletons if singletons found
def update_singletons(_):
def update_loop(i, state):
new_non_singletons, cnt = state
e = col[non_singletons[i]]
keep = counts[e] != 1
# Exactly match Numba's update logic
new_non_singletons = jax.lax.cond(
keep,
lambda x: x[0].at[x[1]].set(non_singletons[i]),
lambda x: x[0],
(new_non_singletons, cnt),
)
return (new_non_singletons, cnt + keep)

new_non_singletons = jnp.zeros_like(non_singletons)
new_non_singletons, new_cnt = jax.lax.fori_loop(
0, n_non_singletons, update_loop, (new_non_singletons, 0)
)
return (new_non_singletons, new_cnt)

return jax.lax.cond(
n_singletons == 0, no_singletons, update_singletons, None
), None

return jax.lax.scan(
process_feature, (non_singletons, n_non_singletons), jnp.arange(n_features)
)[0]


@partial(jax.jit, static_argnames=("n_samples", "n_features", "max_fixef"))
def _singleton_detection_loop(
    ids, non_singletons, n_non_singletons, n_samples, n_features, max_fixef
):
    """Repeat full feature sweeps until the number of surviving rows is stable.

    Defined at module level (rather than inside ``detect_singletons_jax``) so
    that the jitted callable is created once and its compilation cache can be
    reused across calls.
    """

    def _not_converged(state):
        prev_n, curr_carry = state
        return prev_n != curr_carry[1]

    def _sweep(state):
        _, curr_carry = state
        new_carry = _process_features_jax(
            ids, curr_carry[0], curr_carry[1], n_samples, n_features, max_fixef
        )
        return (curr_carry[1], new_carry)

    # n_samples + 1 can never equal a real row count, so the first sweep
    # is always executed.
    init_state = (n_samples + 1, (non_singletons, n_non_singletons))
    return jax.lax.while_loop(_not_converged, _sweep, init_state)[1]


def detect_singletons_jax(ids: np.ndarray) -> np.ndarray:
    """
    JAX implementation of singleton detection in fixed effects.

    Rows whose fixed-effect level occurs only once (among the rows still
    kept) are removed, column by column, and the sweep is repeated until no
    further row is removed.

    Parameters
    ----------
    ids : numpy.ndarray
        A 2D numpy array representing fixed effects, with shape
        (n_samples, n_features). Elements should be non-negative integers
        representing fixed effect identifiers.

    Returns
    -------
    numpy.ndarray
        A boolean array of shape (n_samples,), indicating which observations
        have a singleton fixed effect. If ``ids`` has no rows or no columns,
        an all-False array of shape (n_samples,) is returned.
    """
    ids = np.asarray(ids)
    n_samples, n_features = ids.shape

    # np.max raises on a zero-size array; with no rows or no fixed-effect
    # columns there is nothing that could be a singleton.
    if ids.size == 0:
        return np.zeros(n_samples, dtype=np.bool_)

    # Computed with NumPy before entering JIT because it is a static argument
    # that sizes the per-column count buffer.
    max_fixef = int(np.max(ids))

    ids_jax = jnp.array(ids, dtype=jnp.int32)

    # Start with every row index marked as a non-singleton.
    final_non_singletons, final_n = _singleton_detection_loop(
        ids_jax,
        jnp.arange(n_samples),
        n_samples,
        n_samples,
        n_features,
        max_fixef,
    )

    # Only the first final_n entries are valid row indices; every row that is
    # not among them was dropped as a singleton.
    surviving_rows = np.asarray(final_non_singletons)[: int(final_n)]
    is_singleton = np.ones(n_samples, dtype=np.bool_)
    is_singleton[surviving_rows] = False
    return is_singleton
82 changes: 78 additions & 4 deletions tests/test_detect_singletons.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest

from pyfixest.estimation.detect_singletons_ import detect_singletons
from pyfixest.estimation.detect_singletons_jax import detect_singletons_jax

# Fixture 1: five observations with three fixed-effect columns. In the last
# column the value 3 occurs only once (row index 2), so only that row is
# expected to be flagged.
input1 = np.array([[0, 2, 1], [0, 2, 1], [0, 1, 3], [0, 1, 2], [0, 1, 2]])
solution1 = np.array([False, False, True, False, False])
Expand All @@ -12,7 +14,79 @@
# Expected mask paired with input3 in test_correctness: no row is flagged.
solution3 = np.array([False, False, False, False, False])


def test_correctness():
    """Check the reference detector against the three hand-written fixtures."""
    fixtures = ((input1, solution1), (input2, solution2), (input3, solution3))
    for fixef_ids, expected_mask in fixtures:
        assert np.array_equal(detect_singletons(fixef_ids), expected_mask)
@pytest.mark.parametrize(
    argnames="input, solution",
    argvalues=[(input1, solution1), (input2, solution2), (input3, solution3)],
)
@pytest.mark.parametrize(
    argnames="detection_function",
    argvalues=[detect_singletons, detect_singletons_jax],
    ids=["numba", "jax"],
)
def test_correctness(input, solution, detection_function):
    """Each backend must reproduce the expected mask for every fixture."""
    detected = detection_function(input)
    assert np.array_equal(detected, solution)


@pytest.mark.parametrize(
    argnames="detection_function",
    argvalues=[detect_singletons, detect_singletons_jax],
    ids=["numba", "jax"],
)
def test_single_column(detection_function):
    """A lone fixed-effect column: only the level seen once is flagged."""
    fixef_ids = np.array([0, 0, 1, 2, 2]).reshape(-1, 1)
    expected_mask = np.array([False, False, True, False, False])
    assert np.array_equal(detection_function(fixef_ids), expected_mask)


@pytest.mark.parametrize(
    argnames="detection_function",
    argvalues=[detect_singletons, detect_singletons_jax],
    ids=["numba", "jax"],
)
def test_all_singletons(detection_function):
    """Every level is unique in both columns, so every row is flagged."""
    fixef_ids = np.column_stack([np.arange(4), np.arange(1, 5)])
    expected_mask = np.ones(4, dtype=bool)
    assert np.array_equal(detection_function(fixef_ids), expected_mask)


@pytest.mark.parametrize(
    argnames="detection_function",
    argvalues=[detect_singletons, detect_singletons_jax],
    ids=["numba", "jax"],
)
def test_no_singletons(detection_function):
    """Every level occurs twice in both columns, so no row is flagged."""
    fixef_ids = np.repeat([[0, 0], [1, 1]], 2, axis=0)
    expected_mask = np.zeros(4, dtype=bool)
    assert np.array_equal(detection_function(fixef_ids), expected_mask)


@pytest.mark.parametrize(
    argnames="detection_function",
    argvalues=[detect_singletons, detect_singletons_jax],
    ids=["numba", "jax"],
)
def test_large_input(detection_function):
    """Compare each backend with the Numba result on seeded random data."""
    n_obs = 10000
    rng = np.random.default_rng(42)
    # Three columns with N/10, N/5 and N/2 possible levels, drawn in that order.
    input_data = np.column_stack(
        [rng.integers(0, n_obs // divisor, n_obs) for divisor in (10, 5, 2)]
    )

    # The Numba implementation serves as the reference for large inputs.
    reference = detect_singletons(input_data)
    result = detection_function(input_data)

    assert np.array_equal(result, reference)
    assert len(result) == n_obs
    assert result.dtype == np.bool_

0 comments on commit cf91a1a

Please sign in to comment.