diff --git a/src/elexsolver/EITransitionSolver.py b/src/elexsolver/EITransitionSolver.py index 992aa371..7b7fa814 100644 --- a/src/elexsolver/EITransitionSolver.py +++ b/src/elexsolver/EITransitionSolver.py @@ -38,22 +38,38 @@ def mean_absolute_error(self, X, Y): # return mae return 0 # TODO + def _check_and_rescale(self, A): + if not np.all(A.sum(axis=0) == 1): + LOG.warn("Each column (unit) needs to sum to 1. Rescaling...") + if isinstance(A, np.ndarray): + for j in range(0, A.shape[1]): + A[:, j] /= A[:, j].sum() + else: + # pandas.DataFrame() + for col in A.columns: + A[col] /= A[col].sum() + return A + def fit_predict(self, X, Y): self._check_any_element_nan_or_inf(X) self._check_any_element_nan_or_inf(Y) self._check_percentages(X) self._check_percentages(Y) - # TODO: check if these matrices are (long x short), then transpose - # currently assuming this is the case since the other solver expects (long x short) - X = np.transpose(X) - Y = np.transpose(Y) + # matrices should be (things x units), where the number of units is > the number of things + if X.shape[0] > X.shape[1]: + X = X.T + if Y.shape[0] > Y.shape[1]: + Y = Y.T if X.shape[1] != Y.shape[1]: raise ValueError(f"Number of units in X ({X.shape[1]}) != number of units in Y ({Y.shape[1]}).") if Y.shape[1] != len(self._n): raise ValueError(f"Number of units in Y ({Y.shape[1]}) != number of units in n ({len(self._n)}).") + X = self._check_and_rescale(X) + Y = self._check_and_rescale(Y) + num_units = len(self._n) # should be the same as the number of units in Y num_rows = X.shape[0] # number of things in X that are being transitioned "from" num_cols = Y.shape[0] # number of things in Y that are being transitioned "to" @@ -77,8 +93,11 @@ def fit_predict(self, X, Y): observed=Y_obs, shape=(num_units, num_cols), ) - # TODO: allow other samplers; this one is very good but slow - model_trace = pm.sample(chains=self._chains) + try: + # TODO: allow other samplers; this one is very good but slow + model_trace = pm.sample(chains=self._chains) + except: + print(model.debug()) b_values = np.transpose( model_trace["posterior"]["beta"].stack(all_draws=["chain", "draw"]).values, axes=(3, 0, 1, 2))