Skip to content

Commit

Permalink
Force safe loading of files in torch format on pytorch 2.4+
Browse files Browse the repository at this point in the history
If this breaks something for you make an issue.
  • Loading branch information
comfyanonymous committed Jan 15, 2025
1 parent 5b657f8 commit 2feb8d0
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,29 @@
from torch.nn.functional import interpolate
from einops import rearrange

# Flag read by the checkpoint loader: when True, torch-format files are
# always deserialized with weights_only=True (safe unpickling).
ALWAYS_SAFE_LOAD = False

# torch.serialization.add_safe_globals() first appeared in pytorch 2.4.
# TODO: drop the unsafe fallback path once older torch versions are deprecated.
if hasattr(torch.serialization, "add_safe_globals"):
    # Stand-in class registered under pytorch_lightning's module path so
    # that lightning checkpoints unpickle without importing pytorch_lightning.
    class ModelCheckpoint:
        pass
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

    from _codecs import encode
    from numpy import dtype
    from numpy.core.multiarray import scalar
    from numpy.dtypes import Float64DType

    # Allow-list the pickle globals commonly found in checkpoint files.
    _safe_checkpoint_globals = [ModelCheckpoint, scalar, dtype, Float64DType, encode]
    torch.serialization.add_safe_globals(_safe_checkpoint_globals)
    ALWAYS_SAFE_LOAD = True
    logging.info("Checkpoint files will always be loaded safely.")


def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
Expand Down

0 comments on commit 2feb8d0

Please sign in to comment.