diff --git a/pyfixest/estimation/detect_singletons_jax.py b/pyfixest/estimation/detect_singletons_jax.py
new file mode 100644
index 00000000..bcf91898
--- /dev/null
+++ b/pyfixest/estimation/detect_singletons_jax.py
@@ -0,0 +1,167 @@
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+@partial(jax.jit, static_argnames=("n_samples", "n_features", "max_fixef"))
+def _process_features_jax(
+    ids, non_singletons, n_non_singletons, n_samples, n_features, max_fixef
+):
+    """
+    Run one pass over all fixed effect columns, dropping singleton rows.
+
+    Parameters
+    ----------
+    ids : jax.Array
+        Integer array of shape (n_samples, n_features) of fixed effect ids.
+    non_singletons : jax.Array
+        Buffer of row indices; only the first n_non_singletons entries are
+        meaningful.
+    n_non_singletons : int
+        Number of valid entries in non_singletons.
+    n_samples : int
+        Number of rows in ids. Static; not read by the body.
+    n_features : int
+        Number of columns in ids. Static.
+    max_fixef : int
+        Largest id in ids; sizes the per-column count buffer. Static.
+
+    Returns
+    -------
+    tuple
+        The updated (non_singletons, n_non_singletons) pair.
+    """
+
+    def process_feature(carry, j):
+        non_singletons, n_non_singletons = carry
+        col = ids[:, j]
+
+        # One counter per possible id value in this column.
+        counts = jnp.zeros(max_fixef + 1, dtype=jnp.int32)
+
+        def count_loop(i, state):
+            counts, n_singletons = state
+            e = col[non_singletons[i]]
+            c = counts[e]
+            # The first occurrence of an id adds a singleton and the second
+            # removes it again; later occurrences leave the tally unchanged.
+            n_singletons = n_singletons + (c == 0) - (c == 1)
+            counts = counts.at[e].add(1)
+            return (counts, n_singletons)
+
+        counts, n_singletons = jax.lax.fori_loop(
+            0, n_non_singletons, count_loop, (counts, 0)
+        )
+
+        # Nothing to drop for this column: pass the carry through untouched.
+        def no_singletons(_):
+            return (non_singletons, n_non_singletons)
+
+        # Compact the surviving row indices to the front of a fresh buffer.
+        def update_singletons(_):
+            def update_loop(i, state):
+                new_non_singletons, cnt = state
+                e = col[non_singletons[i]]
+                keep = counts[e] != 1
+                new_non_singletons = jax.lax.cond(
+                    keep,
+                    lambda x: x[0].at[x[1]].set(non_singletons[i]),
+                    lambda x: x[0],
+                    (new_non_singletons, cnt),
+                )
+                return (new_non_singletons, cnt + keep)
+
+            new_non_singletons = jnp.zeros_like(non_singletons)
+            new_non_singletons, new_cnt = jax.lax.fori_loop(
+                0, n_non_singletons, update_loop, (new_non_singletons, 0)
+            )
+            return (new_non_singletons, new_cnt)
+
+        return jax.lax.cond(
+            n_singletons == 0, no_singletons, update_singletons, None
+        ), None
+
+    return jax.lax.scan(
+        process_feature, (non_singletons, n_non_singletons), jnp.arange(n_features)
+    )[0]
+
+
+@partial(jax.jit, static_argnames=("n_samples", "n_features", "max_fixef"))
+def _singleton_detection_loop(
+    ids, non_singletons, n_non_singletons, n_samples, n_features, max_fixef
+):
+    """
+    Repeat passes over all columns until no further rows are dropped.
+
+    Defined at module level so that jax.jit can reuse its compilation cache
+    across calls: a jitted closure created inside detect_singletons_jax is a
+    new function object on every call and is therefore traced again each time.
+
+    Returns
+    -------
+    tuple
+        The final (non_singletons, n_non_singletons) pair.
+    """
+
+    def cond_fun(state):
+        prev_n, curr_carry = state
+        return prev_n != curr_carry[1]
+
+    def body_fun(state):
+        _, curr_carry = state
+        new_carry = _process_features_jax(
+            ids, curr_carry[0], curr_carry[1], n_samples, n_features, max_fixef
+        )
+        return (curr_carry[1], new_carry)
+
+    # Seed the previous count with an unreachable value to force a first pass.
+    init_state = (n_samples + 1, (non_singletons, n_non_singletons))
+    final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)
+    return final_state[1]
+
+
+def detect_singletons_jax(ids: np.ndarray) -> np.ndarray:
+    """
+    JAX implementation of singleton detection in fixed effects.
+
+    Parameters
+    ----------
+    ids : numpy.ndarray
+        A 2D numpy array representing fixed effects, with shape (n_samples, n_features).
+        Elements should be non-negative integers representing fixed effect identifiers.
+
+    Returns
+    -------
+    numpy.ndarray
+        A boolean array of shape (n_samples,), indicating which observations have
+        a singleton fixed effect. An input without rows or without columns yields
+        an all-False array of length n_samples.
+    """
+    ids = np.asarray(ids)
+    n_samples, n_features = ids.shape
+
+    # np.max raises on an empty array, and with no rows or no columns there is
+    # nothing that could be a singleton.
+    if ids.size == 0:
+        return np.zeros(n_samples, dtype=np.bool_)
+
+    # Computed with NumPy so that max_fixef is a concrete Python int, as
+    # required for a static argument of the jitted functions.
+    max_fixef = int(np.max(ids))
+
+    final_non_singletons, final_n = _singleton_detection_loop(
+        ids=jnp.array(ids, dtype=jnp.int32),
+        non_singletons=jnp.arange(n_samples),
+        n_non_singletons=n_samples,
+        n_samples=n_samples,
+        n_features=n_features,
+        max_fixef=max_fixef,
+    )
+
+    # Every row is a singleton except those that survived all passes.
+    survivors = np.asarray(final_non_singletons)[: int(final_n)]
+    is_singleton = np.ones(n_samples, dtype=np.bool_)
+    is_singleton[survivors] = False
+    return is_singleton
diff --git a/tests/test_detect_singletons.py b/tests/test_detect_singletons.py
index 2a24692e..d6c2af32 100644
--- a/tests/test_detect_singletons.py
+++ b/tests/test_detect_singletons.py
@@ -1,6 +1,8 @@
 import numpy as np
+import pytest
 
 from pyfixest.estimation.detect_singletons_ import detect_singletons
+from pyfixest.estimation.detect_singletons_jax import detect_singletons_jax
 
 input1 = np.array([[0, 2, 1], [0, 2, 1], [0, 1, 3], [0, 1, 2], [0, 1, 2]])
 solution1 = np.array([False, False, True, False, False])
@@ -12,7 +14,79 @@
 solution3 = np.array([False, False, False, False, False])
 
 
-def test_correctness():
-    assert np.array_equal(detect_singletons(input1), solution1)
-    assert np.array_equal(detect_singletons(input2), solution2)
-    assert np.array_equal(detect_singletons(input3), solution3)
+@pytest.mark.parametrize(
+    argnames="input_data, solution",
+    argvalues=[(input1, solution1), (input2, solution2), (input3, solution3)],
+)
+@pytest.mark.parametrize(
+    argnames="detection_function",
+    argvalues=[detect_singletons, detect_singletons_jax],
+    ids=["numba", "jax"],
+)
+def test_correctness(input_data, solution, detection_function):
+    assert np.array_equal(detection_function(input_data), solution)
+
+
+@pytest.mark.parametrize(
+    argnames="detection_function",
+    argvalues=[detect_singletons, detect_singletons_jax],
+    ids=["numba", "jax"],
+)
+def test_single_column(detection_function):
+    """Test with a single fixed effect column."""
+    input_data = np.array([[0], [0], [1], [2], [2]])
+    expected = np.array([False, False, True, False, False])
+    result = detection_function(input_data)
+    assert np.array_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    argnames="detection_function",
+    argvalues=[detect_singletons, detect_singletons_jax],
+    ids=["numba", "jax"],
+)
+def test_all_singletons(detection_function):
+    """Test when all observations are singletons."""
+    input_data = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
+    expected = np.array([True, True, True, True])
+    result = detection_function(input_data)
+    assert np.array_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    argnames="detection_function",
+    argvalues=[detect_singletons, detect_singletons_jax],
+    ids=["numba", "jax"],
+)
+def test_no_singletons(detection_function):
+    """Test when there are no singletons."""
+    input_data = np.array([[0, 0], [0, 0], [1, 1], [1, 1]])
+    expected = np.array([False, False, False, False])
+    result = detection_function(input_data)
+    assert np.array_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    argnames="detection_function",
+    argvalues=[detect_singletons, detect_singletons_jax],
+    ids=["numba", "jax"],
+)
+def test_large_input(detection_function):
+    """Test with a larger input to check performance and correctness."""
+    rng = np.random.default_rng(42)
+    N = 10000
+    input_data = np.column_stack(
+        [
+            rng.integers(0, N // 10, N),
+            rng.integers(0, N // 5, N),
+            rng.integers(0, N // 2, N),
+        ]
+    )
+
+    # For large input, we compare against the Numba implementation as reference
+    reference = detect_singletons(input_data)
+    result = detection_function(input_data)
+
+    assert np.array_equal(result, reference)
+    assert len(result) == N
+    assert result.dtype == np.bool_