From 396c3e7d62b682114f164e807bccdd2def63f603 Mon Sep 17 00:00:00 2001 From: unmonoqueteclea Date: Tue, 13 Feb 2024 10:51:23 +0100 Subject: [PATCH] test: fix test that tried to use tensorflowjs --- sensenet/importers.py | 6 +++++- sensenet/models/wrappers.py | 2 +- tests/test_export.py | 20 +++++++++++--------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sensenet/importers.py b/sensenet/importers.py index 3d7f013..385d70f 100755 --- a/sensenet/importers.py +++ b/sensenet/importers.py @@ -10,6 +10,7 @@ from sensenet import __tree_ext_prefix__ +logger = logging.getLogger(__name__) logging.getLogger("tensorflow").setLevel(logging.ERROR) warnings.filterwarnings("ignore", message=".*binary incompatibility.*") @@ -33,7 +34,10 @@ # but it is not mandatory for Sensenet to work try: import tensorflowjs - except: # noqa: E722 + except Exception as e: # noqa: E722 + logger.info( + f"tensorflowjs not found, you can't export models to JS: {e}" + ) tensorflowjs = None bigml_tf_module = None diff --git a/sensenet/models/wrappers.py b/sensenet/models/wrappers.py index 3091af8..c92980f 100644 --- a/sensenet/models/wrappers.py +++ b/sensenet/models/wrappers.py @@ -65,7 +65,7 @@ def write_tfjs_files(self, model_path, save_path): with warnings.catch_warnings(): warnings.filterwarnings("ignore", message=".*alias for the.*") with suppress_stdout(): - if tfjs: + if tfjs and tfjs.converters: tfjs.converters.convert_tf_saved_model( model_path, save_path, skip_op_check=True ) diff --git a/tests/test_export.py b/tests/test_export.py index 34ebf2b..de036c9 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -3,19 +3,22 @@ np = sensenet.importers.import_numpy() tf = sensenet.importers.import_tensorflow() -import os -import shutil import gzip import json - -from PIL import Image +import os +import pathlib +import shutil from sensenet.constants import WARP -from sensenet.models.wrappers import Deepnet, ObjectDetector -from sensenet.models.wrappers import convert, tflite_predict +from sensenet.models.wrappers import ( + Deepnet, + ObjectDetector, + convert, + tflite_predict, +) -from .utils import TEST_DATA_DIR, TEST_IMAGE_DATA from .test_pretrained import create_image_model +from .utils import TEST_DATA_DIR, TEST_IMAGE_DATA MOBILENET_PATH = os.path.join(TEST_DATA_DIR, "mobilenetv2.json.gz") TEST_SAVE_MODEL = os.path.join(TEST_DATA_DIR, "test_model_save") @@ -118,8 +121,7 @@ def test_all_conversions(): for aformat in ["tflite", "tfjs", "smbundle", "h5"]: outpath = TEST_SAVE_MODEL + "." + aformat convert(jmodel, None, outpath, aformat) - - if aformat == "tfjs": + if aformat == "tfjs" and pathlib.Path(outpath).exists(): shutil.rmtree(outpath) else: os.remove(outpath)