From c0ed6cc970f9284073acebc2bf743919d8d89e04 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Tue, 30 Jul 2024 02:34:22 -0400
Subject: [PATCH] Ties upgrade (#37)

* add TIES_SOUP support
* TIES w/ DARE
* DROP: DARE without ARE
* modelstock and n_average
* hotfix + it is memory intensive
* hotfix: nan
* modelstock fix + geometric_median
* DARE / TIES / Median Stack
* move things around
* return type
* forward n_average
* reformat and deduplicate

---------

Co-authored-by: Darren Laurie <6DammK9@gmail.com>
---
 sd_mecha/__init__.py          | 166 ++++++++++++++++++++++++++-
 sd_mecha/merge_methods.py     | 207 ++++++++++++++++++++++++++++++++--
 test/unit_test_dropout.py     |  61 ++++++++++
 test/unit_test_geom_median.py |  82 ++++++++++++++
 test/unit_test_modelstock.py  |  75 ++++++++++++
 test/unit_test_n_average.py   |  34 ++++++
 test/unit_test_ties.py        |   5 +
 7 files changed, 619 insertions(+), 11 deletions(-)
 create mode 100644 test/unit_test_dropout.py
 create mode 100644 test/unit_test_geom_median.py
 create mode 100644 test/unit_test_modelstock.py
 create mode 100644 test/unit_test_n_average.py

diff --git a/sd_mecha/__init__.py b/sd_mecha/__init__.py
index caae446..991ed01 100644
--- a/sd_mecha/__init__.py
+++ b/sd_mecha/__init__.py
@@ -42,6 +42,7 @@ def serialize_and_save(
 
 weighted_sum = merge_methods.weighted_sum
 slerp = merge_methods.slerp
+n_average = merge_methods.n_average
 
 
 def add_difference(
@@ -132,6 +133,7 @@ def add_perpendicular(
 cosine_add_a = merge_methods.add_cosine_a
 cosine_add_b = merge_methods.add_cosine_b
 ties_sum = merge_methods.ties_sum
+ties_sum_extended = merge_methods.ties_sum_extended
 
 
 # latex notes in reference to original implementation: https://arxiv.org/abs/2306.01708
@@ -142,11 +144,13 @@ def add_perpendicular(
 # - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ )
 # - `res`: $$ \lambda * \tau_m $$
 # - `return`: $$ \theta_m $$
+# Special mode "TIES-SOUP" is enabled by setting `vote_sgn` > 0.0
+# Special mode "TIES-STOCK" is enabled by setting `apply_stock` > 0.0
 def add_difference_ties(
     base: RecipeNodeOrPath,
     *models: RecipeNodeOrPath,
-    alpha: float,
-    k: float = 0.2,
+    alpha: Hyper,
+    k: Hyper = 0.2,
     device: Optional[str] = None,
     dtype: Optional[torch.dtype] = None,
 ) -> recipe_nodes.RecipeNode:
@@ -182,6 +186,61 @@ def add_difference_ties(
     )
 
 
+def add_difference_ties_extended(
+    base: RecipeNodeOrPath,
+    *models: RecipeNodeOrPath,
+    alpha: Hyper,
+    k: Hyper = 0.2,
+    vote_sgn: Hyper = 0.0,
+    apply_stock: Hyper = 0.0,
+    cos_eps: Hyper = 1e-6,
+    apply_median: Hyper = 0.0,
+    eps: Hyper = 1e-6,
+    maxiter: Hyper = 100,
+    ftol: Hyper = 1e-20,
+    device: Optional[str] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> recipe_nodes.RecipeNode:
+    # $$ \{\theta_{init}\}_{t=1}^n $$
+    base = path_to_node(base)
+    models = tuple(path_to_node(model) for model in models)
+
+    # Create task vectors.
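+    # A task vector is the difference between a fine-tuned model and its base;
+    # inputs that are already deltas are passed through unchanged below.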
+    # $$ \tau_t $$
+    models = tuple(
+        subtract(model, base)
+        if model.merge_space is MergeSpace.BASE else
+        model
+        for model in models
+    )
+
+    # step 1 + step 2 + step 3
+    res = ties_sum_extended(
+        *models,
+        k=k,
+        vote_sgn=vote_sgn,
+        apply_stock=apply_stock,
+        cos_eps=cos_eps,
+        apply_median=apply_median,
+        eps=eps,
+        maxiter=maxiter,
+        ftol=ftol,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Obtain merged checkpoint
+
+    # $$ \theta_{init} + \lambda * \tau_m $$
+    return add_difference(
+        base, res,
+        alpha=alpha,
+        device=device,
+        dtype=dtype,
+    )
+
+
 def copy_region(
     a: RecipeNodeOrPath, b: RecipeNodeOrPath, c: Optional[RecipeNodeOrPath] = None,
     *,
     width: Hyper = 1.0,
@@ -299,6 +358,109 @@ def dropout(
     return sd_mecha.add_difference(a, ba_delta, alpha=alpha, device=device, dtype=dtype)
 
 
+ties_sum_with_dropout = merge_methods.ties_sum_with_dropout
+
+
+# latex notes in reference to original implementation: https://arxiv.org/abs/2311.03099
+# Notice that this is "TIES Merging w/ DARE", which is "Prune > Merge (TIES) > Rescale"
+# See https://slgero.medium.com/merge-large-language-models-29897aeb1d1a for details
+# - `base`: $$ \theta_{PRE} $$
+# - `*models`: $$ \theta_{SFT}^{t_k} $$
+# - `deltas`: $$ \delta^t = \theta_{SFT}^{t} - \theta_{PRE} \in \mathbb{R}^d $$
+# - `probability`: $$ p $$
+# - `res`: $$ \hat{\delta}^t = \tilde{\delta}^t / (1-p) $$
+# - `alpha`: $$ \lambda $$
+# - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ ) in the TIES paper
+# - `return`: $$ \theta_M = \theta_{PRE} + \lambda \cdot \Sigma_{k=1}^{K} \tilde{\delta}^{t_k} $$
+# Special mode "TIES-SOUP" is enabled by setting `vote_sgn` > 0.0
+def ties_with_dare(
+    base: RecipeNodeOrPath,
+    *models: RecipeNodeOrPath,
+    probability: Hyper = 0.9,
+    no_rescale: Hyper = 0.0,
+    alpha: Hyper = 0.5,
+    seed: Optional[Hyper] = None,
+    k: Hyper = 0.2,
+    vote_sgn: Hyper = 0.0,
+    apply_stock: Hyper = 0.0,
+    cos_eps: Hyper = 1e-6,
+    apply_median: Hyper = 0.0,
+    eps: Hyper = 1e-6,
+    maxiter: Hyper = 100,
+    ftol: Hyper = 1e-20,
+    device: Optional[str] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> recipe_nodes.RecipeNode:
+    # $$ \delta^t = \theta_{SFT}^{t} - \theta_{PRE} \in \mathbb{R}^d $$
+    base = path_to_node(base)
+    models = tuple(path_to_node(model) for model in models)
+    deltas = tuple(
+        subtract(model, base)
+        if model.merge_space is MergeSpace.BASE else
+        model
+        for model in models
+    )
+
+    # $$ \tilde{\delta}^{t_k} $$
+    res = ties_sum_with_dropout(
+        *deltas,
+        probability=probability,
+        no_rescale=no_rescale,
+        k=k,
+        vote_sgn=vote_sgn,
+        seed=seed,
+        apply_stock=apply_stock,
+        cos_eps=cos_eps,
+        apply_median=apply_median,
+        eps=eps,
+        maxiter=maxiter,
+        ftol=ftol,
+        device=device,
+        dtype=dtype
+    )
+
+    # $$ \theta_M = \theta_{PRE} + \lambda \cdot \Sigma_{k=1}^{K} \tilde{\delta}^{t_k} $$
+    return sd_mecha.add_difference(base, res, alpha=alpha, device=device, dtype=dtype)
+
+
+model_stock_for_tensor = merge_methods.model_stock_for_tensor
+
+
+# Following mergekit's implementation of Model Stock (for which no official implementation exists)
+# https://github.com/arcee-ai/mergekit/blob/main/mergekit/merge_methods/model_stock.py
+def model_stock_n_models(
+    base: RecipeNodeOrPath,
+    *models: RecipeNodeOrPath,
+    cos_eps: Hyper = 1e-6,
+    device: Optional[str] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> recipe_nodes.RecipeNode:
+
+    base = path_to_node(base)
+    models = tuple(path_to_node(model) for model in models)
+    deltas = tuple(
+        subtract(model, base)
+        if model.merge_space is MergeSpace.BASE else
+        model
+        for model in models
+    )
+
+    # This is hacky: both w_avg and w_h are calculated in there.
+    # Notice that t and cos_theta are vectors instead of single values.
+    # Conceptually it could be compatible with TIES, but the algorithm would need to be rewritten.
+    res = model_stock_for_tensor(
+        *deltas,
+        cos_eps=cos_eps,
+        device=device,
+        dtype=dtype
+    )
+
+    return sd_mecha.add_difference(base, res, alpha=1.0, device=device, dtype=dtype)
+
+
+geometric_median = merge_methods.geometric_median
+
+
 def model(state_dict: str | pathlib.Path | Mapping[str, Tensor], model_arch: str = "sd1", model_type: str = "base"):
     return recipe_nodes.ModelRecipeNode(state_dict, model_arch, model_type)
diff --git a/sd_mecha/merge_methods.py b/sd_mecha/merge_methods.py
index fb51514..07a9e9b 100644
--- a/sd_mecha/merge_methods.py
+++ b/sd_mecha/merge_methods.py
@@ -25,6 +25,14 @@ def weighted_sum(
 ) -> Tensor | SameMergeSpace:
     return (1 - alpha) * a + alpha * b
 
+# Isotropic merge / Uniform Soup / Uniform Merge... you name it.
+# Instead of a running average, this may run faster.
+@convert_to_recipe
+def n_average(
+    *models: Tensor | SameMergeSpace,
+    **kwargs,
+) -> Tensor | SameMergeSpace:
+    return torch.mean(torch.stack(models), dim=0)
 
 @convert_to_recipe
 def slerp(
@@ -130,21 +138,66 @@ def add_cosine_generic(a: Tensor, b: Tensor, alpha: float, similarity: Tensor) -
     return weighted_sum.__wrapped__(a, b, alpha=k)
 
 
+# Special mode "TIES-STOCK" is enabled by setting `apply_stock` > 0.0
+# Special mode "TIES-GMEDIAN" is enabled by setting `apply_median` > 0.0
+@convert_to_recipe
+def ties_sum_extended(  # aka add_difference_ties
+    *models: Tensor | LiftFlag[MergeSpace.DELTA],
+    k: Hyper = 0.2,
+    vote_sgn: Hyper = 0.0,
+    apply_stock: Hyper = 0.0,
+    cos_eps: Hyper = 1e-6,
+    apply_median: Hyper = 0.0,
+    eps: Hyper = 1e-6,
+    maxiter: Hyper = 100,
+    ftol: Hyper = 1e-20,
+    **kwargs,
+) -> Tensor | LiftFlag[MergeSpace.DELTA]:
+    filtered_delta, param_counts = ties_sum_deltas(*models, k=k, vote_sgn=vote_sgn)
+
+    if apply_median <= 0.0:
+        # Model Stock
+        t = 1.0 if apply_stock <= 0.0 else get_model_stock_t(torch.unbind(filtered_delta), cos_eps=cos_eps)
+
+        filtered_delta = filtered_delta.sum(dim=0)
+
+        # $$ \tau_m $$
+        return torch.nan_to_num(filtered_delta * t / param_counts)
+    else:
+        # $$ \tau_m $$, but using the geometric median instead of the arithmetic mean. Considered as a replacement for model stock.
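+        # The geometric median minimizes the sum of L2 distances to the deltas and is
+        # computed by Weiszfeld iterations (see geometric_median_list_of_array below);
+        # unlike the mean, it is robust to a minority of outlier deltas.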
+        filtered_delta = geometric_median_list_of_array(torch.unbind(filtered_delta), eps=eps, maxiter=maxiter, ftol=ftol)
+
+        return torch.nan_to_num(filtered_delta)
+
 
 # latex notes in reference to original implementation: https://arxiv.org/abs/2306.01708
 # - `delta`: $$ \hat{\tau}_t $$
 # - `signs`: $$ \gamma_t $$
-# - `final_sign`: $$ \gamma_m^p = sgn(\sum_{t=1}^n \hat{\tau}_t^p) $$
+# - `final_sign`: $$ \gamma_m^p = sgn(\Sigma_{t=1}^n \hat{\tau}_t^p) $$
 # - `delta_filters`: $$ \{ \gamma_t^p = \gamma_m^p \} $$
 # - `param_counts`: $$ |A^p| $$
-# - `filtered_delta`: $$ \sum_{t\in{A^p}} \hat{\tau}_t^p $$
+# - `filtered_delta`: $$ \Sigma_{t\in{A^p}} \hat{\tau}_t^p $$
 # - `return`: $$ \lambda * \tau_m $$
+# Special mode "TIES-SOUP" is enabled by setting `vote_sgn` > 0.0
+# - `final_sign`: $$ \gamma_m^p = sgn(\Sigma_{t=1}^n \gamma_t^p) $$
 @convert_to_recipe
 def ties_sum(  # aka add_difference_ties
     *models: Tensor | LiftFlag[MergeSpace.DELTA],
     k: Hyper = 0.2,
+    vote_sgn: Hyper = 0.0,
     **kwargs,
 ) -> Tensor | LiftFlag[MergeSpace.DELTA]:
+    filtered_delta, param_counts = ties_sum_deltas(*models, k=k, vote_sgn=vote_sgn)
 
+    # $$ \tau_m $$
+    return torch.nan_to_num(filtered_delta.sum(dim=0) / param_counts)
+
+
+def ties_sum_deltas(
+    *models: Tensor,
+    k: float = 0.2,
+    vote_sgn: float = 0.0,
+):
     # Step 1: Trim redundant parameters
 
     # $$ \hat{\tau}_t $$ O(N) in space
@@ -157,11 +210,12 @@ def ties_sum(  # aka add_difference_ties
 
     # Step 2: Elect Final Signs.
 
-    # $$ \gamma_t $$ 
+    # $$ \gamma_t $$
     signs = torch.sign(deltas)
 
-    # $$ \gamma_m^p = sgn(\Sigma_{t=1}^n \hat{\tau}_t^p) $$
-    final_sign = torch.sign(torch.sum(deltas, dim=0))
+    # $$ \gamma_m^p = sgn(\Sigma_{t=1}^n \hat{\tau}_t^p) $$ for normal TIES
+    # $$ \gamma_m^p = sgn(\Sigma_{t=1}^n \gamma_t^p) $$ if "TIES-SOUP" is activated
+    final_sign = torch.sign(torch.sum(deltas if vote_sgn <= 0.0 else signs, dim=0))
 
     # Step 3: Disjoint merge.
@@ -171,11 +225,11 @@ def ties_sum(  # aka add_difference_ties
 
     # $$ |A^p| $$
     param_counts = torch.sum(delta_filters, dim=0)
 
-    # $$ \sum_{t\in{A^P}} \hat{\tau}_t^p $$
-    filtered_delta = (deltas * delta_filters).sum(dim=0)
+    # $$ \Sigma_{t\in{A^P}} \hat{\tau}_t^p $$
+    # (note that the sum is not performed here directly)
+    filtered_delta = deltas * delta_filters
 
-    # $$ \tau_m $$
-    return torch.nan_to_num(filtered_delta / param_counts)
+    return filtered_delta, param_counts
 
 
 def filter_top_k(a: Tensor, k: float):
@@ -614,6 +668,49 @@ def dropout(  # aka n-supermario
     return final_delta / masks.sum(0).clamp(1) / (1 - probability)
 
 
+# Part of TIES w/ DARE
+# Hyperparameters default to the values proposed in the paper.
+# Special mode "DROP" is enabled by setting `no_rescale` > 0.0
+# - `return`: $$ \hat{\delta}^t = \tilde{\delta}^t $$
+@convert_to_recipe
+def ties_sum_with_dropout(
+    *deltas: Tensor | LiftFlag[MergeSpace.DELTA],
+    probability: Hyper = 0.9,
+    no_rescale: Hyper = 0.0,
+    k: Hyper = 0.2,
+    vote_sgn: Hyper = 0.0,
+    apply_stock: Hyper = 0.0,
+    cos_eps: Hyper = 1e-6,
+    apply_median: Hyper = 0.0,
+    eps: Hyper = 1e-6,
+    maxiter: Hyper = 100,
+    ftol: Hyper = 1e-20,
+    seed: Hyper = None,
+    **kwargs,
+) -> Tensor | LiftFlag[MergeSpace.DELTA]:
+    # Set the seed (skipped when it is None: torch.manual_seed requires an int)
+    if seed is not None: torch.manual_seed(seed)
+
+    # Under "Dropout", a dropped entry is 0 by definition; multiplying by the mask (Hadamard product) zeroes it.
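+    # Each mask entry below is 1 with probability (1 - p), so every parameter
+    # of a delta survives independently; dropped entries become 0.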
+    # $$ \tilde{\delta}^t = (1 - m^t) \odot \delta^t $$
+    deltas = [delta * torch.bernoulli(torch.full(delta.shape, 1 - probability)) for delta in deltas]
+
+    # $$ \tilde{\delta}^t = \tau_m = \hat{\tau}_t $$ O(N) in space
+    deltas = ties_sum_extended.__wrapped__(*deltas, k=k, vote_sgn=vote_sgn, apply_stock=apply_stock, cos_eps=cos_eps, apply_median=apply_median, eps=eps, maxiter=maxiter, ftol=ftol)
+
+    if probability == 1.0:
+        # Corner case: p = 1 drops everything, and rescaling would divide by zero
+        return deltas * 0.0
+    elif no_rescale <= 0.0:
+        # Rescale
+        # $$ \hat{\delta}^t = \tilde{\delta}^t / (1-p) $$
+        return deltas / (1.0 - probability)
+    else:
+        # No rescale
+        # $$ \hat{\delta}^t = \tilde{\delta}^t $$
+        return deltas
+
+
 def overlapping_sets_pmf(n, p, overlap, overlap_emphasis):
     if np.isclose(overlap, round(overlap)):
         if round(overlap) % 2 == 0:
@@ -653,3 +750,95 @@ def binomial_coefficient_np(n, k):
     for i in range(1, k+1):
         result = result * (n - i + 1) // i
     return result
+
+
+# Following mergekit's implementation of Model Stock (for which no official implementation exists)
+# https://github.com/arcee-ai/mergekit/blob/main/mergekit/merge_methods/model_stock.py
+# The functions are split up so they can be reused by other algorithms like TIES.
+@convert_to_recipe
+def model_stock_for_tensor(
+    *deltas: Tensor | LiftFlag[MergeSpace.DELTA],
+    cos_eps: Hyper = 1e-6,
+    **kwargs,
+) -> Tensor | LiftFlag[MergeSpace.DELTA]:
+
+    # Plain average of the deltas.
+    w_avg = n_average.__wrapped__(*deltas)
+
+    # t can become inf, so handle it with care
+    t = get_model_stock_t(deltas, cos_eps)
+
+    # Return w_h. Notice that w_0 is 0 here.
+    return torch.nan_to_num(t * w_avg)
+
+
+# The guess from mergekit: average of cos(theta). The expected value is 0, which roughly matches the paper.
+# However this may be very unstable, and the range is still -1 to 1.
+def get_model_stock_t(deltas, cos_eps):
+    n = len(deltas)
+
+    # Cosine similarity module. Default eps follows the torch API doc.
+    cos = torch.nn.CosineSimilarity(dim=-1, eps=cos_eps)
+
+    # A one-liner is all you need. This could become a running average if it turns out to be memory hungry.
+    cos_thetas = [cos(deltas[i], deltas[i + 1]) for i, _ in enumerate(deltas) if (i + 1) < n]
+
+    # Still a vector.
+    cos_theta = torch.stack(cos_thetas).mean(dim=0)
+
+    # Convert to a column vector for multiplication.
+    t = (n * cos_theta / (1 + (n - 1) * cos_theta)).unsqueeze(-1)
+
+    return t
+
+
+# This is a thin wrapper, so that TIES can use the geometric median as well.
+@convert_to_recipe
+def geometric_median(
+    *models: Tensor | SameMergeSpace,
+    eps: Hyper = 1e-6,
+    maxiter: Hyper = 100,
+    ftol: Hyper = 1e-20,
+    **kwargs,
+) -> Tensor | SameMergeSpace:
+    return geometric_median_list_of_array(models, eps, maxiter, ftol)
+
+
+# Original source code: https://github.com/krishnap25/geom_median/blob/main/src/geom_median/torch/weiszfeld_list_of_array.py
+# Rewritten with list comprehensions, relying on the torch API only. It is now fully parallel.
+def geometric_median_list_of_array(models, eps, maxiter, ftol):
+    # Per-model weights cannot be passed in from user space, so they are hardcoded here.
+    # "points" from the original implementation is renamed to "models".
+    # The original's no_grad handling covers a rare case only: running a merge algorithm under autograd is unheard of.
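+    # All models start with uniform weights; each Weiszfeld iteration below
+    # reweights them by the inverse of their distance to the current estimate.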
+    weights = torch.ones(len(models), device=models[0].device)
+
+    # Initialize the median estimate at the mean
+    median = weighted_average(models, weights)
+    new_weights = weights
+    objective_value = geometric_median_objective(median, models, weights)
+
+    # Weiszfeld iterations
+    for _ in range(maxiter):
+        prev_obj_value = objective_value
+        denom = torch.stack([l2distance(p, median) for p in models])
+        new_weights = weights / torch.clamp(denom, min=eps)
+        median = weighted_average(models, new_weights)
+
+        objective_value = geometric_median_objective(median, models, weights)
+        if abs(prev_obj_value - objective_value) <= ftol * objective_value:
+            break
+
+    return weighted_average(models, new_weights)
+
+
+def weighted_average(points, weights):
+    # weighted_average_component is not even required.
+    return torch.sum(torch.stack([p * weights[i] for i, p in enumerate(points)]), dim=0) / weights.sum()
+
+
+def geometric_median_objective(median, points, weights):
+    return torch.mean(torch.stack([l2distance(point, median) for point in points]) * weights)
+
+
+def l2distance(p1, p2):
+    return torch.dist(p1, p2, p=2)
diff --git a/test/unit_test_dropout.py b/test/unit_test_dropout.py
new file mode 100644
index 0000000..23acc13
--- /dev/null
+++ b/test/unit_test_dropout.py
@@ -0,0 +1,61 @@
+import torch
+import sd_mecha
+
+_k = 1.0
+_use_delta = 0.0
+_use_signs = 1.0
+
+_probability = 0.25
+_use_rescale = 0.0
+_no_rescale = 1.0
+_seed = 114514
+
+_alpha = 0.0  # Not used
+
+# Sudoku of 4x4
+_models = [
+    torch.tensor([
+        [-1., 2., 3., 4.],
+        [4., -3., 2., 1.],
+        [3., 4., 1., -2.],
+        [2., 1., -4., 3.],
+    ]),
+    torch.tensor([
+        [3., 4., 1., -2.],
+        [2., 1., -4., 3.],
+        [-1., 2., 3., 4.],
+        [4., -3., 2., 1.],
+    ])
+]
+
+_expected1 = torch.tensor([
+    [-1.3333, 4.0000, 2.6667, 0.0000],
+    [ 4.0000, -4.0000, 2.6667, 1.3333],
+    [-1.3333, 4.0000, 1.3333, 0.0000],
+    [ 4.0000, 0.0000, -5.3333, 4.0000]
+])
+
+_expected2 = torch.tensor([
+    [-1., 3., 2., 0.],
+    [ 3., 0., 2., 1.],
+    [-1., 3., 1., 0.],
+    [ 3., 0., -4., 3.]
+])
+
+# Visually inspect whether dropout really happens
+
+_dare1 = sd_mecha.ties_sum_with_dropout.__wrapped__(*_models, probability=_probability, k=_k, seed=_seed)
+# print(_dare1)
+
+assert torch.allclose(_dare1, _expected1, atol=0.0001)
+
+_dare2 = sd_mecha.ties_sum_with_dropout.__wrapped__(*_models, probability=_probability, no_rescale=_no_rescale, k=_k, vote_sgn=_use_signs, seed=_seed)
+
+# print(_dare2)
+assert torch.allclose(_dare2, _expected2, atol=0.0001)
+
+# _ties1 = sd_mecha.ties_sum.__wrapped__(*_models, k=_k)
+# print(_ties1)
+
+# _ties2 = sd_mecha.ties_sum.__wrapped__(*_models, k=_k, vote_sgn=_use_signs)
+# print(_ties2)
diff --git a/test/unit_test_geom_median.py b/test/unit_test_geom_median.py
new file mode 100644
index 0000000..1b29daa
--- /dev/null
+++ b/test/unit_test_geom_median.py
@@ -0,0 +1,82 @@
+import torch
+import sd_mecha
+import time
+
+
+_k = 1.0
+_use_delta = 0.0
+_use_signs = 1.0
+
+_probability = 0.25
+_use_rescale = 0.0
+_no_rescale = 1.0
+_seed = 114514
+
+_alpha = 0.0  # Not used
+_cos_eps = 1e-6
+_apply_stock = 1.0
+_no_stock = 0.0
+
+_apply_median = 1.0
+_no_median = 0.0
+
+_eps = 1e-6
+_maxiter = 100  # 1 iter = 10 sec, avg 5-10 iter
+_ftol = 1e-20
+
+_models = [
+    torch.tensor([
+        [3., 4., 1., -2.],
+        [2., 1., -4., 3.],
+        [-1., 2., 3., 4.],
+        [4., -3., 2., 1.],
+    ]),
+    torch.tensor([
+        [-1., 3., 4., 2.],
+        [4., 2., -3., 1.],
+        [3., -1., 2., 4.],
+        [2., 4., 1., -3.],
+    ]),
+    torch.tensor([
+        [-1., 3., 2., 0.],
+        [ 3., 0., 2., 1.],
+        [-1., 3., 1., 0.],
+        [ 3., 0., -4., 3.]
+    ])
+]
+
+_models2 = []
+for i in range(100):
+    _models2.append(torch.rand(1280, 1280))
+
+# Not used
+_weights = torch.ones(len(_models), device=_models[0].device)
+
+_expected = torch.tensor([
+    [ 0.4791, 3.3698, 2.2343, -0.1354],
+    [ 2.9323, 0.9739, -1.7289, 1.7395],
+    [ 0.2082, 1.4220, 2.0416, 2.6873],
+    [ 3.0677, 0.0989, -0.2711, 0.4481]
+])
+
+_expected2 = torch.tensor([
+    [ 0.0000, 3.1750, 2.2089, 0.0000],
+    [ 3.0170, 0.5588, -0.6999, 1.1580],
+    [ 0.5758, 2.2492, 1.1580, 0.0000],
+    [ 2.9830, 0.0000, 0.3500, 0.1750]
+])
+
+median = sd_mecha.geometric_median.__wrapped__(*_models, eps=_eps, maxiter=_maxiter, ftol=_ftol)
+# print(median)
+assert torch.allclose(median, _expected, atol=0.0001)
+
+_with_dare = sd_mecha.ties_sum_with_dropout.__wrapped__(*_models, probability=_probability, no_rescale=_no_rescale, k=_k, vote_sgn=_use_signs, seed=_seed, apply_stock=_no_stock, apply_median=_apply_median, cos_eps=_cos_eps, eps=_eps, maxiter=_maxiter, ftol=_ftol)
+# print(_with_dare)
+assert torch.allclose(_with_dare, _expected2, atol=0.0001)
+
+ts = time.time()
+median2 = sd_mecha.geometric_median.__wrapped__(*_models2, eps=_eps, maxiter=_maxiter, ftol=_ftol)
+# print(median2)  # Around 0.5 but will flutter
+te = time.time()
+# print(te - ts)  # WS = 0.9, Notebook = 1.76
+assert (te - ts) < 10.0
\ No newline at end of file
diff --git a/test/unit_test_modelstock.py b/test/unit_test_modelstock.py
new file mode 100644
index 0000000..12a17c2
--- /dev/null
+++ b/test/unit_test_modelstock.py
@@ -0,0 +1,75 @@
+import torch
+import sd_mecha
+
+_k = 1.0
+_use_delta = 0.0
+_use_signs = 1.0
+
+_probability = 0.25
+_use_rescale = 0.0
+_no_rescale = 1.0
+_seed = 114514
+
+_alpha = 0.0  # Not used
+_cos_eps = 1e-6
+_apply_stock = 1.0
+_no_stock = 0.0
+
+# More 4x4 sudoku this time.
+
+_models = [
+    torch.tensor([
+        [3., 4., 1., -2.],
+        [2., 1., -4., 3.],
+        [-1., 2., 3., 4.],
+        [4., -3., 2., 1.],
+    ]),
+    torch.tensor([
+        [-1., 3., 4., 2.],
+        [4., 2., -3., 1.],
+        [3., -1., 2., 4.],
+        [2., 4., 1., -3.],
+    ]),
+    torch.tensor([
+        [-1., 2., 3., 4.],
+        [4., -3., 2., 1.],
+        [3., 4., 1., -2.],
+        [2., 1., -4., 3.],
+    ])
+]
+
+_expected1 = torch.tensor([
+    [ 0.2727, 2.4545, 2.1818, 1.0909],
+    [ 2.5000, 0.0000, -1.2500, 1.2500],
+    [ 0.8696, 0.8696, 1.0435, 1.0435],
+    [-2.0000, -0.5000, 0.2500, -0.2500]
+])
+
+# Notice that the result can be brutal.
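+# (The interpolation factor t is a per-row vector here, so whole rows can
+# collapse to zero when the pairwise cosine similarities work against them.)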
+_expected2 = torch.tensor([
+    [ 0.0000, 2.6592, 2.3638, 3.5456],
+    [ 2.8031, 1.2614, -3.3638, 1.6819],
+    [ 0.0000, 0.0000, 0.0000, 0.0000],
+    [ 2.6077, 0.0000, 1.9557, 0.9779]
+])
+
+_expected3 = torch.tensor([
+    [ 0.0000, 3.0000, 2.6667, 4.0000],
+    [ 3.3333, 1.5000, -4.0000, 2.0000],
+    [ 3.0000, 3.0000, 2.0000, 0.0000],
+    [ 2.6667, 0.0000, 2.0000, 1.0000]
+])
+
+# Visually inspect whether model stock really takes effect
+
+_stock_only = sd_mecha.model_stock_for_tensor.__wrapped__(*_models, cos_eps=_cos_eps)
+# print(_stock_only)
+assert torch.allclose(_stock_only, _expected1, atol=0.0001)
+
+_with_dare = sd_mecha.ties_sum_with_dropout.__wrapped__(*_models, probability=_probability, no_rescale=_no_rescale, k=_k, vote_sgn=_use_signs, seed=_seed, apply_stock=_apply_stock, cos_eps=_cos_eps)
+# print(_with_dare)
+assert torch.allclose(_with_dare, _expected2, atol=0.0001)
+
+_dare_only = sd_mecha.ties_sum_with_dropout.__wrapped__(*_models, probability=_probability, no_rescale=_no_rescale, k=_k, vote_sgn=_use_signs, seed=_seed, apply_stock=_no_stock, cos_eps=_cos_eps)
+# print(_dare_only)
+assert torch.allclose(_dare_only, _expected3, atol=0.0001)
diff --git a/test/unit_test_n_average.py b/test/unit_test_n_average.py
new file mode 100644
index 0000000..3343b52
--- /dev/null
+++ b/test/unit_test_n_average.py
@@ -0,0 +1,34 @@
+import torch
+import sd_mecha
+
+_models = [
+    torch.tensor([
+        [3., 4., 1., -2.],
+        [2., 1., -4., 3.],
+        [-1., 2., 3., 4.],
+        [4., -3., 2., 1.],
+    ]),
+    torch.tensor([
+        [-1., 3., 4., 2.],
+        [4., 2., -3., 1.],
+        [3., -1., 2., 4.],
+        [2., 4., 1., -3.],
+    ]),
+    torch.tensor([
+        [-1., 3., 2., 0.],
+        [ 3., 0., 2., 1.],
+        [-1., 3., 1., 0.],
+        [ 3., 0., -4., 3.]
+    ])
+]
+
+_expected = torch.tensor([
+    [ 0.3333, 3.3333, 2.3333, 0.0000],
+    [ 3.0000, 1.0000, -1.6667, 1.6667],
+    [ 0.3333, 1.3333, 2.0000, 2.6667],
+    [ 3.0000, 0.3333, -0.3333, 0.3333]
+])
+
+avg = sd_mecha.n_average.__wrapped__(*_models)
+
+assert torch.allclose(avg, _expected, atol=0.0001)
diff --git a/test/unit_test_ties.py b/test/unit_test_ties.py
index 700bb72..63ca6d6 100644
--- a/test/unit_test_ties.py
+++ b/test/unit_test_ties.py
@@ -24,6 +24,8 @@
 _alpha = 0.33
 _k = 0.5
+_use_delta = 0.0
+_use_signs = 1.0
 
 # Sudoku of 4x4, "top k" should be 2.
 _models = [
     torch.tensor([
@@ -48,3 +50,6 @@
 _actual = sd_mecha.ties_sum.__wrapped__(*_models, k=_k)
 
 assert torch.allclose(_actual, _expected)
+
+_actual2 = sd_mecha.ties_sum.__wrapped__(*_models, k=_k, vote_sgn=_use_signs)
+assert not torch.allclose(_actual, _actual2)
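
Usage sketch: how the new ties_with_dare recipe entry point might be called
once this patch lands. The checkpoint paths and output name are hypothetical,
and sd_mecha.RecipeMerger / merge_and_save are assumed from the library's
existing recipe API; only the keyword arguments come from the signature added
above.

    import sd_mecha

    # DARE + TIES-SOUP: three fine-tunes merged into their shared base.
    recipe = sd_mecha.ties_with_dare(
        "base.safetensors",            # hypothetical checkpoint paths
        "model_a.safetensors",
        "model_b.safetensors",
        "model_c.safetensors",
        probability=0.9,  # DARE drop probability p
        alpha=0.5,        # lambda applied to the merged delta
        k=0.2,            # TIES trim threshold, as a fraction of 1
        vote_sgn=1.0,     # > 0.0 elects final signs from sign votes (TIES-SOUP)
        seed=114514,
    )

    merger = sd_mecha.RecipeMerger()
    merger.merge_and_save(recipe, output="ties_with_dare_merge")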