Skip to content

Commit

Permalink
Add disable_mmap arg in method load_torch_file
Browse files Browse the repository at this point in the history
(cherry picked from commit 2ffcc72)
  • Loading branch information
FE-xiaoJiang committed Mar 6, 2025
1 parent 0124be4 commit 43eaa03
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
1 change: 1 addition & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
# --disable-mmap: read .safetensors/.sft checkpoint files fully into memory
# instead of memory-mapping them (mmap can be slow or unreliable on some
# filesystems, e.g. network mounts).
parser.add_argument("--disable-mmap", action="store_true", help="Disable memory-mapped (mmap) loading of .safetensors/.sft model files; read the whole file into memory instead.")

parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
Expand Down
25 changes: 19 additions & 6 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args

ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
Expand All @@ -46,18 +47,30 @@ class ModelCheckpoint:
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False, disable_mmap=None):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
if disable_mmap is None:
disable_mmap_decision = args.disable_mmap
else:
disable_mmap_decision = True

if disable_mmap_decision:
pl_sd = safetensors.torch.load(open(ckpt, 'rb').read())
sd = {k: v.to(device) for k, v in pl_sd.items()}
if return_metadata:
metadata = f.metadata()
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
metadata = f.metadata()
else:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
Expand Down

0 comments on commit 43eaa03

Please sign in to comment.