diff --git a/scripts/dump_state_dict.py b/scripts/dump_state_dict.py index 9186f89e..9883cfaa 100644 --- a/scripts/dump_state_dict.py +++ b/scripts/dump_state_dict.py @@ -43,6 +43,7 @@ import argparse import collections +import sys import textwrap from dataclasses import dataclass from typing import Any, Dict, Generic, Iterable, Mapping, TypeVar @@ -50,6 +51,9 @@ from torch import Tensor try: + sys.path.insert(0, __file__ + "/../../libs/spandrel") + sys.path.insert(0, __file__ + "/../../libs/spandrel_extra_arches") + from spandrel import MAIN_REGISTRY, ModelLoader # noqa: E402 from spandrel_extra_arches import EXTRA_REGISTRY # noqa: E402 except ImportError: