Use enum list for --fast options (#7024)
huchenlei authored Mar 1, 2025
1 parent cf0b549 commit 4d55f16
Showing 3 changed files with 25 additions and 5 deletions.
18 changes: 17 additions & 1 deletion comfy/cli_args.py
@@ -130,7 +130,12 @@ class LatentPreviewMethod(enum.Enum):

 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
-parser.add_argument("--fast", metavar="number", type=int, const=99, default=0, nargs="?", help="Enable some untested and potentially quality deteriorating optimizations. You can pass a number from 0 to 10 for a bigger speed vs quality tradeoff. Using --fast with no number means maximum speed. 2 or larger enables fp16 accumulation, 5 or larger enables fp8 matrix multiplication.")
+
+class PerformanceFeature(enum.Enum):
+    Fp16Accumulation = "fp16_accumulation"
+    Fp8Optimization = "fp8_optimization"
+
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations.")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -194,3 +199,14 @@ def is_valid_directory(path: Optional[str]) -> Optional[str]:

 if args.force_fp16:
     args.fp16_unet = True
+
+
+# '--fast' is not provided, use an empty set
+if args.fast is None:
+    args.fast = set()
+# '--fast' is provided with an empty list, enable all optimizations
+elif args.fast == []:
+    args.fast = set(PerformanceFeature)
+# '--fast' is provided with a list of performance features, use that list
+else:
+    args.fast = set(args.fast)
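For context, a minimal standalone sketch of how the new flag parses (an illustration, not code from the commit; the parser and enum are re-declared locally and only mirror the names in comfy/cli_args.py). With nargs="*", an omitted --fast parses to None, a bare --fast parses to an empty list, and each named feature is converted through the enum's value lookup:

import argparse
import enum

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8Optimization = "fp8_optimization"

parser = argparse.ArgumentParser()
parser.add_argument("--fast", nargs="*", type=PerformanceFeature)

# Flag omitted entirely -> None (normalized to an empty set above).
assert parser.parse_args([]).fast is None

# Bare --fast -> [] (normalized above to the full feature set).
assert parser.parse_args(["--fast"]).fast == []

# Named features -> enum members via value lookup, i.e.
# PerformanceFeature("fp16_accumulation") is PerformanceFeature.Fp16Accumulation.
assert parser.parse_args(["--fast", "fp16_accumulation"]).fast == [
    PerformanceFeature.Fp16Accumulation
]

In practice this allows invocations such as "--fast" (everything on), "--fast fp16_accumulation" (one feature), or "--fast fp16_accumulation fp8_optimization" (an explicit list), where the old integer form would have been "--fast 2" or "--fast 5".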
4 changes: 2 additions & 2 deletions comfy/model_management.py
@@ -19,7 +19,7 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args
+from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
 import platform
@@ -280,7 +280,7 @@ def is_amd():

 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if is_nvidia() and args.fast >= 2:
+    if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
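One consequence of the set-based flag, shown in a small sketch (an illustration, not commit code; the enum is re-declared to keep it self-contained): features are now independent, whereas the old integer thresholds forced an ordering, since any value enabling fp8 matrix multiplication (>= 5) also enabled fp16 accumulation (>= 2).

import enum

class PerformanceFeature(enum.Enum):  # mirrors comfy/cli_args.py above
    Fp16Accumulation = "fp16_accumulation"
    Fp8Optimization = "fp8_optimization"

# Enable only fp8 optimization; fp16 accumulation stays off, something the
# old scheme could not express because fast >= 5 implied fast >= 2.
fast = {PerformanceFeature.Fp8Optimization}
assert PerformanceFeature.Fp8Optimization in fast
assert PerformanceFeature.Fp16Accumulation not in fast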
8 changes: 6 additions & 2 deletions comfy/ops.py
@@ -18,7 +18,7 @@

 import torch
 import comfy.model_management
-from comfy.cli_args import args
+from comfy.cli_args import args, PerformanceFeature
 import comfy.float

 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -360,7 +360,11 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)

-    if fp8_compute and (fp8_optimizations or args.fast >= 5) and not disable_fast_fp8:
+    if (
+        fp8_compute and
+        (fp8_optimizations or PerformanceFeature.Fp8Optimization in args.fast) and
+        not disable_fast_fp8
+    ):
         return fp8_ops

     if compute_dtype is None or weight_dtype == compute_dtype:
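The reworked gate in pick_operations can be restated as a standalone predicate (a sketch for clarity only; the function name use_fast_fp8 is invented here, and PerformanceFeature is the enum added in cli_args.py above):

from comfy.cli_args import PerformanceFeature

def use_fast_fp8(fp8_compute: bool, fp8_optimizations: bool,
                 fast: set, disable_fast_fp8: bool) -> bool:
    # fp8 ops are picked only when the device supports fp8 compute, the
    # model or the user (via --fast fp8_optimization) asks for it, and
    # nothing has explicitly disabled fast fp8.
    return (fp8_compute
            and (fp8_optimizations or PerformanceFeature.Fp8Optimization in fast)
            and not disable_fast_fp8)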
