forked from keras-team/keras-hub
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A tool to update all json files for a preset (by running them through Keras' serialize and deserialize routines). A tool to update all preset versions in the library to the latest version on kaggle.
- Loading branch information
1 parent
9b024bd
commit 39e5e14
Showing
3 changed files
with
174 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
"""Update all json files for all models on Kaggle. | ||
Run tools/admin/update_all_versions.py before running this tool to make sure | ||
all our kaggle links point to the latest version! | ||
This script downloads all models from KaggleHub, loads and re-serializes all | ||
json files, and reuploads them. This can be useful when changing our metadata or | ||
updating our saved configs. | ||
This script relies on private imports from preset_utils and may need updates | ||
when it is re-run. | ||
Usage: | ||
``` | ||
# Preview changes. | ||
python tools/admin/update_all_json.py | ||
# Upload changes. | ||
python tools/admin/update_all_json.py --upload | ||
``` | ||
""" | ||
|
||
import difflib | ||
import os | ||
import pathlib | ||
import shutil | ||
|
||
import kagglehub | ||
from absl import app | ||
from absl import flags | ||
|
||
import keras_hub | ||
from keras_hub.src.utils import preset_utils | ||
|
||
# `--upload` command-line flag (default False); main() reads it as
# `FLAGS.upload` to decide between uploading and only previewing changes.
FLAGS = flags.FLAGS | ||
flags.DEFINE_boolean("upload", False, "Upload updated models.") | ||
|
||
|
||
# ANSI escape sequences used to style the text this script prints.
BOLD = "\033[1m" | ||
GREEN = "\033[92m" | ||
RED = "\033[91m" | ||
RESET = "\033[0m" | ||
|
||
|
||
def diff(in_path, out_path):
    """Print a colored unified diff of two text files and report changes.

    When the input file's name contains "metadata.json", every line that
    contains "date" is dropped from both files before comparing, so a
    refreshed upload date alone does not count as a change.

    Args:
        in_path: Path (`str` or path-like) of the original file.
        out_path: Path (`str` or path-like) of the re-serialized file.

    Returns:
        True if the files differ (the diff is printed), False otherwise.
    """
    # Accept plain strings as well as `pathlib.Path` objects; the `.name`
    # check below needs a path object.
    in_path = pathlib.Path(in_path)
    out_path = pathlib.Path(out_path)
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(in_path, encoding="utf-8") as in_file:
        in_lines = in_file.readlines()
    with open(out_path, encoding="utf-8") as out_file:
        out_lines = out_file.readlines()
    # Ignore updates to upload_date.
    if "metadata.json" in in_path.name:
        in_lines = [line for line in in_lines if "date" not in line]
        out_lines = [line for line in out_lines if "date" not in line]
    delta = list(difflib.unified_diff(in_lines, out_lines))
    if not delta:
        return False
    for line in delta:
        # Added lines are printed green, removed lines red; the rest of
        # the diff output is printed unstyled.
        if line.startswith("+"):
            print(" " + GREEN + line + RESET, end="")
        elif line.startswith("-"):
            print(" " + RED + line + RESET, end="")
        else:
            print(" " + line, end="")
    print()
    return True
|
||
|
||
def main(argv):
    """Re-serialize every backbone preset's json files and optionally upload.

    For each preset in `keras_hub.models.Backbone.presets`, the model is
    downloaded with KaggleHub and copied to `updates/<model name>`. Every
    top-level `*.json` file is then loaded and saved again through the
    private `preset_utils` loader/saver, and a diff against the downloaded
    copy is printed. Models whose json files changed are uploaded as a new
    version when `--upload` is set; otherwise only the preview is printed.

    Args:
        argv: Command-line arguments handed over by `app.run`; unused.
    """
    presets = keras_hub.models.Backbone.presets
    output_parent = pathlib.Path("updates")
    output_parent.mkdir(parents=True, exist_ok=True)

    # Tracks whether any model was actually uploaded, so the follow-up
    # instructions at the end are only shown when they apply.
    uploaded_any = False

    for preset in sorted(presets.keys()):
        handle = presets[preset]["kaggle_handle"].removeprefix("kaggle://")
        # The last path component of the handle is the version; the one
        # before it is used as the local output directory name.
        handle_no_version = os.path.dirname(handle)
        builtin_name = os.path.basename(handle_no_version)

        # Download the full model with KaggleHub.
        input_dir = pathlib.Path(kagglehub.model_download(handle))
        output_dir = output_parent / builtin_name
        # Start from a fresh copy of the download on every run.
        if output_dir.exists():
            shutil.rmtree(output_dir)
        shutil.copytree(input_dir, output_dir)

        # Manually create saver/loader objects.
        config = preset_utils.load_json(preset, preset_utils.CONFIG_FILE)
        loader = preset_utils.KerasPresetLoader(preset, config)
        saver = preset_utils.KerasPresetSaver(output_dir)

        # Update all json files.
        print(BOLD + handle + RESET)
        updated = False
        for file in input_dir.glob("*.json"):
            if file.name == preset_utils.METADATA_FILE:
                # metadata.json is handled concurrently with config.json.
                continue
            print(" " + BOLD + file.name + RESET)
            config = preset_utils.load_json(preset, file.name)
            layer = loader._load_serialized_object(config)
            saver._save_serialized_object(layer, file.name)
            if file.name == preset_utils.CONFIG_FILE:
                # Handle metadata.json with config.json.
                print(" ", preset_utils.METADATA_FILE)
                saver._save_metadata(layer)
                name = preset_utils.METADATA_FILE
                if diff(input_dir / name, output_dir / name):
                    updated = True
            if diff(input_dir / file.name, output_dir / file.name):
                updated = True
            del layer

        if not updated:
            continue

        # Reupload the model if any json files were updated.
        if FLAGS.upload:
            print(BOLD + "Uploading " + handle_no_version + RESET)
            kagglehub.model_upload(
                handle_no_version,
                output_dir,
                version_notes="updated json files",
            )
            uploaded_any = True
        else:
            print(BOLD + "Preview. Not uploading " + handle_no_version + RESET)

    # The follow-up steps only make sense once something has been uploaded;
    # skip them for preview runs and for runs where nothing changed.
    if uploaded_any:
        print(
            BOLD + "Wait a few hours (for kaggle to process uploads)." + RESET
        )
        print(BOLD + "Then run tools/admin/update_all_versions.py" + RESET)
|
||
|
||
# Script entry point: hand `main` to absl's `app.run`.
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""Update all preset files to use the latest version on kaggle. | ||
Run from the base of the repo. | ||
Usage: | ||
``` | ||
python tools/admin/update_all_versions.py | ||
``` | ||
""" | ||
|
||
import os | ||
import pathlib | ||
|
||
import kagglehub | ||
|
||
import keras_hub | ||
|
||
|
||
def update():
    """Point every preset's kaggle handle at the version KaggleHub resolves.

    For each preset in `keras_hub.models.Backbone.presets`, `metadata.json`
    is downloaded for the handle with its version stripped, and the version
    is read from the name of the directory containing the returned path.
    When it differs from the version in the preset's handle, the quoted uri
    is replaced in every `keras_hub/**/*_presets.py` file found below the
    current working directory, so this must be run from the base of the
    repo.
    """
    presets = keras_hub.models.Backbone.presets
    for preset in sorted(presets.keys()):
        uri = presets[preset]["kaggle_handle"]
        kaggle_handle = uri.removeprefix("kaggle://")
        # The last path component of the handle is the pinned version.
        old_version = os.path.basename(kaggle_handle)
        kaggle_handle = os.path.dirname(kaggle_handle)
        metadata_path = kagglehub.model_download(
            kaggle_handle, path="metadata.json"
        )
        new_version = os.path.basename(os.path.dirname(metadata_path))
        if old_version == new_version:
            continue

        print(f"Updating {preset} from {old_version} to {new_version}")
        # Same uri with only the trailing version component swapped.
        new_uri = os.path.dirname(uri) + f"/{new_version}"
        for path in pathlib.Path(".").glob("keras_hub/**/*_presets.py"):
            with open(path, "r", encoding="utf-8") as file:
                contents = file.read()
            # Match the quoted uri so only the exact handle is replaced.
            new_contents = contents.replace(f'"{uri}"', f'"{new_uri}"')
            # Leave files that do not mention this preset untouched.
            if new_contents != contents:
                with open(path, "w", encoding="utf-8") as file:
                    file.write(new_contents)
|
||
|
||
# Script entry point: run the version update when executed directly.
if __name__ == "__main__": | ||
update() |