From a67304f6dd0445760e3e3462883c3b7b51b35e8e Mon Sep 17 00:00:00 2001 From: Diane Napolitano Date: Mon, 16 Oct 2023 10:01:35 -0400 Subject: [PATCH] Adding a check to make sure we have enough units --- src/elexsolver/TransitionSolver.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/elexsolver/TransitionSolver.py b/src/elexsolver/TransitionSolver.py index 945da4ea..dd3facab 100644 --- a/src/elexsolver/TransitionSolver.py +++ b/src/elexsolver/TransitionSolver.py @@ -45,10 +45,14 @@ def _check_percentages(self, A: np.ndarray): def _check_and_rescale(self, A: np.ndarray): """ - Rescale columns (units) so that they sum to 1 (100%). + After ensuring that A is (things x units), make sure we have enough units. + If that's the case, rescale columns (units) so that they sum to 1 (100%). """ + if A.shape[1] <= A.shape[0] or (A.shape[1] // 2) <= A.shape[0]: + raise ValueError(f"Not enough units ({A.shape[1]}) relative to the number of things ({A.shape[0]}).") + if not np.all(A.sum(axis=0) == 1): - LOG.warn("Each column (unit) needs to sum to 1. Rescaling...") + LOG.warn("Each unit needs to sum to 1. Rescaling...") if isinstance(A, np.ndarray): for j in range(0, A.shape[1]): A[:, j] /= A[:, j].sum()