Skip to content

Commit

Permalink
bring back old impl of train diff and add new methods (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb authored Jul 7, 2024
1 parent bab964b commit 6a517f0
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions sd_mecha/merge_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,48 @@ def train_difference(
*,
alpha: Hyper = 1.0,
**kwargs,
) -> Tensor | SameMergeSpace:
mask = 1.8 * torch.nan_to_num(torch.abs(b - a) / (torch.abs(b - a) + torch.abs(b - c)), nan=0)
return a + (b - c) * alpha * mask


@convert_to_recipe
def add_opposite(
a: Tensor | SameMergeSpace,
b: Tensor | SameMergeSpace,
c: Tensor | SameMergeSpace,
*,
alpha: Hyper = 1.0,
**kwargs,
) -> Tensor | SameMergeSpace:
threshold = torch.maximum(torch.abs(a - c), torch.abs(b - c))
dissimilarity = torch.clamp(torch.nan_to_num((c - a) * (b - c) / threshold**2, nan=0), 0)
mask = 1 - torch.nan_to_num((a - c) * (b - c) / threshold**2, nan=0)
return a + (b - c) * alpha * mask

return a + (b - c) * alpha * dissimilarity

@convert_to_recipe
def clamped_add_opposite(
a: Tensor | SameMergeSpace,
b: Tensor | SameMergeSpace,
c: Tensor | SameMergeSpace,
*,
alpha: Hyper = 1.0,
**kwargs,
) -> Tensor | SameMergeSpace:
threshold = torch.maximum(torch.abs(a - c), torch.abs(b - c))
mask = torch.clamp(torch.nan_to_num((c - a) * (b - c) / threshold**2, nan=0), 0) * 2
return a + (b - c) * alpha * mask


@convert_to_recipe
def select_max_delta(
a: Tensor | LiftFlag[MergeSpace.DELTA],
b: Tensor | LiftFlag[MergeSpace.DELTA],
*,
alpha: Hyper = 0.5,
**kwargs,
) -> Tensor | LiftFlag[MergeSpace.DELTA]:
return torch.where((1 - alpha) * a.abs() >= alpha * b.abs(), a, b)


@convert_to_recipe
Expand Down

0 comments on commit 6a517f0

Please sign in to comment.