Some minor code tweaks #4303

Closed · wants to merge 18 commits
Changes from 14 commits
6 changes: 2 additions & 4 deletions comfy/cldm/cldm.py
@@ -12,7 +12,6 @@

from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from .control_types import UNION_CONTROLNET_TYPES
from collections import OrderedDict
import comfy.ops
@@ -234,12 +233,12 @@ def __init__(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False

if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
if num_attention_blocks is None or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
@@ -434,4 +433,3 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
out_middle.append(self.middle_block_out(h, emb, context))

return {"middle": out_middle, "output": out_output}

16 changes: 2 additions & 14 deletions comfy/gligen.py
@@ -5,19 +5,6 @@
import comfy.ops
ops = comfy.ops.manual_cast

def exists(val):
return val is not None


def uniq(arr):
return{el: True for el in arr}.keys()


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
@@ -34,7 +21,8 @@ class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
if not dim_out:
dim_out = dim() if isfunction(dim) else dim
Contributor
dim is definitely not a function here
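
For reference, a minimal sketch of what the removed `default(dim_out, dim)` did here, assuming `dim` is the integer feature size passed to `FeedForward.__init__` as in the surrounding code: `isfunction(dim)` is False for an int, so the helper simply fell back to `dim`, and the inlined version can be a plain `None` check, matching the equivalent change made in comfy/ldm/modules/attention.py below:

```python
# default(val, d) returned val when val was not None,
# otherwise d() if d was a function, else d.
# With an integer dim that reduces to:
if dim_out is None:
    dim_out = dim
```

Note that `if not dim_out:` would additionally treat `dim_out=0` as unset, which the original `exists(dim_out)` check did not.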

project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
47 changes: 12 additions & 35 deletions comfy/ldm/modules/attention.py
@@ -24,33 +24,9 @@
def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision

def exists(val):
return val is not None


def uniq(arr):
return{el: True for el in arr}.keys()


def default(val, d):
if exists(val):
return val
return d


def max_neg_value(t):
return -torch.finfo(t.dtype).max


def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
if FORCE_UPCAST_ATTENTION_DTYPE is None:
return attn_precision
return FORCE_UPCAST_ATTENTION_DTYPE


# feedforward
@@ -68,7 +44,8 @@ class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
if dim_out is None:
dim_out = dim
project_in = nn.Sequential(
operations.Linear(dim, inner_dim, dtype=dtype, device=device),
nn.GELU()
@@ -121,7 +98,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape

del q, k

if exists(mask):
if mask is not None:
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
@@ -449,7 +426,8 @@ class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
if context_dim is None:
context_dim = query_dim
self.attn_precision = attn_precision

self.heads = heads
@@ -463,7 +441,8 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.

def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
if context is None:
context = x
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
@@ -649,7 +628,7 @@ def __init__(self, in_channels, n_heads, d_head,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
if context_dim is not None and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
@@ -799,7 +778,7 @@ def forward(
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
if context is not None:
spatial_context = context

if self.use_spatial_context:
@@ -861,5 +840,3 @@ def forward(
x = self.proj_out(x)
out = x + x_in
return out


13 changes: 4 additions & 9 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -11,10 +11,6 @@
import comfy.ops
import comfy.ldm.common_dit

def default(x, y):
if x is not None:
return x
return y

class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
@@ -713,14 +709,14 @@ def __init__(
self.learn_sigma = learn_sigma
self.in_channels = in_channels
default_out_channels = in_channels * 2 if learn_sigma else in_channels
self.out_channels = default(out_channels, default_out_channels)
self.out_channels = default_out_channels if out_channels is None else out_channels
self.patch_size = patch_size
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size

# hidden_size = default(hidden_size, 64 * depth)
# num_heads = default(num_heads, hidden_size // 64)
# hidden_size = 64 * depth if hidden_size is None else hidden_size
# num_heads = hidden_size // 64 if num_heads is None else num_heads

# apply magic --> this defines a head_size of 64
self.hidden_size = 64 * depth
@@ -862,7 +858,7 @@ def forward_core_with_concat(
context = torch.cat(
(
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
default(context, torch.Tensor([]).type_as(x)),
torch.Tensor([]).type_as(x) if context is None else context,
),
1,
)
@@ -932,4 +928,3 @@ def forward(
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control)

15 changes: 7 additions & 8 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -13,8 +13,7 @@
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
from ..attention import SpatialTransformer, SpatialVideoTransformer
import comfy.ops
ops = comfy.ops.disable_weight_init

@@ -301,11 +300,11 @@ def __init__(
)

self.time_stack = ResBlock(
default(out_channels, channels),
channels if out_channels is None else out_channels,
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
out_channels=channels if out_channels is None else out_channels,
use_scale_shift_norm=False,
use_conv=False,
up=False,
@@ -642,12 +641,12 @@ def get_resblock(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False

if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
if num_attention_blocks is None or nr < num_attention_blocks[level]:
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
@@ -768,12 +767,12 @@ def get_resblock(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False

if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
if num_attention_blocks is None or i < num_attention_blocks[level]:
layers.append(
get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
4 changes: 0 additions & 4 deletions comfy/ldm/modules/diffusionmodules/upscaling.py
@@ -4,7 +4,6 @@
from functools import partial

from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default


class AbstractLowScaleModel(nn.Module):
@@ -80,6 +79,3 @@ def forward(self, x, noise_level=None, seed=None):
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level



12 changes: 1 addition & 11 deletions comfy/ldm/util.py
@@ -44,16 +44,6 @@ def isimage(x):
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
return x is not None


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
@@ -194,4 +184,4 @@ def step(self, closure=None):
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

return loss
return loss
1 change: 1 addition & 0 deletions folder_paths.py
@@ -139,6 +139,7 @@ def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
folder_names_and_paths[folder_name] = ([full_folder_path], set())

def get_folder_paths(folder_name: str) -> list[str]:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:]
