Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] Speed up example tests #6319

Merged
merged 28 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cba7991
remove validation args from textual inversion tests
sayakpaul Dec 25, 2023
ede49cf
reduce number of train steps in textual inversion tests
sayakpaul Dec 25, 2023
85f3160
fix: directories.
sayakpaul Dec 25, 2023
b7dfcb9
debug
sayakpaul Dec 25, 2023
29f6ece
fix: directories.
sayakpaul Dec 25, 2023
ec52282
remove validation tests from textual inversion
sayakpaul Dec 25, 2023
ff7e4ae
try reducing the time of test_text_to_image_checkpointing_use_ema
sayakpaul Dec 25, 2023
1e63a6e
fix: directories
sayakpaul Dec 25, 2023
e07f9c9
speed up test_text_to_image_checkpointing
sayakpaul Dec 25, 2023
bff0d05
speed up test_text_to_image_checkpointing_checkpoints_total_limit_rem…
sayakpaul Dec 25, 2023
b8032c7
fix
sayakpaul Dec 25, 2023
e84c10a
speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit_…
sayakpaul Dec 25, 2023
443bcd5
set checkpoints_total_limit to 2.
sayakpaul Dec 25, 2023
9d3ea31
test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes…
sayakpaul Dec 25, 2023
4296ce1
speed up test_unconditional_checkpointing_checkpoints_total_limit_rem…
sayakpaul Dec 25, 2023
2c7e979
debug
sayakpaul Dec 25, 2023
6ec0d84
fix: directories.
sayakpaul Dec 25, 2023
fca5113
speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit
sayakpaul Dec 25, 2023
d71605b
speed up: test_controlnet_checkpointing_checkpoints_total_limit_remov…
sayakpaul Dec 25, 2023
e137f8a
speed up test_controlnet_sdxl
sayakpaul Dec 25, 2023
7afc283
speed up dreambooth tests
sayakpaul Dec 25, 2023
4f2fcde
speed up test_dreambooth_lora_checkpointing_checkpoints_total_limit_r…
sayakpaul Dec 25, 2023
651cfce
speed up test_custom_diffusion_checkpointing_checkpoints_total_limit_…
sayakpaul Dec 25, 2023
9fa43da
speed up test_text_to_image_lora_sdxl_text_encoder_checkpointing_chec…
sayakpaul Dec 25, 2023
c1dcdc2
speed up # checkpoint-2 should have been deleted
sayakpaul Dec 25, 2023
e999faf
speed up examples/text_to_image/test_text_to_image.py::TextToImage::t…
sayakpaul Dec 25, 2023
3d66d1e
additional speed ups
sayakpaul Dec 25, 2023
d5dd942
style
sayakpaul Dec 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions examples/controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=9
--max_train_steps=6
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
)

resume_run_args = f"""
Expand All @@ -85,18 +85,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-6
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})


class ControlNetSDXL(ExamplesTestsAccelerate):
Expand All @@ -111,7 +108,7 @@ def test_controlnet_sdxl(self):
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

Expand Down
20 changes: 7 additions & 13 deletions examples/custom_diffusion/test_custom_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})

def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -93,7 +90,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
--no_safe_serialization
""".split()
Expand All @@ -102,7 +99,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
Expand All @@ -115,16 +112,13 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--no_safe_serialization
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
27 changes: 12 additions & 15 deletions examples/dreambooth/test_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_dreambooth_checkpointing(self):

with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4

initial_run_args = f"""
Expand All @@ -100,7 +100,7 @@ def test_dreambooth_checkpointing(self):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
Expand All @@ -114,7 +114,7 @@ def test_dreambooth_checkpointing(self):

# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)

# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
Expand All @@ -123,7 +123,7 @@ def test_dreambooth_checkpointing(self):
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)

# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
Expand All @@ -138,7 +138,7 @@ def test_dreambooth_checkpointing(self):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
Expand All @@ -153,7 +153,7 @@ def test_dreambooth_checkpointing(self):

# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
pipe(instance_prompt, num_inference_steps=1)

# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
Expand Down Expand Up @@ -196,15 +196,15 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
Expand All @@ -216,15 +216,12 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
29 changes: 10 additions & 19 deletions examples/dreambooth/test_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,13 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

resume_run_args = f"""
examples/dreambooth/train_dreambooth_lora.py
Expand All @@ -155,18 +152,15 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

def test_dreambooth_lora_if_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down Expand Up @@ -328,7 +322,7 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--max_train_steps 6
--checkpointing_steps=2
--checkpoints_total_limit=2
--learning_rate 5.0e-04
Expand All @@ -342,14 +336,11 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):

pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)
pipe("a prompt", num_inference_steps=1)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})

def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
Expand Down
14 changes: 7 additions & 7 deletions examples/instruct_pix2pix/test_instruct_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=7
--max_train_steps=6
--checkpointing_steps=2
--checkpoints_total_limit=2
--output_dir {tmpdir}
Expand All @@ -63,7 +63,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=9
--max_train_steps=4
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
Expand All @@ -74,7 +74,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
Expand All @@ -84,18 +84,18 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=11
--max_train_steps=8
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
{"checkpoint-6", "checkpoint-8"},
)
Loading
Loading