Skip to content

Commit

Permalink
Better seeding + slight Gradio demo example schedule changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kuanhenglin committed Sep 28, 2024
1 parent 3b13a7f commit d7d22f2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
10 changes: 4 additions & 6 deletions app_ctrlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ def inference(
structure_schedule, appearance_schedule, use_advanced_config,
control_config,
):
global pipe, refiner

torch.manual_seed(seed)
seed_everything(seed)

pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps
Expand Down Expand Up @@ -310,7 +308,7 @@ def inference(
"a 3D mesh of a cat",
"",
"a photo of a tiger standing on snow",
0.6, 0.6,
0.5, 0.6,
],
[
"assets/images/dog__sketch.jpg",
Expand Down Expand Up @@ -358,7 +356,7 @@ def inference(
"a segmentation map of a bear and an avocado",
"",
"a realistic photo of a bear and an avocado in a forest",
0.6, 0.6,
0.5, 0.6,
],
[
"assets/images/cat__point_cloud.jpg",
Expand All @@ -382,7 +380,7 @@ def inference(
"a 3D model of a person holding a sword and shield",
"",
"a photo of a medieval soldier standing on a barren field, raining",
0.6, 0.6,
0.5, 0.6,
],
[
"assets/images/person__mesh.jpg",
Expand Down
15 changes: 14 additions & 1 deletion ctrl_x/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import random
from os import environ

import numpy as np
import torch


JPEG_QUALITY = 95
JPEG_QUALITY = 100


def seed_everything(seed):
random.seed(seed)
environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def exists(x):
Expand Down
5 changes: 1 addition & 4 deletions run_ctrlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
from ctrl_x.utils.sdxl import *


JPEG_QUALITY = 100


@torch.no_grad()
def inference(
pipe, refiner, device,
Expand All @@ -27,7 +24,7 @@ def inference(
width, height,
structure_schedule, appearance_schedule,
):
torch.manual_seed(seed)
seed_everything(seed)

pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps
Expand Down

0 comments on commit d7d22f2

Please sign in to comment.