diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 4b4e3bc3b..b1a295cbc 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -240,7 +240,7 @@ def tf_copy_gfile_to_cache(preset, path): try: import tensorflow as tf - os.make_dirs(os.path.dirname(local_path), exist_ok=True) + os.makedirs(os.path.dirname(local_path), exist_ok=True) tf.io.gfile.copy(url, local_path) except Exception as e: # gfile.copy will leave an empty file after an error. diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py index 787a1ea43..998dcadfa 100644 --- a/keras_hub/src/utils/preset_utils_test.py +++ b/keras_hub/src/utils/preset_utils_test.py @@ -33,6 +33,18 @@ def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "class keras_hub>BortBackbone"): BertBackbone.from_preset(preset_dir) + @pytest.mark.large + def test_tf_file_io(self): + # Load a model from Kaggle to use as a test model. + preset = "bert_tiny_en_uncased" + backbone = BertBackbone.from_preset(preset) + # Save the model on a local directory. + temp_dir = self.get_temp_dir() + local_preset_dir = os.path.join(temp_dir, "bert_preset") + backbone.save_to_preset(local_preset_dir) + # Load with "file://" which tf supports. + backbone = BertBackbone.from_preset("file://" + local_preset_dir) + @pytest.mark.large def test_upload_empty_preset(self): temp_dir = self.get_temp_dir()