From 9cde56a7291c004003b278e5dee3861b3931f886 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 28 Jul 2023 14:02:48 +0200 Subject: [PATCH] [ONNX] Don't download ONNX model by default (#4338) * [Download] Don't download ONNX weights by default * [Download] Don't download ONNX weights by default * [Download] Don't download ONNX weights by default * fix more * finish * finish * finish --- src/diffusers/pipelines/pipeline_utils.py | 26 ++++++++++- .../pipeline_onnx_stable_diffusion.py | 1 + .../pipeline_onnx_stable_diffusion_img2img.py | 1 + .../pipeline_onnx_stable_diffusion_inpaint.py | 1 + ...ne_onnx_stable_diffusion_inpaint_legacy.py | 1 + .../pipeline_onnx_stable_diffusion_upscale.py | 2 + tests/pipelines/test_pipelines.py | 43 +++++++++++++++++++ 7 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 21ba7a320ff0..133bf3a7a2f8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin): _optional_components = [] _exclude_from_cpu_offload = [] _load_connected_pipes = False + _is_onnx = False def register_modules(self, **kwargs): # import it here to avoid circular import @@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. + use_onnx (`bool`, *optional*, defaults to `None`): + If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights + will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is + `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending + with `.onnx` and `.pb`. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. See example @@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + use_onnx (`bool`, *optional*, defaults to `False`): + If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights + will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is + `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending + with `.onnx` and `.pb`. Returns: `os.PathLike`: @@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: custom_revision = kwargs.pop("custom_revision", None) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) if use_safetensors and not is_safetensors_available(): @@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: pretrained_model_name, use_auth_token, variant, revision, model_filenames ) - model_folder_names = {os.path.split(f)[0] for f in model_filenames} + model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} # all filenames compatible with variant will be added allow_patterns = list(model_filenames) @@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ): ignore_patterns = ["*.bin", "*.msgpack"] + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} if ( @@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: else: ignore_patterns = ["*.safetensors", "*.msgpack"] + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index eb02f6cb321c..6c8ff7fe78df 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 293ed7d981b8..508085094b16 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 0bb39c4b1c61..4856babce807 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index 8ef7a781451c..a4b54b9724fb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 56681391aeeb..93e86def7a05 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -46,6 +46,8 @@ def preprocess(image): class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): + _is_onnx = True + def __init__( self, vae: OnnxRuntimeModel, diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 6ad8241698b0..5ce2316c9b19 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -310,6 +310,49 @@ def test_download_bin_index(self): assert len([f for f in files if ".bin" in f]) == 8 assert not any(".safetensors" in f for f in files) + def test_download_no_openvino_by_default(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-open-vino", + cache_dir=tmpdirname, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # make sure that by default no openvino weights are downloaded + assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) + assert not any("openvino_" in f for f in files) + + def test_download_no_onnx_by_default(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", + cache_dir=tmpdirname, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # make sure that by default no onnx weights are downloaded + assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) + assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files) + + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", + cache_dir=tmpdirname, + use_onnx=True, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # if `use_onnx` is specified make sure weights are downloaded + assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) + assert any((f.endswith(".onnx")) for f in files) + assert any((f.endswith(".pb")) for f in files) + def test_download_no_safety_checker(self): prompt = "hello" pipe = StableDiffusionPipeline.from_pretrained(