Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to upload Ludwig models to Predibase. #3687

Merged
merged 16 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ludwig/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self):
init_config Initialize a user config from a dataset and targets
render_config Renders the fully populated config with all defaults set
check_install Runs a quick training run on synthetic data to verify installation status
upload Push trained model artifacts to a registry (e.g., HuggingFace Hub)
upload Push trained model artifacts to a registry (e.g., HuggingFace Hub, Predibase)
""",
)
parser.add_argument("command", help="Subcommand to run")
Expand Down
13 changes: 12 additions & 1 deletion ludwig/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Optional

from ludwig.utils.print_utils import get_logging_level_registry
from ludwig.utils.upload_utils import HuggingFaceHub
from ludwig.utils.upload_utils import HuggingFaceHub, Predibase

logger = logging.getLogger(__name__)


def get_upload_registry():
    """Return the mapping from upload service name to its uploader class.

    Keys are the service identifiers accepted by the upload command; values
    are the classes (not instances) that implement the upload for that service.
    """
    registry = {}
    registry["hf_hub"] = HuggingFaceHub
    registry["predibase"] = Predibase
    return registry


Expand All @@ -23,6 +24,8 @@ def upload_cli(
private: bool = False,
commit_message: str = "Upload trained [Ludwig](https://ludwig.ai/latest/) model weights",
commit_description: Optional[str] = None,
dataset_file: Optional[str] = None,
dataset_name: Optional[str] = None,
**kwargs,
) -> None:
"""Create an empty repo on the HuggingFace Hub and upload trained model artifacts to that repo.
Expand All @@ -49,6 +52,12 @@ def upload_cli(
`f"Upload {path_in_repo} with huggingface_hub"`
commit_description (`str` *optional*):
The description of the generated commit
dataset_file (`str`, *optional*):
The path to the dataset file. Required if `service` is set to
`"predibase"`.
dataset_name (`str`, *optional*):
The name of the dataset. Used by the `service`
`"predibase"`.
"""
model_service = get_upload_registry().get(service, "hf_hub")
hub = model_service()
Expand All @@ -60,6 +69,8 @@ def upload_cli(
private=private,
commit_message=commit_message,
commit_description=commit_description,
dataset_file=dataset_file,
dataset_name=dataset_name,
)


Expand Down
92 changes: 92 additions & 0 deletions ludwig/utils/upload_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import os
import tempfile
import zipfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

from huggingface_hub import HfApi, login
Expand Down Expand Up @@ -37,6 +40,8 @@ def upload(
private: Optional[bool] = False,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
dataset_file: Optional[str] = None,
dataset_name: Optional[str] = None,
) -> bool:
"""Abstract method to upload trained model artifacts to the target repository.

Expand All @@ -61,6 +66,8 @@ def _validate_upload_parameters(
private: Optional[bool] = False,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
dataset_file: Optional[str] = None,
dataset_name: Optional[str] = None,
):
"""Validate parameters before uploading trained model artifacts.

Expand Down Expand Up @@ -205,3 +212,88 @@ def upload(
return True

return False


class Predibase(BaseModelUpload):
    """Uploader that pushes trained model artifacts (and their dataset) to Predibase.

    Call `login()` first so that `self.pc` holds a client; `upload()` uses that
    client for every Predibase call.
    """

    def __init__(self):
        # Predibase client; set by `login()` and used by `upload()`.
        self.pc = None

    def login(self):
        """Create a Predibase client and store it on `self.pc`.

        `PredibaseClient()` is constructed with no arguments; it is expected to
        pick up the token from the PREDIBASE_API_TOKEN environment variable
        (NOTE(review): behavior of the Predibase SDK, not verifiable from this file).

        Returns:
            `True` once the client has been created and stored.

        Raises:
            Exception: if the client could not be created. The original error
                is chained as `__cause__`.
        """
        # Imported lazily so that `predibase` is only needed when this service is used.
        from predibase import PredibaseClient

        try:
            pc = PredibaseClient()
        except Exception as e:
            raise Exception(f"Failed to login to Predibase: {e}") from e

        # TODO: Check if subscription has expired

        self.pc = pc
        return True

    def upload(
        self,
        repo_id: str,
        model_path: str,
        commit_description: Optional[str] = None,
        dataset_file: Optional[str] = None,
        dataset_name: Optional[str] = None,
        **kwargs,
    ) -> bool:
        """Create a repo in Predibase (if needed) and upload trained model artifacts to it.

        Args:
            repo_id (`str`):
                The name of the Predibase repo. It is okay if the repo already
                exists.
            model_path (`str`):
                The path of the saved model. This is the top level directory where
                the models weights as well as other associated training artifacts
                are saved. Its contents are zipped and uploaded.
            commit_description (`str` *optional*):
                Used as the description of the repo.
            dataset_file (`str` *optional*):
                The path to the dataset file to upload alongside the model.
            dataset_name (`str` *optional*):
                The name to give the uploaded dataset.
            **kwargs:
                Ignored. Accepted so that callers which pass arguments meant for
                other upload services (e.g. `private`, `commit_message`) do not
                fail with a `TypeError`.

        Returns:
            `True` when the repo, dataset and model were all uploaded.

        Raises:
            Exception: if creating the repo, uploading the dataset, or uploading
                the model fails. The original error is chained as `__cause__`.
        """
        # Create empty model repo using repo_id, but it is okay if it already exists.
        try:
            repo = self.pc.create_repo(
                name=repo_id,
                description=commit_description,
                exist_ok=True,
            )
        except Exception as e:
            raise Exception(f"Failed to create repo in Predibase: {e}") from e

        # Upload the dataset to Predibase
        try:
            dataset = self.pc.upload_dataset(
                file_path=dataset_file,
                name=dataset_name,
            )
        except Exception as e:
            raise Exception(f"Failed to upload dataset to Predibase: {e}") from e

        # The archive is written to a temporary directory (not inside model_path)
        # so that it is cleaned up automatically and never ends up inside itself.
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a zip file of the model weights folder
            model_data_zip_path = os.path.join(tmpdir, "model_data.zip")
            path_to_archive = Path(model_path)
            with zipfile.ZipFile(Path(model_data_zip_path), "w", zipfile.ZIP_DEFLATED) as zipf:
                for fp in path_to_archive.glob("**/*"):
                    # Store entries relative to model_path so the archive has no
                    # machine-specific leading directories.
                    zipf.write(fp, arcname=fp.relative_to(path_to_archive))

            # Upload the zip file to Predibase (must happen before tmpdir is removed)
            try:
                self.pc.upload_model(
                    repo_id=repo.id,
                    zip_fp=model_data_zip_path,
                    dataset=dataset,
                )
            except Exception as e:
                raise Exception(f"Failed to upload model to Predibase: {e}") from e

        # Every step succeeded; failures have already raised above.
        return True
martindavis marked this conversation as resolved.
Show resolved Hide resolved
Loading