ELEX-3031: Potential Voterflow Solvers #17

Merged
merged 136 commits into from
Mar 29, 2024
Changes from 125 commits
e0580cd
Adding TransitionSolver abstract base class
dmnapolitano Oct 6, 2023
37b756c
Running pre-commit on TransitionSolver and removing unused warnings i…
dmnapolitano Oct 6, 2023
da59677
Modifying TransitionMatrixSolver according to some experiments I've r…
dmnapolitano Oct 6, 2023
3d6aef0
Adding placeholder for prediction intervals method
dmnapolitano Oct 6, 2023
9ebe5c5
Initial check-in of EI-based transition solver
dmnapolitano Oct 6, 2023
9f49d83
Adding some required-matching-length exceptions to EITransitionSolver
dmnapolitano Oct 9, 2023
7f965d8
Rescale if needed, check for things x units
dmnapolitano Oct 9, 2023
9e5a3fe
Adding _check_and_rescale() to superclass so the same check can be pe…
dmnapolitano Oct 10, 2023
76f4b61
Cleaning up code and exception-handling with pre-commit
dmnapolitano Oct 10, 2023
79076a7
Adding MAE computation/reporting to ETTransitionSolver
dmnapolitano Oct 10, 2023
2254a65
Adding the ability to compute a prediction (credible) interval with t…
dmnapolitano Oct 12, 2023
3aa0f5b
Handle zero and other weird division correctly when rescaling the per…
dmnapolitano Oct 12, 2023
a67304f
Adding a check to make sure we have enough units
dmnapolitano Oct 16, 2023
e54ba85
Silencing some warnings
dmnapolitano Oct 17, 2023
aedcdff
Increasing the target_accept on the sampler and making sure the sampl…
dmnapolitano Oct 17, 2023
e9fc725
Trying out some more EI solver optimizations
dmnapolitano Oct 18, 2023
08401eb
Explicitly specifying the cvxpy solver to use since the default is ab…
dmnapolitano Oct 18, 2023
3d35f67
Fixing poorly-worded comment in transition matrix solver
dmnapolitano Oct 18, 2023
1ac1c78
Default sampling chains of 2 on 2 cores
dmnapolitano Oct 18, 2023
ae0b320
Merge branch 'develop' of github.com:washingtonpost/elex-solver into …
dmnapolitano Oct 20, 2023
c87c25b
Making sure pymc is a requirement
dmnapolitano Oct 20, 2023
8d4637d
Experimenting with drawing fewer samples, different Dirichlet alphas
dmnapolitano Oct 23, 2023
8fc068c
Semi-relaxing the check for units summing to 1 using np.allclose()
dmnapolitano Oct 23, 2023
c05acde
Adding numpyro to the requirements
dmnapolitano Oct 23, 2023
39c9749
Test against Python 3.11 because I don't think some of the requiremen…
dmnapolitano Oct 23, 2023
9d595e8
Increasing github test timeout to 10 minutes
dmnapolitano Oct 23, 2023
78d5655
Trying a slightly-older version of numpy to be compatible with the cu…
dmnapolitano Oct 23, 2023
54bd39f
Switching to HalfNormal prior instead of Gamma after preliminary eval…
dmnapolitano Oct 23, 2023
5ce550c
Addressing some pylint complaints
dmnapolitano Oct 24, 2023
c25abec
Starting on TransitionSolver unit tests
dmnapolitano Oct 24, 2023
14d5239
Splitting TransitionSolver._check_and_rescale() into two separate met…
dmnapolitano Oct 24, 2023
6e2dfaa
Finishing TransitionSolver unit tests
dmnapolitano Oct 24, 2023
a5ced36
Make sure X and Y are numpy arrays for cvxpy if they're not already
dmnapolitano Oct 24, 2023
dabea4e
Adding unit tests for the transition matrix solver
dmnapolitano Oct 24, 2023
bc1a769
Adding test for rescaling the matrix if it's a pandas.DataFrame
dmnapolitano Oct 24, 2023
9f2683c
Pass in matrixes of vote counts (integers) so I can ensure the percen…
dmnapolitano Oct 25, 2023
53d4daa
Getting unit tests working again after all those changes
dmnapolitano Oct 25, 2023
1d2e26b
Add check for integer data
dmnapolitano Oct 25, 2023
9ffdc70
Silencing perfectly-ok warning about division-by-zero in rescale meth…
dmnapolitano Nov 2, 2023
c2a8909
Adding check and exception for units that are completely zero
dmnapolitano Nov 20, 2023
6beb558
Need to push the changes I made to the matrix solver so the tests pas…
dmnapolitano Nov 20, 2023
f734896
Adding the check for zero units and adding some consistency to checks…
dmnapolitano Nov 20, 2023
3d19897
Silencing some extraneous/unnecessary pymc and jax logging messages
dmnapolitano Nov 20, 2023
d32844c
Independent function for MAE calculation and MAE as a class member/pr…
dmnapolitano Dec 1, 2023
e6a8033
Handle situation where numpy arrays passed in to MAE function are act…
dmnapolitano Dec 2, 2023
f8e65c6
Updating the MAE unit test since it's no longer a class method
dmnapolitano Dec 2, 2023
d1044d6
Fixing issue where integer division was being performed when rescalin…
dmnapolitano Dec 6, 2023
8a31558
Fixing the matrix solver fit_predict() unit test now that I fixed suc…
dmnapolitano Dec 6, 2023
399daaf
Ensuring that the tests for rescale() test using integers with dtype int
dmnapolitano Dec 6, 2023
6e130ba
Using the Clarabel solver instead of ECOS after all since it seems mo…
dmnapolitano Dec 6, 2023
5d33777
SUPER preliminary version of a bootstrap matrix solver
dmnapolitano Dec 8, 2023
ccea996
Now generating random residuals for each unit/candidate rather than j…
dmnapolitano Dec 8, 2023
f67852b
Hmmmmmm...
dmnapolitano Dec 8, 2023
baf0e41
Shuffling the random per unit/candidate residuals, which does nothing…
dmnapolitano Dec 8, 2023
b1f3a7a
Adding option for weights to matrix solver
dmnapolitano Dec 8, 2023
5fa30c7
Switching the EI solver over to (units x candidates) to match the mat…
dmnapolitano Dec 8, 2023
80512b1
Finish up the conversion of the EI solver from (things x units) to (u…
dmnapolitano Dec 11, 2023
b453261
Removing some redundancy in the EI solver
dmnapolitano Dec 11, 2023
1ed57ae
Fixing mistake in preparing the weights
dmnapolitano Dec 11, 2023
5088ee8
Finished EI solver with weights (I think)
dmnapolitano Dec 11, 2023
5af4835
Switching to 'classic' bootstrap matrix solver which seems to produce…
dmnapolitano Dec 11, 2023
871bf11
Fix typing with weights in bootstrap matrix solver
dmnapolitano Dec 11, 2023
2f9c75b
Improving some of the logging generated by the matrix solvers
dmnapolitano Dec 12, 2023
833cb82
Trying out some error handling
dmnapolitano Dec 12, 2023
211e5b5
Use the weights in the bootstrap to draw a weighted sample
dmnapolitano Dec 12, 2023
518a6f8
Speeding up the bootstrap matrix solver a bit
dmnapolitano Dec 13, 2023
b2232d4
Removing weighting from EI solver for now since it's SUPER slow and p…
dmnapolitano Dec 14, 2023
a77409c
Correcting comment prediction => credible
dmnapolitano Dec 15, 2023
5559e22
Adding method for confidence interval to bootstrap solver
dmnapolitano Dec 15, 2023
11d23bc
Adding unit tests for the weights standardization/checking
dmnapolitano Dec 15, 2023
f661ce8
Adding option to hide the progress bar during bootstrapping
dmnapolitano Dec 15, 2023
9e48134
Removing extraneous parentheses from bootstrap solver
dmnapolitano Dec 15, 2023
8bda934
Don't fail the one unit test requiring pandas if the user doesn't hav…
dmnapolitano Dec 15, 2023
c0efe08
Adding unit tests for the bootstrap solver
dmnapolitano Dec 15, 2023
cec6a3c
Changing constraint to constraints
dmnapolitano Dec 18, 2023
260446f
Adding L2 regularization option to matrix solver
dmnapolitano Dec 19, 2023
e32fb80
Adding lambda argument to bootstrap solver to enable bootstrap ridge
dmnapolitano Dec 19, 2023
20e14b6
Clarifying predicted percentages vs. transitions
dmnapolitano Dec 20, 2023
0905e66
Adding function for WAPE
dmnapolitano Dec 26, 2023
5f7efcb
Converting the model-fit score to WAPE, too
dmnapolitano Dec 26, 2023
a5346de
Handle situation in calculating WAPE when the expected Y is 0
dmnapolitano Dec 27, 2023
7955e82
Letting WAPE remain undefined when Y_expected is zero
dmnapolitano Dec 27, 2023
27a92be
Switching back to MAE because WAPE is undefined when Y expected is zero
dmnapolitano Dec 27, 2023
61d2f01
Adding option to compute MAE with sample weights
dmnapolitano Dec 28, 2023
372cef3
And apparently fixing a mistake I had made in the MAE formula...
dmnapolitano Dec 28, 2023
a93a189
Moving MAE computation out of elex-solver
dmnapolitano Jan 15, 2024
54a47b0
Finishing up unit tests for TransitionSolver abstract base class
dmnapolitano Jan 17, 2024
9ec424e
Adding type hints to TransitionMatrixSolver
dmnapolitano Jan 17, 2024
1fc45b8
Missed a few type hints in TransitionMatrixSolver constructor
dmnapolitano Jan 22, 2024
12861ed
Adding type hints to EITransitionSolver
dmnapolitano Jan 22, 2024
4579678
Adding unit test for strict constraints with matrix solver and global…
dmnapolitano Jan 22, 2024
bfd05a8
Adding unit test for matrix solver with L2 regularization
dmnapolitano Jan 22, 2024
1b583b8
Adding a matrix solver unit test where the matrix needs to be pivoted…
dmnapolitano Jan 22, 2024
bba196f
Removing redundant/error-prone 'taking the mean' to get bootstrapped …
dmnapolitano Jan 22, 2024
98b9e81
Two more unit tests on the bootstrap confidence interval
dmnapolitano Jan 22, 2024
3e57ffd
Starting work on EI solver unit tests
dmnapolitano Jan 23, 2024
8064b6e
Adding two more EI solver unit tests
dmnapolitano Jan 23, 2024
006de0e
Updating some requirement version numbers
dmnapolitano Jan 23, 2024
85c9114
Increasing the test timeout on github
dmnapolitano Jan 23, 2024
fd2a344
Maybe the numpyro NUTS sampler is the problem?
dmnapolitano Jan 23, 2024
707f641
Please let me have fixed these unit tests
dmnapolitano Jan 23, 2024
9997d13
Trying to fix my failing pymc-related unit tests by adding a pytensor…
dmnapolitano Jan 30, 2024
8c69d6e
Hopefully fixing bad whitespacing in .github/workflows/test.yml ?
dmnapolitano Jan 30, 2024
6d6973b
Reverting the commit where I try adding an environment variable to th…
dmnapolitano Jan 30, 2024
0a9fd86
Maybe setting a numpy random seed will help
dmnapolitano Jan 30, 2024
9f7fea3
Reducing the number of samples drawn; want to see if the tests fail t…
dmnapolitano Jan 30, 2024
ca57292
Starting to see same results on macOS/M1 and Ubuntu/x86-64...
dmnapolitano Jan 30, 2024
8e594c4
Adding a test for credible interval, increasing the number of samples…
dmnapolitano Jan 30, 2024
0bf5278
Adding unit tests for other values that could be specified to credibl…
dmnapolitano Jan 30, 2024
f66e236
More matrix solver unit tests, particularly involving pandas
dmnapolitano Jan 30, 2024
c28a6f5
100% code coverage in transition solver base class :tada:
dmnapolitano Jan 30, 2024
bc74763
One last EI solver unit test :tada:
dmnapolitano Jan 30, 2024
f2d0578
Adding some method docstrings and return types to the transition solv…
dmnapolitano Jan 30, 2024
d62680b
Writing one single docstring for fit_predict() in the base class
dmnapolitano Jan 30, 2024
9c3379d
Matrix and bootstrap solvers docstrings, return types, and modifying …
dmnapolitano Jan 30, 2024
1ac03f3
Adding unit test for bootstrap matrix confidence interval transitions
dmnapolitano Jan 30, 2024
5340492
Adding docstrings to EI transition solver and cleaning up a few others
dmnapolitano Jan 31, 2024
7c69f88
Making EI solver's get_transitions() method super private
dmnapolitano Jan 31, 2024
160066c
Adding text to the README:
dmnapolitano Jan 31, 2024
67dd426
Capitalizing the 'm' in the section header
dmnapolitano Jan 31, 2024
fe69ce4
Removing the bootstrap matrix solver from this branch in favor of it …
dmnapolitano Jan 31, 2024
e9ef2ed
Forgot to remove bootstrap mentions from the README
dmnapolitano Jan 31, 2024
92641ac
Moving the EI solver to elex-voterflow-model as the EI model
dmnapolitano Feb 1, 2024
109b718
Asking pylint to ignore missing-module-docstring
dmnapolitano Feb 1, 2024
70bbaf7
When the Clarabel solver fails or throws a warning, chain that to a R…
dmnapolitano Feb 6, 2024
1c18883
Removing get_prediction_interval() method since it hasn't been implem…
dmnapolitano Mar 18, 2024
ffde8aa
Moving rules about matrix dimensions that are specific to voterflow o…
dmnapolitano Mar 18, 2024
71f29c9
Splitting fit_predict() into two methods, renaming 'percentages' to '…
dmnapolitano Mar 19, 2024
b5a2e07
Specifying return type for TransitionSolver.predict()
dmnapolitano Mar 19, 2024
73f0c9e
A version of TransitionSolver that inherets from LinearSolver
dmnapolitano Mar 19, 2024
2705e8b
Forgot to run pre-commit
dmnapolitano Mar 19, 2024
268ad15
Removing math to convert the coefficients to 'transitions' from Trans…
dmnapolitano Mar 20, 2024
f714ef4
Cleaning up some docstring formatting and removing unnecessary property
dmnapolitano Mar 22, 2024
4edb1fd
BROKEN testing the requirement of integer data and the forced-rescali…
dmnapolitano Mar 27, 2024
9beef48
FIXED updated unit tests; integer data is no longer a requirement for…
dmnapolitano Mar 28, 2024
06385f0
Adding more information to why self._check_for_zero_units() is important
dmnapolitano Mar 28, 2024
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -4,10 +4,10 @@ jobs:
   test:
     name: Run unit tests
     runs-on: ubuntu-latest
-    timeout-minutes: 5
+    timeout-minutes: 15
     strategy:
       matrix:
-        python-version: ['3.10']
+        python-version: ['3.11']
     steps:
       - uses: actions/checkout@v2
       - name: Setup Python
4 changes: 2 additions & 2 deletions README.md
@@ -16,8 +16,8 @@ We have our own implementation of ordinary least squares in Python because this
 ## Quantile Regression
 Since we did not find any implementations of quantile regression in Python that fit our needs, we decided to write one ourselves. At the moment this uses two libraries, the version that solves the non-regularized problem uses `numpy` and solves the dual based on [this](https://arxiv.org/pdf/2305.12616.pdf) paper. The version that solves the regularized problem uses [`cvxpy`](https://www.cvxpy.org/#) and sets up the problem as a normal optimization problem. Eventually, we are planning on replacing the regularized version with the dual also.

-## Transition matrices
-We also have a solver for transition matrices. While this works arbitrarily, we have used this in the past for our primary election night model. We can still use this to create the sankey diagram coefficients.
+## Transition Matrices
+We also have a matrix regression solver built with `cvxpy`. We've used this for our primary election model and analysis. The transitions it generates form the transitions displayed in our sankey diagrams.

 ## Development
 We welcome contributions to this repo. Please open a GitHub issue for any issues or comments you have.
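For readers unfamiliar with the "matrix regression" mentioned in the README change, the problem that `TransitionMatrixSolver` sets up in this diff can be written roughly as follows. The symbols are notation introduced here for illustration, not names from the code: $X$ and $Y$ are the row-rescaled (units x things) matrices, $B$ is the coefficient matrix, $p$ and $q$ are the column counts of $X$ and $Y$, and $W = \operatorname{diag}\big(\sqrt{w_i / \sum_k w_k}\big)$ is the diagonal weight matrix (the identity when no weights are passed).

```latex
% Default objective with strict=True
\min_{B \in \mathbb{R}^{p \times q}} \; \lVert W (X B - Y) \rVert_F
\quad \text{subject to} \quad
0 \le B_{ij} \le 1, \qquad \sum_{j=1}^{q} B_{ij} = 1 \;\; \text{for each row } i

% Objective when lam is set to a nonzero value (ridge)
\min_{B} \; \lVert W (X B - Y) \rVert_F^2 + \lambda \, \lVert B \rVert_F^2
```

With `strict=False`, the row-sum equality is relaxed to $0.9 \le \sum_j B_{ij} \le 1.1$. Reading the ridge objective as a squared Frobenius norm assumes that `cp.pnorm(..., p=2)` treats its matrix argument as a flattened vector.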
6 changes: 5 additions & 1 deletion setup.cfg
@@ -2,4 +2,8 @@
 universal = 1

 [pycodestyle]
-max-line-length = 160
+max-line-length = 160
+
+[pylint]
+max-line-length = 160
+disable = invalid-name, duplicate-code, missing-function-docstring, too-many-instance-attributes, too-many-arguments, too-many-locals, missing-module-docstring
4 changes: 2 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@

 from setuptools import find_packages, setup

-INSTALL_REQUIRES = ["cvxpy~=1.4", "numpy~=1.26", "scipy~=1.11"]
+INSTALL_REQUIRES = ["cvxpy~=1.4", "numpy~=1.26", "scipy~=1.12"]

 THIS_FILE_DIR = os.path.dirname(__file__)

@@ -29,7 +29,7 @@
         "Intended Audience :: Developers",
         "License :: OSI Approved :: MIT License",
         "Programming Language :: Python",
-        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
     ],
     description="A package for optimization solvers",
     long_description=LONG_DESCRIPTION,
117 changes: 98 additions & 19 deletions src/elexsolver/TransitionMatrixSolver.py
@@ -1,28 +1,107 @@
+import logging
+import warnings
+
 import cvxpy as cp
+import numpy as np
+
+from elexsolver.logging import initialize_logging
+from elexsolver.TransitionSolver import TransitionSolver
+
+initialize_logging()
+
+LOG = logging.getLogger(__name__)


+class TransitionMatrixSolver(TransitionSolver):
+    """
+    Matrix regression transition solver using CVXPY.
+    """
+
-class TransitionMatrixSolver:
-    def __init__(self):
-        self.transition_matrix = None
+    def __init__(self, strict: bool = True, lam: float | None = None):
+        """
+        Parameters
+        ----------
+        `strict` : bool, default True
+            If True, solution will be constrained so that all coefficients are >= 0,
+            <= 1, and the sum of each row equals 1.
+        `lam` : float, optional
+            `lam` != 0 will enable L2 regularization (Ridge).
+        """
+        super().__init__()
+        self._strict = strict
+        self._lambda = lam

     @staticmethod
-    def __get_constraint(X, strict):
+    def __get_constraints(coef: np.ndarray, strict: bool) -> list:
         if strict:
-            return [cp.sum(X, axis=1) == 1]
-        return [cp.sum(X, axis=1) <= 1.1, cp.sum(X, axis=1) >= 0.9]
-
-    def __solve(self, A, B, strict):
-        transition_matrix = cp.Variable((A.shape[1], B.shape[1]))
-        loss_function = cp.norm(A @ transition_matrix - B, "fro")
-        objective = cp.Minimize(loss_function)
-        constraint = TransitionMatrixSolver.__get_constraint(transition_matrix, strict)
-        problem = cp.Problem(objective, constraint)
-        problem.solve()
+            return [0 <= coef, coef <= 1, cp.sum(coef, axis=1) == 1]
+        return [cp.sum(coef, axis=1) <= 1.1, cp.sum(coef, axis=1) >= 0.9]
+
+    def __standard_objective(self, A: np.ndarray, B: np.ndarray, beta: np.ndarray) -> cp.Minimize:
+        loss_function = cp.norm(A @ beta - B, "fro")
+        return cp.Minimize(loss_function)
+
+    def __ridge_objective(self, A: np.ndarray, B: np.ndarray, beta: np.ndarray) -> cp.Minimize:
+        # Based on https://www.cvxpy.org/examples/machine_learning/ridge_regression.html
+        lam = cp.Parameter(nonneg=True, value=self._lambda)
+        loss_function = cp.pnorm(A @ beta - B, p=2) ** 2
+        regularizer = cp.pnorm(beta, p=2) ** 2
+        return cp.Minimize(loss_function + lam * regularizer)
+
+    def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray) -> np.ndarray:
+        transition_matrix = cp.Variable((A.shape[1], B.shape[1]), pos=True)
+        Aw = np.dot(weights, A)
+        Bw = np.dot(weights, B)
+
+        if self._lambda is None or self._lambda == 0:
+            objective = self.__standard_objective(Aw, Bw, transition_matrix)
+        else:
+            objective = self.__ridge_objective(Aw, Bw, transition_matrix)
+
+        constraints = TransitionMatrixSolver.__get_constraints(transition_matrix, self._strict)
+        problem = cp.Problem(objective, constraints)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            try:
+                problem.solve(solver=cp.CLARABEL)
+            except (UserWarning, cp.error.SolverError) as e:
+                raise RuntimeError(e) from e
+
         return transition_matrix.value

-    def fit(self, A, B, strict=False):
-        transition_matrix = self.__solve(A, B, strict)
-        self.transition_matrix = transition_matrix
+    def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None) -> np.ndarray:
+        self._check_data_type(X)
+        self._check_data_type(Y)
+        self._check_any_element_nan_or_inf(X)
+        self._check_any_element_nan_or_inf(Y)
+
+        # matrices should be (units x things), where the number of units is > the number of things
+        if X.shape[1] > X.shape[0]:
+            X = X.T
+        if Y.shape[1] > Y.shape[0]:
+            Y = Y.T
+
+        if X.shape[0] != Y.shape[0]:
+            raise ValueError(f"Number of units in X ({X.shape[0]}) != number of units in Y ({Y.shape[0]}).")
+
+        self._check_dimensions(X)
+        self._check_dimensions(Y)
+        self._check_for_zero_units(X)
+        self._check_for_zero_units(Y)
+
+        if not isinstance(X, np.ndarray):
+            X = X.to_numpy()
+        if not isinstance(Y, np.ndarray):
+            Y = Y.to_numpy()
+
+        X_expected_totals = X.sum(axis=0) / X.sum(axis=0).sum()
+
+        X = self._rescale(X)
+        Y = self._rescale(Y)
+
+        weights = self._check_and_prepare_weights(X, Y, weights)

-    def predict(self, A):
-        return A @ self.transition_matrix
+        percentages = self.__solve(X, Y, weights)
+        self._transitions = np.diag(X_expected_totals) @ percentages
+        return percentages
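The last lines of `fit_predict()` above turn the solved row-stochastic `percentages` into `transitions` by scaling each row by that column's share of all votes in `X`. Here is a small numeric sketch of that step, using made-up vote counts and a made-up solver output (neither comes from the PR):

```python
import numpy as np

# Hypothetical vote counts: 4 units (rows) x 2 "from" candidates (columns).
X_counts = np.array([[10, 30], [20, 20], [30, 10], [40, 20]])

# Share of all votes held by each column of X, computed as in fit_predict().
X_expected_totals = X_counts.sum(axis=0) / X_counts.sum(axis=0).sum()

# Hypothetical solver output: each row sums to 1 (the strict constraints).
percentages = np.array([[0.8, 0.2], [0.3, 0.7]])

# Scale row i of `percentages` by the overall share of "from" candidate i.
transitions = np.diag(X_expected_totals) @ percentages

print(transitions)
print(transitions.sum(axis=1))  # matches X_expected_totals
print(transitions.sum())        # the whole matrix sums to 1
```

Under the strict constraints, each row of `transitions` sums to the corresponding entry of `X_expected_totals`, so the full matrix sums to 1. That is presumably what makes it usable as flow widths in a sankey diagram.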
111 changes: 111 additions & 0 deletions src/elexsolver/TransitionSolver.py
@@ -0,0 +1,111 @@
+import logging
+import warnings
+from abc import ABC
+
+import numpy as np
+
+from elexsolver.logging import initialize_logging
+
+initialize_logging()
+
+LOG = logging.getLogger(__name__)
+
+
+class TransitionSolver(ABC):
+    """
+    Abstract class for transition solvers.
+    """
+
+    def __init__(self):
+        self._transitions = None
+
+    def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None):
+        """
+        After this method finishes, transitions will be available in the `transitions` class member.
+
+        Parameters
+        ----------
+        `X` : np.ndarray matrix or pandas.DataFrame of int
+            Must have the same number of rows as `Y` but can have any number of columns greater than the number of rows.
+        `Y` : np.ndarray matrix or pandas.DataFrame of int
+            Must have the same number of rows as `X` but can have any number of columns greater than the number of rows.
+        `weights` : list, np.ndarray, or pandas.Series of int, optional
+            Must have the same length (number of rows) as both `X` and `Y`.
+
+        Returns
+        -------
+        np.ndarray matrix of float of shape (number of columns in `X`) x (number of columns in `Y`).
+        Each float represents the percent of how much of row x is part of column y.
+        """
+        raise NotImplementedError
+
+    def get_prediction_interval(self, pi: float):
+        raise NotImplementedError
+
+    @property
+    def transitions(self) -> np.ndarray:
+        return self._transitions
+
+    def _check_any_element_nan_or_inf(self, A: np.ndarray):
+        """
+        Check whether any element in a matrix or vector is NaN or infinity
+        """
+        if np.any(np.isnan(A)) or np.any(np.isinf(A)):
+            raise ValueError("Matrix contains NaN or Infinity.")
+
+    def _check_data_type(self, A: np.ndarray):
+        if not np.all(A.astype("int64") == A):
+            raise ValueError("Matrix must contain integers.")
+
+    def _check_dimensions(self, A: np.ndarray):
+        """
+        Ensure that in our (units x things) matrix, the number of units is
+        at least twice as large as the number of things.
+        """
+        if A.shape[0] <= A.shape[1] or (A.shape[0] // 2) <= A.shape[1]:
+            raise ValueError(f"Not enough units ({A.shape[0]}) relative to the number of things ({A.shape[1]}).")
+
+    def _check_for_zero_units(self, A: np.ndarray):
+        """
+        If we have at least one unit whose columns are all zero, we can't continue.
+        """
+        if np.any(np.sum(A, axis=1) == 0):
+            raise ValueError("Matrix cannot contain any rows (units) where all columns (things) are zero.")
+
+    def _rescale(self, A: np.ndarray) -> np.ndarray:
+        """
+        Rescale rows (units) to ensure they sum to 1 (100%).
+        """
+        A = A.copy().astype(float)
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")
+            A = (A.T / A.sum(axis=1)).T
+
+        return np.nan_to_num(A, nan=0, posinf=0, neginf=0)
+
+    def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None) -> np.ndarray:
+        """
+        If `weights` is not None, and `weights` has the same number of rows in both matrices `X` and `Y`,
+        we'll rescale the weights by taking the square root after dividing them by their sum,
+        then return a diagonal matrix containing these now-normalized weights.
+        If `weights` is None, return a diagonal matrix of ones.
+
+        Parameters
+        ----------
+        `X` : np.ndarray matrix of int (same number of rows as `Y`)
+        `Y` : np.ndarray matrix of int (same number of rows as `X`)
+        `weights` : np.ndarray of int of the shape (number of rows in `X` and `Y`, 1), optional
+        """
+
+        if weights is not None:
+            if len(weights) != X.shape[0] and len(weights) != Y.shape[0]:
+                raise ValueError("weights must be the same length as the number of rows in X and Y.")
+            if isinstance(weights, list):
+                weights = np.array(weights).copy()
+            elif not isinstance(weights, np.ndarray):
+                # pandas.Series
+                weights = weights.values.copy()
+            return np.diag(np.sqrt(weights.flatten() / weights.sum()))
+
+        return np.diag(np.ones((Y.shape[0],)))
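To illustrate why `_check_and_prepare_weights()` takes the square root of the normalized weights, here is a standalone sketch of the rescaling and weighting steps. The function names `rescale` and `prepare_weights` are hypothetical stand-ins rather than the PR's methods. The sketch also swaps in `np.divide(..., where=...)` for the warning filter plus `np.nan_to_num` used above. That is a different technique, though it should give the same result for non-negative counts.

```python
import numpy as np


def rescale(A):
    """Rescale each row to sum to 1; all-zero rows stay at 0 (hypothetical helper)."""
    A = np.asarray(A, dtype=float)
    row_sums = A.sum(axis=1, keepdims=True)
    # Only divide where the row sum is nonzero; other entries keep the 0 from `out`.
    return np.divide(A, row_sums, out=np.zeros_like(A), where=row_sums != 0)


def prepare_weights(weights, n_rows):
    """Return diag(sqrt(w / sum(w))), or the identity if no weights are given."""
    if weights is None:
        return np.eye(n_rows)
    w = np.asarray(weights, dtype=float).flatten()
    if len(w) != n_rows:
        raise ValueError("weights must have one entry per row.")
    return np.diag(np.sqrt(w / w.sum()))


scaled = rescale(np.array([[10, 30], [5, 5], [0, 0]]))

# Made-up residual matrix (fitted minus observed) for two units.
R = np.array([[1.0, 2.0], [3.0, 4.0]])
W = prepare_weights([1, 3], n_rows=2)

# Squared Frobenius norm of the weighted residual ...
weighted_loss = np.linalg.norm(W @ R, "fro") ** 2
# ... equals the weighted sum of per-unit squared errors.
per_unit = (np.array([1, 3]) / 4) * (R ** 2).sum(axis=1)
print(weighted_loss, per_unit.sum())
```

Because the loss squares the residual, multiplying row i by the square root of its normalized weight makes unit i contribute to the squared loss in proportion to its weight. This is the usual weighted-least-squares construction.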