diff --git a/src/elexsolver/TransitionSolver.py b/src/elexsolver/TransitionSolver.py index 6346ba56..88112fd9 100644 --- a/src/elexsolver/TransitionSolver.py +++ b/src/elexsolver/TransitionSolver.py @@ -26,7 +26,7 @@ def get_prediction_interval(self, pi: float): raise NotImplementedError @property - def transitions(self): + def transitions(self) -> np.ndarray: return self._transitions def _check_any_element_nan_or_inf(self, A: np.ndarray): @@ -55,7 +55,7 @@ def _check_for_zero_units(self, A: np.ndarray): if np.any(np.sum(A, axis=1) == 0): raise ValueError("Matrix cannot contain any rows (units) where all columns (things) are zero.") - def _rescale(self, A: np.ndarray): + def _rescale(self, A: np.ndarray) -> np.ndarray: """ Rescale rows (units) to ensure they sum to 1 (100%). """ @@ -67,7 +67,20 @@ def _rescale(self, A: np.ndarray): return np.nan_to_num(A, nan=0, posinf=0, neginf=0) - def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None): + def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None) -> np.ndarray: + """ + If `weights` is not None, and `weights` has the same number of rows in both matrices `X` and `Y`, + we'll rescale the weights by taking the square root after dividing them by their sum, + then return a diagonal matrix containing these now-normalized weights. + If `weights` is None, return a diagonal matrix of ones. + + Parameters + ---------- + `X` : np.ndarray matrix of int (same number of rows as `Y`) + `Y` : np.ndarray matrix of int (same number of rows as `X`) + `weights` : np.ndarray of int of the shape (number of rows in `X` and `Y`, 1), optional + """ + if weights is not None: if len(weights) != X.shape[0] and len(weights) != Y.shape[0]: raise ValueError("weights must be the same length as the number of rows in X and Y.")