Skip to content

Commit

Permalink
Adding docstrings to EI transition solver and cleaning up a few others
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Jan 31, 2024
1 parent 1ac03f3 commit 5340492
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
31 changes: 27 additions & 4 deletions src/elexsolver/EITransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class EITransitionSolver(TransitionSolver):
"""
A (voter) transition solver based on RxC ecological inference.
A transition solver based on RxC ecological inference.
Somewhat adapted from version 1.0.1 of
Knudson et al., (2021). PyEI: A Python package for ecological inference.
Journal of Open Source Software, 6(64), 3397, https://doi.org/10.21105/joss.03397
Expand All @@ -26,6 +26,18 @@ class EITransitionSolver(TransitionSolver):
"""

def __init__(self, sigma: int = 1, sampling_chains: int = 2, random_seed: int | None = None, n_samples: int = 300):
"""
Parameters
----------
`sigma` : int, default 1
Standard deviation of the half-normal distribution that provides alphas to the Dirichlet distribution.
`sampling_chains` : int, default 2
The number of sampling chains to run in parallel, each of which will draw `n_samples`.
`random_seed` : int, optional
For seeding the NUTS sampler.
`n_samples` : int, default 300
The number of samples to draw. Before sampling, the NUTS sampler will be tuned using `n_samples // 2` samples.
"""
super().__init__()
self._sigma = sigma
self._chains = int(sampling_chains)
Expand All @@ -39,7 +51,6 @@ def __init__(self, sigma: int = 1, sampling_chains: int = 2, random_seed: int |

def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None):
"""
X and Y are matrixes of integers.
NOTE: weighting is not currently implemented.
"""
self._check_data_type(X)
Expand Down Expand Up @@ -117,6 +128,18 @@ def _get_transitions(self, A: np.ndarray):
return np.array(transitions).T

def get_credible_interval(self, ci: float, transitions: bool = False):
"""
Parameters
----------
`ci` : float
Size of the credible interval [0, 100). If <= 1, will be multiplied by 100.
`transitions` : bool, default False
If True, the returned matrices will represent transitions, not percentages.
Returns
-------
A tuple of two np.ndarray matrices of float. Element 0 has the lower bound and 1 has the upper bound.
"""
if ci <= 1:
ci = ci * 100
if ci < 0 or ci > 100:
Expand All @@ -129,10 +152,10 @@ def get_credible_interval(self, ci: float, transitions: bool = False):
upper: np.zeros((self._sampled.shape[1], self._sampled.shape[2])),
}

for ci in [lower, upper]:
for interval in [lower, upper]:
for i in range(0, self._sampled.shape[1]):
for j in range(0, self._sampled.shape[2]):
A_dict[ci][i][j] = np.percentile(self._sampled[:, i, j], ci)
A_dict[interval][i][j] = np.percentile(self._sampled[:, i, j], interval)

if transitions:
return (self._get_transitions(A_dict[lower]), self._get_transitions(A_dict[upper]))
Expand Down
4 changes: 2 additions & 2 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, strict: bool = True, lam: float | None = None):
Parameters
----------
`strict` : bool, default True
If `True`, solution will be constrainted so that all coefficients are >= 0,
If True, solution will be constrainted so that all coefficients are >= 0,
<= 1, and the sum of each row equals 1.
`lam` : float, optional
`lam` != 0 will enable L2 regularization (Ridge).
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_confidence_interval(self, alpha: float, transitions: bool = False) -> (n
`alpha` : float
Value between [0, 1). If greater than 1, will be divided by 100.
`transitions` : bool, default False
If True, the returned matrix will represent transitions, not percentages.
If True, the returned matrices will represent transitions, not percentages.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/elexsolver/TransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class TransitionSolver(ABC):
"""
Abstract class for (voter) transition solvers.
Abstract class for transition solvers.
"""

def __init__(self):
Expand Down

0 comments on commit 5340492

Please sign in to comment.