diff --git a/src/mflux/config/runtime_config.py b/src/mflux/config/runtime_config.py
index 25723b72..a1e87b0e 100644
--- a/src/mflux/config/runtime_config.py
+++ b/src/mflux/config/runtime_config.py
@@ -27,6 +27,10 @@ def width(self) -> int:
     def guidance(self) -> float:
         return self.config.guidance
 
+    @guidance.setter
+    def guidance(self, value: float):
+        self.config.guidance = value
+
     @property
     def num_inference_steps(self) -> int:
         return self.config.num_inference_steps
diff --git a/src/mflux/flux/flux.py b/src/mflux/flux/flux.py
index 08e0de52..fd6746b9 100644
--- a/src/mflux/flux/flux.py
+++ b/src/mflux/flux/flux.py
@@ -69,7 +69,11 @@ def __init__(
     def generate_image(
         self,
         seed: int,
-        prompt: str,
+        src_prompt: str,
+        tar_prompt: str,
+        src_guidance: float,
+        tar_guidance: float,
+        image_path: str,
         config: Config = Config(),
         stepwise_output_dir: Path = None,
     ) -> GeneratedImage:
@@ -80,52 +84,85 @@ def generate_image(
             flux=self,
             config=config,
             seed=seed,
-            prompt=prompt,
+            prompt=f"src_prompt: {src_prompt} | tar_prompt: {tar_prompt}",
             time_steps=time_steps,
             output_dir=stepwise_output_dir,
         )
 
         # 1. Create the initial latents
-        latents = LatentCreator.create_for_txt2img_or_img2img(seed, config, self.vae)
-
-        # 2. Embed the prompt
-        t5_tokens = self.t5_tokenizer.tokenize(prompt)
-        clip_tokens = self.clip_tokenizer.tokenize(prompt)
-        prompt_embeds = self.t5_text_encoder(t5_tokens)
-        pooled_prompt_embeds = self.clip_text_encoder(clip_tokens)
-
+        image_latents = LatentCreator.encode_image(
+            init_image_path=Path(image_path),
+            width=config.width,
+            height=config.height,
+            vae=self.vae
+        )  # fmt:off
+
+        # 2a. Embed the source prompt
+        t5_tokens_src = self.t5_tokenizer.tokenize(src_prompt)
+        clip_tokens_src = self.clip_tokenizer.tokenize(src_prompt)
+        prompt_embeds_src = self.t5_text_encoder(t5_tokens_src)
+        pooled_prompt_embeds_src = self.clip_text_encoder(clip_tokens_src)
+        # 2b. Embed the target prompt
+        t5_tokens_tar = self.t5_tokenizer.tokenize(tar_prompt)
+        clip_tokens_tar = self.clip_tokenizer.tokenize(tar_prompt)
+        prompt_embeds_tar = self.t5_text_encoder(t5_tokens_tar)
+        pooled_prompt_embeds_tar = self.clip_text_encoder(clip_tokens_tar)
+
+        Z_FE = mx.array(image_latents)
         for gen_step, t in enumerate(time_steps, 1):
             try:
+                if config.num_inference_steps - t > 24:
+                    continue
+
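+                # Re-noise the clean source latents to the current sigma level; Z_tar is the
+                # same noisy point shifted by the edit accumulated so far (Z_FE - image_latents)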
+                random_noise = mx.random.normal(shape=[1, (config.height // 16) * (config.width // 16), 64])
+                Z_src = (1 - config.sigmas[t]) * image_latents + config.sigmas[t] * random_noise
+                Z_tar = Z_FE + Z_src - image_latents
+
                 # 3.t Predict the noise
-                noise = self.transformer.predict(
+                config.guidance = src_guidance
+                noise_src = self.transformer.predict(
                     t=t,
-                    prompt_embeds=prompt_embeds,
-                    pooled_prompt_embeds=pooled_prompt_embeds,
-                    hidden_states=latents,
+                    prompt_embeds=prompt_embeds_src,
+                    pooled_prompt_embeds=pooled_prompt_embeds_src,
+                    hidden_states=Z_src,
                     config=config,
                 )
+                config.guidance = tar_guidance
+                noise_tar = self.transformer.predict(
+                    t=t,
+                    prompt_embeds=prompt_embeds_tar,
+                    pooled_prompt_embeds=pooled_prompt_embeds_tar,
+                    hidden_states=Z_tar,
+                    config=config,
+                )
+
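+                # Only the difference between the target- and source-conditioned predictions
+                # drives the update, so structure shared by both prompts cancels out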
+                noise_delta = noise_tar - noise_src
 
                 # 4.t Take one denoise step
                 dt = config.sigmas[t + 1] - config.sigmas[t]
-                latents += noise * dt
+                Z_FE += noise_delta * dt
 
                 # Handle stepwise output if enabled
-                stepwise_handler.process_step(gen_step, latents)
+                stepwise_handler.process_step(gen_step, Z_FE)
 
                 # Evaluate to enable progress tracking
-                mx.eval(latents)
+                mx.eval(Z_FE)
 
             except KeyboardInterrupt:  # noqa: PERF203
                 stepwise_handler.handle_interruption()
                 raise StopImageGenerationException(f"Stopping image generation at step {t + 1}/{len(time_steps)}")
 
         # 5. Decode the latent array and return the image
-        latents = ArrayUtil.unpack_latents(latents=latents, height=config.height, width=config.width)
+        latents = ArrayUtil.unpack_latents(latents=Z_FE, height=config.height, width=config.width)
         decoded = self.vae.decode(latents)
         return ImageUtil.to_image(
             decoded_latents=decoded,
             seed=seed,
-            prompt=prompt,
+            prompt=f"src_prompt: {src_prompt} | tar_prompt: {tar_prompt}",
             quantization=self.bits,
             generation_time=time_steps.format_dict["elapsed"],
             lora_paths=self.lora_paths,
diff --git a/src/mflux/generate.py b/src/mflux/generate.py
index 236b6b03..e6dcf935 100644
--- a/src/mflux/generate.py
+++ b/src/mflux/generate.py
@@ -1,47 +1,44 @@
-import time
 from pathlib import Path
 
 from mflux import Config, Flux1, ModelConfig, StopImageGenerationException
-from mflux.ui.cli.parsers import CommandLineParser
 
+image_path = "/Users/filipstrand/Desktop/gas_station.png"
+source_prompt = "A gas station with a white and red sign that reads 'CAFE' There are several cars parked in front of the gas station, including a white car and a van."
+target_prompt = "A gas station with a white and red sign that reads 'CVPR' There are several cars parked in front of the gas station, including a white car and a van."
+height = 512
+width = 512
+steps = 28
+seed = 2
+source_guidance = 1.5
+target_guidance = 5.5
 
-def main():
-    # fmt: off
-    parser = CommandLineParser(description="Generate an image based on a prompt.")
-    parser.add_model_arguments(require_model_arg=False)
-    parser.add_lora_arguments()
-    parser.add_image_generator_arguments(supports_metadata_config=True)
-    parser.add_image_to_image_arguments(required=False)
-    parser.add_output_arguments()
-    args = parser.parse_args()
 
+def main():
     # Load the model
     flux = Flux1(
-        model_config=ModelConfig.from_alias(args.model),
-        quantize=args.quantize,
-        local_path=args.path,
-        lora_paths=args.lora_paths,
-        lora_scales=args.lora_scales,
+        model_config=ModelConfig.FLUX1_DEV,
+        quantize=4,
     )
 
     try:
-        # Generate an image
         image = flux.generate_image(
-            seed=int(time.time()) if args.seed is None else args.seed,
-            prompt=args.prompt,
-            stepwise_output_dir=Path(args.stepwise_image_output_dir) if args.stepwise_image_output_dir else None,
+            seed=seed,
+            src_prompt=source_prompt,
+            tar_prompt=target_prompt,
+            src_guidance=source_guidance,
+            tar_guidance=target_guidance,
+            image_path=image_path,
+            stepwise_output_dir=Path("/Users/filipstrand/Desktop/edit"),
             config=Config(
-                num_inference_steps=args.steps,
-                height=args.height,
-                width=args.width,
-                guidance=args.guidance,
-                init_image_path=args.init_image_path,
-                init_image_strength=args.init_image_strength,
+                num_inference_steps=steps,
+                height=height,
+                width=width,
+                guidance=0.0,
             ),
         )
 
         # Save the image
-        image.save(path=args.output, export_json_metadata=args.metadata)
+        image.save(path="edited.png")
 
     except StopImageGenerationException as stop_exc:
         print(stop_exc)
diff --git a/src/mflux/latent_creator/latent_creator.py b/src/mflux/latent_creator/latent_creator.py
index 84e816f0..c3dd4ef9 100644
--- a/src/mflux/latent_creator/latent_creator.py
+++ b/src/mflux/latent_creator/latent_creator.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import mlx.core as mx
 from mlx import nn
 
@@ -35,14 +37,12 @@ def create_for_txt2img_or_img2img(
             return pure_noise
         else:
             # Image2Image
-            user_image = ImageUtil.load_image(runtime_conf.config.init_image_path).convert("RGB")
-            scaled_user_image = ImageUtil.scale_to_dimensions(
-                image=user_image,
-                target_width=runtime_conf.width,
-                target_height=runtime_conf.height,
+            latents = LatentCreator.encode_image(
+                init_image_path=runtime_conf.config.init_image_path,
+                height=runtime_conf.height,
+                width=runtime_conf.width,
+                vae=vae,
             )
-            encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
-            latents = ArrayUtil.pack_latents(latents=encoded, height=runtime_conf.height, width=runtime_conf.width)
             sigma = runtime_conf.sigmas[runtime_conf.init_time_step]
             return LatentCreator.add_noise_by_interpolation(
                 clean=latents,
@@ -50,6 +50,23 @@
                 sigma=sigma
             )  # fmt: off
 
+    @staticmethod
+    def encode_image(
+        init_image_path: Path,
+        width: int,
+        height: int,
+        vae: nn.Module,
+    ):
+        user_image = ImageUtil.load_image(init_image_path).convert("RGB")
+        scaled_user_image = ImageUtil.scale_to_dimensions(
+            image=user_image,
+            target_width=width,
+            target_height=height,
+        )
+        encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
+        latents = ArrayUtil.pack_latents(latents=encoded, height=height, width=width)
+        return latents
+
     @staticmethod
     def add_noise_by_interpolation(clean: mx.array, noise: mx.array, sigma: float) -> mx.array:
         return (1 - sigma) * clean + sigma * noise
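
Note on the algorithm: the flux.py loop above implements inversion-free editing in the
style of FlowEdit (Kulikov et al., 2024); the Z_FE naming and the defaults in generate.py
(28 steps with only the last 24 edited, source guidance 1.5, target guidance 5.5) match
that paper's FLUX configuration, and the guidance setter added to RuntimeConfig exists so
the loop can swap guidance strength between the two predict calls. Stripped of the mflux
plumbing, the per-step update reduces to the sketch below. This is a minimal illustration,
not mflux API: flowedit_step, velocity (standing in for transformer.predict), src_cond and
tar_cond are hypothetical names, and sigmas is any noise schedule that decreases to 0 and
has one entry past the final step.

    import numpy as np

    def flowedit_step(z_fe, x_src, sigmas, t, velocity, src_cond, tar_cond, n_skip=4):
        # Skip the noisiest steps, mirroring `num_inference_steps - t > 24` for 28 steps
        if t < n_skip:
            return z_fe
        noise = np.random.normal(size=x_src.shape)
        z_src = (1 - sigmas[t]) * x_src + sigmas[t] * noise  # re-noised source latents
        z_tar = z_fe + (z_src - x_src)                       # same noise plus the edit so far
        dt = sigmas[t + 1] - sigmas[t]
        # Integrate only the difference between the two conditional predictions
        return z_fe + (velocity(z_tar, tar_cond, t) - velocity(z_src, src_cond, t)) * dt

Starting from z_fe = x_src and folding flowedit_step over all t reproduces steps 3.t and
4.t above: z_fe never passes through pure noise, so no inversion of the source image is
needed, and after the last step the result decodes directly through the VAE.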