test easy-import, more tests, fixes
Fixes:
- streamline model training for extra-input
- fine-tune metric calculation in train_model
- robustify JTK functions to deal with unexpected or imperfect data
- support more input flexibility in plot_embeddings
- fix random effects in get_meta_analysis
Bribak committed Nov 19, 2024
1 parent 94646ad commit d5f5d4e
Showing 8 changed files with 1,006 additions and 119 deletions.
glycowork/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,4 +1,5 @@
 __version__ = "1.4.0"
+from .motif.draw import GlycoDraw
 #from .glycowork import *

-__all__ = ['ml', 'motif', 'glycan_data', 'network']
+__all__ = ['ml', 'motif', 'glycan_data', 'network', 'GlycoDraw']
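The net effect is that GlycoDraw becomes importable from the package root, which is what the "test easy-import" part of the commit title exercises. A minimal sketch of the new import path (the glycan string is an illustrative IUPAC-condensed example, not taken from this commit):

from glycowork import GlycoDraw  # previously: from glycowork.motif.draw import GlycoDraw

GlycoDraw("Neu5Ac(a2-3)Gal(b1-4)Glc")  # draws the glycan from its IUPAC-condensed string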
glycowork/glycan_data/loader.py (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 from os import path
 from itertools import chain
 from importlib import resources
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 with resources.files("glycowork.glycan_data").joinpath("glycan_motifs.csv").open(encoding = 'utf-8-sig') as f:
   motif_list = pd.read_csv(f)
glycowork/glycan_data/stats.py (213 changes: 140 additions & 73 deletions)
@@ -58,10 +58,13 @@ def expansion_sum(*args: Union[int, float] # numbers to sum
 def hlm(z: Union[np.ndarray, List[float]] # array of values
         ) -> float: # median estimate
   "Hodges-Lehmann estimator of the median"
-  z = np.array(z).flatten()
+  z = np.array(z, dtype = float).flatten()
+  if len(z) == 0 or np.all(np.isnan(z)):
+    return 0.0
+  z = z[~np.isnan(z)]
   zz = np.add.outer(z, z)
   zz = zz[np.tril_indices(len(z))]
-  return np.median(zz) / 2
+  return float(np.median(zz) * 0.5)


 def update_cf_for_m_n(m: int, # parameter m
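For reference, hlm computes the Hodges-Lehmann estimator as the median of all Walsh averages (pairwise means, self-pairs included); a self-contained sketch of the patched behaviour (example input assumed):

import numpy as np

def hodges_lehmann(z):
  # Median of all Walsh averages, mirroring the np.add.outer / np.tril_indices construction in the patch
  z = np.array(z, dtype = float).flatten()
  if len(z) == 0 or np.all(np.isnan(z)):
    return 0.0  # degenerate input returns a neutral estimate instead of raising
  z = z[~np.isnan(z)]
  zz = np.add.outer(z, z)[np.tril_indices(len(z))]
  return float(np.median(zz) * 0.5)

print(hodges_lehmann([1.0, 2.0, np.nan, 3.0]))  # 2.0, the NaN is simply dropped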
@@ -262,26 +265,41 @@ def jtkdist(timepoints: Union[int, np.ndarray], # number/array of timepoints wit
   "Precalculates all possible JT test statistic permutation probabilities for reference using the Harding algorithm"
   timepoints = timepoints if isinstance(timepoints, int) else timepoints.sum()
   tim = np.full(timepoints, reps) if reps != timepoints else reps # Support for unbalanced replication (unequal replicates in all groups)
-  maxnlp = gammaln(np.sum(tim)) - np.sum(np.log(np.arange(1, np.max(tim)+1)))
+  if np.max(tim) > 0:
+    range_array = np.arange(1, np.max(tim)+1)
+    if len(range_array) > 0:
+      log_sum = np.sum(np.log(range_array))
+      maxnlp = gammaln(np.sum(tim)) - log_sum
+    else:
+      maxnlp = 0
+  else:
+    maxnlp = 0
   limit = math.log(float('inf'))
   normal = normal or (maxnlp > limit - 1) # Switch to normal approximation if maxnlp is too large
   nn = sum(tim) # Number of data values (Independent of period and lag)
-  M = (nn ** 2 - np.sum(np.square(tim))) / 2 # Max possible jtk statistic
+  M = (nn ** 2 - np.sum(np.square(tim))) * 0.5 if nn > 0 else 0 # Max possible jtk statistic
   param_dic.update({"GRP_SIZE": tim, "NUM_GRPS": len(tim), "NUM_VALS": nn,
-                    "MAX": M, "DIMS": [int(nn * (nn - 1) / 2), 1]})
+                    "MAX": M, "DIMS": [int(nn * (nn - 1) * 0.5), 1] if nn > 1 else [0, 0]})
   if normal:
-    param_dic["VAR"] = (nn ** 2 * (2 * nn + 3) - np.sum(np.fromiter((np.square(tim) * (2 * tim + 3)), dtype = float))) / 72 # Variance of JTK
-    param_dic["SDV"] = math.sqrt(param_dic["VAR"]) # Standard deviation of JTK
-    param_dic["EXV"] = M / 2 # Expected value of JTK
+    if nn > 0:
+      squared_terms = np.square(tim) * (2 * tim + 3)
+      var = (nn ** 2 * (2 * nn + 3) - np.sum(squared_terms)) / 72
+    else:
+      var = 0
+    param_dic["VAR"] = var # Variance of JTK
+    param_dic["SDV"] = np.sqrt(max(var, 0.0)) # Standard deviation of JTK
+    param_dic["EXV"] = M * 0.5 # Expected value of JTK
     param_dic["EXACT"] = False
   MM = int(M // 2) # Mode of this possible alternative to JTK distribution
   cf = [1] * (MM + 1) # Initial lower half cumulative frequency (cf) distribution
   size = sorted(tim) # Sizes of each group of known replicate values, in ascending order for fastest calculation
   k = len(tim) # Number of groups of replicates
-  N = [size[k-1]]
-  if k > 2:
+  if k > 1:
+    N = [size[k-1]]
     for i in range(k - 1, 1, -1): # Count permutations using the Harding algorithm
       N.insert(0, (size[i] + N[0]))
+  else:
+    N = [size[0]] if k == 1 else []
   for m, n in zip(size[:-1], N):
     update_cf_for_m_n(m, n, MM, cf)
   cf = np.array(cf)
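The guarded block above reduces to a short computation; a standalone sketch with an assumed balanced design of three timepoints and two replicates each:

import numpy as np
from scipy.special import gammaln

tim = np.array([2, 2, 2])  # replicates per timepoint, an assumed balanced example
nn = tim.sum()  # 6 data values in total
# log permutation count, now guarded against empty or all-zero group sizes
maxnlp = gammaln(nn) - np.sum(np.log(np.arange(1, tim.max() + 1))) if tim.max() > 0 else 0
M = (nn ** 2 - np.sum(np.square(tim))) * 0.5 if nn > 0 else 0  # max JT statistic, 12.0 here
print(maxnlp, M)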
@@ -293,7 +311,7 @@ def jtkdist(timepoints: Union[int, np.ndarray], # number/array of timepoints wit
   jtkcf = np.concatenate((cf, cf[MM - 1] + cf[MM] - cf[:MM-1][::-1], [cf[MM - 1] + cf[MM]]))[::-1]
   ajtkcf = list((jtkcf[i - 1] + jtkcf[i]) / 2 for i in range(1, len(jtkcf))) # interpolated cumulative frequency values for all half-integer jtk
   cf = [ajtkcf[(j - 1) // 2] if j % 2 == 0 else jtkcf[j // 2] for j in [i for i in range(1, 2 * int(M) + 2)]]
-  param_dic["CP"] = [c / jtkcf[0] for c in cf] # all upper-tail p-values
+  param_dic["CP"] = [c / jtkcf[0] if jtkcf[0] != 0 else 1 for c in cf] # all upper-tail p-values
   return param_dic
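The added guard only matters when the total cumulative frequency is zero; a minimal sketch of the fallback (values assumed):

cf = [6.0, 3.0, 1.0]  # assumed cumulative frequencies
jtkcf0 = 0.0          # total count; zero only for degenerate inputs
cp = [c / jtkcf0 if jtkcf0 != 0 else 1 for c in cf]
print(cp)  # [1, 1, 1], the most conservative p-values instead of a ZeroDivisionError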


@@ -304,57 +322,71 @@ def jtkinit(periods: List[int], # possible periods of rhythmicity in biological
             ) -> Dict: # updated param_dic with waveform parameters
   "Defines the parameters of the simulated sine waves for reference later"
   param_dic["INTERVAL"] = interval
-  if len(periods) > 1:
-    param_dic["PERIODS"] = list(periods)
-  else:
-    param_dic["PERIODS"] = list(periods)
+  param_dic["PERIODS"] = list(periods)
   param_dic["PERFACTOR"] = np.concatenate([np.repeat(i, ti) for i, ti in enumerate(periods, start = 1)])
   tim = np.array(param_dic["GRP_SIZE"])
   timepoints = int(param_dic["NUM_GRPS"])
   timerange = np.arange(timepoints) # Zero-based time indices
-  param_dic["SIGNCOS"] = np.zeros((periods[0], ((math.floor(timepoints / (periods[0]))*int(periods[0]))* replicates)), dtype = int)
+  max_period = max(periods)
+  signcos_length = ((math.floor(timepoints / max_period) * max_period) * replicates)
+  param_dic["SIGNCOS"] = np.zeros((max_period, signcos_length), dtype = int)
+  param_dic["CGOOSV"] = []
   for i, period in enumerate(periods):
-    time2angle = np.array([(2*round(math.pi, 4))/period]) # convert time to angle using an ~pi value
-    theta = timerange*time2angle # zero-based angular values across time indices
+    time2angle = np.array([(2*round(math.pi, 4)) / period]) # convert time to angle using an ~pi value
+    theta = timerange * time2angle # zero-based angular values across time indices
     cos_v = np.cos(theta) # unique cosine values at each time point
-    cos_r = np.repeat(rankdata(cos_v), np.max(tim)) # replicated ranks of unique cosine values
-    cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int)
-    lower_tri = []
-    for col in range(len(cgoos)):
-      for row in range(col + 1, len(cgoos)):
-        lower_tri.append(cgoos[row, col])
-    cgoos = np.array(lower_tri)
-    cgoosv = np.array(cgoos).reshape(param_dic["DIMS"])
-    param_dic["CGOOSV"] = []
-    param_dic["CGOOSV"].append(np.zeros((cgoos.shape[0], period)))
-    param_dic["CGOOSV"][i][:, 0] = cgoosv[:, 0]
+    if len(cos_v) > 0:
+      ranked = rankdata(cos_v)
+      cos_r = np.repeat(ranked, np.max(tim)) if np.max(tim) > 0 else ranked # replicated ranks of unique cosine values
+    else:
+      cos_r = np.array([])
+    if len(cos_r) > 0:
+      cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int)
+      lower_tri = []
+      for col in range(len(cgoos)):
+        for row in range(col + 1, len(cgoos)):
+          lower_tri.append(cgoos[row, col])
+      cgoos = np.array(lower_tri)
+      if len(cgoos) > 0:
+        cgoosv = np.array(cgoos).reshape(param_dic["DIMS"])
+        period_array = np.zeros((cgoos.shape[0], period))
+        period_array[:, 0] = cgoosv[:, 0]
+        param_dic["CGOOSV"].append(period_array)
+      else:
+        param_dic["CGOOSV"].append(np.zeros((1, period)))
+    else:
+      param_dic["CGOOSV"].append(np.zeros((1, period)))
    cycles = math.floor(timepoints / period)
    jrange = np.arange(cycles * period)
-    cos_s = np.sign(cos_v)[jrange]
-    cos_s = np.repeat(cos_s, (tim[jrange]))
-    if replicates == 1:
-      param_dic["SIGNCOS"][:, i] = cos_s
-    else:
-      param_dic["SIGNCOS"][i] = cos_s
-    for j in range(1, period): # One-based half-integer lag index j
-      delta_theta = j * time2angle / 2 # Angles of half-integer lags
-      cos_v = np.cos(theta + delta_theta) # Cycle left
-      cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks
-      cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T
-      mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool)
-      mask[np.diag_indices(mask.shape[0])] = False
-      cgoos = cgoos[mask]
-      cgoosv = cgoos.reshape(param_dic["DIMS"])
-      matrix_i = param_dic["CGOOSV"][i]
-      matrix_i[:, j] = cgoosv.flatten()
-      param_dic["CGOOSV[i]"] = matrix_i
-      cos_v = cos_v.flatten()
-      cos_s = np.sign(cos_v)[jrange]
-      cos_s = np.repeat(cos_s, (tim[jrange]))
-      if replicates == 1:
-        param_dic["SIGNCOS"][:, j] = cos_s
-      else:
-        param_dic["SIGNCOS"][j] = cos_s
+    if len(cos_v) > 0:
+      cos_s = np.sign(cos_v)[jrange]
+      if len(tim[jrange]) > 0:
+        cos_s = np.repeat(cos_s, (tim[jrange]))
+      if replicates == 1:
+        param_dic["SIGNCOS"][:len(cos_s), i] = cos_s
+      else:
+        param_dic["SIGNCOS"][i, :len(cos_s)] = cos_s
+    for j in range(1, period): # One-based half-integer lag index j
+      delta_theta = j * time2angle / 2 # Angles of half-integer lags
+      cos_v = np.cos(theta + delta_theta) # Cycle left
+      if len(cos_v) > 0:
+        cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks
+        if len(cos_r) > 0:
+          cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T
+          mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool)
+          mask[np.diag_indices(mask.shape[0])] = False
+          cgoos = cgoos[mask]
+          if len(cgoos) > 0:
+            cgoosv = cgoos.reshape(param_dic["DIMS"])
+            param_dic["CGOOSV"][i][:, j] = cgoosv.flatten()
+      cos_v = cos_v.flatten()
+      cos_s = np.sign(cos_v)[jrange]
+      if len(tim[jrange]) > 0:
+        cos_s = np.repeat(cos_s, (tim[jrange]))
+      if replicates == 1:
+        param_dic["SIGNCOS"][:len(cos_s), j] = cos_s
+      else:
+        param_dic["SIGNCOS"][j, :len(cos_s)] = cos_s
   return param_dic
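The SIGNCOS entries built here are sign-of-cosine square waves that the JT statistic is later correlated against; a standalone sketch of one reference waveform (period and timepoint counts assumed):

import math
import numpy as np

period, timepoints = 4, 8  # assumed example values
timerange = np.arange(timepoints)
theta = timerange * (2 * round(math.pi, 4) / period)  # time indices mapped onto angles
cos_s = np.sign(np.cos(theta))  # the +/-1 square wave stored in SIGNCOS
print(cos_s)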


@@ -365,23 +397,44 @@ def jtkstat(z: pd.DataFrame, # expression data for a molecule ordered in groups
   param_dic["CJTK"] = []
   M = param_dic["MAX"]
   z = np.array(z).flatten()
-  foosv = np.sign(np.subtract.outer(z, z)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle
-  mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index
-  mask[np.diag_indices(mask.shape[0])] = False
-  foosv = foosv[mask].reshape(param_dic["DIMS"])
+  valid_mask = ~np.isnan(z)
+  z_valid = z[valid_mask]
+  if len(z_valid) > 1:
+    foosv = np.sign(np.subtract.outer(z_valid, z_valid)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle
+    mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index
+    mask[np.diag_indices(mask.shape[0])] = False
+    foosv = foosv[mask]
+    expected_dims = param_dic["DIMS"][0] * param_dic["DIMS"][1]
+    if len(foosv) == expected_dims:
+      foosv = foosv.reshape(param_dic["DIMS"])
+    else:
+      temp_foosv = np.zeros(expected_dims)
+      temp_foosv[:len(foosv)] = foosv[:expected_dims]
+      foosv = temp_foosv.reshape(param_dic["DIMS"])
+  else:
+    foosv = np.zeros(param_dic["DIMS"])
   for i in range(param_dic["PERIODS"][0]):
-    cgoosv = param_dic["CGOOSV"][0][i]
-    S = np.nansum(np.diag(foosv * cgoosv))
-    jtk = (abs(S) + M) / 2 # Two-tailed JTK statistic for this lag and distribution
-    if S == 0:
-      param_dic["CJTK"].append([1, 0, 0])
-    elif param_dic.get("EXACT", False):
-      jtki = 1 + 2 * int(jtk) # index into the exact upper-tail distribution
-      p = 2 * param_dic["CP"][jtki-1]
-      param_dic["CJTK"].append([p, S, S / M])
-    else:
-      p = 2 * norm.cdf(-(jtk - 0.5), -param_dic["EXV"], param_dic["SDV"])
-      param_dic["CJTK"].append([p, S, S / M]) # include tau = s/M for this lag and distribution
+    if i < len(param_dic["CGOOSV"][0]):
+      cgoosv = param_dic["CGOOSV"][0][i]
+      if foosv.shape == cgoosv.shape:
+        S = np.nansum(np.diag(foosv * cgoosv))
+        jtk = (abs(S) + M) / 2 # Two-tailed JTK statistic for this lag and distribution
+      else:
+        S = 0
+        jtk = M / 2 if M != 0 else 0
+      if S == 0:
+        param_dic["CJTK"].append([1, 0, 0])
+      elif param_dic.get("EXACT", False):
+        jtki = min(1 + 2 * int(jtk), len(param_dic["CP"])) # index into the exact upper-tail distribution
+        p = 2 * param_dic["CP"][jtki-1] if jtki > 0 else 1
+        tau = S / M if M != 0 else 0
+        param_dic["CJTK"].append([p, S, tau])
+      else:
+        tau = S / M if M != 0 else 0
+        p = 2 * norm.cdf(-(jtk - 0.5), -param_dic["EXV"], param_dic["SDV"])
+        param_dic["CJTK"].append([p, S, tau]) # include tau = s/M for this lag and distribution
+    else:
+      param_dic["CJTK"].append([1, 0, 0])
   return param_dic
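For orientation, the normal-approximation branch turns the continuity-corrected statistic into a two-tailed p-value; a standalone sketch with assumed values for M, EXV, and SDV:

from scipy.stats import norm

M, EXV, SDV = 24, 12.0, 4.0  # assumed values; jtkdist sets EXV = M / 2
S = 18  # observed Kendall-style score for one lag
jtk = (abs(S) + M) / 2  # two-tailed JTK statistic, 21.0 here
p = 2 * norm.cdf(-(jtk - 0.5), -EXV, SDV)  # continuity-corrected two-tailed p-value
tau = S / M  # 0.75, the rank correlation reported alongside p
print(p, tau)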


@@ -402,23 +455,37 @@ def groupings(padj, param_dic):
     return dict(d)
   dpadj = groupings(padj, param_dic)
   padj = np.array(pd.DataFrame(dpadj.values()).T)
-  minpadj = [padj[i].min() for i in range(0, np.shape(padj)[1])] # Minimum adjusted p-values for each period
+  minpadj = [] # Minimum adjusted p-values for each period
+  for i in range(np.shape(padj)[1]):
+    col_values = padj[:, i][~np.isnan(padj[:, i])] # Remove NaN values
+    minpadj.append(np.min(col_values) if len(col_values) > 0 else np.nan)
   if len(param_dic["PERIODS"]) > 1:
-    pers_index = np.where(JTK_ADJP == minpadj)[0] # indices of all optimal periods
+    min_p_idx = np.argmin(minpadj) # index of optimal period
+    pers = param_dic["PERIODS"][min_p_idx]
   else:
-    pers_index = 0
-  pers = param_dic["PERIODS"][int(pers_index)] # all optimal periods
-  lagis = np.where(padj == JTK_ADJP)[0] # list of optimal lag indices for each optimal period
+    pers = param_dic["PERIODS"][0]
+  valid_padj = ~np.isnan(padj)
+  if np.any(valid_padj):
+    lagis = np.where(np.abs(padj - JTK_ADJP) < 1e-10)[0] # list of optimal lag indices for each optimal period
+    if len(lagis) == 0:
+      closest_idx = np.nanargmin(np.abs(padj - JTK_ADJP))
+      lagis = np.array([closest_idx])
+  else:
+    lagis = np.array([0])
   best_results = {'bestper': 0, 'bestlag': 0, 'besttau': 0, 'maxamp': 0, 'maxamp_ci': 2, 'maxamp_pval': 0}
   sc = np.transpose(param_dic["SIGNCOS"])
-  w = (z[:len(sc)] - hlm(z[:len(sc)])) * math.sqrt(2)
+  hlm_z = hlm(z[:len(sc)])
+  if np.isnan(hlm_z).all():
+    hlm_z = np.zeros_like(z[:len(sc)]) # Fallback if all values are NaN
+  w = (z[:len(sc)] - hlm_z) * math.sqrt(2)
   for i in range(abs(pers)):
     for lagi in lagis:
       S = param_dic["CJTK"][lagi][1]
       s = np.sign(S) if S != 0 else 1
       lag = (pers + (1 - s) * pers / 4 - lagi / 2) % pers
       tmp = s * w * sc[:, lagi]
-      amp = hlm(tmp) # Allows missing values
+      tmp_clean = tmp[np.isfinite(tmp)]
+      amp = hlm(tmp_clean) if len(tmp_clean) > 0 else 0 # Allows missing values
       if ampci:
         jtkwt = pd.DataFrame(wilcoxon(tmp[np.isfinite(tmp)], zero_method = 'wilcox', correction = False,
                                       alternative = 'two-sided', mode = 'exact'))
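The NaN-aware column minima can be exercised in isolation; a sketch with an assumed adjusted p-value matrix containing an all-NaN column:

import numpy as np

padj = np.array([[0.01, np.nan],
                 [0.20, np.nan]])  # assumed adjusted p-values with one all-NaN period column
minpadj = []
for i in range(np.shape(padj)[1]):
  col_values = padj[:, i][~np.isnan(padj[:, i])]  # drop NaNs before taking the minimum
  minpadj.append(np.min(col_values) if len(col_values) > 0 else np.nan)
print(minpadj)  # [0.01, nan], an all-NaN column no longer breaks period selection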
