Punch cache_dir through model factory / builder / pretrain helpers. Improve some annotations in related code.
rwightman committed Dec 5, 2024
1 parent 553ded5 commit 71849b9
Showing 4 changed files with 99 additions and 36 deletions.
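A minimal usage sketch of what this commit enables at the top-level API, assuming only the new `cache_dir` argument added in the diffs below (the model name and path are illustrative):

    import timm

    # `cache_dir` is threaded from create_model -> build_model_with_cfg ->
    # load_pretrained -> the Hugging Face Hub / torch.hub download helpers,
    # overriding the system cache locations for this load only.
    model = timm.create_model(
        'resnet50',                   # illustrative model name
        pretrained=True,
        cache_dir='./weights-cache',  # project-local cache override
    )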
36 changes: 23 additions & 13 deletions timm/models/_builder.py
@@ -2,7 +2,8 @@
import logging
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from torch import nn as nn
from torch.hub import load_state_dict_from_url
@@ -90,6 +91,7 @@ def load_custom_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
load_fn: Optional[Callable] = None,
cache_dir: Optional[Union[str, Path]] = None,
):
r"""Loads a custom (read non .pth) weight file
@@ -102,9 +104,9 @@
Args:
model: The instantiated model to load weights into
pretrained_cfg (dict): Default pretrained model cfg
pretrained_cfg: Default pretrained model cfg
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
'laod_pretrained' on the model will be called if it exists
'load_pretrained' on the model will be called if it exists
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
@@ -122,6 +124,7 @@
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS,
cache_dir=cache_dir,
)

if load_fn is not None:
@@ -139,17 +142,18 @@ def load_pretrained(
in_chans: int = 3,
filter_fn: Optional[Callable] = None,
strict: bool = True,
cache_dir: Optional[Union[str, Path]] = None,
):
""" Load pretrained checkpoint
Args:
model (nn.Module) : PyTorch model module
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
num_classes (int): num_classes for target model
in_chans (int): in_chans for target model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint
model: PyTorch module
pretrained_cfg: configuration for pretrained weights / target dataset
num_classes: number of classes for target model
in_chans: number of input chans for target model
filter_fn: state_dict filter fn for load (takes state_dict, model as args)
strict: strict load of checkpoint
cache_dir: override path to cache dir for this load
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
@@ -173,6 +177,7 @@ def load_pretrained(
pretrained_loc,
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
cache_dir=cache_dir,
)
model.load_pretrained(pretrained_loc)
return
@@ -184,25 +189,27 @@ def load_pretrained(
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
weights_only=True,
model_dir=cache_dir,
)
except TypeError:
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
model_dir=cache_dir,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
custom_load = pretrained_cfg.get('custom_load', False)
if isinstance(custom_load, str) and custom_load == 'hf':
load_custom_from_hf(*pretrained_loc, model)
load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
return
else:
state_dict = load_state_dict_from_hf(*pretrained_loc)
state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
else:
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
else:
model_name = pretrained_cfg.get('architecture', 'this model')
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
@@ -362,6 +369,7 @@ def build_model_with_cfg(
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None,
cache_dir: Optional[Union[str, Path]] = None,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs,
):
@@ -382,6 +390,7 @@
feature_cfg: feature extraction adapter config
pretrained_strict: load pretrained weights strictly
pretrained_filter_fn: filter callable for pretrained weights
cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
kwargs_filter: kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
@@ -431,6 +440,7 @@
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
cache_dir=cache_dir,
)

# Wrap the model in a feature extraction module if enabled
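The `weights_only` try/except in the hunk above follows the usual pattern for staying compatible with older torch releases that reject the kwarg; a standalone sketch of the same idea (the helper name is illustrative):

    from torch.hub import load_state_dict_from_url

    def _load_url_compat(url, cache_dir=None):
        # Newer torch accepts `weights_only`; older releases raise TypeError
        # for the unknown kwarg, so retry without it. `model_dir` is how
        # torch.hub spells the cache-directory override.
        try:
            return load_state_dict_from_url(
                url, map_location='cpu', weights_only=True, model_dir=cache_dir)
        except TypeError:
            return load_state_dict_from_url(
                url, map_location='cpu', model_dir=cache_dir)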
14 changes: 10 additions & 4 deletions timm/models/_factory.py
@@ -1,4 +1,5 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit

@@ -40,7 +41,8 @@ def create_model(
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
checkpoint_path: str = '',
checkpoint_path: Optional[Union[str, Path]] = None,
cache_dir: Optional[Union[str, Path]] = None,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
@@ -50,17 +52,17 @@
Lookup model's entrypoint function and pass relevant args to create a new model.
<Tip>
Tip:
**kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
and then the model class __init__(). kwargs values set to None are pruned before passing.
</Tip>
Args:
model_name: Name of model to instantiate.
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
pretrained_cfg: Pass in an external pretrained_cfg for model.
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
@@ -99,7 +101,10 @@ def create_model(
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
pretrained_cfg, model_name, model_args = load_model_config_from_hf(
model_name,
cache_dir=cache_dir,
)
if model_args:
for k, v in model_args.items():
kwargs.setdefault(k, v)
@@ -118,6 +123,7 @@
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
cache_dir=cache_dir,
**kwargs,
)

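With the factory change above, the same `cache_dir` override also covers the case where the model config itself is resolved from the Hub via an `hf-hub:` name; a hedged sketch (repo id illustrative):

    import timm

    # For `hf-hub:` names, config.json and the weights are both fetched
    # into the overridden cache location.
    model = timm.create_model(
        'hf-hub:timm/convnext_base.fb_in22k',  # illustrative hf-hub repo id
        pretrained=True,
        cache_dir='/data/hf-cache',
    )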
1 change: 0 additions & 1 deletion timm/models/_helpers.py
@@ -4,7 +4,6 @@
"""
import logging
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union

import torch
84 changes: 66 additions & 18 deletions timm/models/_hub.py
@@ -5,7 +5,7 @@
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -53,7 +53,7 @@
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version


def get_cache_dir(child_dir=''):
def get_cache_dir(child_dir: str = ''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
@@ -68,13 +68,22 @@ def get_cache_dir(child_dir=''):
return model_dir


def download_cached_file(url, check_hash=True, progress=False):
def download_cached_file(
url: Union[str, List[str], Tuple[str, str]],
check_hash: bool = True,
progress: bool = False,
cache_dir: Optional[Union[str, Path]] = None,
):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)
else:
cache_dir = get_cache_dir()
cached_file = os.path.join(cache_dir, filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
@@ -85,13 +94,19 @@ def download_cached_file(url, check_hash=True, progress=False):
return cached_file


def check_cached_file(url, check_hash=True):
def check_cached_file(
url: Union[str, List[str], Tuple[str, str]],
check_hash: bool = True,
cache_dir: Optional[Union[str, Path]] = None,
):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if not cache_dir:
cache_dir = get_cache_dir()
cached_file = os.path.join(cache_dir, filename)
if os.path.exists(cached_file):
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
@@ -105,7 +120,7 @@ def check_cached_file(url, check_hash=True):
return False


def has_hf_hub(necessary=False):
def has_hf_hub(necessary: bool = False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
Expand All @@ -122,20 +137,32 @@ def hf_split(hf_id: str):
return hf_model_id, hf_revision


def load_cfg_from_json(json_file: Union[str, os.PathLike]):
def load_cfg_from_json(json_file: Union[str, Path]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)


def download_from_hf(model_id: str, filename: str):
def download_from_hf(
model_id: str,
filename: str,
cache_dir: Optional[Union[str, Path]] = None,
):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
return hf_hub_download(
hf_model_id,
filename,
revision=hf_revision,
cache_dir=cache_dir,
)


def load_model_config_from_hf(model_id: str):
def load_model_config_from_hf(
model_id: str,
cache_dir: Optional[Union[str, Path]] = None,
):
assert has_hf_hub(True)
cached_file = download_from_hf(model_id, 'config.json')
cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)

hf_config = load_cfg_from_json(cached_file)
if 'pretrained_cfg' not in hf_config:
Expand Down Expand Up @@ -172,6 +199,7 @@ def load_state_dict_from_hf(
model_id: str,
filename: str = HF_WEIGHTS_NAME,
weights_only: bool = False,
cache_dir: Optional[Union[str, Path]] = None,
):
assert has_hf_hub(True)
hf_model_id, hf_revision = hf_split(model_id)
@@ -180,7 +208,12 @@
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
cached_safe_file = hf_hub_download(
repo_id=hf_model_id,
filename=safe_filename,
revision=hf_revision,
cache_dir=cache_dir,
)
_logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
@@ -189,7 +222,12 @@
pass

# Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
cached_file = hf_hub_download(
hf_model_id,
filename=filename,
revision=hf_revision,
cache_dir=cache_dir,
)
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
try:
state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
@@ -198,15 +236,25 @@
return state_dict


def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
def load_custom_from_hf(
model_id: str,
filename: str,
model: torch.nn.Module,
cache_dir: Optional[Union[str, Path]] = None,
):
assert has_hf_hub(True)
hf_model_id, hf_revision = hf_split(model_id)
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
cached_file = hf_hub_download(
hf_model_id,
filename=filename,
revision=hf_revision,
cache_dir=cache_dir,
)
return model.load_pretrained(cached_file)


def save_config_for_hf(
model,
model: torch.nn.Module,
config_path: str,
model_config: Optional[dict] = None,
model_args: Optional[dict] = None
@@ -255,7 +303,7 @@


def save_for_hf(
model,
model: torch.nn.Module,
save_directory: str,
model_config: Optional[dict] = None,
model_args: Optional[dict] = None,
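A small sketch of the updated `download_cached_file` contract from `_hub.py` above, assuming the private import path stays usable as-is (URL and paths illustrative):

    from timm.models._hub import download_cached_file

    # Plain URL: the filename is derived from the URL path.
    path = download_cached_file(
        'https://example.com/checkpoints/model.pth',  # illustrative URL
        cache_dir='./weights-cache',                  # created if missing
    )

    # (url, filename) pair: store under an explicit filename instead.
    path = download_cached_file(
        ('https://example.com/ckpt?id=123', 'model.pth'),
        cache_dir='./weights-cache',
    )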
