From d5f5d4e541f6b20f9d6f33c397bfbd0c72c65acc Mon Sep 17 00:00:00 2001 From: Daniel Bojar Date: Tue, 19 Nov 2024 10:02:19 +0100 Subject: [PATCH] test easy-import, more tests, fixes Fixes: - streamline model training for extra-input - fine-tune metric calculation in train_model - robustify JTK functions to deal with unexpected or imperfect data - support more input flexibility in plot_embeddings - fix random effects in get_meta_analysis --- glycowork/__init__.py | 3 +- ...cs_human_keratinocytes_N_PMID37956981.csv} | 0 glycowork/glycan_data/loader.py | 2 +- glycowork/glycan_data/stats.py | 213 +++-- glycowork/ml/model_training.py | 54 +- glycowork/motif/analysis.py | 25 +- glycowork/motif/regex.py | 2 +- tests/test_core_functions.py | 826 +++++++++++++++++- 8 files changed, 1006 insertions(+), 119 deletions(-) rename glycowork/glycan_data/{glycoproteomics_human_keratinocytes_PMID37956981.csv => glycoproteomics_human_keratinocytes_N_PMID37956981.csv} (100%) diff --git a/glycowork/__init__.py b/glycowork/__init__.py index 9d2ef1f..12e4701 100644 --- a/glycowork/__init__.py +++ b/glycowork/__init__.py @@ -1,4 +1,5 @@ __version__ = "1.4.0" +from .motif.draw import GlycoDraw #from .glycowork import * -__all__ = ['ml', 'motif', 'glycan_data', 'network'] +__all__ = ['ml', 'motif', 'glycan_data', 'network', 'GlycoDraw'] diff --git a/glycowork/glycan_data/glycoproteomics_human_keratinocytes_PMID37956981.csv b/glycowork/glycan_data/glycoproteomics_human_keratinocytes_N_PMID37956981.csv similarity index 100% rename from glycowork/glycan_data/glycoproteomics_human_keratinocytes_PMID37956981.csv rename to glycowork/glycan_data/glycoproteomics_human_keratinocytes_N_PMID37956981.csv diff --git a/glycowork/glycan_data/loader.py b/glycowork/glycan_data/loader.py index dd09134..d2e2695 100644 --- a/glycowork/glycan_data/loader.py +++ b/glycowork/glycan_data/loader.py @@ -6,7 +6,7 @@ from os import path from itertools import chain from importlib import resources -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List with resources.files("glycowork.glycan_data").joinpath("glycan_motifs.csv").open(encoding = 'utf-8-sig') as f: motif_list = pd.read_csv(f) diff --git a/glycowork/glycan_data/stats.py b/glycowork/glycan_data/stats.py index 6e59909..23e8cb2 100644 --- a/glycowork/glycan_data/stats.py +++ b/glycowork/glycan_data/stats.py @@ -58,10 +58,13 @@ def expansion_sum(*args: Union[int, float] # numbers to sum def hlm(z: Union[np.ndarray, List[float]] # array of values ) -> float: # median estimate "Hodges-Lehmann estimator of the median" - z = np.array(z).flatten() + z = np.array(z, dtype = float).flatten() + if len(z) == 0 or np.all(np.isnan(z)): + return 0.0 + z = z[~np.isnan(z)] zz = np.add.outer(z, z) zz = zz[np.tril_indices(len(z))] - return np.median(zz) / 2 + return float(np.median(zz) * 0.5) def update_cf_for_m_n(m: int, # parameter m @@ -262,26 +265,41 @@ def jtkdist(timepoints: Union[int, np.ndarray], # number/array of timepoints wit "Precalculates all possible JT test statistic permutation probabilities for reference using the Harding algorithm" timepoints = timepoints if isinstance(timepoints, int) else timepoints.sum() tim = np.full(timepoints, reps) if reps != timepoints else reps # Support for unbalanced replication (unequal replicates in all groups) - maxnlp = gammaln(np.sum(tim)) - np.sum(np.log(np.arange(1, np.max(tim)+1))) + if np.max(tim) > 0: + range_array = np.arange(1, np.max(tim)+1) + if len(range_array) > 0: + log_sum = np.sum(np.log(range_array)) + maxnlp = 
gammaln(np.sum(tim)) - log_sum + else: + maxnlp = 0 + else: + maxnlp = 0 limit = math.log(float('inf')) normal = normal or (maxnlp > limit - 1) # Switch to normal approximation if maxnlp is too large nn = sum(tim) # Number of data values (Independent of period and lag) - M = (nn ** 2 - np.sum(np.square(tim))) / 2 # Max possible jtk statistic + M = (nn ** 2 - np.sum(np.square(tim))) * 0.5 if nn > 0 else 0 # Max possible jtk statistic param_dic.update({"GRP_SIZE": tim, "NUM_GRPS": len(tim), "NUM_VALS": nn, - "MAX": M, "DIMS": [int(nn * (nn - 1) / 2), 1]}) + "MAX": M, "DIMS": [int(nn * (nn - 1) * 0.5), 1 if nn > 1 else [0, 0]]}) if normal: - param_dic["VAR"] = (nn ** 2 * (2 * nn + 3) - np.sum(np.fromiter((np.square(tim) * (2 * tim + 3)), dtype = float))) / 72 # Variance of JTK - param_dic["SDV"] = math.sqrt(param_dic["VAR"]) # Standard deviation of JTK - param_dic["EXV"] = M / 2 # Expected value of JTK + if nn > 0: + squared_terms = np.square(tim) * (2 * tim + 3) + var = (nn ** 2 * (2 * nn + 3) - np.sum(squared_terms)) / 72 + else: + var = 0 + param_dic["VAR"] = var # Variance of JTK + param_dic["SDV"] = np.sqrt(max(var, 0.0)) # Standard deviation of JTK + param_dic["EXV"] = M * 0.5 # Expected value of JTK param_dic["EXACT"] = False MM = int(M // 2) # Mode of this possible alternative to JTK distribution cf = [1] * (MM + 1) # Initial lower half cumulative frequency (cf) distribution size = sorted(tim) # Sizes of each group of known replicate values, in ascending order for fastest calculation k = len(tim) # Number of groups of replicates - N = [size[k-1]] - if k > 2: + if k > 1: + N = [size[k-1]] for i in range(k - 1, 1, -1): # Count permutations using the Harding algorithm N.insert(0, (size[i] + N[0])) + else: + N = [size[0]] if k == 1 else [] for m, n in zip(size[:-1], N): update_cf_for_m_n(m, n, MM, cf) cf = np.array(cf) @@ -293,7 +311,7 @@ def jtkdist(timepoints: Union[int, np.ndarray], # number/array of timepoints wit jtkcf = np.concatenate((cf, cf[MM - 1] + cf[MM] - cf[:MM-1][::-1], [cf[MM - 1] + cf[MM]]))[::-1] ajtkcf = list((jtkcf[i - 1] + jtkcf[i]) / 2 for i in range(1, len(jtkcf))) # interpolated cumulative frequency values for all half-intgeger jtk cf = [ajtkcf[(j - 1) // 2] if j % 2 == 0 else jtkcf[j // 2] for j in [i for i in range(1, 2 * int(M) + 2)]] - param_dic["CP"] = [c / jtkcf[0] for c in cf] # all upper-tail p-values + param_dic["CP"] = [c / jtkcf[0] if jtkcf[0] != 0 else 1 for c in cf] # all upper-tail p-values return param_dic @@ -304,57 +322,71 @@ def jtkinit(periods: List[int], # possible periods of rhythmicity in biological ) -> Dict: # updated param_dic with waveform parameters "Defines the parameters of the simulated sine waves for reference later" param_dic["INTERVAL"] = interval - if len(periods) > 1: - param_dic["PERIODS"] = list(periods) - else: - param_dic["PERIODS"] = list(periods) + param_dic["PERIODS"] = list(periods) param_dic["PERFACTOR"] = np.concatenate([np.repeat(i, ti) for i, ti in enumerate(periods, start = 1)]) tim = np.array(param_dic["GRP_SIZE"]) timepoints = int(param_dic["NUM_GRPS"]) timerange = np.arange(timepoints) # Zero-based time indices - param_dic["SIGNCOS"] = np.zeros((periods[0], ((math.floor(timepoints / (periods[0]))*int(periods[0]))* replicates)), dtype = int) + max_period = max(periods) + signcos_length = ((math.floor(timepoints / max_period) * max_period) * replicates) + param_dic["SIGNCOS"] = np.zeros((max_period, signcos_length), dtype = int) + param_dic["CGOOSV"] = [] for i, period in enumerate(periods): - time2angle = 
np.array([(2*round(math.pi, 4))/period]) # convert time to angle using an ~pi value - theta = timerange*time2angle # zero-based angular values across time indices + time2angle = np.array([(2*round(math.pi, 4)) / period]) # convert time to angle using an ~pi value + theta = timerange * time2angle # zero-based angular values across time indices cos_v = np.cos(theta) # unique cosine values at each time point - cos_r = np.repeat(rankdata(cos_v), np.max(tim)) # replicated ranks of unique cosine values - cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int) - lower_tri = [] - for col in range(len(cgoos)): - for row in range(col + 1, len(cgoos)): - lower_tri.append(cgoos[row, col]) - cgoos = np.array(lower_tri) - cgoosv = np.array(cgoos).reshape(param_dic["DIMS"]) - param_dic["CGOOSV"] = [] - param_dic["CGOOSV"].append(np.zeros((cgoos.shape[0], period))) - param_dic["CGOOSV"][i][:, 0] = cgoosv[:, 0] + if len(cos_v) > 0: + ranked = rankdata(cos_v) + cos_r = np.repeat(ranked, np.max(tim)) if np.max(tim) > 0 else ranked # replicated ranks of unique cosine values + else: + cos_r = np.array([]) + if len(cos_r) > 0: + cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int) + lower_tri = [] + for col in range(len(cgoos)): + for row in range(col + 1, len(cgoos)): + lower_tri.append(cgoos[row, col]) + cgoos = np.array(lower_tri) + if len(cgoos) > 0: + cgoosv = np.array(cgoos).reshape(param_dic["DIMS"]) + period_array = np.zeros((cgoos.shape[0], period)) + period_array[:, 0] = cgoosv[:, 0] + param_dic["CGOOSV"].append(period_array) + else: + param_dic["CGOOSV"].append(np.zeros((1, period))) + else: + param_dic["CGOOSV"].append(np.zeros((1, period))) cycles = math.floor(timepoints / period) jrange = np.arange(cycles * period) - cos_s = np.sign(cos_v)[jrange] - cos_s = np.repeat(cos_s, (tim[jrange])) - if replicates == 1: - param_dic["SIGNCOS"][:, i] = cos_s - else: - param_dic["SIGNCOS"][i] = cos_s - for j in range(1, period): # One-based half-integer lag index j - delta_theta = j * time2angle / 2 # Angles of half-integer lags - cos_v = np.cos(theta + delta_theta) # Cycle left - cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks - cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T - mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool) - mask[np.diag_indices(mask.shape[0])] = False - cgoos = cgoos[mask] - cgoosv = cgoos.reshape(param_dic["DIMS"]) - matrix_i = param_dic["CGOOSV"][i] - matrix_i[:, j] = cgoosv.flatten() - param_dic["CGOOSV[i]"] = matrix_i - cos_v = cos_v.flatten() + if len(cos_v) > 0: cos_s = np.sign(cos_v)[jrange] - cos_s = np.repeat(cos_s, (tim[jrange])) + if len(tim[jrange]) > 0: + cos_s = np.repeat(cos_s, (tim[jrange])) if replicates == 1: - param_dic["SIGNCOS"][:, j] = cos_s + param_dic["SIGNCOS"][:len(cos_s), i] = cos_s else: - param_dic["SIGNCOS"][j] = cos_s + param_dic["SIGNCOS"][i, :len(cos_s)] = cos_s + for j in range(1, period): # One-based half-integer lag index j + delta_theta = j * time2angle / 2 # Angles of half-integer lags + cos_v = np.cos(theta + delta_theta) # Cycle left + if len(cos_v) > 0: + cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks + if len(cos_r) > 0: + cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T + mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool) + mask[np.diag_indices(mask.shape[0])] = False + cgoos = cgoos[mask] + if len(cgoos) > 0: + cgoosv = cgoos.reshape(param_dic["DIMS"]) + 
param_dic["CGOOSV"][i][:, j] = cgoosv.flatten() + cos_v = cos_v.flatten() + cos_s = np.sign(cos_v)[jrange] + if len(tim[jrange]) > 0: + cos_s = np.repeat(cos_s, (tim[jrange])) + if replicates == 1: + param_dic["SIGNCOS"][:len(cos_s), j] = cos_s + else: + param_dic["SIGNCOS"][j, :len(cos_s)] = cos_s return param_dic @@ -365,23 +397,44 @@ def jtkstat(z: pd.DataFrame, # expression data for a molecule ordered in groups param_dic["CJTK"] = [] M = param_dic["MAX"] z = np.array(z).flatten() - foosv = np.sign(np.subtract.outer(z, z)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle - mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index - mask[np.diag_indices(mask.shape[0])] = False - foosv = foosv[mask].reshape(param_dic["DIMS"]) + valid_mask = ~np.isnan(z) + z_valid = z[valid_mask] + if len(z_valid) > 1: + foosv = np.sign(np.subtract.outer(z_valid, z_valid)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle + mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index + mask[np.diag_indices(mask.shape[0])] = False + foosv = foosv[mask] + expected_dims = param_dic["DIMS"][0] * param_dic["DIMS"][1] + if len(foosv) == expected_dims: + foosv = foosv.reshape(param_dic["DIMS"]) + else: + temp_foosv = np.zeros(expected_dims) + temp_foosv[:len(foosv)] = foosv[:expected_dims] + foosv = temp_foosv.reshape(param_dic["DIMS"]) + else: + foosv = np.zeros(param_dic["DIMS"]) for i in range(param_dic["PERIODS"][0]): - cgoosv = param_dic["CGOOSV"][0][i] - S = np.nansum(np.diag(foosv * cgoosv)) - jtk = (abs(S) + M) / 2 # Two-tailed JTK statistic for this lag and distribution - if S == 0: - param_dic["CJTK"].append([1, 0, 0]) - elif param_dic.get("EXACT", False): - jtki = 1 + 2 * int(jtk) # index into the exact upper-tail distribution - p = 2 * param_dic["CP"][jtki-1] - param_dic["CJTK"].append([p, S, S / M]) + if i < len(param_dic["CGOOSV"][0]): + cgoosv = param_dic["CGOOSV"][0][i] + if foosv.shape == cgoosv.shape: + S = np.nansum(np.diag(foosv * cgoosv)) + jtk = (abs(S) + M) / 2 # Two-tailed JTK statistic for this lag and distribution + else: + S = 0 + jtk = M / 2 if M != 0 else 0 + if S == 0: + param_dic["CJTK"].append([1, 0, 0]) + elif param_dic.get("EXACT", False): + jtki = min(1 + 2 * int(jtk), len(param_dic["CP"])) # index into the exact upper-tail distribution + p = 2 * param_dic["CP"][jtki-1] if jtki > 0 else 1 + tau = S / M if M != 0 else 0 + param_dic["CJTK"].append([p, S, tau]) + else: + tau = S / M if M != 0 else 0 + p = 2 * norm.cdf(-(jtk - 0.5), -param_dic["EXV"], param_dic["SDV"]) + param_dic["CJTK"].append([p, S, tau]) # include tau = s/M for this lag and distribution else: - p = 2 * norm.cdf(-(jtk - 0.5), -param_dic["EXV"], param_dic["SDV"]) - param_dic["CJTK"].append([p, S, S / M]) # include tau = s/M for this lag and distribution + param_dic["CJTK"].append([1, 0, 0]) return param_dic @@ -402,23 +455,37 @@ def groupings(padj, param_dic): return dict(d) dpadj = groupings(padj, param_dic) padj = np.array(pd.DataFrame(dpadj.values()).T) - minpadj = [padj[i].min() for i in range(0, np.shape(padj)[1])] # Minimum adjusted p-values for each period + minpadj = [] # Minimum adjusted p-values for each period + for i in range(np.shape(padj)[1]): + col_values = padj[:, 
i][~np.isnan(padj[:, i])] # Remove NaN values + minpadj.append(np.min(col_values) if len(col_values) > 0 else np.nan) if len(param_dic["PERIODS"]) > 1: - pers_index = np.where(JTK_ADJP == minpadj)[0] # indices of all optimal periods + min_p_idx = np.argmin(minpadj) # index of optimal period + pers = param_dic["PERIODS"][min_p_idx] + else: + pers = param_dic["PERIODS"][0] + valid_padj = ~np.isnan(padj) + if np.any(valid_padj): + lagis = np.where(np.abs(padj - JTK_ADJP) < 1e-10)[0] # list of optimal lag indices for each optimal period + if len(lagis) == 0: + closest_idx = np.nanargmin(np.abs(padj - JTK_ADJP)) + lagis = np.array([closest_idx]) else: - pers_index = 0 - pers = param_dic["PERIODS"][int(pers_index)] # all optimal periods - lagis = np.where(padj == JTK_ADJP)[0] # list of optimal lag indice for each optimal period + lagis = np.array([0]) best_results = {'bestper': 0, 'bestlag': 0, 'besttau': 0, 'maxamp': 0, 'maxamp_ci': 2, 'maxamp_pval': 0} sc = np.transpose(param_dic["SIGNCOS"]) - w = (z[:len(sc)] - hlm(z[:len(sc)])) * math.sqrt(2) + hlm_z = hlm(z[:len(sc)]) + if np.isnan(hlm_z).all(): + hlm_z = np.zeros_like(z[:len(sc)]) # Fallback if all values are NaN + w = (z[:len(sc)] - hlm_z) * math.sqrt(2) for i in range(abs(pers)): for lagi in lagis: S = param_dic["CJTK"][lagi][1] s = np.sign(S) if S != 0 else 1 lag = (pers + (1 - s) * pers / 4 - lagi / 2) % pers tmp = s * w * sc[:, lagi] - amp = hlm(tmp) # Allows missing values + tmp_clean = tmp[np.isfinite(tmp)] + amp = hlm(tmp_clean) if len(tmp_clean) > 0 else 0 # Allows missing values if ampci: jtkwt = pd.DataFrame(wilcoxon(tmp[np.isfinite(tmp)], zero_method = 'wilcox', correction = False, alternatives = 'two-sided', mode = 'exact')) diff --git a/glycowork/ml/model_training.py b/glycowork/ml/model_training.py index f70c113..53e4b6e 100644 --- a/glycowork/ml/model_training.py +++ b/glycowork/ml/model_training.py @@ -59,7 +59,7 @@ def save_checkpoint(self, val_loss: float, model: torch.nn.Module) -> None: def sigmoid(x: float # input value ) -> float: # sigmoid transformed value "Apply sigmoid transformation to input" - return 1 / (1 + math.exp(-x)) + return 1 / (1 + np.exp(-x)) def disable_running_stats(model: torch.nn.Module # model to disable batch norm @@ -127,17 +127,17 @@ def train_model(model: torch.nn.Module, # graph neural network for analyzing gly running_metrics["weights"] = [] for data in dataloaders[phase]: - # Get all relevant node attributes; top LectinOracle-style models, bottom SweetNet-style models - try: - x, y, edge_index, prot, batch = data.labels, data.y, data.edge_index, data.train_idx, data.batch + # Get all relevant node attributes + x, y, edge_index, batch = data.labels, data.y, data.edge_index, data.batch + prot = getattr(data, 'train_idx', None) + if prot is not None: prot = prot.view(max(batch) + 1, -1).to(device) - except: - x, y, edge_index, batch = data.labels, data.y, data.edge_index, data.batch x = x.to(device) if mode == 'multilabel': y = y.view(max(batch) + 1, -1).to(device) else: y = y.to(device) + y = y.view(-1, 1) if mode == 'regression' else y edge_index = edge_index.to(device) batch = batch.to(device) optimizer.zero_grad() @@ -146,24 +146,18 @@ def train_model(model: torch.nn.Module, # graph neural network for analyzing gly # First forward pass if mode + mode2 == 'classificationmulti' or mode + mode2 == 'multilabelmulti': enable_running_stats(model) - try: - pred = model(prot, x, edge_index, batch) - loss = criterion(pred, y.view(-1, 1)) - except: - pred = model(x, edge_index, batch) - loss = 
criterion(pred, y) + pred = model(prot, x, edge_index, batch) if prot is not None else model(x, edge_index, batch) + loss = criterion(pred, y) if phase == 'train': loss.backward() if mode + mode2 == 'classificationmulti' or mode + mode2 == 'multilabelmulti': - optimizer.first_step(zero_grad=True) + optimizer.first_step(zero_grad = True) # Second forward pass disable_running_stats(model) - try: - criterion(model(prot, x, edge_index, batch), y.view(-1, 1)).backward() - except: - criterion(model(x, edge_index, batch), y).backward() - optimizer.second_step(zero_grad=True) + second_pred = model(prot, x, edge_index, batch) if prot is not None else model(x, edge_index, batch) + criterion(second_pred, y).backward() + optimizer.second_step(zero_grad = True) else: optimizer.step() @@ -175,17 +169,21 @@ def train_model(model: torch.nn.Module, # graph neural network for analyzing gly pred_det = pred.cpu().detach().numpy() if mode == 'classification': if mode2 == 'multi': - pred2 = np.argmax(pred_det, axis=1) + pred_proba = np.exp(pred_det) / np.sum(np.exp(pred_det), axis = 1, keepdims = True) # numpy softmax + pred2 = np.argmax(pred_det, axis = 1) else: - pred2 = [np.round(sigmoid(x)) for x in pred_det] + pred_proba = sigmoid(pred_det) + pred2 = (pred_proba >= 0.5).astype(int) running_metrics["acc"].append(accuracy_score(y_det.astype(int), pred2)) running_metrics["mcc"].append(matthews_corrcoef(y_det, pred2)) - running_metrics["auroc"].append(roc_auc_score(y_det.astype(int), pred2)) + running_metrics["auroc"].append(roc_auc_score(y_det.astype(int), pred_proba if mode2 == 'binary' else pred_proba[:, 1])) elif mode == 'multilabel': - running_metrics["acc"].append(accuracy_score(y_det.astype(int), pred_det)) - running_metrics["mcc"].append(matthews_corrcoef(y_det, pred_det)) - running_metrics["lrap"].append(label_ranking_average_precision_score(y_det.astype(int), pred_det)) - running_metrics["ndcg"].append(ndcg_score(y_det.astype(int), pred_det)) + pred_proba = sigmoid(pred_det) + pred2 = (pred_proba >= 0.5).astype(int) + running_metrics["acc"].append(accuracy_score(y_det.astype(int), pred2)) + running_metrics["mcc"].append(matthews_corrcoef(y_det.flatten(), pred2.flatten())) + running_metrics["lrap"].append(label_ranking_average_precision_score(y_det.astype(int), pred_proba)) + running_metrics["ndcg"].append(ndcg_score(y_det.astype(int), pred_proba)) else: running_metrics["mse"].append(mean_squared_error(y_det, pred_det)) running_metrics["mae"].append(mean_absolute_error(y_det, pred_det)) @@ -195,7 +193,7 @@ def train_model(model: torch.nn.Module, # graph neural network for analyzing gly for key in running_metrics: if key == "weights": continue - metrics[phase][key].append(np.average(running_metrics[key], weights=running_metrics["weights"])) + metrics[phase][key].append(np.average(running_metrics[key], weights = running_metrics["weights"])) if mode == 'classification': print('{} Loss: {:.4f} Accuracy: {:.4f} MCC: {:.4f}'.format(phase, metrics[phase]["loss"][-1], metrics[phase]["acc"][-1], metrics[phase]["mcc"][-1])) @@ -220,9 +218,9 @@ def train_model(model: torch.nn.Module, # graph neural network for analyzing gly # Check Early Stopping & adjust learning rate if needed early_stopping(metrics[phase]["loss"][-1], model) - try: + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): scheduler.step(metrics[phase]["loss"][-1]) - except: + else: scheduler.step() if early_stopping.early_stop: diff --git a/glycowork/motif/analysis.py b/glycowork/motif/analysis.py index d862cc9..c70cad7 100644 
--- a/glycowork/motif/analysis.py +++ b/glycowork/motif/analysis.py @@ -269,31 +269,33 @@ def plot_embeddings( "Visualizes learned glycan embeddings using t-SNE dimensionality reduction with optional group coloring" idx = [i for i, g in enumerate(glycans) if '{' not in g] glycans = [glycans[i] for i in idx] - label_list = [label_list[i] for i in idx] + if label_list is not None: + label_list = [label_list[i] for i in idx] # Get all glycan embeddings if emb is None: if not os.path.exists('glycan_representations_v1_4.pkl'): download_model("https://drive.google.com/file/d/1--tf0kyea9jFLfffUtICKkyIw36E9hJ3/view?usp=sharing", local_path = 'glycan_representations_v1_4.pkl') emb = pickle.load(open('glycan_representations_v1_4.pkl', 'rb')) # Get the subset of embeddings corresponding to 'glycans' - if isinstance(emb, pd.DataFrame): - emb = {g: emb.iloc[i, :] for i, g in enumerate(glycans)} - embs = np.vstack([emb[g] for g in glycans]) + embs = emb.values if isinstance(emb, pd.DataFrame) else np.vstack([emb[g] for g in glycans]) # Calculate t-SNE of embeddings - embs = TSNE(random_state = 42, + n_samples = embs.shape[0] + perplexity = min(30, n_samples - 1) + embs = TSNE(random_state = 42, perplexity = perplexity, init = 'pca', learning_rate = 'auto').fit_transform(embs) # Plot the t-SNE markers = None if shape_feature is not None: markers = {shape_feature: "X", "Absent": "o"} shape_feature = [shape_feature if shape_feature in g else 'Absent' for g in glycans] - sns.scatterplot(x = embs[:, 0], y = embs[:, 1], hue = label_list, - palette = palette, style = shape_feature, markers = markers, - alpha = alpha, **kwargs) + sns.scatterplot(x = embs[:, 0], y = embs[:, 1], hue = label_list if label_list is not None else None, + palette = palette if label_list is not None else None, style = shape_feature, + markers = markers, alpha = alpha, **kwargs) sns.despine(left = True, bottom = True) plt.xlabel('Dim1') plt.ylabel('Dim2') - plt.legend(bbox_to_anchor = (1.05, 1), loc = 2, borderaxespad = 0.) + if label_list is not None: + plt.legend(bbox_to_anchor = (1.05, 1), loc = 2, borderaxespad = 0.) 
plt.tight_layout() if filepath: plt.savefig(filepath, format = filepath.split('.')[-1], dpi = 300, @@ -773,7 +775,8 @@ def get_meta_analysis( "Performs fixed/random effects meta-analysis using DerSimonian-Laird method for between-study variance estimation, with optional Forest plot visualization" if model not in ['fixed', 'random']: raise ValueError("Model must be 'fixed' or 'random'") - weights = 1 / np.array(variances) + variances = np.array(variances) + weights = 1 / variances total_weight = np.sum(weights) combined_effect_size = np.dot(weights, effect_sizes) / total_weight if model == 'random': @@ -783,7 +786,7 @@ def get_meta_analysis( c = total_weight - np.sum(weights**2) / total_weight tau_squared = max((q - df) / (c + 1e-8), 0) # Update weights for tau_squared - weights /= (variances + tau_squared) + weights = 1 / (variances + tau_squared) total_weight = np.sum(weights) # Recalculate combined effect size combined_effect_size = np.dot(weights, effect_sizes) / (total_weight + 1e-8) diff --git a/glycowork/motif/regex.py b/glycowork/motif/regex.py index 089e0ad..f04c219 100644 --- a/glycowork/motif/regex.py +++ b/glycowork/motif/regex.py @@ -2,7 +2,7 @@ import copy import networkx as nx from itertools import product, combinations, chain -from typing import Dict, List, Union, Optional, Tuple, Any +from typing import Dict, List, Union, Optional, Tuple from glycowork.glycan_data.loader import replace_every_second, unwrap from glycowork.motif.processing import min_process_glycans, bracket_removal, canonicalize_iupac from glycowork.motif.graph import graph_to_string, subgraph_isomorphism, compare_glycans, glycan_to_nxGraph diff --git a/tests/test_core_functions.py b/tests/test_core_functions.py index 5b2ccee..5782aa9 100644 --- a/tests/test_core_functions.py +++ b/tests/test_core_functions.py @@ -3,7 +3,12 @@ import networkx.algorithms.isomorphism as iso import pandas as pd import numpy as np +import xgboost as xgb +import seaborn as sns import torch +import torch.nn as nn +import matplotlib.pyplot as plt +from unittest.mock import Mock, patch, MagicMock from torch_geometric.data import Data from glycowork.glycan_data.data_entry import check_presence from glycowork.motif.query import get_insight, glytoucan_to_glycan @@ -66,8 +71,9 @@ get_hit_atoms_and_bonds, add_colours_to_map, unique) from glycowork.motif.analysis import (preprocess_data, get_pvals_motifs, select_grouping, get_glycanova, get_differential_expression, - get_biodiversity, get_time_series, get_SparCC, get_roc, - get_representative_substructures, get_lectin_array) + get_biodiversity, get_time_series, get_SparCC, get_roc, get_ma, get_volcano, get_meta_analysis, + get_representative_substructures, get_lectin_array, get_coverage, plot_embeddings, + characterize_monosaccharide, get_heatmap, get_pca, get_jtk, multi_feature_scoring, get_glycoshift_per_site) from glycowork.network.biosynthesis import (safe_compare, safe_index, get_neighbors, create_neighbors, find_diff, find_path, construct_network, prune_network, estimate_weights, extend_glycans, highlight_network, infer_roots, deorphanize_nodes, get_edge_weight_by_abundance, @@ -81,6 +87,10 @@ from glycowork.ml.processing import (augment_glycan, AugmentedGlycanDataset, dataset_to_graphs, dataset_to_dataloader, split_data_to_train) +from glycowork.ml.model_training import (EarlyStopping, sigmoid, disable_running_stats, + enable_running_stats, train_model, SAM, + Poly1CrossEntropyLoss, training_setup, + train_ml_model, analyze_ml_model, get_mismatch) @pytest.mark.parametrize("glycan", 
[ @@ -2480,6 +2490,484 @@ def simple_glycans(): ] +# Mock plt.show() to avoid displaying plots during tests +@pytest.fixture(autouse=True) +def mock_show(): + with patch('matplotlib.pyplot.show'): + yield + + +# Sample data fixtures +@pytest.fixture +def sample_df(): + data = { + 'glycan': ['Man(a1-3)[Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc', + 'Man(a1-2)Man(a1-3)[Man(a1-6)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc', + 'Man(a1-2)Man(a1-2)Man(a1-3)[Man(a1-2)Man(a1-3)[Man(a1-2)Man(a1-6)]Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc'], + 'sample1': [10, 20, 30], + 'sample2': [15, 25, 35], + 'sample3': [12, 22, 32] + } + return pd.DataFrame(data) + + +@pytest.fixture +def sample_abundance_df(): + data = { + 'sample1': [0.2, 0.3, 0.5], + 'sample2': [0.25, 0.35, 0.4], + 'sample3': [0.3, 0.3, 0.4] + } + index = ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'] + return pd.DataFrame(data, index=index) + + +@pytest.fixture +def sample_diff_expr_results(): + data = { + 'Glycan': ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'], + 'Mean abundance': [0.3, 0.35, 0.4], + 'Log2FC': [1.5, -0.5, 0.8], + 'Effect size': [0.6, -0.3, 0.4], + 'p-val': [0.01, 0.04, 0.002], + 'corr p-val': [0.03, 0.06, 0.006] + } + return pd.DataFrame(data) + + +def test_get_coverage_basic(sample_df): + """Test basic functionality of get_coverage""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_coverage(sample_df) + mock_savefig.assert_not_called() + + +def test_get_coverage_with_filepath(sample_df): + """Test get_coverage with filepath saving""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_coverage(sample_df, filepath='test.png') + mock_savefig.assert_called_once() + + +def test_get_ma_basic(sample_diff_expr_results): + """Test basic functionality of get_ma""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_ma(sample_diff_expr_results) + mock_savefig.assert_not_called() + + +def test_get_ma_with_filepath(sample_diff_expr_results): + """Test get_ma with filepath saving""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_ma(sample_diff_expr_results, filepath='test.png') + mock_savefig.assert_called_once() + + +def test_get_ma_with_custom_thresholds(sample_diff_expr_results): + """Test get_ma with custom thresholds""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_ma(sample_diff_expr_results, log2fc_thresh=2, sig_thresh=0.01) + mock_savefig.assert_not_called() + + +def test_get_volcano_basic(sample_diff_expr_results): + """Test basic functionality of get_volcano""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_volcano(sample_diff_expr_results) + mock_savefig.assert_not_called() + + +def test_get_volcano_with_filepath(sample_diff_expr_results): + """Test get_volcano with filepath saving""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_volcano(sample_diff_expr_results, filepath='test.png') + mock_savefig.assert_called_once() + + +def test_get_volcano_with_custom_thresholds(sample_diff_expr_results): + """Test get_volcano with custom thresholds""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_volcano(sample_diff_expr_results, y_thresh=0.01, x_thresh=1.0) + mock_savefig.assert_not_called() + + +def test_get_volcano_with_effect_size(sample_diff_expr_results): + """Test get_volcano using effect size instead of Log2FC""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_volcano(sample_diff_expr_results, x_metric='Effect size') + mock_savefig.assert_not_called() + + +def test_get_meta_analysis_fixed(): 
+ """Test meta-analysis with fixed effects model""" + effect_sizes = [0.5, 0.3, 0.7] + variances = [0.1, 0.15, 0.08] + combined_effect, p_value = get_meta_analysis(effect_sizes, variances, model='fixed') + assert isinstance(combined_effect, float) + assert isinstance(p_value, float) + assert 0 <= p_value <= 1 + + +def test_get_meta_analysis_random(): + """Test meta-analysis with random effects model""" + effect_sizes = [0.5, 0.3, 0.7] + variances = [0.1, 0.15, 0.08] + combined_effect, p_value = get_meta_analysis(effect_sizes, variances, model='random') + assert isinstance(combined_effect, float) + assert isinstance(p_value, float) + assert 0 <= p_value <= 1 + + +def test_get_meta_analysis_with_study_names(): + """Test meta-analysis with study names""" + effect_sizes = [0.5, 0.3, 0.7] + variances = [0.1, 0.15, 0.08] + study_names = ['Study1', 'Study2', 'Study3'] + with patch('matplotlib.pyplot.subplots') as mock_subplots: + mock_fig = MagicMock() + mock_ax = MagicMock() + mock_subplots.return_value = (mock_fig, mock_ax) + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_meta_analysis(effect_sizes, variances, study_names=study_names, filepath='test.png') + mock_savefig.assert_called_once() + + +def test_get_meta_analysis_invalid_model(): + """Test meta-analysis with invalid model specification""" + effect_sizes = [0.5, 0.3, 0.7] + variances = [0.1, 0.15, 0.08] + with pytest.raises(ValueError): + get_meta_analysis(effect_sizes, variances, model='invalid') + + +def test_get_meta_analysis_mismatched_lengths(): + """Test meta-analysis with mismatched effect sizes and variances""" + effect_sizes = [0.5, 0.3, 0.7] + variances = [0.1, 0.15] + with pytest.raises(ValueError): + get_meta_analysis(effect_sizes, variances) + + +def test_plot_embeddings_basic(): + """Test basic functionality of plot_embeddings""" + glycans = ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'] + emb = pd.DataFrame(np.random.rand(3, 10)) + with patch('matplotlib.pyplot.savefig') as mock_savefig: + plot_embeddings(glycans, emb) + mock_savefig.assert_not_called() + + +def test_plot_embeddings_with_labels(): + """Test plot_embeddings with group labels""" + glycans = ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'] + emb = pd.DataFrame(np.random.rand(3, 10)) + labels = ['A', 'B', 'A'] + with patch('matplotlib.pyplot.savefig') as mock_savefig: + plot_embeddings(glycans, emb, label_list=labels) + mock_savefig.assert_not_called() + + +def test_plot_embeddings_with_shape_feature(): + """Test plot_embeddings with shape feature""" + glycans = ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'] + emb = pd.DataFrame(np.random.rand(3, 10)) + with patch('matplotlib.pyplot.savefig') as mock_savefig: + plot_embeddings(glycans, emb, shape_feature='Man') + mock_savefig.assert_not_called() + + +def test_plot_embeddings_with_filepath(tmp_path): + """Test plot_embeddings with file saving""" + glycans = ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'] + emb = pd.DataFrame(np.random.rand(3, 10)) + filepath = tmp_path / "test.png" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + plot_embeddings(glycans, emb, filepath=str(filepath)) + mock_savefig.assert_called_once() + + +@pytest.fixture +def sample_time_series_df(): + data = { + 'glycan': ['Man3GlcNAc2', 'Man5GlcNAc2', 'Man9GlcNAc2'], + 'T1_h0_r1': [10, 20, 30], + 'T1_h4_r1': [15, 25, 35], + 'T1_h8_r1': [12, 22, 32], + 'T1_h0_r2': [11, 21, 31], + 'T1_h4_r2': [14, 24, 34], + 'T1_h8_r2': [13, 23, 33] + } + return pd.DataFrame(data) + + +@pytest.fixture +def sample_glycoshift_df(): + data = 
{ + 'protein_site_composition': ['ProtA_123_Man3GlcNAc2', 'ProtA_123_Man5GlcNAc2', 'ProtB_456_Man3GlcNAc2'], + 'sample1': [10, 20, 30], + 'sample2': [15, 25, 35], + 'sample3': [12, 22, 32], + 'sample4': [11, 21, 31] + } + return pd.DataFrame(data) + + +def test_characterize_monosaccharide_basic(): + """Test basic functionality of characterize_monosaccharide""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + characterize_monosaccharide('Man') + mock_savefig.assert_not_called() + + +def test_characterize_monosaccharide_with_custom_df(sample_df): + """Test characterize_monosaccharide with custom DataFrame""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + characterize_monosaccharide('Man', df=sample_df, thresh=1) + mock_savefig.assert_not_called() + + +def test_characterize_monosaccharide_with_bond_mode(): + """Test characterize_monosaccharide in bond mode""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + characterize_monosaccharide('a1-3', mode='bond') + mock_savefig.assert_not_called() + + +def test_characterize_monosaccharide_with_modifications(): + """Test characterize_monosaccharide with modifications enabled""" + with patch('matplotlib.pyplot.savefig') as mock_savefig: + characterize_monosaccharide('Man', modifications=True) + mock_savefig.assert_not_called() + + +@pytest.fixture +def mock_clustermap(): + mock_fig = MagicMock() + mock_fig.figure = plt.figure() + mock_clustermap = MagicMock() + mock_clustermap.fig = mock_fig.figure + return mock_clustermap + + +def test_get_heatmap_basic(sample_df): + """Test basic functionality of get_heatmap""" + with patch.object(sns, 'clustermap', return_value=mock_clustermap): + get_heatmap(sample_df) + + +def test_get_heatmap_with_motifs(sample_df): + """Test get_heatmap with motif analysis""" + with patch.object(sns, 'clustermap', return_value=mock_clustermap): + get_heatmap(sample_df, motifs=True) + + +def test_get_heatmap_with_transform(sample_df): + """Test get_heatmap with data transformation""" + with patch.object(sns, 'clustermap', return_value=mock_clustermap): + get_heatmap(sample_df, transform='CLR') + + +def test_get_heatmap_with_custom_feature_set(sample_df): + """Test get_heatmap with custom feature set""" + with patch.object(sns, 'clustermap', return_value=mock_clustermap): + get_heatmap(sample_df, motifs=True, feature_set=['known']) + + +def test_get_pca_basic(sample_df): + """Test basic functionality of get_pca""" + groups = [1, 1, 1] + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_pca(sample_df, groups) + mock_savefig.assert_not_called() + + +def test_get_pca_with_metadata(sample_df): + """Test get_pca with metadata DataFrame""" + metadata = pd.DataFrame({'id': ['sample1', 'sample2', 'sample3'], + 'group': ['A', 'A', 'B']}) + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_pca(sample_df, metadata) + mock_savefig.assert_not_called() + + +def test_get_pca_with_motifs(sample_df): + """Test get_pca with motif analysis""" + groups = [1, 1, 1] + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_pca(sample_df, groups, motifs=True) + mock_savefig.assert_not_called() + + +def test_get_pca_with_custom_components(sample_df): + """Test get_pca with custom principal components""" + groups = [1, 1, 1] + with patch('matplotlib.pyplot.savefig') as mock_savefig: + get_pca(sample_df, groups, pc_x=2, pc_y=3) + mock_savefig.assert_not_called() + + +@pytest.fixture +def sample_jtk_df(): + """ + Creates a test dataset for JTK analysis with: + - 24h coverage with 4h 
intervals (0,4,8,12,16,20) + - 2 replicates per timepoint + - 3 glycans with different rhythmic patterns: + - Man3GlcNAc2: Strong 24h rhythm + - Man5GlcNAc2: Weak 12h rhythm + - Man9GlcNAc2: No rhythm (control) + """ + # Generate timepoints every 4 hours for 24 hours + timepoints = [0, 4, 8, 12, 16, 20] + columns = ['glycan'] + [f'T_h{t}_r{r}' for t in timepoints for r in [1, 2]] + # Create rhythmic patterns + # Man3GlcNAc2: 24h cycle (peak at 12h) + h24_pattern = [10, 15, 25, 30, 25, 15] # Base values for 24h rhythm + # Man5GlcNAc2: 12h cycle (peaks at 4h and 16h) + h12_pattern = [15, 25, 15, 20, 25, 15] # Base values for 12h rhythm + # Man9GlcNAc2: No rhythm (random variation around mean) + no_rhythm = [20, 21, 19, 22, 20, 21] # Stable values with small variation + # Add some noise to replicates (±10% variation) + data = { + 'glycan': ['Man(a1-3)[Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc', + 'Man(a1-2)Man(a1-3)[Man(a1-6)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc', + 'Man(a1-2)Man(a1-2)Man(a1-3)[Man(a1-2)Man(a1-3)[Man(a1-2)Man(a1-6)]Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc'] + } + # Add two replicates for each timepoint with small random variation + np.random.seed(42) # For reproducibility + for t_idx, t in enumerate(timepoints): + for r in [1, 2]: + col = f'T_h{t}_r{r}' + base_values = [h24_pattern[t_idx], h12_pattern[t_idx], no_rhythm[t_idx]] + # Add 5% random noise + noise = np.random.normal(0, 0.05, 3) * np.array(base_values) + data[col] = np.array(base_values) + noise + return pd.DataFrame(data) + + +def test_get_jtk_basic(sample_jtk_df): + """Test basic functionality of get_jtk""" + periods = [12, 24] + result = get_jtk(sample_jtk_df, timepoints=6, periods=periods, interval=4) + assert isinstance(result, pd.DataFrame) + assert 'Adjusted_P_value' in result.columns + assert 'Period_Length' in result.columns + + +def test_get_jtk_with_motifs(sample_jtk_df): + """Test get_jtk with motif analysis""" + periods = [12, 24] + result = get_jtk(sample_jtk_df, timepoints=6, periods=periods, + interval=4, motifs=True) + assert isinstance(result, pd.DataFrame) + assert 'Adjusted_P_value' in result.columns + + +def test_get_jtk_with_transform(sample_jtk_df): + """Test get_jtk with data transformation""" + periods = [12, 24] + result = get_jtk(sample_jtk_df, timepoints=6, periods=periods, + interval=4, transform='CLR') + assert isinstance(result, pd.DataFrame) + + +def test_multi_feature_scoring_basic(sample_df): + """Test basic functionality of multi_feature_scoring""" + np.random.seed(42) + n_features = 5 + n_samples = 10 + data = np.random.rand(n_features, n_samples) + # Make the first feature strongly predictive + data[0, :5] = 0.1 # Clear signal for group 1 + data[0, 5:] = 0.9 # Clear signal for group 2 + # Make the second feature moderately predictive + data[1, :5] = 0.3 + data[1, 5:] = 0.7 + df_transformed = pd.DataFrame(data) + group1 = [0, 1, 2, 3, 4] + group2 = [5, 6, 7, 8, 9] + model, roc_auc = multi_feature_scoring(df_transformed, group1, group2) + assert hasattr(model, 'predict') + assert 0 <= roc_auc <= 1 + assert roc_auc > 0.5 + + +def test_multi_feature_scoring_with_filepath(sample_df): + """Test multi_feature_scoring with file saving""" + np.random.seed(42) + n_features = 5 + n_samples = 10 + data = np.random.rand(n_features, n_samples) + # Make the first feature strongly predictive + data[0, :5] = 0.1 + data[0, 5:] = 0.9 + df_transformed = pd.DataFrame(data) + group1 = [0, 1, 2, 3, 4] + group2 = [5, 6, 7, 8, 9] + with patch('matplotlib.pyplot.subplots') as mock_subplots: + mock_fig = 
MagicMock() + mock_ax = MagicMock() + mock_subplots.return_value = (mock_fig, mock_ax) + with patch('matplotlib.pyplot.savefig') as mock_savefig: + multi_feature_scoring(df_transformed, group1, group2, filepath='test.png') + mock_savefig.assert_called_once() + + +def test_multi_feature_scoring_imbalanced_groups(sample_df): + """Test multi_feature_scoring with imbalanced groups""" + np.random.seed(42) + n_features = 5 + n_samples = 10 + data = np.random.rand(n_features, n_samples) + # Make the first feature strongly predictive + data[0, :4] = 0.1 # Group 1 (4 samples) + data[0, 4:] = 0.9 # Group 2 (6 samples) + df_transformed = pd.DataFrame(data) + group1 = [0, 1, 2, 3] + group2 = [4, 5, 6, 7, 8, 9] + with patch('matplotlib.pyplot.figure'): + model, roc_auc = multi_feature_scoring(df_transformed, group1, group2) + assert hasattr(model, 'predict') + assert 0 <= roc_auc <= 1 + assert roc_auc > 0.5 + + +def test_get_glycoshift_per_site_basic(sample_glycoshift_df): + """Test basic functionality of get_glycoshift_per_site""" + group1 = ['sample1', 'sample2'] + group2 = ['sample3', 'sample4'] + result = get_glycoshift_per_site(sample_glycoshift_df, group1, group2) + assert isinstance(result, pd.DataFrame) + assert 'Condition_corr_pval' in result.columns + + +def test_get_glycoshift_per_site_with_custom_params(sample_glycoshift_df): + """Test get_glycoshift_per_site with custom parameters""" + group1 = ['sample1', 'sample2'] + group2 = ['sample3', 'sample4'] + result = get_glycoshift_per_site(sample_glycoshift_df, group1, group2, + min_samples=0.3, gamma=0.2) + assert isinstance(result, pd.DataFrame) + + +def test_get_glycoshift_per_site_paired(sample_glycoshift_df): + """Test get_glycoshift_per_site with paired samples""" + group1 = ['sample1', 'sample2'] + group2 = ['sample3', 'sample4'] + result = get_glycoshift_per_site(sample_glycoshift_df, group1, group2, + paired=True) + assert isinstance(result, pd.DataFrame) + + +def test_get_glycoshift_per_site_no_imputation(sample_glycoshift_df): + """Test get_glycoshift_per_site without imputation""" + group1 = ['sample1', 'sample2'] + group2 = ['sample3', 'sample4'] + result = get_glycoshift_per_site(sample_glycoshift_df, group1, group2, + impute=False) + assert isinstance(result, pd.DataFrame) + + @pytest.fixture def simple_network(): # Create a simple directed network @@ -3159,7 +3647,7 @@ def test_hierarchy_filter_basic(): 'Domain': ['d1']*5 + ['d2']*5 } df = pd.DataFrame(data) - train_x, val_x, train_y, val_y, id_val, class_list, class_converter = hierarchy_filter( + train_x, val_x, train_y, val_y, _, class_list, class_converter = hierarchy_filter( df, rank='Domain', min_seq=1 ) assert len(train_x) + len(val_x) == len(set(df['glycan'])) # Check for duplicates removal @@ -3445,4 +3933,334 @@ def test_dataset_to_dataloader_batch_size(mock_glycan_dataset, mock_library): batch_size=batch_size ) assert dataloader.batch_size == batch_size - first_batch = next(iter(dataloader)) + _ = next(iter(dataloader)) + + +@pytest.fixture(params=[ + 'regression', + 'classification', + 'multilabel' +]) +def mode(request): + return request.param + + +@pytest.fixture +def expected_metrics(mode): + metrics_map = { + 'regression': ['loss', 'mse', 'mae', 'r2'], + 'classification': ['loss', 'acc', 'mcc', 'auroc'], + 'multilabel': ['loss', 'acc', 'mcc', 'lrap', 'ndcg'] + } + return metrics_map[mode] + + +@pytest.fixture +def mock_model(mode): + """Create a mock model that adapts its output size based on the mode""" + class SimpleModel(nn.Module): + def __init__(self, 
output_size): + super().__init__() + self.embedding = nn.Embedding(10, 32) + self.fc = nn.Linear(32, output_size) + self.bn = nn.BatchNorm1d(32) + + def forward(self, x, edge_index, batch): + x = self.embedding(x) + x = self.bn(x) + graph_embed = torch.zeros(batch.max().item() + 1, x.size(1), + device=x.device) + graph_embed.index_add_(0, batch, x) + out = self.fc(graph_embed) + return out + output_sizes = { + 'regression': 1, + 'classification': 2, + 'multilabel': 2 + } + return SimpleModel(output_sizes[mode]) + + +@pytest.fixture +def mock_dataloader(mode): + class MockData: + def __init__(self, mode): + # Node features for 6 nodes across 2 graphs + self.labels = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.long) + # Edge connections + self.edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], + [1, 2, 0, 4, 5, 3]], dtype=torch.long) + # Batch assignments: first 3 nodes to first graph, last 3 to second + self.batch = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) + # Adapt y based on mode + if mode == 'regression': + self.y = torch.tensor([0.5, 1.5], dtype=torch.float) + elif mode == 'classification': + self.y = torch.tensor([0, 1], dtype=torch.long) + else: # multilabel + self.y = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float) + class MockLoader: + def __init__(self, data_list): + self.data_list = data_list + def __iter__(self): + return iter(self.data_list) + def __len__(self): + return len(self.data_list) + mock_data_list = [MockData(mode=mode) for _ in range(3)] + loader = MockLoader(mock_data_list) + return {'train': loader, 'val': loader} + + +@pytest.fixture +def mock_xgb_data(): + X_train = pd.DataFrame({ + 'feature1': [1, 2, 3], + 'feature2': [4, 5, 6] + }) + X_test = pd.DataFrame({ + 'feature1': [7, 8, 9], + 'feature2': [10, 11, 12] + }) + y_train = [0, 1, 0] + y_test = [1, 0, 1] + return X_train, X_test, y_train, y_test + + +def test_early_stopping(): + early_stopping = EarlyStopping(patience=2, verbose=True) + model = Mock() + # Should not stop after first higher loss + early_stopping(0.5, model) + early_stopping(0.6, model) + assert not early_stopping.early_stop + # Should stop after patience exceeded + early_stopping(0.7, model) + assert early_stopping.early_stop + # Should reset counter on improvement + early_stopping = EarlyStopping(patience=2) + early_stopping(0.5, model) + early_stopping(0.4, model) # Improvement + early_stopping(0.45, model) # Worse + assert not early_stopping.early_stop # Counter should have reset + + +def test_sigmoid(): + assert abs(sigmoid(0) - 0.5) < 1e-6 + assert sigmoid(100) > 0.99 + assert sigmoid(-100) < 0.01 + + +def test_batch_norm_stats(mock_model): + # Test disabling + disable_running_stats(mock_model) + assert mock_model.bn.momentum == 0 + assert hasattr(mock_model.bn, 'backup_momentum') + # Test enabling + enable_running_stats(mock_model) + assert mock_model.bn.momentum == mock_model.bn.backup_momentum + + +@patch('torch.cuda.is_available', return_value=False) +def test_poly1_cross_entropy_loss(_): + criterion = Poly1CrossEntropyLoss(num_classes=3, epsilon=1.0) + logits = torch.tensor([[2.0, 1.0, 0.0], [0.0, 2.0, 1.0]]) + labels = torch.tensor([0, 1]) + loss = criterion(logits, labels) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 # Scalar tensor + + +def test_sam_optimizer(mock_model, mode): + if mode == "classification": + # Create a dummy input and target + x = torch.randint(0, 10, (6,)) # 6 nodes with features in range [0, 10) + edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long) # Some edges 
+ batch = torch.tensor([0, 0, 0, 1, 1, 1]) # Two graphs + target = torch.tensor([0, 1]) # Labels for the two graphs + # Create optimizer + sam = SAM( + mock_model.parameters(), + base_optimizer=torch.optim.SGD, + rho=0.5, + adaptive=True, + lr=0.1 + ) + # Forward pass + output = mock_model(x, edge_index, batch) + loss = torch.nn.functional.cross_entropy(output, target) + # Backward pass to create gradients + loss.backward() + # Now test SAM steps + sam.first_step(zero_grad=True) + # Another forward-backward pass + output = mock_model(x, edge_index, batch) + loss = torch.nn.functional.cross_entropy(output, target) + loss.backward() + sam.second_step(zero_grad=True) + # Verify state + assert hasattr(sam, 'base_optimizer') + assert isinstance(sam.base_optimizer, torch.optim.SGD) + + +def test_sam_optimizer_state(): + # Create a simple model for testing + model = torch.nn.Linear(2, 2) + # Create dummy data + x = torch.randn(4, 2) + y = torch.tensor([0, 1, 0, 1]) + # Initialize SAM + sam = SAM( + model.parameters(), + base_optimizer=torch.optim.SGD, + rho=0.5, + adaptive=True, + lr=0.1 + ) + # Initial parameter values + initial_params = {name: param.clone() for name, param in model.named_parameters()} + # Forward pass + output = model(x) + loss = torch.nn.functional.cross_entropy(output, y) + loss.backward() + # First step + sam.first_step(zero_grad=True) + # Check parameters were updated + for name, param in model.named_parameters(): + assert not torch.equal(param, initial_params[name]), f"Parameters {name} were not updated in first step" + # Store parameters after first step + params_after_first = {name: param.clone() for name, param in model.named_parameters()} + # Another forward-backward pass + output = model(x) + loss = torch.nn.functional.cross_entropy(output, y) + loss.backward() + # Second step + sam.second_step(zero_grad=True) + # Check parameters were updated again + for name, param in model.named_parameters(): + assert not torch.equal(param, params_after_first[name]), f"Parameters {name} were not updated in second step" + + +def test_sam_optimizer_zero_grad(): + model = torch.nn.Linear(2, 2) + x = torch.randn(4, 2) + y = torch.tensor([0, 1, 0, 1]) + sam = SAM( + model.parameters(), + base_optimizer=torch.optim.SGD, + rho=0.5, + adaptive=True, + lr=0.1 + ) + # Forward and backward pass + output = model(x) + loss = torch.nn.functional.cross_entropy(output, y) + loss.backward() + # Check gradients exist + assert all(p.grad is not None for p in model.parameters()) + # Test zero_grad + sam.zero_grad() + assert all(p.grad is None or torch.all(p.grad == 0) for p in model.parameters()) + + +@patch('torch.cuda.is_available', return_value=False) +def test_training_setup(_, mock_model): + optimizer, scheduler, criterion = training_setup( + mock_model, + lr=0.001, + mode='multiclass', + num_classes=3 + ) + assert isinstance(optimizer, SAM) + assert isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + assert isinstance(criterion, Poly1CrossEntropyLoss) + + +def test_train_ml_model(mock_xgb_data): + X_train, X_test, y_train, y_test = mock_xgb_data + # Test classification + model = train_ml_model( + X_train, X_test, y_train, y_test, + mode='classification', + feature_calc=False + ) + assert isinstance(model, xgb.XGBClassifier) + # Test regression + model = train_ml_model( + X_train, X_test, y_train, y_test, + mode='regression', + feature_calc=False + ) + assert isinstance(model, xgb.XGBRegressor) + + +@patch('matplotlib.pyplot.show') +def test_analyze_ml_model(mock_show, 
mock_xgb_data): + X_train, X_test, y_train, y_test = mock_xgb_data + model = train_ml_model(X_train, X_test, y_train, y_test, mode='classification') + analyze_ml_model(model) + mock_show.assert_called_once() + + +def test_get_mismatch(mock_xgb_data): + X_train, X_test, y_train, y_test = mock_xgb_data + model = train_ml_model(X_train, X_test, y_train, y_test, mode='classification') + mismatches = get_mismatch(model, X_test, y_test, n=2) + assert isinstance(mismatches, list) + assert all(isinstance(m, tuple) and len(m) == 2 for m in mismatches) + assert len(mismatches) <= 2 + + +def verify_mock_data(data): + """Helper function to verify mock data structure""" + assert hasattr(data, 'labels') + assert hasattr(data, 'y') + assert hasattr(data, 'edge_index') + assert hasattr(data, 'batch') + assert isinstance(data.labels, torch.Tensor) + assert isinstance(data.y, torch.Tensor) + assert isinstance(data.edge_index, torch.Tensor) + assert isinstance(data.batch, torch.Tensor) + assert data.batch.max() == 1 # Should have exactly 2 graphs + assert len(data.y) == 2 # Should have 2 labels (one per graph) + assert len(data.labels) == 6 # Should have 6 nodes + assert (data.batch == 0).sum() == 3 # 3 nodes in first graph + assert (data.batch == 1).sum() == 3 # 3 nodes in second graph + + +@patch('torch.cuda.is_available', return_value=False) +def test_train_model_all_modes(mock_cuda, mode, expected_metrics, mock_model, mock_dataloader): + # Configure based on mode + if mode == 'regression': + criterion = nn.MSELoss() + optimizer = torch.optim.Adam(mock_model.parameters(), lr=0.1) + elif mode == 'classification': + criterion = nn.CrossEntropyLoss() + optimizer = SAM(mock_model.parameters(), torch.optim.SGD, lr=0.1) + else: # multilabel + criterion = nn.BCEWithLogitsLoss() + optimizer = SAM(mock_model.parameters(), torch.optim.SGD, lr=0.1) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer.base_optimizer if isinstance(optimizer, SAM) else optimizer, + step_size=1 + ) + # Run training + _, metrics = train_model( + mock_model, + mock_dataloader, + criterion, + optimizer, + scheduler, + num_epochs=2, + mode=mode, + mode2='multi' if mode != 'regression' else 'binary', + return_metrics=True + ) + # Verify metrics structure + assert isinstance(metrics, dict) + assert 'train' in metrics and 'val' in metrics + for phase in ['train', 'val']: + assert all(key in metrics[phase] for key in expected_metrics) + for metric in metrics[phase].values(): + assert len(metric) == 2 # Two epochs + assert all(isinstance(v, float) for v in metric) + assert all(not np.isnan(v) for v in metric)
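
Note (illustrative, not part of the patch): the `get_meta_analysis` change above replaces `weights /= (variances + tau_squared)` with `weights = 1 / (variances + tau_squared)`, i.e. the random-effects weights are re-derived from the total variance rather than dividing the already-inverted fixed-effects weights a second time. A minimal standalone sketch of the DerSimonian-Laird weighting this corresponds to is given below; the variable names and the final z-test are assumptions for illustration, not code taken from glycowork itself.

import numpy as np
from scipy import stats

effect_sizes = np.array([0.5, 0.3, 0.7])
variances = np.array([0.1, 0.15, 0.08])

# Fixed-effects (inverse-variance) weights and pooled effect
weights = 1 / variances
combined = np.dot(weights, effect_sizes) / weights.sum()

# DerSimonian-Laird between-study variance (tau^2) from Cochran's Q
q = np.dot(weights, (effect_sizes - combined) ** 2)
df = len(effect_sizes) - 1
c = weights.sum() - (weights ** 2).sum() / weights.sum()
tau_squared = max((q - df) / c, 0)

# Random-effects weights rebuilt from the total variance (the fixed behavior),
# not obtained by dividing the existing weights again (the old bug)
weights = 1 / (variances + tau_squared)
combined = np.dot(weights, effect_sizes) / weights.sum()
se = np.sqrt(1 / weights.sum())
p_value = 2 * stats.norm.sf(abs(combined / se))
print(round(combined, 3), round(p_value, 4))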