From 245d8b3d20687c5ac0e0e8f94a8dba313405497c Mon Sep 17 00:00:00 2001 From: Matthias Schmidtblaicher <42544829+MatthiasSchmidtblaicherQC@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:15:56 +0100 Subject: [PATCH] Add tutorial on how to estimate a Cox Proportional Hazards Model in glum (#876) * add tutorial * first version * add anchors and adjust format * title case * categoricals * tiny changes * even more cosmetics * wording * add lifelines optional dependency * tiny wordings * add lifelines to pixi lock * even more tiny wordings * add output for all cells * clearer notation and small wordings * some more words on data part * update pixi lock * add reference to penalized splines blog post * Apply suggestions from code review Co-authored-by: Martin Stancsics * suggestion & add conclusion * add clarifying sentence --------- Co-authored-by: Martin Stancsics --- docs/tutorials/cox_model/cox_model.ipynb | 896 ++++++++++++++++++ .../penalized_splines/penalized_splines.ipynb | 20 + docs/tutorials/tutorials.rst | 2 + pixi.lock | 65 ++ pixi.toml | 1 + 5 files changed, 984 insertions(+) create mode 100644 docs/tutorials/cox_model/cox_model.ipynb create mode 100644 docs/tutorials/penalized_splines/penalized_splines.ipynb diff --git a/docs/tutorials/cox_model/cox_model.ipynb b/docs/tutorials/cox_model/cox_model.ipynb new file mode 100644 index 00000000..cb9c1cc2 --- /dev/null +++ b/docs/tutorials/cox_model/cox_model.ipynb @@ -0,0 +1,896 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Cox Proportional Hazards Model in Glum\n", + "\n", + "**Intro**\n", + "\n", + "This tutorial shows how the Cox proportional hazards model (from now on: Cox model), which cannot be represented as an Exponential Dispersion Model (EDM), can still be estimated by a simple data transformation followed by a standard Poisson regression in `glum` (from now on: Poisson approach). 
The Poisson approach requires estimating the coefficients of a high-dimensional categorical variable of time fixed effects, which can be done efficiently in `glum`, leveraging the capabilities of `tabmat`'s `CategoricalMatrix`. The exposition of the Poisson approach here is based on [1], but the equivalence has been described before [2].\n", + "\n", + "## Table of Contents\n", + "* [1. Equivalence Between the Cox Likelihood and a Profile Poisson Likelihood](#1.-Equivalence-Between-the-Cox-Likelihood-and-a-Profile-Poisson-Likelihood)\n", + "* [2. Estimating a Cox Model in Glum](#2.-Estimating-a-Cox-Model-in-Glum)\n", + "* [3. Speed Considerations](#3.-Speed-Considerations)\n", + "\n", + "## 1. Equivalence Between the Cox Likelihood and a Profile Poisson Likelihood\n", + "\n", + "In the Cox model, the rate of event occurrence, $\\lambda(t,x_i)$, factorizes nicely into a linear predictor $\\eta_i=\\sum_k \\beta_k x_{ik}$ that depends on individual $i$'s characteristics but not on time $t$, and a baseline hazard $\\lambda_0$ that depends only on time: $\\lambda(t,x_i)=\\lambda_0(t)\\exp(\\eta_i)$. This is known as the proportional hazards assumption. The partial log-likelihood of $\\eta_i$ is\n", + "$$\n", + "\\sum_{\\text{event times}}\\log\\left(\\frac{y_{i,t}\\exp(\\eta_{i})}{\\sum_{i \\in \\mathcal{R}_t} \\exp(\\eta_i)} \\right),\n", + "$$\n", + "where $\\mathcal{R}_t$ is the set of individuals observed at event time $t$ and $y_{i,t}$ is one if the individual has an event at $t$ and zero otherwise.[1](#fn1) This partial log-likelihood cannot be represented as the log-likelihood of an EDM.[2](#fn2) Now consider an alternative Poisson regression with $y_{i,t}$ as an outcome. 
Apart from a constant, the log-likelihood is\n", + "$$\n", + "\\sum_{\\text{event times}}\\sum_{i \\in \\mathcal{R}_t} y_{i,t} \\log(\\lambda(t,x_i)) - \\lambda(t,x_i).\n", + "$$\n", + "Using the proportional hazards assumption and letting $\\alpha_t = \\log(\\lambda_0(t))$, this becomes\n", + "$$\n", + "\\sum_{\\text{event times}}\\sum_{i \\in \\mathcal{R}_t} y_{i,t} \\left(\\alpha_t + \\eta_i\\right) - \\exp(\\alpha_t + \\eta_i).\n", + "$$\n", + "Solving the first-order condition with respect to $\\alpha_t$ yields $\\exp(\\hat{\\alpha}_t) = \\left(\\sum_{i \\in \\mathcal{R}_t} \\exp(\\eta_i)\\right)^{-1}$. This can be plugged back into the log-likelihood to yield, after some simplifications,\n", + "$$\n", + "\\sum_{\\text{event times}}\\left[\\log\\left(\\frac{y_{i,t}\\exp(\\eta_{i})}{\\sum_{i \\in \\mathcal{R}_t} \\exp(\\eta_i)} \\right) - 1\\right],\n", + "$$\n", + "which is the same as the partial log-likelihood of the Cox model, apart from the $-1$ per event time, which drops out when taking derivatives. In short, the Cox partial log-likelihood is equivalent to a Poisson log-likelihood with the estimates of the time period effects fed back in (\"profiled out\"). This means that, to estimate the parameters of the Cox model, one can simply run a Poisson regression with time fixed effects $\\alpha_t$. The data structures for the two objectives are different: the Cox partial log-likelihood operates on data with one row per observed individual, while the Poisson log-likelihood uses one row per individual and time period.\n", + "\n", + "## 2. Estimating a Cox Model in Glum\n", + "\n", + "We now show that the Poisson approach in `glum` yields the same parameter estimates as a Cox model. For the latter, we use the [lifelines](https://github.com/CamDavidsonPilon/lifelines) library. We also take the dataset from lifelines, which is from an RCT on recidivism for 432 convicts released from Maryland state prisons, with the first arrest after release as the event. We first import the required packages and load the dataset. 
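Before turning to the data, the profiling argument above can be checked numerically. The following standalone sketch (simulated linear predictors, no ties, no censoring; all variable names are our own, not part of the tutorial) evaluates the Cox partial log-likelihood and the Poisson log-likelihood at the profiled-out $\hat{\alpha}_t$, and confirms that the two differ by exactly one per event time:

```python
import numpy as np

rng = np.random.default_rng(42)
n = 8
eta = rng.normal(size=n)  # linear predictors eta_i
# individual i has its event at time t = i + 1, so with no ties and no
# censoring the risk set at time t is every individual with index >= t - 1
times = np.arange(1, n + 1)

cox_ll = 0.0   # Cox partial log-likelihood
pois_ll = 0.0  # Poisson log-likelihood with alpha_t_hat plugged in
for t in times:
    at_risk = eta[times >= t]
    log_denom = np.log(np.exp(at_risk).sum())
    # Cox: contribution of the individual with an event at time t
    cox_ll += eta[t - 1] - log_denom
    # Poisson: alpha_t_hat = -log(sum of exp(eta_i) over the risk set)
    alpha_hat = -log_denom
    for j, e in enumerate(at_risk):
        y = 1.0 if j == 0 else 0.0  # the first at-risk individual has the event
        pois_ll += y * (alpha_hat + e) - np.exp(alpha_hat + e)

# the two objectives agree up to -1 per event time
assert np.isclose(pois_ll, cox_ll - n)
```

The profiled baseline hazard `exp(alpha_hat)` recovered in the loop is exactly the Breslow estimate, which is why lifelines' `baseline estimation = breslow` output matches the Poisson fit below.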
The dataset has one row per convict, with two outcome columns: the `week` until which the observation lasts and `arrest`, which indicates whether an arrest event happened or not (censoring)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
weekarrestfinageracewexpmarparoprio
020102710013
117101810018
2251019010113
352012311111
452001901013
\n", + "
" + ], + "text/plain": [ + " week arrest fin age race wexp mar paro prio\n", + "0 20 1 0 27 1 0 0 1 3\n", + "1 17 1 0 18 1 0 0 1 8\n", + "2 25 1 0 19 0 1 0 1 13\n", + "3 52 0 1 23 1 1 1 1 1\n", + "4 52 0 0 19 0 1 0 1 3" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from lifelines.datasets import load_rossi\n", + "from lifelines import CoxPHFitter\n", + "import numpy as np\n", + "import pandas as pd\n", + "import glum\n", + "\n", + "df = load_rossi()\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we estimate the Cox model with an indicator for financial aid (the treatment), a B-spline in the age at time of release, indicators for race, prior experience, marital status, and parole, and a B-spline in the number of prior convictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
modellifelines.CoxPHFitter
duration col'week'
event col'arrest'
baseline estimationbreslow
number of observations432
number of events observed114
partial log-likelihood-656.25
time fit was run2024-11-01 19:22:33 UTC
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
coefexp(coef)se(coef)coef lower 95%coef upper 95%exp(coef) lower 95%exp(coef) upper 95%cmp tozp-log2(p)
fin-0.350.700.19-0.730.030.481.030.00-1.820.073.85
bs(age, df=4)[1]-0.490.610.64-1.740.760.182.130.00-0.770.441.18
bs(age, df=4)[2]-1.810.160.84-3.46-0.160.030.850.00-2.150.034.99
bs(age, df=4)[3]-0.910.401.41-3.671.860.036.400.00-0.640.520.94
bs(age, df=4)[4]-1.760.171.10-3.920.410.021.500.00-1.590.113.17
race0.361.430.31-0.250.960.782.620.001.150.252.00
wexp-0.090.910.22-0.520.330.601.390.00-0.430.670.59
mar-0.330.720.39-1.090.420.341.530.00-0.870.391.37
paro-0.140.870.20-0.530.250.591.280.00-0.700.481.05
bs(prio, df=3)[1]1.363.910.96-0.533.250.5925.870.001.420.162.67
bs(prio, df=3)[2]-0.240.781.05-2.301.810.106.110.00-0.230.820.29
bs(prio, df=3)[3]2.7415.470.811.164.323.1974.970.003.40<0.00510.54

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Concordance0.65
Partial AIC1336.50
log-likelihood ratio test38.26 on 12 df
-log2(p) of ll-ratio test12.81
\n", + "
" + ], + "text/latex": [ + "\\begin{tabular}{lrrrrrrrrrrr}\n", + " & coef & exp(coef) & se(coef) & coef lower 95% & coef upper 95% & exp(coef) lower 95% & exp(coef) upper 95% & cmp to & z & p & -log2(p) \\\\\n", + "covariate & & & & & & & & & & & \\\\\n", + "fin & -0.35 & 0.70 & 0.19 & -0.73 & 0.03 & 0.48 & 1.03 & 0.00 & -1.82 & 0.07 & 3.85 \\\\\n", + "bs(age, df=4)[1] & -0.49 & 0.61 & 0.64 & -1.74 & 0.76 & 0.18 & 2.13 & 0.00 & -0.77 & 0.44 & 1.18 \\\\\n", + "bs(age, df=4)[2] & -1.81 & 0.16 & 0.84 & -3.46 & -0.16 & 0.03 & 0.85 & 0.00 & -2.15 & 0.03 & 4.99 \\\\\n", + "bs(age, df=4)[3] & -0.91 & 0.40 & 1.41 & -3.67 & 1.86 & 0.03 & 6.40 & 0.00 & -0.64 & 0.52 & 0.94 \\\\\n", + "bs(age, df=4)[4] & -1.76 & 0.17 & 1.10 & -3.92 & 0.41 & 0.02 & 1.50 & 0.00 & -1.59 & 0.11 & 3.17 \\\\\n", + "race & 0.36 & 1.43 & 0.31 & -0.25 & 0.96 & 0.78 & 2.62 & 0.00 & 1.15 & 0.25 & 2.00 \\\\\n", + "wexp & -0.09 & 0.91 & 0.22 & -0.52 & 0.33 & 0.60 & 1.39 & 0.00 & -0.43 & 0.67 & 0.59 \\\\\n", + "mar & -0.33 & 0.72 & 0.39 & -1.09 & 0.42 & 0.34 & 1.53 & 0.00 & -0.87 & 0.39 & 1.37 \\\\\n", + "paro & -0.14 & 0.87 & 0.20 & -0.53 & 0.25 & 0.59 & 1.28 & 0.00 & -0.70 & 0.48 & 1.05 \\\\\n", + "bs(prio, df=3)[1] & 1.36 & 3.91 & 0.96 & -0.53 & 3.25 & 0.59 & 25.87 & 0.00 & 1.42 & 0.16 & 2.67 \\\\\n", + "bs(prio, df=3)[2] & -0.24 & 0.78 & 1.05 & -2.30 & 1.81 & 0.10 & 6.11 & 0.00 & -0.23 & 0.82 & 0.29 \\\\\n", + "bs(prio, df=3)[3] & 2.74 & 15.47 & 0.81 & 1.16 & 4.32 & 3.19 & 74.97 & 0.00 & 3.40 & 0.00 & 10.54 \\\\\n", + "\\end{tabular}\n" + ], + "text/plain": [ + "\n", + " duration col = 'week'\n", + " event col = 'arrest'\n", + " baseline estimation = breslow\n", + " number of observations = 432\n", + "number of events observed = 114\n", + " partial log-likelihood = -656.25\n", + " time fit was run = 2024-11-01 19:22:33 UTC\n", + "\n", + "---\n", + " coef exp(coef) se(coef) coef lower 95% coef upper 95% exp(coef) lower 95% exp(coef) upper 95%\n", + "covariate \n", + "fin -0.35 0.70 0.19 -0.73 0.03 0.48 
1.03\n", + "bs(age, df=4)[1] -0.49 0.61 0.64 -1.74 0.76 0.18 2.13\n", + "bs(age, df=4)[2] -1.81 0.16 0.84 -3.46 -0.16 0.03 0.85\n", + "bs(age, df=4)[3] -0.91 0.40 1.41 -3.67 1.86 0.03 6.40\n", + "bs(age, df=4)[4] -1.76 0.17 1.10 -3.92 0.41 0.02 1.50\n", + "race 0.36 1.43 0.31 -0.25 0.96 0.78 2.62\n", + "wexp -0.09 0.91 0.22 -0.52 0.33 0.60 1.39\n", + "mar -0.33 0.72 0.39 -1.09 0.42 0.34 1.53\n", + "paro -0.14 0.87 0.20 -0.53 0.25 0.59 1.28\n", + "bs(prio, df=3)[1] 1.36 3.91 0.96 -0.53 3.25 0.59 25.87\n", + "bs(prio, df=3)[2] -0.24 0.78 1.05 -2.30 1.81 0.10 6.11\n", + "bs(prio, df=3)[3] 2.74 15.47 0.81 1.16 4.32 3.19 74.97\n", + "\n", + " cmp to z p -log2(p)\n", + "covariate \n", + "fin 0.00 -1.82 0.07 3.85\n", + "bs(age, df=4)[1] 0.00 -0.77 0.44 1.18\n", + "bs(age, df=4)[2] 0.00 -2.15 0.03 4.99\n", + "bs(age, df=4)[3] 0.00 -0.64 0.52 0.94\n", + "bs(age, df=4)[4] 0.00 -1.59 0.11 3.17\n", + "race 0.00 1.15 0.25 2.00\n", + "wexp 0.00 -0.43 0.67 0.59\n", + "mar 0.00 -0.87 0.39 1.37\n", + "paro 0.00 -0.70 0.48 1.05\n", + "bs(prio, df=3)[1] 0.00 1.42 0.16 2.67\n", + "bs(prio, df=3)[2] 0.00 -0.23 0.82 0.29\n", + "bs(prio, df=3)[3] 0.00 3.40 <0.005 10.54\n", + "---\n", + "Concordance = 0.65\n", + "Partial AIC = 1336.50\n", + "log-likelihood ratio test = 38.26 on 12 df\n", + "-log2(p) of ll-ratio test = 12.81" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cph = CoxPHFitter()\n", + "cph.fit(\n", + " df,\n", + " duration_col=\"week\",\n", + " event_col=\"arrest\",\n", + " formula=\"fin + bs(age, df=4) + race + wexp + mar + paro + bs(prio, df=3)\",\n", + ")\n", + "cph.print_summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The results imply that financial aid leads to an about 30% reduction in the hazard of being arrested again, but the effect is just barely significant. We can now go about replicating these results in `glum`. 
For that, we first need to define a function that creates a dataset with one row per convict and period:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def survival_split(df, time, outcome):\n", + "    \"\"\"Split survival data into one row per observation and period. Inspired by `survSplit` in R's survival package.\"\"\"\n", + "\n", + "    # table with unique event or censoring times\n", + "    df_times = df[[time]].drop_duplicates().sort_values(time)\n", + "\n", + "    # cross join: one row per individual and unique time\n", + "    df[\"temp\"] = 1\n", + "    df_times[\"temp\"] = 1\n", + "    df_ = pd.merge(df, df_times, on=\"temp\", how=\"left\", suffixes=[\"_end\", \"\"]).drop(\n", + "        columns=\"temp\"\n", + "    )\n", + "    df = df.drop(columns=\"temp\")\n", + "\n", + "    # remove rows after the individual's own event or censoring time\n", + "    df_ = df_.loc[df_[time + \"_end\"] >= df_[time]].reset_index(drop=True)\n", + "\n", + "    # the outcome is only switched on in the individual's final period\n", + "    df_[outcome] = np.where(\n", + "        df_[time + \"_end\"] == df_[time],\n", + "        df_[outcome],\n", + "        False if pd.api.types.is_bool_dtype(df_[outcome]) else 0,\n", + "    )\n", + "\n", + "    return df_.drop(columns=time + \"_end\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "All that remains is to estimate a Poisson regression on this transformed dataset. 
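Before fitting, it may help to see what this split produces. The following standalone miniature applies the same logic to a made-up two-person dataset (the toy values and the `person` column are illustrative, not part of the tutorial's data):

```python
import numpy as np
import pandas as pd

# person A is arrested in week 2; person B is censored after week 3
toy = pd.DataFrame({"person": ["A", "B"], "week": [2, 3], "arrest": [1, 0]})

# cross join each person with the unique event/censoring times ...
times = toy[["week"]].drop_duplicates().sort_values("week")
toy["temp"] = 1
times["temp"] = 1
split = pd.merge(toy, times, on="temp", how="left", suffixes=["_end", ""]).drop(
    columns="temp"
)

# ... keep only periods up to the person's own end time ...
split = split.loc[split["week_end"] >= split["week"]].reset_index(drop=True)

# ... and switch the outcome off everywhere except the final period
split["arrest"] = np.where(split["week_end"] == split["week"], split["arrest"], 0)
split = split.drop(columns="week_end")

# person A contributes one row (week 2); person B contributes two (weeks 2 and 3)
```

Each surviving row says "this person was at risk in this week", which is exactly the person-period structure the Poisson log-likelihood in Section 1 sums over.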
Note that the formula here contains the week of the year as a categorical in addition to the regressors of the Cox model:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "df_split = survival_split(df, \"week\", \"arrest\")\n", + "model_glum = glum.GeneralizedLinearRegressor(\n", + " family=\"poisson\",\n", + " formula=\"arrest ~ fin + bs(age, df=4) + race + wexp + mar + paro + bs(prio, df=3) + C(week)\",\n", + " fit_intercept=False,\n", + ").fit(df_split)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can check that the Poisson regression yields estimates and standard errors that, for all practical purposes, are the same as those of the Cox regression:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
covariatecoef_coxphse_coxphcoef_poissonse_poisson
0fin-0.3500010.192611-0.3496680.192792
1bs(age, df=4)[1]-0.4905600.636493-0.4880860.637050
2bs(age, df=4)[2]-1.8078380.840686-1.7972580.841228
3bs(age, df=4)[3]-0.9060261.409066-0.9053131.410311
4bs(age, df=4)[4]-1.7562151.102792-1.7472881.103601
5race0.3569410.3101880.3565860.310484
6wexp-0.0931440.215796-0.0951440.215862
7mar-0.3343480.385933-0.3339040.386238
8paro-0.1389120.198157-0.1386370.198338
9bs(prio, df=3)[1]1.3640300.9638991.3552610.964604
10bs(prio, df=3)[2]-0.2448461.048481-0.2340071.048920
11bs(prio, df=3)[3]2.7389140.8052022.7143820.805035
\n", + "
" + ], + "text/plain": [ + " covariate coef_coxph se_coxph coef_poisson se_poisson\n", + "0 fin -0.350001 0.192611 -0.349668 0.192792\n", + "1 bs(age, df=4)[1] -0.490560 0.636493 -0.488086 0.637050\n", + "2 bs(age, df=4)[2] -1.807838 0.840686 -1.797258 0.841228\n", + "3 bs(age, df=4)[3] -0.906026 1.409066 -0.905313 1.410311\n", + "4 bs(age, df=4)[4] -1.756215 1.102792 -1.747288 1.103601\n", + "5 race 0.356941 0.310188 0.356586 0.310484\n", + "6 wexp -0.093144 0.215796 -0.095144 0.215862\n", + "7 mar -0.334348 0.385933 -0.333904 0.386238\n", + "8 paro -0.138912 0.198157 -0.138637 0.198338\n", + "9 bs(prio, df=3)[1] 1.364030 0.963899 1.355261 0.964604\n", + "10 bs(prio, df=3)[2] -0.244846 1.048481 -0.234007 1.048920\n", + "11 bs(prio, df=3)[3] 2.738914 0.805202 2.714382 0.805035" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cph.summary.reset_index()[[\"covariate\", \"coef\", \"se(coef)\"]].merge(\n", + " model_glum.coef_table(X=df_split, robust=False)[[\"coef\", \"se\"]],\n", + " left_on=\"covariate\",\n", + " right_index=True,\n", + ").rename(\n", + " columns={\n", + " \"coef_x\": \"coef_coxph\",\n", + " \"coef_y\": \"coef_poisson\",\n", + " \"se(coef)\": \"se_coxph\",\n", + " \"se\": \"se_poisson\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Speed Considerations\n", + "\n", + "The Poisson model estimates each time fixed effect parameter and, therefore, many more parameters than the Cox model. 
In our example, there are 61 coefficients in the Poisson model as opposed to 12 in the Cox model:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(12, 61)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(cph.summary), len(model_glum.coef_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One might, therefore, wonder if the Poisson approach is competitive in terms of estimation speed. For the dataset here, the Poisson approach, including the data transformation by `survival_split`, turns out to be faster than the Cox model. This speedup is aided by tabmat's optimizations for the high-dimensional `week` categorical." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32.8 ms ± 3.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit cox_model = CoxPHFitter().fit(df, duration_col='week', event_col='arrest', formula=\"fin + bs(age, df=4) + race + wexp + mar + paro + bs(prio,df=3)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18 ms ± 73.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit model_glum = glum.GeneralizedLinearRegressor(family=\"poisson\",formula=\"arrest ~ fin + bs(age, df=4) + race + wexp + mar + paro + bs(prio,df=3) + C(week)\").fit(survival_split(df, 'week', 'arrest'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This tutorial has shown how the Cox proportional hazards model can be reformulated as a Poisson regression on a transformed dataset, which can be run without specialized survival estimators. 
The Poisson regression involves estimating the coefficients of a high-dimensional categorical variable with one category per event time, for which `glum`'s handling of categoricals is handy. This reformulation also allows for exploration beyond the Cox model, including the option to replace the piecewise-constant baseline hazard with a smooth one, and to swap the Poisson objective for another, such as a binomial.\n", + "\n", + "\n", + "## Footnotes\n", + "<a id=\"fn1\"></a>1:\n", + "The Cox model assumes that at most one individual has an event at any time (\"no ties\"). Different tie-breaking methods are available for datasets with duplicate event times. The most popular ones are an exact method, which is expensive to compute; Efron's method, which is reasonably fast and accurate; and the Breslow approximation, which is the fastest. In the case of many ties, an inherently discrete survival model is probably the best option.\n", + "\n", + "<a id=\"fn2\"></a>2:\n", + "The log-likelihood of an EDM with parameters $\\theta_i$ is\n", + "$$\n", + "\\sum_i\\frac{y_i\\theta_i - b(\\theta_i)}{\\sigma_i} + \\text{constant}, \n", + "$$\n", + "with cumulant function $b$ and dispersion parameter $\\sigma_i$. The linear predictor is given by $\\eta_i=g(b'(\\theta_i))$ for link function $g$. For example, the Poisson family has $\\theta_i = \\log(\\lambda_i)$, $b(\\theta_i) = \\exp(\\theta_i)$, and $\\sigma_i = 1$. The Cox partial log-likelihood cannot be broken down into such a sum over individuals $i$.\n", + "\n", + "## References\n", + "\n", + "[1] Carstensen, B. 2023. Who needs the Cox model anyway. December. Available at: http://bendixcarstensen.com/WntCma.pdf.\n", + "\n", + "[2] Whitehead, J., 1980. Fitting Cox’s regression model to survival data using GLIM. _Journal of the Royal Statistical Society Series C: Applied Statistics_, 29(3), pp.268-275."
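As a closing illustration of the binomial alternative mentioned in the conclusion, here is a self-contained sketch. It is entirely hypothetical: the data are simulated, the numbers are made up, and a hand-rolled Newton-Raphson stands in for a GLM solver. It fits a discrete-time logistic hazard with period fixed effects, the binomial analogue of the `C(week)` Poisson regression above:

```python
import numpy as np

rng = np.random.default_rng(0)
n_people, n_periods = 500, 4
alpha = np.array([-2.5, -2.0, -1.5, -1.0])  # period intercepts (baseline hazard)
beta_true = 0.8                             # effect of a single covariate x
x = rng.normal(size=n_people)

# person-period rows: period dummies (like C(week) above) plus the covariate
rows, ys = [], []
for i in range(n_people):
    for t in range(n_periods):
        hazard = 1.0 / (1.0 + np.exp(-(alpha[t] + beta_true * x[i])))
        event = rng.random() < hazard
        row = np.zeros(n_periods + 1)
        row[t] = 1.0
        row[-1] = x[i]
        rows.append(row)
        ys.append(float(event))
        if event:  # a person leaves the risk set after their event
            break
X, y = np.array(rows), np.array(ys)

# Newton-Raphson for the logistic (binomial) log-likelihood
coef = np.zeros(n_periods + 1)
for _ in range(25):
    p = 1.0 / (1.0 + np.exp(-(X @ coef)))
    grad = X.T @ (y - p)
    hess = X.T @ (X * (p * (1.0 - p))[:, None])
    coef += np.linalg.solve(hess, grad)
# coef[:-1] estimates the period intercepts, coef[-1] the covariate effect
```

The per-period hazards here are probabilities rather than rates, which is the natural choice when many individuals share the same event time.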
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "glum", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorials/penalized_splines/penalized_splines.ipynb b/docs/tutorials/penalized_splines/penalized_splines.ipynb new file mode 100644 index 00000000..d2d19b65 --- /dev/null +++ b/docs/tutorials/penalized_splines/penalized_splines.ipynb @@ -0,0 +1,20 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fitting Penalized Splines in glum\n", + "\n", + "See the [post on Matt Mills' website](http://statmills.com/2023-11-20-Penalized_Splines_Using_glum/)." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorials/tutorials.rst b/docs/tutorials/tutorials.rst index 86166d17..36b9891b 100644 --- a/docs/tutorials/tutorials.rst +++ b/docs/tutorials/tutorials.rst @@ -8,3 +8,5 @@ Tutorials High Dimensional Fixed Effects with Rossman Sales Regularization with King County Housing Sales Formula interface + Cox Proportional Hazards Model + Fitting Penalized Splines diff --git a/pixi.lock b/pixi.lock index 04ebd477..a6e72273 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1014,6 +1014,8 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/async-lru-2.0.4-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/attrs-24.2.0-pyh71513ae_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-1.7.0-pyhd8ed1ab_0.conda + - conda: 
https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/aws-c-auth-0.8.0-h56a2c13_4.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/aws-c-cal-0.8.0-hd3f4568_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/aws-c-common-0.9.31-hb9d3cd8_0.conda @@ -1203,6 +1205,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.13.4-hb346dea_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/lifelines-0.30.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/llvmlite-0.43.0-py312h374181b_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/lz4-4.3.3-py312hb3f7f12_1.conda @@ -1369,6 +1372,8 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/async-lru-2.0.4-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/attrs-24.2.0-pyh71513ae_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-1.7.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/aws-c-auth-0.8.0-ha41d1bc_4.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/aws-c-cal-0.8.0-hfd083d3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/aws-c-common-0.9.31-h7ab814d_0.conda @@ -1551,6 +1556,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.17.0-hdb1d25a_0.conda - conda: 
https://conda.anaconda.org/conda-forge/osx-arm64/libxml2-2.13.4-h8424949_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/lifelines-0.30.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/llvm-openmp-19.1.3-hb52a8e5_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/llvm-tools-17.0.6-h5090b49_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/llvmlite-0.43.0-py312ha9ca408_1.conda @@ -1720,6 +1726,8 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/async-lru-2.0.4-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/attrs-24.2.0-pyh71513ae_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-1.7.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/win-64/aws-c-auth-0.8.0-h75ad88d_4.conda - conda: https://conda.anaconda.org/conda-forge/win-64/aws-c-cal-0.8.0-h0da4a7a_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/aws-c-common-0.9.31-h2466b09_0.conda @@ -1872,6 +1880,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/libxcb-1.17.0-h0e4246c_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/libxml2-2.13.4-h442d1da_2.conda - conda: https://conda.anaconda.org/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/lifelines-0.30.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/llvm-meta-5.0.0-0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/win-64/llvmlite-0.43.0-py312h1f7db74_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 @@ -5478,6 +5487,40 @@ packages: 
license_family: MIT size: 56048 timestamp: 1722977241383 +- kind: conda + name: autograd + version: 1.7.0 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/autograd-1.7.0-pyhd8ed1ab_0.conda + sha256: 2640424d10419ebc53cf962b606de81deec33611f9d1f08eee1e23bbd7fc4a96 + md5: 3d0c2341f515348d7fdc4f6d5f448446 + depends: + - numpy >=1.25 + - python >=3.8 + - scipy >=1.11 + license: MIT + license_family: MIT + size: 47431 + timestamp: 1724658599538 +- kind: conda + name: autograd-gamma + version: 0.5.0 + build: pyh9f0ad1d_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/autograd-gamma-0.5.0-pyh9f0ad1d_0.tar.bz2 + sha256: ee6384ca35889fbc2a877ae7140eb90ca66980310640027eb39185ffa84f92bd + md5: 1d2f3cd0881ead2f033ec5a9d567c6f0 + depends: + - autograd >=1.2.0 + - python + - scipy >=1.2.0 + license: MIT + license_family: MIT + size: 7767 + timestamp: 1602812490828 - kind: conda name: aws-c-auth version: 0.8.0 @@ -14128,6 +14171,28 @@ packages: license_family: Other size: 60963 timestamp: 1727963148474 +- kind: conda + name: lifelines + version: 0.30.0 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/lifelines-0.30.0-pyhd8ed1ab_0.conda + sha256: 226cf5033ddf1a5a977dd7fa754d06d518f58f4dbdcadbfbc4e9e4768642d91c + md5: b8d86da95fef6c11135562d2b37ce3e3 + depends: + - autograd >=1.5 + - autograd-gamma >=0.3 + - formulaic >=0.2.2 + - matplotlib-base >=3.0 + - numpy >=1.14.0 + - pandas >=2.1.0 + - python >=3.9 + - scipy >=1.7 + license: MIT + license_family: MIT + size: 266707 + timestamp: 1730245900323 - kind: conda name: line_profiler version: 4.1.3 diff --git a/pixi.toml b/pixi.toml index 7ba83a14..22c33cac 100644 --- a/pixi.toml +++ b/pixi.toml @@ -109,6 +109,7 @@ altair = "*" # used in docs/tutorials/rossman dask-ml = ">=2022.5.27" # used in tutorials rossman and insurance jupyterlab = "*" libpysal = "*" # used in 
docs/tutorials/regularization_housing_data +lifelines = "*" # used in docs/tutorials/cox_model openml = "*" # used for downloading datasets in the tutorials. shapely = "*" # used in docs/tutorials/regularization_housing_data