Skip to content

Commit

Permalink
[ONNX] Don't download ONNX model by default (#4338)
Browse files Browse the repository at this point in the history
* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* fix more

* finish

* finish

* finish
  • Loading branch information
patrickvonplaten authored and sayakpaul committed Jul 28, 2023
1 parent c63d7cd commit 9cde56a
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
_optional_components = []
_exclude_from_cpu_offload = []
_load_connected_pipes = False
_is_onnx = False

def register_modules(self, **kwargs):
# import it here to avoid circular import
Expand Down Expand Up @@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
Expand Down Expand Up @@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `False`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
Returns:
`os.PathLike`:
Expand All @@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

if use_safetensors and not is_safetensors_available():
Expand Down Expand Up @@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)

model_folder_names = {os.path.split(f)[0] for f in model_filenames}
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
Expand Down Expand Up @@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
Expand All @@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor

_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor

_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor

_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def preprocess(image):


class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
_is_onnx = True

def __init__(
self,
vae: OnnxRuntimeModel,
Expand Down
43 changes: 43 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,49 @@ def test_download_bin_index(self):
assert len([f for f in files if ".bin" in f]) == 8
assert not any(".safetensors" in f for f in files)

def test_download_no_openvino_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-open-vino",
cache_dir=tmpdirname,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# make sure that by default no openvino weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any("openvino_" in f for f in files)

def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# make sure that by default no onnx weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)

with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
use_onnx=True,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# if `use_onnx` is specified make sure weights are downloaded
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files)

def test_download_no_safety_checker(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(
Expand Down

0 comments on commit 9cde56a

Please sign in to comment.