diff --git a/kandinsky2/model/unet.py b/kandinsky2/model/unet.py index 6a5b6cf..2bac4b5 100644 --- a/kandinsky2/model/unet.py +++ b/kandinsky2/model/unet.py @@ -271,7 +271,7 @@ def forward(self, x, encoder_out=None): class QKVAttention(nn.Module): """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping """ def __init__(self, n_heads, use_flash_attention=False):