diff --git a/src/elexsolver/TransitionSolver.py b/src/elexsolver/TransitionSolver.py index 945da4ea..dd3facab 100644 --- a/src/elexsolver/TransitionSolver.py +++ b/src/elexsolver/TransitionSolver.py @@ -45,10 +45,14 @@ def _check_percentages(self, A: np.ndarray): def _check_and_rescale(self, A: np.ndarray): """ - Rescale columns (units) so that they sum to 1 (100%). + After ensuring that A is (things x units), make sure we have enough units. + If that's the case, rescale columns (units) so that they sum to 1 (100%). """ + if A.shape[1] <= A.shape[0] or (A.shape[1] // 2) <= A.shape[0]: + raise ValueError(f"Not enough units ({A.shape[1]}) relative to the number of things ({A.shape[0]}).") + if not np.all(A.sum(axis=0) == 1): - LOG.warn("Each column (unit) needs to sum to 1. Rescaling...") + LOG.warn("Each unit needs to sum to 1. Rescaling...") if isinstance(A, np.ndarray): for j in range(0, A.shape[1]): A[:, j] /= A[:, j].sum()