[Training] QoL improvements in the Flux Control training scripts #10461

Merged · 4 commits · Jan 7, 2025
Changes from 1 commit
6 changes: 3 additions & 3 deletions examples/flux-control/README.md
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
-   condition_image=image,
+   control_image=image,
    num_inference_steps=50,
    joint_attention_kwargs={"scale": 0.9},
    guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
-   condition_image=image,
+   control_image=image,
    num_inference_steps=50,
    guidance_scale=25.,
).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
## Things to note

* The scripts provided in this directory are experimental and educational. This means we may need to tweak things to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized, but when `--offload` is specified, we offload the VAE and the text encoders to CPU while they are not in use (see the sketch below).
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides similar functionality.
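A minimal sketch of the offload pattern behind `--offload`, assuming the offloaded models are plain `torch.nn.Module`s; the helper name and the list of model handles are illustrative, not taken from the actual scripts:

```python
import torch

def offload_when_idle(models, offload: bool) -> None:
    # Hypothetical helper: move auxiliary models (VAE, text encoders) to CPU
    # once their outputs are cached, freeing GPU memory for the transformer.
    if offload:
        for model in models:
            model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
```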
60 changes: 51 additions & 9 deletions examples/flux-control/train_control_flux.py
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f

    for _ in range(args.num_validation_images):
        with autocast_ctx:
-           # need to fix in pipeline_flux_controlnet
            image = pipeline(
                prompt=validation_prompt,
                control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
        img_str += f"![images_{i}](./images_{i}.png)\n"

    model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}

These are Control weights trained on {base_model} with a new type of conditioning.
{img_str}
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the controlnet conditioning image.",
help="The column of the dataset containing the control conditioning image.",
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
        default=None,
        nargs="+",
        help=(
-           "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+           "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
        default=None,
        help="Path to the jsonl file containing the training data.",
    )
-
+   parser.add_argument(
+       "--only_target_transformer_blocks",
+       action="store_true",
+       help="Whether to train only the transformer blocks along with the input layer (`x_embedder`).",
+   )
    parser.add_argument(
        "--guidance_scale",
        type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):

    if args.resolution % 8 != 0:
        raise ValueError(
-           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )

    return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
        conditioning_images = [image_transforms(image) for image in conditioning_images]
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
-       examples["captions"] = list(examples[args.caption_column])
+
+       is_caption_list = isinstance(examples[args.caption_column][0], list)
+       if is_caption_list:
+           examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+       else:
+           examples["captions"] = list(examples[args.caption_column])

        return examples
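The caption-column branch above keeps the longest caption when the dataset provides a list of captions per image. A self-contained illustration with made-up captions:

```python
# Made-up example: two images, each with several candidate captions.
captions_column = [
    ["a couple", "a couple, 4k photo, highly detailed"],
    ["a room", "a bright living room with large windows"],
]
# Same longest-caption selection as in preprocess_train above.
captions = [max(caps, key=len) for caps in captions_column]
print(captions)
# ['a couple, 4k photo, highly detailed', 'a bright living room with large windows']
```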

@@ -765,7 +774,8 @@ def main(args):
subfolder="scheduler",
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
flux_transformer.requires_grad_(True)
if not args.only_target_transformer_blocks:
flux_transformer.requires_grad_(True)
vae.requires_grad_(False)

# cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
    assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

+   if args.only_target_transformer_blocks:
+       flux_transformer.x_embedder.requires_grad_(True)
+       for name, module in flux_transformer.named_modules():
+           if "transformer_blocks" in name:
+               module.requires_grad_(True)
+
[Author's review comment on this block: "Doesn't apply to the LoRA script."]

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
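The `--only_target_transformer_blocks` path above makes only the input embedder and the transformer blocks trainable. A minimal sketch of the same selective-unfreeze pattern on a toy module (the module names mirror the script, but the shapes and the explicit freeze line are only there to make the demo self-contained):

```python
import torch.nn as nn

# Toy stand-in for the Flux transformer.
model = nn.ModuleDict({
    "x_embedder": nn.Linear(128, 64),
    "transformer_blocks": nn.ModuleList([nn.Linear(64, 64) for _ in range(2)]),
    "proj_out": nn.Linear(64, 8),
})
model.requires_grad_(False)  # freeze everything for the demo

# Unfreeze the input layer and every transformer block, as in the script.
model["x_embedder"].requires_grad_(True)
for name, module in model.named_modules():
    if "transformer_blocks" in name:
        module.requires_grad_(True)

# proj_out stays frozen; only x_embedder and the blocks are trainable.
print([name for name, p in model.named_parameters() if p.requires_grad])
```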
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
    else:
        initial_global_step = 0

+   if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+       logger.info("Logging some dataset samples.")
+       formatted_images = []
+       formatted_control_images = []
+       all_prompts = []
+       for i, batch in enumerate(train_dataloader):
+           images = (batch["pixel_values"] + 1) / 2
+           control_images = (batch["conditioning_pixel_values"] + 1) / 2
+           prompts = batch["captions"]
+
+           if len(formatted_images) > 10:
+               break
+
+           for img, control_img, prompt in zip(images, control_images, prompts):
+               formatted_images.append(img)
+               formatted_control_images.append(control_img)
+               all_prompts.append(prompt)
+
+       logged_artifacts = []
+       for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+           logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+           logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+       wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+       wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
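One note on the dataset-sample logging block above: the training transforms normalize images to [-1, 1], so the block rescales them with `(x + 1) / 2` before passing them to `wandb.Image` (which, to our understanding, treats float arrays as values in [0, 1]; the rescale itself is straight from the diff). A self-contained illustration:

```python
import torch

# Stand-in batch with the same [-1, 1] normalization used in training.
batch = {"pixel_values": torch.rand(2, 3, 64, 64) * 2 - 1}

# Same rescale as the logging block: map [-1, 1] back to [0, 1].
images = (batch["pixel_values"] + 1) / 2
assert images.min() >= 0.0 and images.max() <= 1.0
```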
40 changes: 33 additions & 7 deletions examples/flux-control/train_control_lora_flux.py
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f

    for _ in range(args.num_validation_images):
        with autocast_ctx:
-           # need to fix in pipeline_flux_controlnet
            image = pipeline(
                prompt=validation_prompt,
                control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
-           formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+           formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
            for image in images:
                image = wandb.Image(image, caption=validation_prompt)
                formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
        img_str += f"![images_{i}](./images_{i}.png)\n"

    model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}

These are Control LoRA weights trained on {base_model} with a new type of conditioning.
{img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
    parser.add_argument(
        "--output_dir",
        type=str,
-       default="controlnet-lora",
+       default="control-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
@@ -466,14 +465,15 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the controlnet conditioning image.",
help="The column of the dataset containing the control conditioning image.",
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
        default=None,
        nargs="+",
        help=(
-           "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+           "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):

    if args.resolution % 8 != 0:
        raise ValueError(
-           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )

    return args
@@ -1132,6 +1132,32 @@ def load_model_hook(models, input_dir):
    else:
        initial_global_step = 0

+   if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+       logger.info("Logging some dataset samples.")
+       formatted_images = []
+       formatted_control_images = []
+       all_prompts = []
+       for i, batch in enumerate(train_dataloader):
+           images = (batch["pixel_values"] + 1) / 2
+           control_images = (batch["conditioning_pixel_values"] + 1) / 2
+           prompts = batch["captions"]
+
+           if len(formatted_images) > 10:
+               break
+
+           for img, control_img, prompt in zip(images, control_images, prompts):
+               formatted_images.append(img)
+               formatted_control_images.append(control_img)
+               all_prompts.append(prompt)
+
+       logged_artifacts = []
+       for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+           logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+           logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+       wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+       wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,