diff --git a/src/elexsolver/EITransitionSolver.py b/src/elexsolver/EITransitionSolver.py index 2991c209..992aa371 100644 --- a/src/elexsolver/EITransitionSolver.py +++ b/src/elexsolver/EITransitionSolver.py @@ -49,6 +49,11 @@ def fit_predict(self, X, Y): X = np.transpose(X) Y = np.transpose(Y) + if X.shape[1] != Y.shape[1]: + raise ValueError(f"Number of units in X ({X.shape[1]}) != number of units in Y ({Y.shape[1]}).") + if Y.shape[1] != len(self._n): + raise ValueError(f"Number of units in Y ({Y.shape[1]}) != number of units in n ({len(self._n)}).") + num_units = len(self._n) # should be the same as the number of units in Y num_rows = X.shape[0] # number of things in X that are being transitioned "from" num_cols = Y.shape[0] # number of things in Y that are being transitioned "to"