[Training] QoL improvements in the Flux Control training scripts (#10461)

* QoL improvements to the Flux script.

* Propagate the dataloader changes.
sayakpaul authored Jan 7, 2025
1 parent 661bde0 commit b94cfd7
Showing 3 changed files with 93 additions and 20 deletions.
6 changes: 3 additions & 3 deletions examples/flux-control/README.md
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
 
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     joint_attention_kwargs={"scale": 0.9},
     guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
 
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     guidance_scale=25.,
 ).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
 ## Things to note
 
 * The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized, but we offload the VAE and the text encoders to CPU when they are not used, if `--offload` is specified.
 * We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides similar functionality (a minimal sketch of the idea follows below).
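For illustration, below is a minimal, hypothetical sketch of the extraction idea: diff the fine-tuned weights against the base weights and factor the delta with a truncated SVD. The function name and the default rank are illustrative only, not part of these scripts; the linked script is the more complete reference.

```python
# Hypothetical sketch of LoRA extraction from a fully fine-tuned weight.
import torch

def extract_lora(base_weight: torch.Tensor, tuned_weight: torch.Tensor, rank: int = 64):
    # The update the fine-tune applied on top of the base weight.
    delta = (tuned_weight - base_weight).float()
    # A truncated SVD keeps the top-`rank` directions of that update.
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    lora_up = U[:, :rank] * S[:rank]  # shape: (out_features, rank)
    lora_down = Vh[:rank, :]          # shape: (rank, in_features)
    return lora_up, lora_down         # lora_up @ lora_down ≈ delta
```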
60 changes: 51 additions & 9 deletions examples/flux-control/train_control_flux.py
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
             images = log["images"]
             validation_prompt = log["validation_prompt"]
             validation_image = log["validation_image"]
-            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+            formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
             for image in images:
                 image = wandb.Image(image, caption=validation_prompt)
                 formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
             img_str += f"![images_{i}](./images_{i}.png)\n"
 
     model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}
 These are Control weights trained on {base_model} with a new type of conditioning.
 {img_str}
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
         type=str,
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -468,7 +468,7 @@
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@
         default=None,
         help="Path to the jsonl file containing the training data.",
     )
-
+    parser.add_argument(
+        "--only_target_transformer_blocks",
+        action="store_true",
+        help="Whether to train only the transformer blocks along with the input layer (`x_embedder`).",
+    )
     parser.add_argument(
         "--guidance_scale",
         type=float,
@@ -581,7 +585,7 @@
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
         )
 
     return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
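The new branch above covers datasets whose caption column stores a list of captions per image: the longest caption is kept as the training caption, while plain string columns keep the previous behavior. A tiny, self-contained illustration of the heuristic (values made up):

```python
# Made-up example of the longest-caption heuristic used in preprocess_train.
captions_column = [
    ["a couple", "a couple, 4k photo, highly detailed"],  # list-style sample
    ["a portrait"],
]
picked = [max(captions, key=len) for captions in captions_column]
print(picked)  # ['a couple, 4k photo, highly detailed', 'a portrait']
```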
@@ -765,7 +774,8 @@
         subfolder="scheduler",
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    flux_transformer.requires_grad_(True)
+    if not args.only_target_transformer_blocks:
+        flux_transformer.requires_grad_(True)
     vae.requires_grad_(False)
 
     # cast down and move to the CPU
@@ -797,6 +807,12 @@
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
+    if args.only_target_transformer_blocks:
+        flux_transformer.x_embedder.requires_grad_(True)
+        for name, module in flux_transformer.named_modules():
+            if "transformer_blocks" in name:
+                module.requires_grad_(True)
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
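One subtlety in the block above: the substring test `"transformer_blocks" in name` matches both the double-stream `transformer_blocks` and the single-stream `single_transformer_blocks` modules of the Flux transformer, so both families are unfrozen along with `x_embedder`. A small sketch (assuming `flux_transformer` is in scope; not part of the script) of how one might list what gets targeted:

```python
# Sketch: module names the loop above would mark as trainable.
targeted = [name for name, _ in flux_transformer.named_modules() if "transformer_blocks" in name]
# Matches both "transformer_blocks.N..." and "single_transformer_blocks.N...".
print(len(targeted), targeted[:3])
```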
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
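Two details worth noting in the sampling-log block above: `(x + 1) / 2` maps the `[-1, 1]`-normalized tensors back to `[0, 1]` for visualization, and the preview is capped at roughly ten samples. The tracker lookup scans `accelerator.trackers` by name; a sketch of an equivalent lookup via Accelerate's helper (an alternative, not what the script uses):

```python
# Alternative tracker lookup (sketch): Accelerate exposes a helper for this.
wandb_tracker = accelerator.get_tracker("wandb")
wandb_tracker.log({"dataset_samples": logged_artifacts})
```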
47 changes: 39 additions & 8 deletions examples/flux-control/train_control_lora_flux.py
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
             images = log["images"]
             validation_prompt = log["validation_prompt"]
             validation_image = log["validation_image"]
-            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+            formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
             for image in images:
                 image = wandb.Image(image, caption=validation_prompt)
                 formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
             img_str += f"![images_{i}](./images_{i}.png)\n"
 
     model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}
 These are Control LoRA weights trained on {base_model} with a new type of conditioning.
 {img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="controlnet-lora",
+        default="control-lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -466,14 +465,15 @@
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
         type=str,
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -500,7 +500,7 @@
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )
 
     return args
@@ -697,7 +697,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
@@ -1132,6 +1137,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
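For completeness, a minimal inference sketch for a trained Control LoRA (the LoRA repository id and the conditioning file are hypothetical; the README snippets above show the canonical call with `control_image`):

```python
# Hypothetical sketch: load the base Flux Control pipeline, then the LoRA.
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-username/control-lora")  # hypothetical repo id

control_image = load_image("conditioning.png")  # hypothetical conditioning image
gen_image = pipe(
    prompt="A couple, 4k photo, highly detailed",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=25.0,
).images[0]
gen_image.save("output.png")
```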
