Skip to content

Commit

Permalink
Adding some method docstrings and return types to the transition solv…
Browse files Browse the repository at this point in the history
…er base class
  • Loading branch information
dmnapolitano committed Jan 30, 2024
1 parent bc74763 commit f2d0578
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/elexsolver/TransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_prediction_interval(self, pi: float):
raise NotImplementedError

@property
def transitions(self):
def transitions(self) -> np.ndarray:
return self._transitions

def _check_any_element_nan_or_inf(self, A: np.ndarray):
Expand Down Expand Up @@ -55,7 +55,7 @@ def _check_for_zero_units(self, A: np.ndarray):
if np.any(np.sum(A, axis=1) == 0):
raise ValueError("Matrix cannot contain any rows (units) where all columns (things) are zero.")

def _rescale(self, A: np.ndarray):
def _rescale(self, A: np.ndarray) -> np.ndarray:
"""
Rescale rows (units) to ensure they sum to 1 (100%).
"""
Expand All @@ -67,7 +67,20 @@ def _rescale(self, A: np.ndarray):

return np.nan_to_num(A, nan=0, posinf=0, neginf=0)

def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None):
def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None) -> np.ndarray:
"""
If `weights` is not None, and `weights` has the same number of rows in both matrices `X` and `Y`,
we'll rescale the weights by taking the square root after dividing them by their sum,
then return a diagonal matrix containing these now-normalized weights.
If `weights` is None, return a diagonal matrix of ones.
Parameters
----------
`X` : np.ndarray matrix of int (same number of rows as `Y`)
`Y` : np.ndarray matrix of int (same number of rows as `X`)
`weights` : np.ndarray of int of the shape (number of rows in `X` and `Y`, 1), optional
"""

if weights is not None:
if len(weights) != X.shape[0] and len(weights) != Y.shape[0]:
raise ValueError("weights must be the same length as the number of rows in X and Y.")
Expand Down

0 comments on commit f2d0578

Please sign in to comment.