
Updating residuals #27

Open · wants to merge 6 commits into develop
Conversation

@lennybronner (Collaborator) commented Sep 24, 2024

Description

This PR changes how we compute the residuals for the linear solvers. For LinearSolver we add the ability to compute the k-fold residual as an estimate of the leave-one-out residual. The k-fold residual is computed by generating K folds of the data, re-estimating the model on all folds but one, and computing the residuals on the held-out fold. The residuals from all K folds are then concatenated together.

If K is set to None, then we just compute the training residual. For the OLS model we still allow the exact leave-one-out residual if K is equal to the number of units.
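For illustration, here is a minimal sketch of the k-fold procedure described above (the kfold_residuals helper and the plain-OLS fit are assumptions for the example, not the PR's actual API):

    import numpy as np

    def kfold_residuals(x: np.ndarray, y: np.ndarray, K: int) -> np.ndarray:
        # split the unit indices into K folds
        indices = np.arange(x.shape[0])
        folds = np.array_split(indices, K)
        residuals = np.empty_like(y, dtype=float)
        for fold in folds:
            train = np.setdiff1d(indices, fold)
            # re-estimate the model on all folds but one (plain OLS here)
            coefficients = np.linalg.lstsq(x[train], y[train], rcond=None)[0]
            # compute the residuals on the held-out fold
            residuals[fold] = y[fold] - x[fold] @ coefficients
        # writing into residuals[fold] stitches the K per-fold residual
        # vectors back together in the original unit order
        return residuals

With K equal to the number of units this reduces to the exact leave-one-out residual; smaller K trades accuracy for fewer refits.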

Jira Ticket

This is necessary for this ticket: https://arcpublishing.atlassian.net/browse/ELEX-4549

Test Steps

Unit tests have been added

@lennybronner requested a review from a team as a code owner on September 24, 2024 at 23:45
y: np.ndarray,
weights: np.ndarray | None = None,
lambda_: float = 0.0,
cache: bool = True,
Contributor:

Can you help me understand how the cache is used? It seems like there are cases where we do (or don't) want to keep track of certain things computed during the k-fold residuals process, but I'm not sure I understand why or why not. Also, this looks like it's only used in the subclasses, so I'd suggest we either (a) remove it from here, or (b) add a method in this super-class, something like def cache_things() (bad method name), where the logic lives and can be shared by all subclasses 🤔
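For what it's worth, a rough sketch of option (b), with hypothetical names (_cache_fit_artifacts and the stored keys are placeholders, not the actual code):

    class LinearSolver:
        def __init__(self):
            self._cache: dict = {}

        def _cache_fit_artifacts(self, cache: bool, **artifacts) -> None:
            # shared by all subclasses: keep fit byproducts only when asked to
            if cache:
                self._cache.update(artifacts)

A subclass's fit() could then end with something like self._cache_fit_artifacts(cache, coefficients=beta), so the cache flag is interpreted in exactly one place.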

"""
Fits model
"""
raise NotImplementedError

-def predict(self, x: np.ndarray) -> np.ndarray:
+def predict(self, x: np.ndarray, coefficients: np.ndarray | None = None) -> np.ndarray:
Contributor:

How would you feel about leaving this (public method) like:

def predict(self, x: np.ndarray) -> np.ndarray:
    return self._predict(x, self.coefficients)

and adding:

def _predict(self, x: np.ndarray, coefficients: np.ndarray) -> np.ndarray:
    return x @ coefficients

and then in residuals() you call self._predict(x_test, coefficients_k)? The way it's written now invites users to pass in arbitrary coefficients, which might not be a good idea 🤔
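Under that suggestion, the per-fold step in residuals() would look roughly like this (train_k/test_k bookkeeping elided, and assuming fit() returns the fold's coefficients; a sketch, not the actual implementation):

    # inside residuals(), for fold k:
    coefficients_k = self.fit(x[train_k], y[train_k], cache=False)
    residuals_k = y[test_k] - self._predict(x[test_k], coefficients_k)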

center: bool = True,
**kwargs
) -> np.ndarray:
if K == x.shape[0]:
Contributor:

Is it possible that any other subclasses would benefit from this logic? 🤔
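For reference, the K == x.shape[0] branch presumably corresponds to the standard closed-form leave-one-out residual for OLS, which avoids refitting n times. A generic sketch of that identity (not the PR's implementation; assumes y is 1-D and x has full column rank):

    import numpy as np

    def ols_loo_residuals(x: np.ndarray, y: np.ndarray) -> np.ndarray:
        beta = np.linalg.lstsq(x, y, rcond=None)[0]
        residuals = y - x @ beta
        # leverages: diagonal of the hat matrix H = X (X'X)^{-1} X'
        h = np.einsum("ij,ij->i", x @ np.linalg.inv(x.T @ x), x)
        # exact leave-one-out residual: e_i / (1 - h_ii)
        return residuals / (1 - h)

Only OLS has this closed form, which is likely why the special case lives in this subclass rather than the base class.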

@@ -3,6 +3,7 @@
 import os
 import sys

+import numpy as np
Contributor:

I know you said you're struggling to add unit tests for all these changes and I'm curious what you think is missing. These look good to me although I'll keep thinking about it 🤔 🎉

@@ -23,7 +23,7 @@ def test_basic_median_1():
     preds = quantreg.predict(x)
     # you'd think it would be 8 instead of 7.5, but run quantreg in R to confirm
     # has to do with missing intercept
-    np.testing.assert_array_equal(preds, [[7.5, 7.5, 7.5, 15]])
+    np.testing.assert_array_equal(preds, [[7.5], [7.5], [7.5], [15]])
Contributor:

Why did this return type have to change? 🤔
