Skip to content

Commit

Permalink
Fixed stat_smooth for method=glm & family param
Browse files Browse the repository at this point in the history
fixes #769
  • Loading branch information
has2k1 committed Apr 11, 2024
1 parent a8a2e97 commit b0da71e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
9 changes: 9 additions & 0 deletions doc/changelog.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@
title: Changelog
---

## v0.13.5
(not-yet-released)

### Bug Fixes

- Fix bug in [](:class:`~plotnine.stat_smooth`) where you could not set the
family when using a `glm`. ({{< issue 769 >}})

## v0.13.4
(2024-04-03)

[![](https://zenodo.org/badge/DOI/10.5281/zenodo.10912461.svg)](https://doi.org/10.5281/zenodo.10912461)

### Bug Fixes
Expand Down
37 changes: 35 additions & 2 deletions plotnine/stats/smoothers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..exceptions import PlotnineError, PlotnineWarning

if TYPE_CHECKING:
import statsmodels.api as sm
from patsy.eval import EvalEnvironment

from plotnine.mapping import Environment
Expand Down Expand Up @@ -274,6 +275,10 @@ def glm(data, xseq, **params):
init_kwargs, fit_kwargs = separate_method_kwargs(
params["method_args"], sm.GLM, sm.GLM.fit
)

if isinstance(family := init_kwargs.get("family"), str):
init_kwargs["family"] = _glm_family(family)

model = sm.GLM(data["y"], X, **init_kwargs)
results = model.fit(**fit_kwargs)

Expand All @@ -282,7 +287,7 @@ def glm(data, xseq, **params):

if params["se"]:
prediction = results.get_prediction(Xseq)
ci = prediction.conf_int(1 - params["level"])
ci = prediction.conf_int(alpha=1 - params["level"])
data["ymin"] = ci[:, 0]
data["ymax"] = ci[:, 1]

Expand All @@ -300,6 +305,10 @@ def glm_formula(data, xseq, **params):
init_kwargs, fit_kwargs = separate_method_kwargs(
params["method_args"], sm.GLM, sm.GLM.fit
)

if isinstance(family := init_kwargs.get("family"), str):
init_kwargs["family"] = _glm_family(family)

model = smf.glm(params["formula"], data, eval_env=eval_env, **init_kwargs)
results = model.fit(**fit_kwargs)
data = pd.DataFrame({"x": xseq})
Expand All @@ -308,7 +317,7 @@ def glm_formula(data, xseq, **params):
if params["se"]:
xdata = pd.DataFrame({"x": xseq})
prediction = results.get_prediction(xdata)
ci = prediction.conf_int(1 - params["level"])
ci = prediction.conf_int(alpha=1 - params["level"])
data["ymin"] = ci[:, 0]
data["ymax"] = ci[:, 1]
return data
Expand Down Expand Up @@ -594,3 +603,27 @@ def _to_patsy_env(environment: Environment) -> EvalEnvironment:

eval_env = EvalEnvironment(environment.namespaces)
return eval_env


def _glm_family(family: str) -> sm.families.Family:
"""
Get glm-family instance
Ref: https://www.statsmodels.org/stable/glm.html#families
"""
import statsmodels.api as sm

lookup: dict[str, type[sm.families.Family]] = {
"binomial": sm.families.Binomial,
"gamma": sm.families.Gamma,
"gaussian": sm.families.Gaussian,
"inverseGaussian": sm.families.InverseGaussian,
"negativeBinomial": sm.families.NegativeBinomial,
"poisson": sm.families.Poisson,
"tweedie": sm.families.Tweedie,
}
try:
return lookup[family.lower()](link=None) # pyright: ignore
except KeyError as err:
msg = f"GLM family should be one of {tuple(lookup)}"
raise ValueError(msg) from err
4 changes: 3 additions & 1 deletion tests/test_geom_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def test_rlm(self):
p.draw_test()

def test_glm(self):
p = self.p + geom_smooth(aes(y="y_noisy"), method="glm")
p = self.p + geom_smooth(
aes(y="y_noisy"), method="glm", method_args={"family": "gaussian"}
)
p.draw_test()

def test_gls(self):
Expand Down

0 comments on commit b0da71e

Please sign in to comment.