Skip to content

Commit

Permalink
[Tests] Speed up example tests (#6319)
Browse files Browse the repository at this point in the history
* remove validation args from textual onverson tests

* reduce number of train steps in textual inversion tests

* fix: directories.

* debig

* fix: directories.

* remove validation tests from textual onversion

* try reducing the time of test_text_to_image_checkpointing_use_ema

* fix: directories

* speed up test_text_to_image_checkpointing

* speed up test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* fix

* speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* set checkpoints_total_limit to 2.

* test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints speed up

* speed up test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* debug

* fix: directories.

* speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit

* speed up: test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* speed up test_controlnet_sdxl

* speed up dreambooth tests

* speed up test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* speed up test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints

* speed up test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit

* speed up # checkpoint-2 should have been deleted

* speed up examples/text_to_image/test_text_to_image.py::TextToImage::test_text_to_image_checkpointing_checkpoints_total_limit

* additional speed ups

* style
  • Loading branch information
sayakpaul authored Dec 25, 2023
1 parent 89459a5 commit f4b0b26
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 164 deletions.
17 changes: 7 additions & 10 deletions examples/controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=9
--max_train_steps=6
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
)

resume_run_args = f"""
Expand All @@ -85,18 +85,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-6
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})


class ControlNetSDXL(ExamplesTestsAccelerate):
Expand All @@ -111,7 +108,7 @@ def test_controlnet_sdxl(self):
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

Expand Down
20 changes: 7 additions & 13 deletions examples/custom_diffusion/test_custom_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})

def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -93,7 +90,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
--no_safe_serialization
""".split()
Expand All @@ -102,7 +99,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
Expand All @@ -115,16 +112,13 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--no_safe_serialization
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
27 changes: 12 additions & 15 deletions examples/dreambooth/test_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_dreambooth_checkpointing(self):

with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4

initial_run_args = f"""
Expand All @@ -100,7 +100,7 @@ def test_dreambooth_checkpointing(self):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
Expand All @@ -114,7 +114,7 @@ def test_dreambooth_checkpointing(self):

# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)

# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
Expand All @@ -123,7 +123,7 @@ def test_dreambooth_checkpointing(self):
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)

# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
Expand All @@ -138,7 +138,7 @@ def test_dreambooth_checkpointing(self):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
Expand All @@ -153,7 +153,7 @@ def test_dreambooth_checkpointing(self):

# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)

# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
Expand Down Expand Up @@ -196,15 +196,15 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
Expand All @@ -216,15 +216,12 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
29 changes: 10 additions & 19 deletions examples/dreambooth/test_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,13 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

resume_run_args = f"""
examples/dreambooth/train_dreambooth_lora.py
Expand All @@ -155,18 +152,15 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

def test_dreambooth_lora_if_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down Expand Up @@ -328,7 +322,7 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--checkpointing_steps=2
--checkpoints_total_limit=2
--learning_rate 5.0e-04
Expand All @@ -342,14 +336,11 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):

pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)
pipe("a prompt", num_inference_steps=1)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})

def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
Expand Down
14 changes: 7 additions & 7 deletions examples/instruct_pix2pix/test_instruct_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=7
--max_train_steps=6
--checkpointing_steps=2
--checkpoints_total_limit=2
--output_dir {tmpdir}
Expand All @@ -63,7 +63,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
Expand All @@ -74,7 +74,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
Expand All @@ -84,18 +84,18 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
{"checkpoint-6", "checkpoint-8"},
)
Loading

0 comments on commit f4b0b26

Please sign in to comment.