From d7d22f2bcc9fe140f554c748a73ec902d6414a44 Mon Sep 17 00:00:00 2001 From: "Kuan Heng (Jordan) Lin" Date: Sat, 28 Sep 2024 14:42:06 -0700 Subject: [PATCH] Better seeding + slight Gradio demo example schedule changes --- app_ctrlx.py | 10 ++++------ ctrl_x/utils/utils.py | 15 ++++++++++++++- run_ctrlx.py | 5 +---- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/app_ctrlx.py b/app_ctrlx.py index a0ba477..4f3689c 100644 --- a/app_ctrlx.py +++ b/app_ctrlx.py @@ -157,9 +157,7 @@ def inference( structure_schedule, appearance_schedule, use_advanced_config, control_config, ): - global pipe, refiner - - torch.manual_seed(seed) + seed_everything(seed) pipe.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = pipe.scheduler.timesteps @@ -310,7 +308,7 @@ def inference( "a 3D mesh of a cat", "", "a photo of a tiger standing on snow", - 0.6, 0.6, + 0.5, 0.6, ], [ "assets/images/dog__sketch.jpg", @@ -358,7 +356,7 @@ def inference( "a segmentation map of a bear and an avocado", "", "a realistic photo of a bear and an avocado in a forest", - 0.6, 0.6, + 0.5, 0.6, ], [ "assets/images/cat__point_cloud.jpg", @@ -382,7 +380,7 @@ def inference( "a 3D model of a person holding a sword and shield", "", "a photo of a medieval soldier standing on a barren field, raining", - 0.6, 0.6, + 0.5, 0.6, ], [ "assets/images/person__mesh.jpg", diff --git a/ctrl_x/utils/utils.py b/ctrl_x/utils/utils.py index 0597411..8f7f27d 100644 --- a/ctrl_x/utils/utils.py +++ b/ctrl_x/utils/utils.py @@ -1,7 +1,20 @@ +import random +from os import environ + +import numpy as np import torch -JPEG_QUALITY = 95 +JPEG_QUALITY = 100 + + +def seed_everything(seed): + random.seed(seed) + environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False def exists(x): diff --git a/run_ctrlx.py b/run_ctrlx.py index 6b8963a..0d7aa20 100644 --- a/run_ctrlx.py +++ b/run_ctrlx.py @@ -13,9 +13,6 @@ from ctrl_x.utils.sdxl import * -JPEG_QUALITY = 100 - - @torch.no_grad() def inference( pipe, refiner, device, @@ -27,7 +24,7 @@ def inference( width, height, structure_schedule, appearance_schedule, ): - torch.manual_seed(seed) + seed_everything(seed) pipe.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = pipe.scheduler.timesteps