Skip to content

Commit

Permalink
Rescale if needed, check for things x units
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Oct 9, 2023
1 parent 9f49d83 commit 7f965d8
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions src/elexsolver/EITransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,38 @@ def mean_absolute_error(self, X, Y):
# return mae
return 0 # TODO

def _check_and_rescale(self, A):
    """
    Ensure every column (unit) of ``A`` sums to 1, rescaling if it does not.

    Parameters
    ----------
    A : numpy.ndarray or pandas.DataFrame
        Two-dimensional matrix whose columns are each expected to sum to 1.

    Returns
    -------
    numpy.ndarray or pandas.DataFrame
        ``A`` itself when every column already sums to 1 (within a small
        floating-point tolerance).  Otherwise a new object of the same kind
        in which each column has been divided by its own sum; the input is
        left unmodified.

    Notes
    -----
    A column whose sum is 0 cannot be normalized; dividing by it yields
    NaN/inf entries exactly as plain NumPy/pandas division would.
    """
    column_sums = A.sum(axis=0)

    # Compare with a tolerance rather than ``== 1``: sums of floats that are
    # mathematically 1 are frequently off by a rounding step, which would
    # otherwise trigger a needless warning and rescale.
    if np.allclose(column_sums, 1.0, rtol=0.0, atol=1e-9):
        return A

    LOG.warning("Each column (unit) needs to sum to 1. Rescaling...")

    # One broadcast division handles both supported containers: for an
    # ndarray the (n_columns,) sums broadcast across rows, and for a DataFrame
    # the Series of sums is aligned on column labels.  Building a new object
    # (instead of ``/=``) avoids mutating the caller's data through a
    # transposed view and also works for integer-typed input.
    return A / column_sums

def fit_predict(self, X, Y):
self._check_any_element_nan_or_inf(X)
self._check_any_element_nan_or_inf(Y)
self._check_percentages(X)
self._check_percentages(Y)

# TODO: check if these matrices are (long x short), then transpose
# currently assuming this is the case since the other solver expects (long x short)
X = np.transpose(X)
Y = np.transpose(Y)
# matrices should be (things x units), where the number of units is > the number of things
if X.shape[0] > X.shape[1]:
X = X.T
if Y.shape[0] > Y.shape[1]:
Y = Y.T

if X.shape[1] != Y.shape[1]:
raise ValueError(f"Number of units in X ({X.shape[1]}) != number of units in Y ({Y.shape[1]}).")
if Y.shape[1] != len(self._n):
raise ValueError(f"Number of units in Y ({Y.shape[1]}) != number of units in n ({len(self._n)}).")

X = self._check_and_rescale(X)
Y = self._check_and_rescale(Y)

num_units = len(self._n) # should be the same as the number of units in Y
num_rows = X.shape[0] # number of things in X that are being transitioned "from"
num_cols = Y.shape[0] # number of things in Y that are being transitioned "to"
Expand All @@ -77,8 +93,11 @@ def fit_predict(self, X, Y):
observed=Y_obs,
shape=(num_units, num_cols),
)
# TODO: allow other samplers; this one is very good but slow
model_trace = pm.sample(chains=self._chains)
try:
# TODO: allow other samplers; this one is very good but slow
model_trace = pm.sample(chains=self._chains)
except:
print(model.debug())

b_values = np.transpose(
model_trace["posterior"]["beta"].stack(all_draws=["chain", "draw"]).values, axes=(3, 0, 1, 2))
Expand Down

0 comments on commit 7f965d8

Please sign in to comment.