New tools for model preset admin
A tool to update all json files for a preset (by running them through
Keras' serialize and deserialize routines).

A tool to update all preset versions in the library to the latest version
on kaggle.
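
For context, a minimal sketch of the round-trip the first tool performs on each config file. It assumes the public `keras.saving` helpers in place of the private `preset_utils` loader/saver the script actually uses, and the file path is hypothetical:

```python
import json

import keras
import keras_hub  # noqa: F401  # importing registers keras_hub classes for deserialization

# Hypothetical path to one preset config downloaded from Kaggle.
config_path = "updates/bert_base_en/config.json"

with open(config_path) as f:
    config = json.load(f)

# Rebuild the layer from its serialized config, then serialize it again so the
# json is rewritten in whatever format current Keras produces.
layer = keras.saving.deserialize_keras_object(config)
config = keras.saving.serialize_keras_object(layer)

with open(config_path, "w") as f:
    json.dump(config, f, indent=4)
```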
mattdangerw committed Dec 20, 2024
1 parent 9b024bd commit 39e5e14
Showing 3 changed files with 174 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tools/admin/mirror_weights_on_hf.py
@@ -16,7 +16,7 @@
 kagglehub = None
 
 HF_BASE_URI = "hf://keras"
-JSON_FILE_PATH = "tools/hf_uploaded_presets.json"
+JSON_FILE_PATH = "tools/admin/hf_uploaded_presets.json"
 HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")


133 changes: 133 additions & 0 deletions tools/admin/update_all_json.py
@@ -0,0 +1,133 @@
"""Update all json files for all models on Kaggle.
Run tools/admin/update-all-versions.py before running this tool to make sure
all our kaggle links point to the latest version!
This script downloads all models from KaggleHub, loads and re-serializes all
json files, and reuploads them. This can be useful when changing our metadata or
updating our saved configs.
This script relies on private imports from preset_utils and may need updates
when it is re-run.
Usage:
```
# Preview changes.
python tools/admin/update_all_json.py
# Upload changes.
python tools/admin/update_all_json.py --upload
```
"""

import difflib
import os
import pathlib
import shutil

import kagglehub
from absl import app
from absl import flags

import keras_hub
from keras_hub.src.utils import preset_utils

FLAGS = flags.FLAGS
flags.DEFINE_boolean("upload", False, "Upload updated models.")


BOLD = "\033[1m"
GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m"


def diff(in_path, out_path):
    with open(in_path) as in_file, open(out_path) as out_file:
        in_lines = in_file.readlines()
        out_lines = out_file.readlines()
    # Ignore updates to upload_date.
    if "metadata.json" in in_path.name:
        in_lines = [line for line in in_lines if "date" not in line]
        out_lines = [line for line in out_lines if "date" not in line]
    diff = difflib.unified_diff(
        in_lines,
        out_lines,
    )
    diff = list(diff)
    if not diff:
        return False
    for line in diff:
        if line.startswith("+"):
            print(" " + GREEN + line + RESET, end="")
        elif line.startswith("-"):
            print(" " + RED + line + RESET, end="")
        else:
            print(" " + line, end="")
    print()
    return True


def main(argv):
    presets = keras_hub.models.Backbone.presets
    output_parent = pathlib.Path("updates")
    output_parent.mkdir(parents=True, exist_ok=True)

    for preset in sorted(presets.keys()):
        handle = presets[preset]["kaggle_handle"].removeprefix("kaggle://")
        handle_no_version = os.path.dirname(handle)
        builtin_name = os.path.basename(handle_no_version)

        # Download the full model with KaggleHub.
        input_dir = kagglehub.model_download(handle)
        input_dir = pathlib.Path(input_dir)
        output_dir = output_parent / builtin_name
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        shutil.copytree(input_dir, output_dir)

        # Manually create saver/loader objects.
        config = preset_utils.load_json(preset, preset_utils.CONFIG_FILE)
        loader = preset_utils.KerasPresetLoader(preset, config)
        saver = preset_utils.KerasPresetSaver(output_dir)

        # Update all json files.
        print(BOLD + handle + RESET)
        updated = False
        for file in input_dir.glob("*.json"):
            if file.name == preset_utils.METADATA_FILE:
                # metadata.json is handled concurrently with config.json.
                continue
            print(" " + BOLD + file.name + RESET)
            config = preset_utils.load_json(preset, file.name)
            layer = loader._load_serialized_object(config)
            saver._save_serialized_object(layer, file.name)
            if file.name == preset_utils.CONFIG_FILE:
                # Handle metadata.json with config.json.
                print(" ", preset_utils.METADATA_FILE)
                saver._save_metadata(layer)
                name = preset_utils.METADATA_FILE
                if diff(input_dir / name, output_dir / name):
                    updated = True
            if diff(input_dir / file.name, output_dir / file.name):
                updated = True
            del layer

        if not updated:
            continue

        # Reupload the model if any json files were updated.
        if FLAGS.upload:
            print(BOLD + "Uploading " + handle_no_version + RESET)
            kagglehub.model_upload(
                handle_no_version,
                output_dir,
                version_notes="updated json files",
            )
        else:
            print(BOLD + "Preview. Not uploading " + handle_no_version + RESET)
    print(BOLD + "Wait a few hours (for kaggle to process uploads)." + RESET)
    print(BOLD + "Then run tools/admin/update_all_versions.py" + RESET)


if __name__ == "__main__":
    app.run(main)
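
Both admin tools walk the same preset registry; a small sketch of inspecting that registry before running either one (the slice is just to keep the preview short):

```python
import keras_hub

# Backbone.presets maps preset names to metadata dicts; both admin tools read
# the "kaggle_handle" entry from each one.
presets = keras_hub.models.Backbone.presets
for name in sorted(presets)[:5]:  # first few only, as a preview
    print(name, "->", presets[name]["kaggle_handle"])
```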
40 changes: 40 additions & 0 deletions tools/admin/update_all_versions.py
@@ -0,0 +1,40 @@
"""Update all preset files to use the latest version on kaggle.
Run from the base of the repo.
Usage:
```
python tools/admin/update_all_versions.py
```
"""

import os
import pathlib

import kagglehub

import keras_hub


def update():
    presets = keras_hub.models.Backbone.presets
    for preset in sorted(presets.keys()):
        uri = presets[preset]["kaggle_handle"]
        kaggle_handle = uri.removeprefix("kaggle://")
        old_version = os.path.basename(kaggle_handle)
        kaggle_handle = os.path.dirname(kaggle_handle)
        hub_dir = kagglehub.model_download(kaggle_handle, path="metadata.json")
        new_version = os.path.basename(os.path.dirname(hub_dir))
        if old_version != new_version:
            print(f"Updating {preset} from {old_version} to {new_version}")
            for path in pathlib.Path(".").glob("keras_hub/**/*_presets.py"):
                with open(path, "r") as file:
                    contents = file.read()
                new_uri = os.path.dirname(uri) + f"/{new_version}"
                contents = contents.replace(f'"{uri}"', f'"{new_uri}"')
                with open(path, "w") as file:
                    file.write(contents)


if __name__ == "__main__":
    update()
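
The version check above works because `kagglehub.model_download(handle, path="metadata.json")` returns the cached path of that single file, and the parent directory of that path is named after the latest model version. A small sketch of the path arithmetic, with a made-up handle and cache location:

```python
import os

# Hypothetical preset uri and the cache path kagglehub might return for it.
uri = "kaggle://keras/bert/keras/bert_base_en/2"
hub_dir = "~/.cache/kagglehub/models/keras/bert/keras/bert_base_en/3/metadata.json"

old_version = os.path.basename(uri.removeprefix("kaggle://"))  # "2"
new_version = os.path.basename(os.path.dirname(hub_dir))       # "3"
new_uri = os.path.dirname(uri) + f"/{new_version}"
print(new_uri)  # kaggle://keras/bert/keras/bert_base_en/3
```

The script then rewrites every `*_presets.py` file so each registered handle points at the new version string.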
