Skip to content

Commit

Permalink
Add context capture to formula materialisation (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamvig96 authored Jan 3, 2025
1 parent cf91a1a commit 6b2a2c5
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 13 deletions.
13 changes: 12 additions & 1 deletion pyfixest/estimation/FixestMulti_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
from collections.abc import Mapping
from importlib import import_module
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import pandas as pd

Expand All @@ -10,6 +11,7 @@
from pyfixest.estimation.fepois_ import Fepois
from pyfixest.estimation.FormulaParser import FixestFormulaParser
from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas
from pyfixest.utils.utils import capture_context


class FixestMulti:
Expand All @@ -29,6 +31,7 @@ def __init__(
split: Optional[str],
fsplit: Optional[str],
separation_check: Optional[list[str]] = None,
context: Union[int, Mapping[str, Any]] = 0,
) -> None:
"""
Initialize a class for multiple fixed effect estimations.
Expand Down Expand Up @@ -61,6 +64,11 @@ def __init__(
separation_check: list[str], optional
Only used in "fepois". Methods to identify and drop separated observations.
Either "fe" or "ir". Executes both by default.
context : int or Mapping[str, Any]
A dictionary containing additional context variables to be used by
formulaic during the creation of the model matrix. This can include
custom factorization functions, transformations, or any other
variables that need to be available in the formula environment.
Returns
-------
Expand All @@ -75,6 +83,7 @@ def __init__(
self._reps = reps if use_compression else None
self._seed = seed if use_compression else None
self._separation_check = separation_check
self._context = capture_context(context)

self._run_split = split is not None or fsplit is not None
self._run_full = not (split and not fsplit)
Expand Down Expand Up @@ -243,6 +252,7 @@ def _estimate_all_models(
_run_split = self._run_split
_run_full = self._run_full
_splitvar = self._splitvar
_context = self._context

FixestFormulaDict = self.FixestFormulaDict
_fixef_keys = list(FixestFormulaDict.keys())
Expand Down Expand Up @@ -282,6 +292,7 @@ def _estimate_all_models(
store_data=_store_data,
copy_data=_copy_data,
lean=_lean,
context=_context,
sample_split_value=sample_split_value,
sample_split_var=_splitvar,
)
Expand Down
64 changes: 61 additions & 3 deletions pyfixest/estimation/estimation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Union
from collections.abc import Mapping
from typing import Any, Optional, Union

import pandas as pd

Expand All @@ -14,6 +15,7 @@
WeightsTypeOptions,
)
from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas
from pyfixest.utils.utils import capture_context
from pyfixest.utils.utils import ssc as ssc_func


Expand All @@ -36,6 +38,7 @@ def feols(
demeaner_backend: DemeanerBackendOptions = "numba",
use_compression: bool = False,
reps: int = 100,
context: Optional[Union[int, Mapping[str, Any]]] = None,
seed: Optional[int] = None,
split: Optional[str] = None,
fsplit: Optional[str] = None,
Expand Down Expand Up @@ -139,6 +142,12 @@ def feols(
Number of bootstrap repetitions. Only relevant for bootstrap inference applied to
compute cluster robust errors when `use_compression = True`.
context : int or Mapping[str, Any]
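Either a dictionary containing additional context variables, or an integer
selecting the calling scope from which variables are captured (`0` is the
scope from which `feols()` is called). The context is used by formulaic
during the creation of the model matrix. It can include custom factorization
functions, transformations, or any other variables that need to be available
in the formula environment.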
seed: Optional[int]
Seed for the random number generator. Only relevant for bootstrap inference applied to
compute cluster robust errors when `use_compression = True`.
Expand Down Expand Up @@ -166,6 +175,8 @@ class for multiple models specified via `fml`.
```{python}
import pyfixest as pf
import pandas as pd
import numpy as np
data = pf.get_data()
Expand Down Expand Up @@ -306,6 +317,40 @@ class for multiple models specified via `fml`.
Last, `feols()` supports interaction of variables via the `i()` syntax.
Documentation on this is tba.
You can pass custom transforms via the `context` argument. If you set `context = 0`, all
functions defined in the scope from which `feols()` is called will be available in the formula:
```{python}
def _lspline(series: pd.Series, knots: list[float]) -> np.array:
'Generate a linear spline design matrix for the input series based on knots.'
vector = series.values
columns = []
for i, knot in enumerate(knots):
column = np.minimum(vector, knot if i == 0 else knot - knots[i - 1])
columns.append(column)
vector = vector - column
# Add the remainder as the last column
columns.append(vector)
# Combine columns into a design matrix
return np.column_stack(columns)
spline_split = _lspline(data["X2"], [0, 1])
data["X2_0"] = spline_split[:, 0]
data["0_X2_1"] = spline_split[:, 1]
data["1_X2"] = spline_split[:, 2]
explicit_fit = pf.feols("Y ~ X2_0 + 0_X2_1 + 1_X2 | f1 + f2", data=data)
# set context = 0 to make _lspline available for feols' internal call to Formulaic.model_matrix
context_captured_fit = pf.feols("Y ~ _lspline(X2,[0,1]) | f1 + f2", data=data, context = 0)
# or provide it as a dict / mapping
context_captured_fit_map = pf.feols("Y ~ _lspline(X2,[0,1]) | f1 + f2", data=data, context = {"_lspline":_lspline})
pf.etable([explicit_fit, context_captured_fit, context_captured_fit_map])
```
After fitting a model via `feols()`, you can use the `predict()` method to
get the predicted values:
Expand Down Expand Up @@ -396,6 +441,7 @@ class for multiple models specified via `fml`.
instead of the former feols('Y~ i(f1)', data = data, i_ref=1).
"""
)
context = {} if context is None else capture_context(context)

_estimation_input_checks(
fml=fml,
Expand Down Expand Up @@ -429,6 +475,7 @@ class for multiple models specified via `fml`.
seed=seed,
split=split,
fsplit=fsplit,
context=context,
)

estimation = "feols" if not use_compression else "compression"
Expand Down Expand Up @@ -469,6 +516,7 @@ def fepois(
copy_data: bool = True,
store_data: bool = True,
lean: bool = False,
context: Optional[Union[int, Mapping[str, Any]]] = None,
split: Optional[str] = None,
fsplit: Optional[str] = None,
) -> Union[Feols, Fepois, FixestMulti]:
Expand Down Expand Up @@ -559,6 +607,12 @@ def fepois(
to obtain the appropriate standard-errors at estimation time,
since obtaining different SEs won't be possible afterwards.
context : int or Mapping[str, Any]
A dictionary containing additional context variables to be used by
formulaic during the creation of the model matrix. This can include
custom factorization functions, transformations, or any other
variables that need to be available in the formula environment.
split: Optional[str]
A character string, i.e. 'split = var'. If provided, the sample is split according to the
variable and one estimation is performed for each value of that variable. If you also want
Expand Down Expand Up @@ -602,6 +656,7 @@ def fepois(
instead of the former fepois('Y~ i(f1)', data = data, i_ref=1).
"""
)
context = {} if context is None else capture_context(context)

# WLS currently not supported for Poisson regression
weights = None
Expand Down Expand Up @@ -640,6 +695,7 @@ def fepois(
seed=None,
split=split,
fsplit=fsplit,
context=context,
)

fixest._prepare_estimation(
Expand Down Expand Up @@ -770,10 +826,12 @@ def _estimation_input_checks(
raise TypeError("The function argument fsplit needs to be of type str.")

if split is not None and fsplit is not None and split != fsplit:
raise ValueError(f"""
raise ValueError(
f"""
Arguments split and fsplit are both specified, but not identical.
split is specified as {split}, while fsplit is specified as {fsplit}.
""")
"""
)

if isinstance(split, str) and split not in data.columns:
raise KeyError(f"Column '{split}' not found in data.")
Expand Down
5 changes: 4 additions & 1 deletion pyfixest/estimation/feiv_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from collections.abc import Mapping
from importlib import import_module
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
store_data: bool = True,
copy_data: bool = True,
lean: bool = False,
context: Union[int, Mapping[str, Any]] = 0,
sample_split_var: Optional[str] = None,
sample_split_value: Optional[Union[str, int]] = None,
) -> None:
Expand All @@ -168,6 +170,7 @@ def __init__(
store_data,
copy_data,
lean,
context,
sample_split_var,
sample_split_value,
)
Expand Down
13 changes: 11 additions & 2 deletions pyfixest/estimation/feols_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import gc
import re
import warnings
from collections.abc import Mapping
from importlib import import_module
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import numba as nb
import numpy as np
Expand Down Expand Up @@ -42,7 +43,7 @@
_narwhals_to_pandas,
_select_order_coefs,
)
from pyfixest.utils.utils import get_ssc, simultaneous_crit_val
from pyfixest.utils.utils import capture_context, get_ssc, simultaneous_crit_val

decomposition_type = Literal["gelbach"]
prediction_type = Literal["response", "link"]
Expand Down Expand Up @@ -78,6 +79,11 @@ class Feols:
solver : str, optional.
The solver to use for the regression. Can be either "np.linalg.solve" or
"np.linalg.lstsq". Defaults to "np.linalg.solve".
context : int or Mapping[str, Any]
A dictionary containing additional context variables to be used by
formulaic during the creation of the model matrix. This can include
custom factorization functions, transformations, or any other
variables that need to be available in the formula environment.
Attributes
----------
Expand Down Expand Up @@ -220,6 +226,7 @@ def __init__(
store_data: bool = True,
copy_data: bool = True,
lean: bool = False,
context: Union[int, Mapping[str, Any]] = 0,
sample_split_var: Optional[str] = None,
sample_split_value: Optional[Union[str, int, float]] = None,
) -> None:
Expand Down Expand Up @@ -253,6 +260,7 @@ def __init__(
self._copy_data = copy_data
self._lean = lean
self._use_mundlak = False
self._context = capture_context(context)

self._support_crv3_inference = True
if self._weights_name is not None:
Expand Down Expand Up @@ -339,6 +347,7 @@ def prepare_model_matrix(self):
drop_singletons=self._drop_singletons,
drop_intercept=self._drop_intercept,
weights=self._weights_name,
context=self._context,
)

self._Y = mm_dict.get("Y")
Expand Down
10 changes: 9 additions & 1 deletion pyfixest/estimation/feols_compressed_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import narwhals as nw
import numpy as np
Expand Down Expand Up @@ -59,6 +60,11 @@ class FeolsCompressed(Feols):
Whether to copy the data.
lean : bool
Whether to keep memory-heavy objects as attributes or not.
context : int or Mapping[str, Any]
A dictionary containing additional context variables to be used by
formulaic during the creation of the model matrix. This can include
custom factorization functions, transformations, or any other
variables that need to be available in the formula environment.
reps : int
The number of bootstrap repetitions. Default is 100. Only used for CRV1 inference, where
a wild cluster bootstrap is used.
Expand Down Expand Up @@ -86,6 +92,7 @@ def __init__(
store_data: bool = True,
copy_data: bool = True,
lean: bool = False,
context: Union[int, Mapping[str, Any]] = 0,
reps=100,
seed: Optional[int] = None,
sample_split_var: Optional[str] = None,
Expand All @@ -107,6 +114,7 @@ def __init__(
store_data,
copy_data,
lean,
context,
sample_split_var,
sample_split_value,
)
Expand Down
10 changes: 9 additions & 1 deletion pyfixest/estimation/fepois_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from collections.abc import Mapping
from importlib import import_module
from typing import Literal, Optional, Protocol, Union
from typing import Any, Literal, Optional, Protocol, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -56,6 +57,11 @@ class Fepois(Feols):
The backend used for demeaning.
fixef_tol: float, default = 1e-08.
Tolerance level for the convergence of the demeaning algorithm.
context : int or Mapping[str, Any]
A dictionary containing additional context variables to be used by
formulaic during the creation of the model matrix. This can include
custom factorization functions, transformations, or any other
variables that need to be available in the formula environment.
weights_name : Optional[str]
Name of the weights variable.
weights_type : Optional[str]
Expand Down Expand Up @@ -83,6 +89,7 @@ def __init__(
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
] = "np.linalg.solve",
demeaner_backend: Literal["numba", "jax"] = "numba",
context: Union[int, Mapping[str, Any]] = 0,
store_data: bool = True,
copy_data: bool = True,
lean: bool = False,
Expand All @@ -106,6 +113,7 @@ def __init__(
store_data,
copy_data,
lean,
context,
sample_split_var,
sample_split_value,
)
Expand Down
Loading

0 comments on commit 6b2a2c5

Please sign in to comment.