diff --git a/modelconverter/__init__.py b/modelconverter/__init__.py index 0acfaec..ee884c7 100644 --- a/modelconverter/__init__.py +++ b/modelconverter/__init__.py @@ -1,10 +1,12 @@ import pkg_resources from luxonis_ml.utils import PUT_FILE_REGISTRY -from .hub import * +from .hub import convert __version__ = "0.3.1" +__all__ = ["convert"] + def load_put_file_plugins() -> None: """Registers any external put file plugins.""" diff --git a/modelconverter/cli/types.py b/modelconverter/cli/types.py index f8a2026..3ceb300 100644 --- a/modelconverter/cli/types.py +++ b/modelconverter/cli/types.py @@ -224,11 +224,6 @@ class Status(str, Enum): ), ] -TeamIDOption = Annotated[ - Optional[str], - typer.Option(help="The team ID", show_default=False), -] - RepositoryUrlOption = Annotated[ Optional[str], typer.Option(help="The repository URL", show_default=False), @@ -294,9 +289,12 @@ class Status(str, Enum): str, typer.Argument(help="Name of the model", show_default=False) ] -UserIDOption = Annotated[ - Optional[str], - typer.Option(help="The user ID", show_default=False), +IsOwnerOption = Annotated[ + bool, + typer.Option( + help="Whether the user is the owner of the resource", + show_default=False, + ), ] ArchitectureIDOption = Annotated[ @@ -408,13 +406,6 @@ class Status(str, Enum): typer.Option(help="The project ID", show_default=False), ] -FilterPublicEntityByTeamIDOption = Annotated[ - Optional[bool], - typer.Option( - help="Whether to filter public entity by team ID", show_default=False - ), -] - LuxonisOnlyOption = Annotated[ bool, typer.Option(help="Whether Luxonis only models", show_default=False), diff --git a/modelconverter/cli/utils.py b/modelconverter/cli/utils.py index 1fad5d3..7c27c6e 100644 --- a/modelconverter/cli/utils.py +++ b/modelconverter/cli/utils.py @@ -249,7 +249,7 @@ def hub_ls( **kwargs, ) -> List[Dict[str, Any]]: rename = rename or {} - data = Request.get(f"{endpoint}/", params=kwargs).json() + data = Request.get(f"{endpoint}/", params=kwargs) table = Table(row_styles=["yellow", "cyan"], box=ROUNDED) for key in keys: table.add_column(rename.get(key, key), header_style="magenta i") @@ -285,7 +285,7 @@ def slug_to_id( "is_public": is_public, "slug": slug, } - data = Request.get(f"{endpoint}/", params=params).json() + data = Request.get(f"{endpoint}/", params=params) if data: return data[0]["id"] raise ValueError(f"Model with slug '{slug}' not found.") @@ -307,7 +307,7 @@ def request_info( resource_id = get_resource_id(identifier, endpoint) try: - return Request.get(f"{endpoint}/{resource_id}/").json() + return Request.get(f"{endpoint}/{resource_id}/") except HTTPError: typer.echo(f"Resource with ID '{resource_id}' not found.") exit(1) @@ -333,26 +333,21 @@ def get_variant_name( def get_version_number(model_id: str) -> str: - versions = Request.get( - "modelVersions/", params={"model_id": model_id} - ).json() + versions = Request.get("modelVersions/", params={"model_id": model_id}) if not versions: - version = "0.1.0" - else: - max_version = Version(versions[0]["version"]) - for v in versions[1:]: - max_version = max(max_version, Version(v["version"])) - max_version = str(max_version) - version_numbers = max_version.split(".") - version_numbers[-1] = str(int(version_numbers[-1]) + 1) - version = ".".join(version_numbers) - return version + return "0.1.0" + max_version = Version(versions[0]["version"]) + for v in versions[1:]: + max_version = max(max_version, Version(v["version"])) + max_version = str(max_version) + version_numbers = max_version.split(".") + version_numbers[-1] = str(int(version_numbers[-1]) + 1) + return ".".join(version_numbers) def wait_for_export(run_id: str) -> None: def _get_run(run_id: str) -> Dict[str, Any]: - run = Request.dag_get(f"runs/{run_id}").json() - return run + return Request.dag_get(f"runs/{run_id}") def _clean_logs(logs: str) -> str: pattern = r"\[.*?\] \{.*?\} INFO - \[base\] logs:\s*" diff --git a/modelconverter/hub/__main__.py b/modelconverter/hub/__main__.py index 39fa368..2abbca9 100644 --- a/modelconverter/hub/__main__.py +++ b/modelconverter/hub/__main__.py @@ -20,11 +20,11 @@ DescriptionOption, DescriptionShortOption, DomainOption, - FilterPublicEntityByTeamIDOption, HashOption, HubVersionOption, HubVersionOptionRequired, IdentifierArgument, + IsOwnerOption, IsPublicOption, JSONOption, LicenseTypeOption, @@ -58,8 +58,6 @@ TagsOption, TargetPrecisionOption, TasksOption, - TeamIDOption, - UserIDOption, VariantSlugOption, VersionOption, get_configs, @@ -140,14 +138,12 @@ def login( @model.command(name="ls") def model_ls( - team_id: TeamIDOption = None, tasks: TasksOption = None, - user_id: UserIDOption = None, license_type: LicenseTypeOption = None, is_public: IsPublicOption = None, + is_owner: IsOwnerOption = True, slug: SlugOption = None, project_id: ProjectIDOption = None, - filter_public_entity_by_team_id: FilterPublicEntityByTeamIDOption = None, luxonis_only: LuxonisOnlyOption = False, limit: LimitOption = 50, sort: SortOption = "updated", @@ -156,14 +152,12 @@ def model_ls( """Lists models.""" return hub_ls( "models", - team_id=team_id, tasks=[task for task in tasks] if tasks else [], - user_id=user_id, license_type=license_type, is_public=is_public, + is_owner=is_owner, slug=slug, project_id=project_id, - filter_public_entity_by_team_id=filter_public_entity_by_team_id, luxonis_only=luxonis_only, limit=limit, sort=sort, @@ -224,7 +218,7 @@ def model_create( "links": links or [], } try: - res = Request.post("models", json=data).json() + res = Request.post("models", json=data) except requests.HTTPError as e: if ( e.response is not None @@ -248,13 +242,12 @@ def model_delete(identifier: IdentifierArgument): @variant.command(name="ls") def variant_ls( - team_id: TeamIDOption = None, - user_id: UserIDOption = None, model_id: ModelIDOption = None, slug: SlugOption = None, variant_slug: VariantSlugOption = None, version: HubVersionOption = None, is_public: IsPublicOption = None, + is_owner: IsOwnerOption = True, limit: LimitOption = 50, sort: SortOption = "updated", order: OrderOption = Order.DESC, @@ -262,10 +255,9 @@ def variant_ls( """Lists model versions.""" return hub_ls( "modelVersions", - team_id=team_id, - user_id=user_id, model_id=model_id, is_public=is_public, + is_owner=is_owner, slug=slug, variant_slug=variant_slug, version=version, @@ -324,7 +316,7 @@ def variant_create( "tags": tags or [], } try: - res = Request.post("modelVersions", json=data).json() + res = Request.post("modelVersions", json=data) except requests.HTTPError as e: if str(e).startswith("{'detail': 'Unique constraint error."): raise ValueError( @@ -348,8 +340,6 @@ def variant_delete(identifier: IdentifierArgument): @instance.command(name="ls") def instance_ls( platforms: PlatformsOption = None, - team_id: TeamIDOption = None, - user_id: UserIDOption = None, model_id: ModelIDOption = None, variant_id: ModelVersionIDOption = None, model_type: ModelTypeOption = None, @@ -359,6 +349,7 @@ def instance_ls( hash: HashOption = None, status: StatusOption = None, is_public: IsPublicOption = None, + is_owner: IsOwnerOption = True, compression_level: CompressionLevelOption = None, optimization_level: OptimizationLevelOption = None, slug: SlugOption = None, @@ -382,9 +373,8 @@ def instance_ls( status=status, compression_level=compression_level, optimization_level=optimization_level, - team_id=team_id, - user_id=user_id, is_public=is_public, + is_owner=is_owner, slug=slug, limit=limit, sort=sort, @@ -434,7 +424,7 @@ def instance_download( dest = Path(output_dir) if output_dir else None model_instance_id = get_resource_id(identifier, "modelInstances") downloaded_path = None - urls = Request.get(f"modelInstances/{model_instance_id}/download").json() + urls = Request.get(f"modelInstances/{model_instance_id}/download") if not urls: raise ValueError("No files to download") @@ -445,9 +435,9 @@ def instance_download( filename = unquote(Path(urlparse(url).path).name) if dest is None: dest = Path( - Request.get(f"modelInstances/{model_instance_id}") - .json() - .get("slug", model_instance_id) + Request.get(f"modelInstances/{model_instance_id}").get( + "slug", model_instance_id + ) ) dest.mkdir(parents=True, exist_ok=True) @@ -487,7 +477,7 @@ def instance_create( "quantization_data": quantization_data, "is_deployable": is_deployable, } - res = Request.post("modelInstances", json=data).json() + res = Request.post("modelInstances", json=data) print(f"Model instance '{res['name']}' created with ID '{res['id']}'") if not silent: instance_info(res["id"]) @@ -506,16 +496,14 @@ def instance_delete(identifier: IdentifierArgument): def config(identifier: IdentifierArgument): """Prints the configuration of a model instance.""" model_instance_id = get_resource_id(identifier, "modelInstances") - res = Request.get(f"modelInstances/{model_instance_id}/config") - print(res.json()) + print(Request.get(f"modelInstances/{model_instance_id}/config")) @instance.command() def files(identifier: IdentifierArgument): """Prints the configuration of a model instance.""" model_instance_id = get_resource_id(identifier, "modelInstances") - res = Request.get(f"modelInstances/{model_instance_id}/files") - print(res.json()) + print(Request.get(f"modelInstances/{model_instance_id}/files")) @instance.command() @@ -548,7 +536,7 @@ def _export( res = Request.post( f"modelInstances/{model_instance_id}/export/{target.lower()}", json=json, - ).json() + ) print( f"Model instance '{name}' created for {target} export with ID '{res['id']}'" ) diff --git a/modelconverter/hub/hub_requests.py b/modelconverter/hub/hub_requests.py index 1027bb3..9143243 100644 --- a/modelconverter/hub/hub_requests.py +++ b/modelconverter/hub/hub_requests.py @@ -1,5 +1,5 @@ from json import JSONDecodeError -from typing import Dict, Final, Optional +from typing import Any, Dict, Final, Optional import requests from requests import HTTPError, Response @@ -8,26 +8,36 @@ class Request: - URL: Final[str] = f"{environ.HUBAI_URL.rstrip('/')}/api/v1" + URL: Final[str] = f"{environ.HUBAI_URL.rstrip('/')}/models/api/v1" DAG_URL: Final[str] = URL.replace("models", "dags") HEADERS: Final[Dict[str, str]] = { "accept": "application/json", "Authorization": f"Bearer {environ.HUBAI_API_KEY}", } + @staticmethod + def _process_response(response: Response) -> Any: + return Request._get_json(Request._check_response(response)) + @staticmethod def _check_response(response: Response) -> Response: if response.status_code >= 400: - try: - json = response.json() - raise HTTPError(json, response=response) - except JSONDecodeError as e: - raise HTTPError(response.text) from e + raise HTTPError(Request._get_json(response), response=response) return response @staticmethod - def get(endpoint: str = "", **kwargs) -> requests.Response: - return Request._check_response( + def _get_json(response: Response) -> Any: + try: + return response.json() + except JSONDecodeError as e: + raise HTTPError( + f"Unexpected response from the server:\n{response.text}", + response=response, + ) from e + + @staticmethod + def get(endpoint: str = "", **kwargs) -> Any: + return Request._process_response( requests.get( Request._get_url(endpoint), headers=Request.HEADERS, @@ -36,8 +46,8 @@ def get(endpoint: str = "", **kwargs) -> requests.Response: ) @staticmethod - def dag_get(endpoint: str = "", **kwargs) -> requests.Response: - return Request._check_response( + def dag_get(endpoint: str = "", **kwargs) -> Any: + return Request._process_response( requests.get( Request._get_url(endpoint, Request.DAG_URL), headers=Request.HEADERS, @@ -46,19 +56,19 @@ def dag_get(endpoint: str = "", **kwargs) -> requests.Response: ) @staticmethod - def post(endpoint: str = "", **kwargs) -> requests.Response: + def post(endpoint: str = "", **kwargs) -> Any: headers = Request.HEADERS if "headers" in kwargs: headers = {**Request.HEADERS, **kwargs.pop("headers")} - return Request._check_response( + return Request._process_response( requests.post( Request._get_url(endpoint), headers=headers, **kwargs ) ) @staticmethod - def delete(endpoint: str = "", **kwargs) -> requests.Response: - return Request._check_response( + def delete(endpoint: str = "", **kwargs) -> Any: + return Request._process_response( requests.delete( Request._get_url(endpoint), headers=Request.HEADERS, **kwargs )