Skip to content

Commit

Permalink
Ties upgrade (#37)
Browse files Browse the repository at this point in the history
* add TIES_SOUP support

* TIES w/ DARE

* DROP: DARE without ARE

* modelstock and n_average

* hotfix + it is memory intensive

* hotfix: nan

* modelstock fix + geometric_median

* DARE / TIES / Median Stack

* move things around

* return type

* forward n_average

* reformat and deduplicate

---------

Co-authored-by: Darren Laurie <[email protected]>
  • Loading branch information
ljleb and 6DammK9 authored Jul 30, 2024
1 parent 9291b14 commit c0ed6cc
Show file tree
Hide file tree
Showing 7 changed files with 619 additions and 11 deletions.
166 changes: 164 additions & 2 deletions sd_mecha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def serialize_and_save(

weighted_sum = merge_methods.weighted_sum
slerp = merge_methods.slerp
n_average = merge_methods.n_average


def add_difference(
Expand Down Expand Up @@ -132,6 +133,7 @@ def add_perpendicular(
cosine_add_a = merge_methods.add_cosine_a
cosine_add_b = merge_methods.add_cosine_b
ties_sum = merge_methods.ties_sum
ties_sum_extended = merge_methods.ties_sum_extended


# latex notes in reference to original implementation: https://arxiv.org/abs/2306.01708
Expand All @@ -142,11 +144,13 @@ def add_perpendicular(
# - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ )
# - `res`: $$ \lambda * \tau_m $$
# - `return`: $$ \theta_m $$
# Special mode "TIES-SOUP" has been implemented by setting `vote_sgn` > 0.0
# Special mode "TIES-STOCK" has been implemented by setting `apply_stock` > 0.0
def add_difference_ties(
base: RecipeNodeOrPath,
*models: RecipeNodeOrPath,
alpha: float,
k: float = 0.2,
alpha: Hyper,
k: Hyper = 0.2,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
) -> recipe_nodes.RecipeNode:
Expand Down Expand Up @@ -182,6 +186,61 @@ def add_difference_ties(
)


def add_difference_ties_extended(
base: RecipeNodeOrPath,
*models: RecipeNodeOrPath,
alpha: Hyper,
k: Hyper = 0.2,
vote_sgn: Hyper = 0.0,
apply_stock: Hyper = 0.0,
cos_eps: Hyper = 1e-6,
apply_median: Hyper = 0.0,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
ftol: Hyper =1e-20,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
) -> recipe_nodes.RecipeNode:
# $$ \{\theta_{init}\}_{t=1}^n $$
base = path_to_node(base)
models = tuple(path_to_node(model) for model in models)

# Create task vectors.
# $$ \tau_t $$
models = tuple(
subtract(model, base)
if model.merge_space is MergeSpace.BASE else
model
for model in models
)

# step 1 + step 2 + step 3
res = ties_sum_extended(
*models,
k=k,
vote_sgn=vote_sgn,
apply_stock=apply_stock,
cos_eps=cos_eps,
apply_median=apply_median,
eps=eps,
maxiter=maxiter,
ftol=ftol,
device=device,
dtype=dtype,
)

# Obtain merged checkpoint

# $$ \theta_{init} + \lambda * \tau_m $$
return add_difference(
base, res,
alpha=alpha,
device=device,
dtype=dtype,
)



def copy_region(
a: RecipeNodeOrPath, b: RecipeNodeOrPath, c: Optional[RecipeNodeOrPath] = None, *,
width: Hyper = 1.0,
Expand Down Expand Up @@ -299,6 +358,109 @@ def dropout(
return sd_mecha.add_difference(a, ba_delta, alpha=alpha, device=device, dtype=dtype)


ties_sum_with_dropout = merge_methods.ties_sum_with_dropout


# latex notes in reference to original implementation: https://arxiv.org/abs/2311.03099
# Notice that this is "TIES Merging w/ DARE", which is "Prune > Merge (TIES) > Rescale"
# See https://slgero.medium.com/merge-large-language-models-29897aeb1d1a for details
# - `base`: $$ \theta_{PRE} $$
# - `*models`: $$ \theta_{SFT}^{t_k} $$
# - `deltas`: $$ \delta^t = \theta_{SFT}^{t} - \theta_{PRE} \in \mathbb{R}^d $$
# - `probability`: $$ p $$
# - `res`: $$ \hat{\delta}^t = \tilde{\delta}^t / (1-p) $$
# - `alpha`: $$ \lambda $$
# - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ ) in TIES paper
# - `return`: $$ \theta_M = \theta_{PRE} + \lambda \cdot \Sigma_{k=1}^{K} \tilde{\delta}^{t_k} $$
# Special mode "TIES-SOUP" has been implemented by setting `vote_sgn` > 0.0
def ties_with_dare(
base: RecipeNodeOrPath,
*models: RecipeNodeOrPath,
probability: Hyper = 0.9,
no_rescale: Hyper = 0.0,
alpha: Hyper = 0.5,
seed: Optional[Hyper] = None,
k: Hyper = 0.2,
vote_sgn: Hyper = 0.0,
apply_stock: Hyper = 0.0,
cos_eps: Hyper = 1e-6,
apply_median: Hyper = 0.0,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
ftol: Hyper =1e-20,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
) -> recipe_nodes.RecipeNode:
# $$ \delta^t = \theta_{SFT}^{t} - \theta_{PRE} \in \mathbb{R}^d $$
base = path_to_node(base)
models = tuple(path_to_node(model) for model in models)
deltas = tuple(
subtract(model, base)
if model.merge_space is MergeSpace.BASE else
model
for model in models
)

# $$ \tilde{\delta}^{t_k} $$
res = ties_sum_with_dropout(
*deltas,
probability=probability,
no_rescale=no_rescale,
k=k,
vote_sgn=vote_sgn,
seed=seed,
apply_stock=apply_stock,
cos_eps=cos_eps,
apply_median=apply_median,
eps=eps,
maxiter=maxiter,
ftol=ftol,
device=device,
dtype=dtype
)

# $$ \theta_M = \theta_{PRE} + \lambda \cdot \Sigma_{k=1}^{K} \tilde{\delta}^{t_k} $$
return sd_mecha.add_difference(base, res, alpha=alpha, device=device, dtype=dtype)


model_stock_for_tensor = merge_methods.model_stock_for_tensor


# Following mergekit's implementation of Model Stock (which official implementation doesn't exist)
# https://github.com/arcee-ai/mergekit/blob/main/mergekit/merge_methods/model_stock.py
def model_stock_n_models(
base: RecipeNodeOrPath,
*models: RecipeNodeOrPath,
cos_eps: Hyper = 1e-6,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
) -> recipe_nodes.RecipeNode:

base = path_to_node(base)
models = tuple(path_to_node(model) for model in models)
deltas = tuple(
subtract(model, base)
if model.merge_space is MergeSpace.BASE else
model
for model in models
)

# This is hacky: Both w_avg and w_h will be calculated there.
# Notice that t and cos_theta is vector instead of single value.
# Conceptually it could compatable with TIES, but algorithm should be rewritten.
res = model_stock_for_tensor(
*deltas,
cos_eps=cos_eps,
device=device,
dtype=dtype
)

return sd_mecha.add_difference(base, res, alpha=1.0, device=device, dtype=dtype)


geometric_median = merge_methods.geometric_median


def model(state_dict: str | pathlib.Path | Mapping[str, Tensor], model_arch: str = "sd1", model_type: str = "base"):
return recipe_nodes.ModelRecipeNode(state_dict, model_arch, model_type)

Expand Down
Loading

0 comments on commit c0ed6cc

Please sign in to comment.