Skip to content

Commit

Permalink
add SDXL turbo
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Dec 7, 2023
1 parent f76ba5b commit 4e7fb4d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 11 deletions.
28 changes: 28 additions & 0 deletions benchmarks/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"stabilityai/stable-diffusion-2-1": (768, 768),
"stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
"stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
"stabilityai/sdxl-turbo": (512, 512),
}


Expand Down Expand Up @@ -119,6 +120,19 @@ def benchmark(self, args):
flush()


class TurboTextToImageBenchmark(TextToImageBenchmark):
def __init__(self, args):
super().__init__(args)

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
guidance_scale=0.0,
)


class ImageToImageBenchmark(TextToImageBenchmark):
pipeline_class = AutoPipelineForImage2Image
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
Expand All @@ -137,6 +151,20 @@ def run_inference(self, pipe, args):
)


class TurboImageToImageBenchmark(ImageToImageBenchmark):
def __init__(self, args):
super().__init__(args)

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
image=self.image,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
guidance_scale=0.0,
)


class InpaintingBenchmark(ImageToImageBenchmark):
pipeline_class = AutoPipelineForInpainting
mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/benchmark_sd_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


sys.path.append(".")
from base_classes import ImageToImageBenchmark # noqa: E402
from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402


if __name__ == "__main__":
Expand All @@ -16,6 +16,7 @@
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-refiner-1.0",
"stabilityai/sdxl-turbo",
],
)
parser.add_argument("--batch_size", type=int, default=1)
Expand All @@ -24,5 +25,5 @@
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()

benchmark_pipe = ImageToImageBenchmark(args)
benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
benchmark_pipe.benchmark(args)
5 changes: 3 additions & 2 deletions benchmarks/benchmark_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


sys.path.append(".")
from base_classes import TextToImageBenchmark # noqa: E402
from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402


if __name__ == "__main__":
Expand All @@ -18,6 +18,7 @@
"stabilityai/stable-diffusion-xl-base-1.0",
"kandinsky-community/kandinsky-2-2-decoder",
"warp-ai/wuerstchen",
"stabilityai/sdxl-turbo",
],
)
parser.add_argument("--batch_size", type=int, default=1)
Expand All @@ -26,5 +27,5 @@
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()

benchmark_pipe = TextToImageBenchmark(args)
benchmark_pipe = TextToImageBenchmark(args) if "turbo" not in args.ckpt else TurboTextToImageBenchmark(args)
benchmark_pipe.benchmark(args)
18 changes: 11 additions & 7 deletions benchmarks/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ def main():
command += " --run_compile"
run_command(command.split())

elif file in ["benchmark_sd_img.py", "benchmark_sd_inpainting.py"]:
sdxl_ckpt = (
"stabilityai/stable-diffusion-xl-refiner-1.0"
if "inpainting" not in file
else "stabilityai/stable-diffusion-xl-base-1.0"
)
command = f"python {file} --ckpt {sdxl_ckpt}"
elif file == "benchmark_sd_img.py":
for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
command = f"python {file} --ckpt {ckpt}"
run_command(command.split())

command += " --run_compile"
run_command(command.split())

elif file == "benchmark_sd_inpainting.py":
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
command = f"python {file} --ckpt {ckpt}"
run_command(command.split())

command += " --run_compile"
Expand Down

0 comments on commit 4e7fb4d

Please sign in to comment.