Skip to content

Commit

Permalink
fix flux test shape
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jan 10, 2025
1 parent 4051f76 commit 3bd5336
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions tests/onnxruntime/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,19 @@ def test_shape(self, model_arch: str):
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
expected_height = height // pipeline.vae_scale_factor
expected_width = width // pipeline.vae_scale_factor

if model_arch == "flux":
expected_height = height // 2**pipeline.vae_scale_factor
expected_width = width // 2**pipeline.vae_scale_factor
channels = pipeline.transformer.config.in_channels
if is_diffusers_version(">=", "0.32.0"):
expected_shape = (batch_size, expected_height * expected_width // 4, channels)
else:
expected_shape = (batch_size, expected_height * expected_width, channels)

elif model_arch == "stable-diffusion-3":
out_channels = pipeline.transformer.config.out_channels
expected_shape = (batch_size, out_channels, expected_height, expected_width)
expected_shape = (batch_size, expected_height * expected_width, channels)
else:
out_channels = pipeline.unet.config.out_channels
expected_height = height // pipeline.vae_scale_factor
expected_width = width // pipeline.vae_scale_factor
out_channels = (
pipeline.unet.config.out_channels
if getattr(pipeline, "unet", None) is not None
else pipeline.transformer.config.out_channels
)
expected_shape = (batch_size, out_channels, expected_height, expected_width)

self.assertEqual(outputs.shape, expected_shape)
Expand Down

0 comments on commit 3bd5336

Please sign in to comment.