From 58d195aade9b9a09952f165858f71ae900cf3695 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Thu, 7 Dec 2023 03:04:12 +0900 Subject: [PATCH] Update orig_attention.py continous -> continuous --- magicanimate/models/orig_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/magicanimate/models/orig_attention.py b/magicanimate/models/orig_attention.py index 9c3eba09..fd4ab237 100644 --- a/magicanimate/models/orig_attention.py +++ b/magicanimate/models/orig_attention.py @@ -114,7 +114,7 @@ def __init__( self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim - # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = in_channels is not None self.is_input_vectorized = num_vector_embeds is not None @@ -185,7 +185,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to @@ -985,4 +985,4 @@ def forward( if not return_dict: return (output_states,) - return Transformer2DModelOutput(sample=output_states) \ No newline at end of file + return Transformer2DModelOutput(sample=output_states)