Skip to content

Commit

Permalink
WIP...
Browse files Browse the repository at this point in the history
  • Loading branch information
filipstrand committed Jan 12, 2025
1 parent 0eb8f50 commit b42d5fb
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 53 deletions.
4 changes: 4 additions & 0 deletions src/mflux/config/runtime_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def width(self) -> int:
def guidance(self) -> float:
    """Return the guidance value held by the wrapped config."""
    return self.config.guidance

@guidance.setter
def guidance(self, value: float):
    """Store ``value`` as the guidance on the wrapped config.

    The assignment writes straight through to ``self.config``, so any other
    holder of the same config object sees the new value.
    """
    self.config.guidance = value

@property
def num_inference_steps(self) -> int:
    """Return the number of inference steps held by the wrapped config."""
    return self.config.num_inference_steps
Expand Down
71 changes: 52 additions & 19 deletions src/mflux/flux/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def __init__(
def generate_image(
self,
seed: int,
prompt: str,
src_prompt: str,
tar_prompt: str,
src_guidance: float,
tar_guidance: float,
image_path: str,
config: Config = Config(),
stepwise_output_dir: Path = None,
) -> GeneratedImage:
Expand All @@ -80,52 +84,81 @@ def generate_image(
flux=self,
config=config,
seed=seed,
prompt=prompt,
prompt=f"src_prompt: {src_prompt} | tar_prompt: {tar_prompt}",
time_steps=time_steps,
output_dir=stepwise_output_dir,
)

# 1. Create the initial latents
latents = LatentCreator.create_for_txt2img_or_img2img(seed, config, self.vae)

# 2. Embed the prompt
t5_tokens = self.t5_tokenizer.tokenize(prompt)
clip_tokens = self.clip_tokenizer.tokenize(prompt)
prompt_embeds = self.t5_text_encoder(t5_tokens)
pooled_prompt_embeds = self.clip_text_encoder(clip_tokens)

image_latents = LatentCreator.encode_image(
init_image_path=Path(image_path),
width=config.width,
height=config.height,
vae=self.vae
) # fmt:off

# 2a. Embed the source prompt
t5_tokens_src = self.t5_tokenizer.tokenize(src_prompt)
clip_tokens_src = self.clip_tokenizer.tokenize(src_prompt)
prompt_embeds_src = self.t5_text_encoder(t5_tokens_src)
pooled_prompt_embeds_src = self.clip_text_encoder(clip_tokens_src)
# 2b. Embed the target prompt
t5_tokens_tar = self.t5_tokenizer.tokenize(tar_prompt)
clip_tokens_tar = self.clip_tokenizer.tokenize(tar_prompt)
prompt_embeds_tar = self.t5_text_encoder(t5_tokens_tar)
pooled_prompt_embeds_tar = self.clip_text_encoder(clip_tokens_tar)

Z_FE = mx.array(image_latents)
for gen_step, t in enumerate(time_steps, 1):
try:
if config.num_inference_steps - t > 24:
continue

random_noise = mx.random.normal(shape=[1, (config.height // 16) * (config.width // 16), 64])
Z_src = (1 - config.sigmas[t]) * image_latents + config.sigmas[t] * random_noise
Z_tar = Z_FE + Z_src - image_latents

# 3.t Predict the noise
noise = self.transformer.predict(
config.guidance = src_guidance
noise_src = self.transformer.predict(
t=t,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents,
prompt_embeds=prompt_embeds_src,
pooled_prompt_embeds=pooled_prompt_embeds_src,
hidden_states=Z_src,
config=config,
)
config.guidance = tar_guidance
noise_tar = self.transformer.predict(
t=t,
prompt_embeds=prompt_embeds_tar,
pooled_prompt_embeds=pooled_prompt_embeds_tar,
hidden_states=Z_tar,
config=config,
)

noise_delta = noise_tar - noise_src

# 4.t Take one denoise step
dt = config.sigmas[t + 1] - config.sigmas[t]
latents += noise * dt
Z_FE += noise_delta * dt

# Handle stepwise output if enabled
stepwise_handler.process_step(gen_step, latents)
stepwise_handler.process_step(gen_step, Z_FE)

# Evaluate to enable progress tracking
mx.eval(latents)
mx.eval(Z_FE)

except KeyboardInterrupt: # noqa: PERF203
stepwise_handler.handle_interruption()
raise StopImageGenerationException(f"Stopping image generation at step {t + 1}/{len(time_steps)}")

# 5. Decode the latent array and return the image
latents = ArrayUtil.unpack_latents(latents=latents, height=config.height, width=config.width)
latents = ArrayUtil.unpack_latents(latents=Z_FE, height=config.height, width=config.width)
decoded = self.vae.decode(latents)
return ImageUtil.to_image(
decoded_latents=decoded,
seed=seed,
prompt=prompt,
prompt=f"src_prompt: {src_prompt} | tar_prompt: {tar_prompt}",
quantization=self.bits,
generation_time=time_steps.format_dict["elapsed"],
lora_paths=self.lora_paths,
Expand Down
51 changes: 24 additions & 27 deletions src/mflux/generate.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,44 @@
import time
from pathlib import Path

from mflux import Config, Flux1, ModelConfig, StopImageGenerationException
from mflux.ui.cli.parsers import CommandLineParser

# Hard-coded inputs for this work-in-progress script; main() forwards every one
# of them to flux.generate_image().
# NOTE(review): image_path is an absolute path on the author's machine —
# presumably temporary while the command-line arguments are bypassed; confirm.
image_path = "/Users/filipstrand/Desktop/gas_station.png"
# Passed as src_prompt / tar_prompt. The two prompts differ only in the quoted
# sign text ('CAFE' vs 'CVPR').
source_prompt = "A gas station with a white and red sign that reads 'CAFE' There are several cars parked in front of the gas station, including a white car and a van."
target_prompt = "A gas station with a white and red sign that reads 'CVPR' There are several cars parked in front of the gas station, including a white car and a van."
# Passed to Config(height=..., width=..., num_inference_steps=...).
height = 512
width = 512
steps = 28
# Passed as the seed argument of generate_image().
seed = 2
# Passed as src_guidance / tar_guidance; generate_image() assigns each to
# config.guidance before the corresponding transformer prediction.
source_guidance = 1.5
target_guidance = 5.5

def main():
# fmt: off
parser = CommandLineParser(description="Generate an image based on a prompt.")
parser.add_model_arguments(require_model_arg=False)
parser.add_lora_arguments()
parser.add_image_generator_arguments(supports_metadata_config=True)
parser.add_image_to_image_arguments(required=False)
parser.add_output_arguments()
args = parser.parse_args()

def main():
# Load the model
flux = Flux1(
model_config=ModelConfig.from_alias(args.model),
quantize=args.quantize,
local_path=args.path,
lora_paths=args.lora_paths,
lora_scales=args.lora_scales,
model_config=ModelConfig.FLUX1_DEV,
quantize=4,
)

try:
# Generate an image
image = flux.generate_image(
seed=int(time.time()) if args.seed is None else args.seed,
prompt=args.prompt,
stepwise_output_dir=Path(args.stepwise_image_output_dir) if args.stepwise_image_output_dir else None,
seed=seed,
src_prompt=source_prompt,
tar_prompt=target_prompt,
src_guidance=source_guidance,
tar_guidance=target_guidance,
image_path=image_path,
stepwise_output_dir=Path("/Users/filipstrand/Desktop/edit"),
config=Config(
num_inference_steps=args.steps,
height=args.height,
width=args.width,
guidance=args.guidance,
init_image_path=args.init_image_path,
init_image_strength=args.init_image_strength,
num_inference_steps=steps,
height=height,
width=width,
guidance=0.0,
),
)

# Save the image
image.save(path=args.output, export_json_metadata=args.metadata)
image.save(path="edited.png")
except StopImageGenerationException as stop_exc:
print(stop_exc)

Expand Down
31 changes: 24 additions & 7 deletions src/mflux/latent_creator/latent_creator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import mlx.core as mx
from mlx import nn

Expand Down Expand Up @@ -35,21 +37,36 @@ def create_for_txt2img_or_img2img(
return pure_noise
else:
# Image2Image
user_image = ImageUtil.load_image(runtime_conf.config.init_image_path).convert("RGB")
scaled_user_image = ImageUtil.scale_to_dimensions(
image=user_image,
target_width=runtime_conf.width,
target_height=runtime_conf.height,
latents = LatentCreator.encode_image(
init_image_path=runtime_conf.config.init_image_path,
height=runtime_conf.height,
width=runtime_conf.width,
vae=vae,
)
encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
latents = ArrayUtil.pack_latents(latents=encoded, height=runtime_conf.height, width=runtime_conf.width)
sigma = runtime_conf.sigmas[runtime_conf.init_time_step]
return LatentCreator.add_noise_by_interpolation(
clean=latents,
noise=pure_noise,
sigma=sigma
) # fmt: off

@staticmethod
def encode_image(
    init_image_path: Path,
    width: int,
    height: int,
    vae: nn.Module,
):
    """Load an image from disk and turn it into packed latents.

    The image at ``init_image_path`` is loaded and converted to RGB, scaled
    to ``width`` x ``height`` via ``ImageUtil.scale_to_dimensions``, encoded
    with ``vae.encode`` and finally packed with ``ArrayUtil.pack_latents``.

    Returns:
        Whatever ``ArrayUtil.pack_latents`` produces for the encoded image.
    """
    rgb_image = ImageUtil.load_image(init_image_path).convert("RGB")
    resized_image = ImageUtil.scale_to_dimensions(
        image=rgb_image,
        target_width=width,
        target_height=height,
    )
    image_array = ImageUtil.to_array(resized_image)
    return ArrayUtil.pack_latents(
        latents=vae.encode(image_array),
        height=height,
        width=width,
    )

@staticmethod
def add_noise_by_interpolation(clean: mx.array, noise: mx.array, sigma: float) -> mx.array:
return (1 - sigma) * clean + sigma * noise

0 comments on commit b42d5fb

Please sign in to comment.