From 5eb380d1a950949d827d63f094ff79ae9589aa1e Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Fri, 19 Apr 2024 13:13:02 -0400 Subject: [PATCH 01/16] Added tracing methods --- .vscode/settings.json | 11 +++++++ src/kagglehub/__init__.py | 2 +- src/kagglehub/{logging.py => logger.py} | 0 src/kagglehub/tracing.py | 30 +++++++++++++++++ src/kagglehub/tracing_test.py | 43 +++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json rename src/kagglehub/{logging.py => logger.py} (100%) create mode 100644 src/kagglehub/tracing.py create mode 100644 src/kagglehub/tracing_test.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..e4c9f29a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./src", + "-p", + "*test*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} \ No newline at end of file diff --git a/src/kagglehub/__init__.py b/src/kagglehub/__init__.py index a0bb3888..66c636e8 100644 --- a/src/kagglehub/__init__.py +++ b/src/kagglehub/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.2.3" -import kagglehub.logging # configures the library logger. +import kagglehub.logger # configures the library logger. from kagglehub import colab_cache_resolver, http_resolver, kaggle_cache_resolver, registry from kagglehub.auth import login from kagglehub.models import model_download, model_upload diff --git a/src/kagglehub/logging.py b/src/kagglehub/logger.py similarity index 100% rename from src/kagglehub/logging.py rename to src/kagglehub/logger.py diff --git a/src/kagglehub/tracing.py b/src/kagglehub/tracing.py new file mode 100644 index 00000000..d8dfc2c6 --- /dev/null +++ b/src/kagglehub/tracing.py @@ -0,0 +1,30 @@ +import secrets + +class TraceContext: + """ + Generates and manages identifiers for distributed tracing. 
+ + More information on trace can be found at https://www.w3.org/TR/trace-context/ + + Attributes: + trace: A 16-byte hexadecimal string representing a unique trace ID. + """ + def __init__(self) -> None: + self.trace = secrets.token_bytes(16).hex() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + pass + + def next(self): + """ + Generates a new span ID within the context of the current trace ID. + + Returns: + A formatted string representing a span ID, in a standard + distributed tracing format (e.g., "00-{trace}-{span}-01") + """ + span = secrets.token_bytes(8).hex() + return f"00-{self.trace}-{span}-01" diff --git a/src/kagglehub/tracing_test.py b/src/kagglehub/tracing_test.py new file mode 100644 index 00000000..09ae9599 --- /dev/null +++ b/src/kagglehub/tracing_test.py @@ -0,0 +1,43 @@ +import unittest + +from kagglehub.tracing import TraceContext + +_CANONICAL_EXAMPLE = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + +class TraceContextSuite(unittest.TestCase): + def test_length(self): + with TraceContext() as ctx: + traceparent = ctx.next() + self.assertEqual(len(traceparent), len(_CANONICAL_EXAMPLE)) + + def test_prefix(self): + with TraceContext() as ctx: + traceparent = ctx.next() + self.assertEqual(traceparent[0:2], "00") + + def test_suffix(self): + # always sample + with TraceContext() as ctx: + traceparent = ctx.next() + self.assertEqual(traceparent[-2:], "01") + + def test_pattern(self): + with TraceContext() as ctx: + traceparent = ctx.next() + version,trace,span,flag = traceparent.split("-") + print(version,trace,span,flag) + self.assertRegex(version, "^[0-9]{2}$", "version does not meet pattern") + self.assertRegex(trace, "^[A-Fa-f0-9]{32}$") + self.assertRegex(span, "^[A-Fa-f0-9]{16}$") + self.assertRegex(flag, "^[0-9]{2}$") + + def test_notempty(self): + with TraceContext() as ctx: + traceparent = ctx.next() + _,trace,span,_ = traceparent.split("-") + self.assertNotEqual(trace, f"{0:016x}") + 
self.assertNotEqual(span, f"{0:08x}") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From d46f42db413c811d8b1677034aef7763f2da6cbb Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 11:35:54 -0400 Subject: [PATCH 02/16] Added formatter --- .vscode/launch.json | 15 +++++++ .vscode/settings.json | 6 ++- src/kagglehub/clients.py | 60 ++++++++++++++++++++-------- src/kagglehub/gcs_upload.py | 54 +++++++++++++++++++------- src/kagglehub/models.py | 12 ++++-- src/kagglehub/models_helpers.py | 69 ++++++++++++++++++++++----------- src/kagglehub/tracing.py | 32 +++++++++++---- 7 files changed, 181 insertions(+), 67 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..6b76b4fa --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index e4c9f29a..acd71588 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,9 @@ "*test*.py" ], "python.testing.pytestEnabled": false, - "python.testing.unittestEnabled": true + "python.testing.unittestEnabled": true, + "editor.defaultFormatter": "ms-python.autopep8", + "[python]": { + "editor.formatOnSave": true, + } } \ No newline at end of file diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py index 812aa94d..b11410ff 100644 --- a/src/kagglehub/clients.py +++ b/src/kagglehub/clients.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Optional, Tuple +from typing 
import Optional, Tuple, Callable from urllib.parse import urljoin import requests @@ -31,6 +31,7 @@ ) from kagglehub.handle import ResourceHandle from kagglehub.integrity import get_md5_checksum_from_response, to_b64_digest, update_hash_from_file +from kagglehub.tracing import TraceContext, default_context_factory CHUNK_SIZE = 1048576 # The `connect` timeout is the number of seconds `requests` will wait for your client to establish a connection. @@ -66,8 +67,10 @@ def get_user_agent() -> str: _cached_user_agent = f"{base_user_agent} kkb/{build_date}" elif is_in_colab_notebook(): colab_tag = os.getenv("COLAB_RELEASE_TAG") - runtime_suffix = "-managed" if os.getenv("TBE_RUNTIME_ADDR") else "-unmanaged" - _cached_user_agent = f"{base_user_agent} colab/{colab_tag}{runtime_suffix}" + runtime_suffix = "-managed" if os.getenv( + "TBE_RUNTIME_ADDR") else "-unmanaged" + _cached_user_agent = f"{ + base_user_agent} colab/{colab_tag}{runtime_suffix}" else: _cached_user_agent = base_user_agent @@ -81,9 +84,10 @@ def get_user_agent() -> str: class KaggleApiV1Client: BASE_PATH = "api/v1" - def __init__(self) -> None: + def __init__(self, ctx_factory: Callable[[], TraceContext]) -> None: self.credentials = get_kaggle_credentials() self.endpoint = get_kaggle_api_endpoint() + self.ctx_factory = ctx_factory if ctx_factory != None else default_context_factory def _check_for_version_update(self, response: requests.Response) -> None: latest_version_str = response.headers.get("X-Kaggle-HubVersion") @@ -93,14 +97,16 @@ def _check_for_version_update(self, response: requests.Response) -> None: if latest_version > current_version: logger.info( "Warning: Looks like you're using an outdated `kagglehub` " - f"version, please consider updating (latest version: {latest_version})" + f"version, please consider updating (latest version: { + latest_version})" ) def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> dict: url = self._build_url(path) with requests.get( url, - 
headers={"User-Agent": get_user_agent()}, + headers={"User-Agent": get_user_agent(), + "traceparent": self.ctx_factory().next()}, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), ) as response: @@ -112,7 +118,10 @@ def post(self, path: str, data: dict) -> dict: url = self._build_url(path) with requests.post( url, - headers={"User-Agent": get_user_agent()}, + headers={ + "User-Agent": get_user_agent(), + "traceparent": self.ctx_factory().next() + }, json=data, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), @@ -126,9 +135,13 @@ def post(self, path: str, data: dict) -> dict: def download_file(self, path: str, out_file: str, resource_handle: Optional[ResourceHandle] = None) -> None: url = self._build_url(path) logger.info(f"Downloading from {url}...") + ctx = self.ctx_factory() with requests.get( url, - headers={"User-Agent": get_user_agent()}, + headers={ + "User-Agent": get_user_agent(), + "traceparent": ctx.next() + }, stream=True, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), @@ -145,10 +158,12 @@ def download_file(self, path: str, out_file: str, resource_handle: Optional[Reso update_hash_from_file(hash_object, out_file) if size_read == total_size: - logger.info(f"Download already complete ({size_read} bytes).") + logger.info( + f"Download already complete ({size_read} bytes).") return - logger.info(f"Resuming download from {size_read} bytes ({total_size - size_read} bytes left)...") + logger.info(f"Resuming download from {size_read} bytes ({ + total_size - size_read} bytes left)...") # Send the request again with the 'Range' header. 
with requests.get( @@ -156,18 +171,24 @@ def download_file(self, path: str, out_file: str, resource_handle: Optional[Reso stream=True, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), - headers={"Range": f"bytes={size_read}-"}, + headers={ + "Range": f"bytes={size_read}-", + "traceparent": ctx.next() + }, ) as resumed_response: - _download_file(resumed_response, out_file, size_read, total_size, hash_object) + _download_file(resumed_response, out_file, + size_read, total_size, hash_object) else: - _download_file(response, out_file, size_read, total_size, hash_object) + _download_file(response, out_file, size_read, + total_size, hash_object) if hash_object: actual_md5_hash = to_b64_digest(hash_object) if actual_md5_hash != expected_md5_hash: os.remove(out_file) # Delete the corrupted file. raise DataCorruptionError( - _CHECKSUM_MISMATCH_MSG_TEMPLATE.format(expected_md5_hash, actual_md5_hash) + _CHECKSUM_MISMATCH_MSG_TEMPLATE.format( + expected_md5_hash, actual_md5_hash) ) def _get_http_basic_auth(self) -> Optional[HTTPBasicAuth]: @@ -218,7 +239,8 @@ def __init__(self) -> None: if jwt_token is None: msg = ( "A JWT Token is required to call Kaggle, " - f"but none found in environment variable {KAGGLE_JWT_TOKEN_ENV_VAR_NAME}" + f"but none found in environment variable { + KAGGLE_JWT_TOKEN_ENV_VAR_NAME}" ) raise CredentialError(msg) @@ -226,7 +248,8 @@ def __init__(self) -> None: if data_proxy_token is None: msg = ( "A Data Proxy Token is required to call Kaggle, " - f"but none found in environment variable {KAGGLE_DATA_PROXY_TOKEN_ENV_VAR_NAME}" + f"but none found in environment variable { + KAGGLE_DATA_PROXY_TOKEN_ENV_VAR_NAME}" ) raise CredentialError(msg) @@ -240,7 +263,10 @@ def post( self, request_name: str, data: dict, - timeout: Tuple[float, float] = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), + timeout: Tuple[float, float] = ( + DEFAULT_CONNECT_TIMEOUT, + DEFAULT_READ_TIMEOUT + ), ) -> dict: url = 
f"{self.endpoint}{KaggleJwtClient.BASE_PATH}{request_name}" with requests.post( diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py index c551d350..8c171763 100644 --- a/src/kagglehub/gcs_upload.py +++ b/src/kagglehub/gcs_upload.py @@ -7,7 +7,7 @@ from multiprocessing import Pool from pathlib import Path from tempfile import TemporaryDirectory -from typing import List, Tuple, Union +from typing import Callable, List, Tuple, Union import requests from requests.exceptions import ConnectionError, Timeout @@ -16,6 +16,7 @@ from kagglehub.clients import KaggleApiV1Client from kagglehub.exceptions import BackendError +from kagglehub.tracing import TraceContext, default_context_factory logger = logging.getLogger(__name__) @@ -26,7 +27,12 @@ def parse_datetime_string(string: str) -> Union[datetime, str]: - time_formats = ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%fZ"] + time_formats = [ + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M:%SZ", + "%Y-%m-%dT%H:%M:%S.%f", + "%Y-%m-%dT%H:%M:%S.%fZ" + ] for t in time_formats: try: return datetime.strptime(string[:26], t).replace(microsecond=0) # noqa: DTZ007 @@ -37,7 +43,8 @@ def parse_datetime_string(string: str) -> Union[datetime, str]: class File(object): # noqa: UP004 def __init__(self, init_dict: dict) -> None: - parsed_dict = {k: parse_datetime_string(v) for k, v in init_dict.items()} + parsed_dict = {k: parse_datetime_string( + v) for k, v in init_dict.items()} self.__dict__.update(parsed_dict) @staticmethod @@ -57,7 +64,8 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int = while retry_count < MAX_RETRIES: try: - response = requests.put(session_uri, headers=headers, timeout=REQUEST_TIMEOUT) + response = requests.put( + session_uri, headers=headers, timeout=REQUEST_TIMEOUT) if response.status_code == 308: # Resume Incomplete # noqa: PLR2004 range_header = response.headers.get("Range") if range_header: @@ -67,7 +75,8 @@ def 
_check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int = else: return file_size except (ConnectionError, Timeout): - logger.info(f"Network issue while checking uploaded size, retrying in {backoff_factor} seconds...") + logger.info(f"Network issue while checking uploaded size, retrying in { + backoff_factor} seconds...") time.sleep(backoff_factor) backoff_factor = min(backoff_factor * 2, 60) retry_count += 1 @@ -75,13 +84,15 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int = return 0 # Return 0 if all retries fail -def _upload_blob(file_path: str, model_type: str) -> str: +def _upload_blob(file_path: str, model_type: str, + ctx_factory: Callable[[], TraceContext] = None) -> str: """Uploads a file to a remote server as a blob and returns an upload token. Parameters ========== file_path: The path to the file to be uploaded. model_type : The type of the model associated with the file. + ctx_factory: The function to initalize a trace context """ file_size = os.path.getsize(file_path) data = { @@ -90,7 +101,9 @@ def _upload_blob(file_path: str, model_type: str) -> str: "contentLength": file_size, "lastModifiedEpochSeconds": int(os.path.getmtime(file_path)), } - api_client = KaggleApiV1Client() + if ctx_factory == None: + ctx_factory = default_context_factory + api_client = KaggleApiV1Client(ctx_factory) response = api_client.post("/blobs/upload", data=data) # Validate response content @@ -102,7 +115,11 @@ def _upload_blob(file_path: str, model_type: str) -> str: raise BackendError(token_exception) session_uri = response["createUrl"] - headers = {"Content-Type": "application/octet-stream", "Content-Range": f"bytes 0-{file_size - 1}/{file_size}"} + headers = { + "Content-Type": "application/octet-stream", + "Content-Range": f"bytes 0-{file_size - 1}/{file_size}", + "traceparent": ctx_factory().next() + } retry_count = 0 uploaded_bytes = 0 @@ -113,7 +130,8 @@ def _upload_blob(file_path: str, model_type: str) -> str: try: 
f.seek(uploaded_bytes) reader_wrapper = CallbackIOWrapper(pbar.update, f, "read") - headers["Content-Range"] = f"bytes {uploaded_bytes}-{file_size - 1}/{file_size}" + headers["Content-Range"] = f"bytes { + uploaded_bytes}-{file_size - 1}/{file_size}" upload_response = requests.put( session_uri, headers=headers, data=reader_wrapper, timeout=REQUEST_TIMEOUT ) @@ -121,19 +139,23 @@ def _upload_blob(file_path: str, model_type: str) -> str: if upload_response.status_code in [200, 201]: return response["token"] elif upload_response.status_code == 308: # Resume Incomplete # noqa: PLR2004 - uploaded_bytes = _check_uploaded_size(session_uri, file_size) + uploaded_bytes = _check_uploaded_size( + session_uri, file_size) else: upload_failed_exception = ( - f"Upload failed with status code {upload_response.status_code}: {upload_response.text}" + f"Upload failed with status code { + upload_response.status_code}: {upload_response.text}" ) raise BackendError(upload_failed_exception) except (requests.ConnectionError, requests.Timeout) as e: - logger.info(f"Network issue: {e}, retrying in {backoff_factor} seconds...") + logger.info(f"Network issue: {e}, retrying in { + backoff_factor} seconds...") time.sleep(backoff_factor) backoff_factor = min(backoff_factor * 2, 60) retry_count += 1 uploaded_bytes = _check_uploaded_size(session_uri, file_size) pbar.n = uploaded_bytes # Update progress bar to reflect actual uploaded bytes + headers["traceparent"] = ctx_factory().next() return response["token"] @@ -156,7 +178,8 @@ def zip_files(source_path_obj: Path, zip_path: Path) -> List[int]: return sizes -def upload_files(source_path: str, model_type: str) -> List[str]: +def upload_files(source_path: str, model_type: str, + ctx_factory: Callable[[], TraceContext] = None) -> List[str]: source_path_obj = Path(source_path) with TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir) @@ -184,5 +207,6 @@ def upload_files(source_path: str, model_type: str) -> List[str]: 
shutil.copy(source_path_obj, temp_file_path) pbar.update(temp_file_path.stat().st_size) upload_path = str(temp_file_path) - - return [token for token in [_upload_blob(upload_path, model_type)] if token] + if ctx_factory == None: + ctx_factory = default_context_factory + return [token for token in [_upload_blob(upload_path, model_type, ctx_factory)] if token] diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index d5e3d5e9..0ae878aa 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -5,6 +5,7 @@ from kagglehub.gcs_upload import upload_files from kagglehub.handle import parse_model_handle from kagglehub.models_helpers import create_model_if_missing, create_model_instance_or_version +from kagglehub.tracing import TraceContext logger = logging.getLogger(__name__) @@ -42,12 +43,15 @@ def model_upload( if h.is_versioned(): is_versioned_exception = "The model handle should not include the version" raise ValueError(is_versioned_exception) - + ctx = TraceContext() + def shared_context_factory(): + return ctx # Create the model if it doesn't already exist - create_model_if_missing(h.owner, h.model) + create_model_if_missing(h.owner, h.model, shared_context_factory) # Upload the model files to GCS - tokens = upload_files(local_model_dir, "model") + tokens = upload_files(local_model_dir, "model", shared_context_factory) # Create a model instance if it doesn't exist, and create a new instance version if an instance exists - create_model_instance_or_version(h, tokens, license_name, version_notes) + create_model_instance_or_version( + h, tokens, license_name, version_notes, shared_context_factory) diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index 4bdaef10..fbdd903f 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -1,22 +1,30 @@ import logging from http import HTTPStatus -from typing import List, Optional +from typing import List, Optional, Callable from kagglehub.clients import 
KaggleApiV1Client from kagglehub.exceptions import KaggleApiHTTPError from kagglehub.handle import ModelHandle +from kagglehub.tracing import TraceContext logger = logging.getLogger(__name__) -def _create_model(owner_slug: str, model_slug: str) -> None: - data = {"ownerSlug": owner_slug, "slug": model_slug, "title": model_slug, "isPrivate": True} - api_client = KaggleApiV1Client() +def _create_model(owner_slug: str, model_slug: str, + ctx_factory: Callable[[], TraceContext] = None) -> None: + data = { + "ownerSlug": owner_slug, + "slug": model_slug, + "title": model_slug, + "isPrivate": True + } + api_client = KaggleApiV1Client(ctx_factory) api_client.post("/models/create/new", data) logger.info(f"Model '{model_slug}' Created.") -def _create_model_instance(model_handle: ModelHandle, files: List[str], license_name: Optional[str] = None) -> None: +def _create_model_instance(model_handle: ModelHandle, files: List[str], license_name: Optional[str] = None, + ctx_factory: Callable[[], TraceContext] = None) -> None: data = { "instanceSlug": model_handle.variation, "framework": model_handle.framework, @@ -25,44 +33,57 @@ def _create_model_instance(model_handle: ModelHandle, files: List[str], license_ if license_name is not None: data["licenseName"] = license_name - api_client = KaggleApiV1Client() - api_client.post(f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) - logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}") + api_client = KaggleApiV1Client(ctx_factory) + api_client.post( + f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) + logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: { + model_handle.to_url()}") -def _create_model_instance_version(model_handle: ModelHandle, files: List[str], version_notes: str = "") -> None: - data = {"versionNotes": version_notes, "files": [{"token": file_token} for file_token in 
files]} - api_client = KaggleApiV1Client() +def _create_model_instance_version(model_handle: ModelHandle, files: List[str], version_notes: str = "", + ctx_factory: Callable[[], TraceContext] = None) -> None: + data = { + "versionNotes": version_notes, + "files": [ + {"token": file_token} for file_token in files] + } + api_client = KaggleApiV1Client(ctx_factory) api_client.post( - f"/models/{model_handle.owner}/{model_handle.model}/{model_handle.framework}/{model_handle.variation}/create/version", + f"/models/{model_handle.owner}/{model_handle.model}/{ + model_handle.framework}/{model_handle.variation}/create/version", data, ) logger.info( - f"Your model instance version has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}" + f"Your model instance version has been created.\nFiles are being processed...\nSee at: { + model_handle.to_url()}" ) def create_model_instance_or_version( - model_handle: ModelHandle, files: List[str], license_name: Optional[str], version_notes: str = "" + model_handle: ModelHandle, files: List[str], license_name: Optional[str], version_notes: str = "", + ctx_factory: Callable[[], TraceContext] = None ) -> None: try: - api_client = KaggleApiV1Client() + api_client = KaggleApiV1Client(ctx_factory) api_client.get(f"/models/{model_handle}/get", model_handle) # the instance exist, create a new version. 
- _create_model_instance_version(model_handle, files, version_notes) + _create_model_instance_version( + model_handle, files, version_notes, ctx_factory) except KaggleApiHTTPError as e: if e.response is not None and ( e.response.status_code == HTTPStatus.NOT_FOUND # noqa: PLR1714 or e.response.status_code == HTTPStatus.FORBIDDEN ): - _create_model_instance(model_handle, files, license_name) + _create_model_instance(model_handle, files, + license_name, ctx_factory) else: raise (e) -def create_model_if_missing(owner_slug: str, model_slug: str) -> None: +def create_model_if_missing(owner_slug: str, model_slug: str, + ctx_factory: Callable[[], TraceContext] = None) -> None: try: - api_client = KaggleApiV1Client() + api_client = KaggleApiV1Client(ctx_factory) api_client.get(f"/models/{owner_slug}/{model_slug}/get") except KaggleApiHTTPError as e: if e.response is not None and ( @@ -72,20 +93,22 @@ def create_model_if_missing(owner_slug: str, model_slug: str) -> None: logger.info( f"Model '{model_slug}' does not exist or access is forbidden for user '{owner_slug}'. Creating or handling Model..." 
# noqa: E501 ) - _create_model(owner_slug, model_slug) + _create_model(owner_slug, model_slug, ctx_factory) else: raise (e) -def delete_model(owner_slug: str, model_slug: str) -> None: +def delete_model(owner_slug: str, model_slug: str, + ctx_factory: Callable[[], TraceContext] = None) -> None: try: - api_client = KaggleApiV1Client() + api_client = KaggleApiV1Client(ctx_factory) api_client.post( f"/models/{owner_slug}/{model_slug}/delete", {}, ) except KaggleApiHTTPError as e: if e.response is not None and e.response.status_code == HTTPStatus.NOT_FOUND: - logger.info(f"Could not delete Model '{model_slug}' for user '{owner_slug}'...") + logger.info(f"Could not delete Model '{ + model_slug}' for user '{owner_slug}'...") else: raise (e) diff --git a/src/kagglehub/tracing.py b/src/kagglehub/tracing.py index d8dfc2c6..b0f5c448 100644 --- a/src/kagglehub/tracing.py +++ b/src/kagglehub/tracing.py @@ -1,4 +1,11 @@ import secrets +from typing import Optional, Type +from types import TracebackType + +# Constants can be found at https://www.w3.org/TR/trace-context/#version-format +_TRACE_LENGTH_BYTES = 16 +_SPAN_LENGTH_BYTES = 8 + class TraceContext: """ @@ -9,22 +16,33 @@ class TraceContext: Attributes: trace: A 16-byte hexadecimal string representing a unique trace ID. """ + def __init__(self) -> None: - self.trace = secrets.token_bytes(16).hex() - - def __enter__(self): + self.trace = secrets.token_bytes(_TRACE_LENGTH_BYTES).hex() + + def __enter__(self) -> 'TraceContext': return self - - def __exit__(self, type, value, traceback): + + def __exit__(self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: Optional[TracebackType]) -> bool: pass - def next(self): + def next(self) -> str: """ Generates a new span ID within the context of the current trace ID. 
+ An example traceparent: + "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + Returns: A formatted string representing a span ID, in a standard distributed tracing format (e.g., "00-{trace}-{span}-01") """ - span = secrets.token_bytes(8).hex() + span = secrets.token_bytes(_SPAN_LENGTH_BYTES).hex() return f"00-{self.trace}-{span}-01" + + +def default_context_factory(): + return TraceContext() From cd825dfdfbe1224a04abc88e3bf57c7cadd8fee8 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 12:14:55 -0400 Subject: [PATCH 03/16] Added trace id to log --- src/kagglehub/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 0ae878aa..8cd68351 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -44,8 +44,10 @@ def model_upload( is_versioned_exception = "The model handle should not include the version" raise ValueError(is_versioned_exception) ctx = TraceContext() + def shared_context_factory(): return ctx + logger.debug(f"Using shared trace {ctx.trace}") # Create the model if it doesn't already exist create_model_if_missing(h.owner, h.model, shared_context_factory) From c847690a09f133224d9234358a2e3ccb65d64cfe Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 13:52:23 -0400 Subject: [PATCH 04/16] Fixed all ruff format checks --- .vscode/settings.json | 2 +- pyproject.toml | 9 ++--- src/kagglehub/clients.py | 42 +++++++--------------- src/kagglehub/gcs_upload.py | 38 ++++++++------------ src/kagglehub/models.py | 6 ++-- src/kagglehub/models_helpers.py | 58 +++++++++++++++---------------- src/kagglehub/tracing.py | 20 ++++++----- src/kagglehub/tracing_test.py | 20 +++++------ tests/test_http_model_download.py | 6 +++- 9 files changed, 89 insertions(+), 112 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index acd71588..f9679744 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,7 +8,7 @@ ], 
"python.testing.pytestEnabled": false, "python.testing.unittestEnabled": true, - "editor.defaultFormatter": "ms-python.autopep8", + "editor.defaultFormatter": "charliermarsh.ruff", "[python]": { "editor.formatOnSave": true, } diff --git a/pyproject.toml b/pyproject.toml index 5981927e..e4fd7e6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py37" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -114,7 +116,6 @@ select = [ "ANN001", "ANN002", "ANN003", - "ANN102", "ANN201", "ANN202", "ANN401", @@ -138,13 +139,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["kagglehub"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # Ignore unused imports in __init__.py diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py index b11410ff..4b26f0b3 100644 --- a/src/kagglehub/clients.py +++ b/src/kagglehub/clients.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Optional, Tuple, Callable +from typing import Callable, Optional, Tuple from urllib.parse import urljoin import requests @@ -67,8 +67,7 @@ def get_user_agent() -> str: _cached_user_agent = f"{base_user_agent} kkb/{build_date}" elif is_in_colab_notebook(): colab_tag = os.getenv("COLAB_RELEASE_TAG") - runtime_suffix = "-managed" if os.getenv( - "TBE_RUNTIME_ADDR") else "-unmanaged" + runtime_suffix = "-managed" if os.getenv("TBE_RUNTIME_ADDR") else "-unmanaged" _cached_user_agent = f"{ base_user_agent} colab/{colab_tag}{runtime_suffix}" else: @@ -87,7 +86,7 @@ class KaggleApiV1Client: def __init__(self, ctx_factory: Callable[[], TraceContext]) -> None: self.credentials = get_kaggle_credentials() self.endpoint = 
get_kaggle_api_endpoint() - self.ctx_factory = ctx_factory if ctx_factory != None else default_context_factory + self.ctx_factory = ctx_factory if ctx_factory is not None else default_context_factory def _check_for_version_update(self, response: requests.Response) -> None: latest_version_str = response.headers.get("X-Kaggle-HubVersion") @@ -105,8 +104,7 @@ def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> di url = self._build_url(path) with requests.get( url, - headers={"User-Agent": get_user_agent(), - "traceparent": self.ctx_factory().next()}, + headers={"User-Agent": get_user_agent(), "traceparent": self.ctx_factory().next()}, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), ) as response: @@ -118,10 +116,7 @@ def post(self, path: str, data: dict) -> dict: url = self._build_url(path) with requests.post( url, - headers={ - "User-Agent": get_user_agent(), - "traceparent": self.ctx_factory().next() - }, + headers={"User-Agent": get_user_agent(), "traceparent": self.ctx_factory().next()}, json=data, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), @@ -138,10 +133,7 @@ def download_file(self, path: str, out_file: str, resource_handle: Optional[Reso ctx = self.ctx_factory() with requests.get( url, - headers={ - "User-Agent": get_user_agent(), - "traceparent": ctx.next() - }, + headers={"User-Agent": get_user_agent(), "traceparent": ctx.next()}, stream=True, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), @@ -158,8 +150,7 @@ def download_file(self, path: str, out_file: str, resource_handle: Optional[Reso update_hash_from_file(hash_object, out_file) if size_read == total_size: - logger.info( - f"Download already complete ({size_read} bytes).") + logger.info(f"Download already complete ({size_read} bytes).") return logger.info(f"Resuming download from {size_read} bytes ({ @@ -171,24 +162,18 @@ def download_file(self, 
path: str, out_file: str, resource_handle: Optional[Reso stream=True, auth=self._get_http_basic_auth(), timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), - headers={ - "Range": f"bytes={size_read}-", - "traceparent": ctx.next() - }, + headers={"Range": f"bytes={size_read}-", "traceparent": ctx.next()}, ) as resumed_response: - _download_file(resumed_response, out_file, - size_read, total_size, hash_object) + _download_file(resumed_response, out_file, size_read, total_size, hash_object) else: - _download_file(response, out_file, size_read, - total_size, hash_object) + _download_file(response, out_file, size_read, total_size, hash_object) if hash_object: actual_md5_hash = to_b64_digest(hash_object) if actual_md5_hash != expected_md5_hash: os.remove(out_file) # Delete the corrupted file. raise DataCorruptionError( - _CHECKSUM_MISMATCH_MSG_TEMPLATE.format( - expected_md5_hash, actual_md5_hash) + _CHECKSUM_MISMATCH_MSG_TEMPLATE.format(expected_md5_hash, actual_md5_hash) ) def _get_http_basic_auth(self) -> Optional[HTTPBasicAuth]: @@ -263,10 +248,7 @@ def post( self, request_name: str, data: dict, - timeout: Tuple[float, float] = ( - DEFAULT_CONNECT_TIMEOUT, - DEFAULT_READ_TIMEOUT - ), + timeout: Tuple[float, float] = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT), ) -> dict: url = f"{self.endpoint}{KaggleJwtClient.BASE_PATH}{request_name}" with requests.post( diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py index 8c171763..4b3b1484 100644 --- a/src/kagglehub/gcs_upload.py +++ b/src/kagglehub/gcs_upload.py @@ -7,7 +7,7 @@ from multiprocessing import Pool from pathlib import Path from tempfile import TemporaryDirectory -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import requests from requests.exceptions import ConnectionError, Timeout @@ -27,12 +27,7 @@ def parse_datetime_string(string: str) -> Union[datetime, str]: - time_formats = [ - "%Y-%m-%dT%H:%M:%S", - "%Y-%m-%dT%H:%M:%SZ", 
- "%Y-%m-%dT%H:%M:%S.%f", - "%Y-%m-%dT%H:%M:%S.%fZ" - ] + time_formats = ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%fZ"] for t in time_formats: try: return datetime.strptime(string[:26], t).replace(microsecond=0) # noqa: DTZ007 @@ -43,8 +38,7 @@ def parse_datetime_string(string: str) -> Union[datetime, str]: class File(object): # noqa: UP004 def __init__(self, init_dict: dict) -> None: - parsed_dict = {k: parse_datetime_string( - v) for k, v in init_dict.items()} + parsed_dict = {k: parse_datetime_string(v) for k, v in init_dict.items()} self.__dict__.update(parsed_dict) @staticmethod @@ -64,8 +58,7 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int = while retry_count < MAX_RETRIES: try: - response = requests.put( - session_uri, headers=headers, timeout=REQUEST_TIMEOUT) + response = requests.put(session_uri, headers=headers, timeout=REQUEST_TIMEOUT) if response.status_code == 308: # Resume Incomplete # noqa: PLR2004 range_header = response.headers.get("Range") if range_header: @@ -84,8 +77,7 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int = return 0 # Return 0 if all retries fail -def _upload_blob(file_path: str, model_type: str, - ctx_factory: Callable[[], TraceContext] = None) -> str: +def _upload_blob(file_path: str, model_type: str, ctx_factory: Optional[Callable[[], TraceContext]] = None) -> str: """Uploads a file to a remote server as a blob and returns an upload token. 
Parameters @@ -101,7 +93,7 @@ def _upload_blob(file_path: str, model_type: str, "contentLength": file_size, "lastModifiedEpochSeconds": int(os.path.getmtime(file_path)), } - if ctx_factory == None: + if ctx_factory is None: ctx_factory = default_context_factory api_client = KaggleApiV1Client(ctx_factory) response = api_client.post("/blobs/upload", data=data) @@ -118,7 +110,7 @@ def _upload_blob(file_path: str, model_type: str, headers = { "Content-Type": "application/octet-stream", "Content-Range": f"bytes 0-{file_size - 1}/{file_size}", - "traceparent": ctx_factory().next() + "traceparent": ctx_factory().next(), } retry_count = 0 @@ -139,13 +131,10 @@ def _upload_blob(file_path: str, model_type: str, if upload_response.status_code in [200, 201]: return response["token"] elif upload_response.status_code == 308: # Resume Incomplete # noqa: PLR2004 - uploaded_bytes = _check_uploaded_size( - session_uri, file_size) + uploaded_bytes = _check_uploaded_size(session_uri, file_size) else: - upload_failed_exception = ( - f"Upload failed with status code { - upload_response.status_code}: {upload_response.text}" - ) + upload_failed_exception = f"Upload failed with status code { + upload_response.status_code}: {upload_response.text}" raise BackendError(upload_failed_exception) except (requests.ConnectionError, requests.Timeout) as e: logger.info(f"Network issue: {e}, retrying in { @@ -178,8 +167,9 @@ def zip_files(source_path_obj: Path, zip_path: Path) -> List[int]: return sizes -def upload_files(source_path: str, model_type: str, - ctx_factory: Callable[[], TraceContext] = None) -> List[str]: +def upload_files( + source_path: str, model_type: str, ctx_factory: Optional[Callable[[], TraceContext]] = None +) -> List[str]: source_path_obj = Path(source_path) with TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir) @@ -207,6 +197,6 @@ def upload_files(source_path: str, model_type: str, shutil.copy(source_path_obj, temp_file_path) 
pbar.update(temp_file_path.stat().st_size) upload_path = str(temp_file_path) - if ctx_factory == None: + if ctx_factory is None: ctx_factory = default_context_factory return [token for token in [_upload_blob(upload_path, model_type, ctx_factory)] if token] diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 8cd68351..2caade3a 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -45,8 +45,9 @@ def model_upload( raise ValueError(is_versioned_exception) ctx = TraceContext() - def shared_context_factory(): + def shared_context_factory() -> TraceContext: return ctx + logger.debug(f"Using shared trace {ctx.trace}") # Create the model if it doesn't already exist create_model_if_missing(h.owner, h.model, shared_context_factory) @@ -55,5 +56,4 @@ def shared_context_factory(): tokens = upload_files(local_model_dir, "model", shared_context_factory) # Create a model instance if it doesn't exist, and create a new instance version if an instance exists - create_model_instance_or_version( - h, tokens, license_name, version_notes, shared_context_factory) + create_model_instance_or_version(h, tokens, license_name, version_notes, shared_context_factory) diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index fbdd903f..94c9fa3e 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -1,6 +1,6 @@ import logging from http import HTTPStatus -from typing import List, Optional, Callable +from typing import Callable, List, Optional from kagglehub.clients import KaggleApiV1Client from kagglehub.exceptions import KaggleApiHTTPError @@ -10,21 +10,19 @@ logger = logging.getLogger(__name__) -def _create_model(owner_slug: str, model_slug: str, - ctx_factory: Callable[[], TraceContext] = None) -> None: - data = { - "ownerSlug": owner_slug, - "slug": model_slug, - "title": model_slug, - "isPrivate": True - } +def _create_model(owner_slug: str, model_slug: str, ctx_factory: Optional[Callable[[], 
TraceContext]] = None) -> None: + data = {"ownerSlug": owner_slug, "slug": model_slug, "title": model_slug, "isPrivate": True} api_client = KaggleApiV1Client(ctx_factory) api_client.post("/models/create/new", data) logger.info(f"Model '{model_slug}' Created.") -def _create_model_instance(model_handle: ModelHandle, files: List[str], license_name: Optional[str] = None, - ctx_factory: Callable[[], TraceContext] = None) -> None: +def _create_model_instance( + model_handle: ModelHandle, + files: List[str], + license_name: Optional[str] = None, + ctx_factory: Optional[Callable[[], TraceContext]] = None, +) -> None: data = { "instanceSlug": model_handle.variation, "framework": model_handle.framework, @@ -34,19 +32,18 @@ def _create_model_instance(model_handle: ModelHandle, files: List[str], license_ data["licenseName"] = license_name api_client = KaggleApiV1Client(ctx_factory) - api_client.post( - f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) + api_client.post(f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: { model_handle.to_url()}") -def _create_model_instance_version(model_handle: ModelHandle, files: List[str], version_notes: str = "", - ctx_factory: Callable[[], TraceContext] = None) -> None: - data = { - "versionNotes": version_notes, - "files": [ - {"token": file_token} for file_token in files] - } +def _create_model_instance_version( + model_handle: ModelHandle, + files: List[str], + version_notes: str = "", + ctx_factory: Optional[Callable[[], TraceContext]] = None, +) -> None: + data = {"versionNotes": version_notes, "files": [{"token": file_token} for file_token in files]} api_client = KaggleApiV1Client(ctx_factory) api_client.post( f"/models/{model_handle.owner}/{model_handle.model}/{ @@ -60,28 +57,30 @@ def _create_model_instance_version(model_handle: ModelHandle, files: List[str], def 
create_model_instance_or_version( - model_handle: ModelHandle, files: List[str], license_name: Optional[str], version_notes: str = "", - ctx_factory: Callable[[], TraceContext] = None + model_handle: ModelHandle, + files: List[str], + license_name: Optional[str], + version_notes: str = "", + ctx_factory: Optional[Callable[[], TraceContext]] = None, ) -> None: try: api_client = KaggleApiV1Client(ctx_factory) api_client.get(f"/models/{model_handle}/get", model_handle) # the instance exist, create a new version. - _create_model_instance_version( - model_handle, files, version_notes, ctx_factory) + _create_model_instance_version(model_handle, files, version_notes, ctx_factory) except KaggleApiHTTPError as e: if e.response is not None and ( e.response.status_code == HTTPStatus.NOT_FOUND # noqa: PLR1714 or e.response.status_code == HTTPStatus.FORBIDDEN ): - _create_model_instance(model_handle, files, - license_name, ctx_factory) + _create_model_instance(model_handle, files, license_name, ctx_factory) else: raise (e) -def create_model_if_missing(owner_slug: str, model_slug: str, - ctx_factory: Callable[[], TraceContext] = None) -> None: +def create_model_if_missing( + owner_slug: str, model_slug: str, ctx_factory: Optional[Callable[[], TraceContext]] = None +) -> None: try: api_client = KaggleApiV1Client(ctx_factory) api_client.get(f"/models/{owner_slug}/{model_slug}/get") @@ -98,8 +97,7 @@ def create_model_if_missing(owner_slug: str, model_slug: str, raise (e) -def delete_model(owner_slug: str, model_slug: str, - ctx_factory: Callable[[], TraceContext] = None) -> None: +def delete_model(owner_slug: str, model_slug: str, ctx_factory: Optional[Callable[[], TraceContext]] = None) -> None: try: api_client = KaggleApiV1Client(ctx_factory) api_client.post( diff --git a/src/kagglehub/tracing.py b/src/kagglehub/tracing.py index b0f5c448..5b709c7a 100644 --- a/src/kagglehub/tracing.py +++ b/src/kagglehub/tracing.py @@ -1,6 +1,6 @@ import secrets -from typing import Optional, Type 
from types import TracebackType +from typing import Optional, Type # Constants can be found at https://www.w3.org/TR/trace-context/#version-format _TRACE_LENGTH_BYTES = 16 @@ -20,29 +20,31 @@ class TraceContext: def __init__(self) -> None: self.trace = secrets.token_bytes(_TRACE_LENGTH_BYTES).hex() - def __enter__(self) -> 'TraceContext': + def __enter__(self) -> "TraceContext": return self - def __exit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - traceback: Optional[TracebackType]) -> bool: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: pass def next(self) -> str: """ Generates a new span ID within the context of the current trace ID. - An example traceparent: + An example traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" Returns: - A formatted string representing a span ID, in a standard + A formatted string representing a span ID, in a standard distributed tracing format (e.g., "00-{trace}-{span}-01") """ span = secrets.token_bytes(_SPAN_LENGTH_BYTES).hex() return f"00-{self.trace}-{span}-01" -def default_context_factory(): +def default_context_factory() -> TraceContext: return TraceContext() diff --git a/src/kagglehub/tracing_test.py b/src/kagglehub/tracing_test.py index 09ae9599..22ebf133 100644 --- a/src/kagglehub/tracing_test.py +++ b/src/kagglehub/tracing_test.py @@ -4,40 +4,40 @@ _CANONICAL_EXAMPLE = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + class TraceContextSuite(unittest.TestCase): - def test_length(self): + def test_length(self) -> None: with TraceContext() as ctx: traceparent = ctx.next() self.assertEqual(len(traceparent), len(_CANONICAL_EXAMPLE)) - def test_prefix(self): + def test_prefix(self) -> None: with TraceContext() as ctx: traceparent = ctx.next() self.assertEqual(traceparent[0:2], "00") - def test_suffix(self): + def test_suffix(self) -> None: # always 
sample with TraceContext() as ctx: traceparent = ctx.next() self.assertEqual(traceparent[-2:], "01") - def test_pattern(self): + def test_pattern(self) -> None: with TraceContext() as ctx: traceparent = ctx.next() - version,trace,span,flag = traceparent.split("-") - print(version,trace,span,flag) + version, trace, span, flag = traceparent.split("-") self.assertRegex(version, "^[0-9]{2}$", "version does not meet pattern") self.assertRegex(trace, "^[A-Fa-f0-9]{32}$") self.assertRegex(span, "^[A-Fa-f0-9]{16}$") self.assertRegex(flag, "^[0-9]{2}$") - def test_notempty(self): + def test_notempty(self) -> None: with TraceContext() as ctx: traceparent = ctx.next() - _,trace,span,_ = traceparent.split("-") + _, trace, span, _ = traceparent.split("-") self.assertNotEqual(trace, f"{0:016x}") self.assertNotEqual(span, f"{0:08x}") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_http_model_download.py b/tests/test_http_model_download.py index ec738813..597ee94e 100644 --- a/tests/test_http_model_download.py +++ b/tests/test_http_model_download.py @@ -86,7 +86,11 @@ def do_GET(self) -> None: # noqa: N802 # Test cases for the ModelHttpResolver. class TestHttpModelDownload(BaseTestCase): def _download_model_and_assert_downloaded( - self, d: str, model_handle: str, expected_subdir_or_subpath: str, **kwargs # noqa: ANN003 + self, + d: str, + model_handle: str, + expected_subdir_or_subpath: str, + **kwargs, # noqa: ANN003 ) -> None: # Download the full model and ensure all files are there. 
model_path = kagglehub.model_download(model_handle, **kwargs) From f98e49ba9ed95f53d29d34760ce436cd08b2e8e1 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 13:53:45 -0400 Subject: [PATCH 05/16] Added ruff cache to ignore --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6769e21d..2e47ccac 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,8 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +# ruff Linter / Formatter +.ruff_cache + From 41411f4095c1bd9c212d959eca4a203a0b5b99a7 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 14:50:34 -0400 Subject: [PATCH 06/16] Removed logging configuration. Logger config should be specified by the client. Updated build failures. 
--- src/kagglehub/clients.py | 3 +-- src/kagglehub/gcs_upload.py | 3 +-- src/kagglehub/logger.py | 2 +- src/kagglehub/models.py | 2 +- src/kagglehub/models_helpers.py | 9 +++------ 5 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py index 4b26f0b3..f843f926 100644 --- a/src/kagglehub/clients.py +++ b/src/kagglehub/clients.py @@ -68,8 +68,7 @@ def get_user_agent() -> str: elif is_in_colab_notebook(): colab_tag = os.getenv("COLAB_RELEASE_TAG") runtime_suffix = "-managed" if os.getenv("TBE_RUNTIME_ADDR") else "-unmanaged" - _cached_user_agent = f"{ - base_user_agent} colab/{colab_tag}{runtime_suffix}" + _cached_user_agent = f"{base_user_agent} colab/{colab_tag}{runtime_suffix}" else: _cached_user_agent = base_user_agent diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py index 4b3b1484..86a4e907 100644 --- a/src/kagglehub/gcs_upload.py +++ b/src/kagglehub/gcs_upload.py @@ -68,8 +68,7 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int = else: return file_size except (ConnectionError, Timeout): - logger.info(f"Network issue while checking uploaded size, retrying in { - backoff_factor} seconds...") + logger.info(f"Network issue while checking uploaded size, retrying in {backoff_factor} seconds...") time.sleep(backoff_factor) backoff_factor = min(backoff_factor * 2, 60) retry_count += 1 diff --git a/src/kagglehub/logger.py b/src/kagglehub/logger.py index 4fe419f5..74a770a8 100644 --- a/src/kagglehub/logger.py +++ b/src/kagglehub/logger.py @@ -13,4 +13,4 @@ def _configure_logger() -> None: library_logger.setLevel(get_log_verbosity()) -_configure_logger() +# _configure_logger() diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 2caade3a..7891878d 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -48,7 +48,7 @@ def model_upload( def shared_context_factory() -> TraceContext: return ctx - logger.debug(f"Using shared trace 
{ctx.trace}") + logger.debug(f"Using shared trace: {ctx.trace}") # Create the model if it doesn't already exist create_model_if_missing(h.owner, h.model, shared_context_factory) diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index 94c9fa3e..6702138f 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -33,8 +33,7 @@ def _create_model_instance( api_client = KaggleApiV1Client(ctx_factory) api_client.post(f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) - logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: { - model_handle.to_url()}") + logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}") def _create_model_instance_version( @@ -51,8 +50,7 @@ def _create_model_instance_version( data, ) logger.info( - f"Your model instance version has been created.\nFiles are being processed...\nSee at: { - model_handle.to_url()}" + f"Your model instance version has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}" ) @@ -106,7 +104,6 @@ def delete_model(owner_slug: str, model_slug: str, ctx_factory: Optional[Callabl ) except KaggleApiHTTPError as e: if e.response is not None and e.response.status_code == HTTPStatus.NOT_FOUND: - logger.info(f"Could not delete Model '{ - model_slug}' for user '{owner_slug}'...") + logger.info(f"Could not delete Model '{model_slug}' for user '{owner_slug}'...") else: raise (e) From 952781b80cf1506a65bfdc582db32f249177e57c Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 14:58:38 -0400 Subject: [PATCH 07/16] Fixed remaining black linter issues --- .gitignore | 1 - src/kagglehub/clients.py | 12 ++++-------- src/kagglehub/gcs_upload.py | 11 +++++------ src/kagglehub/models_helpers.py | 3 +-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 2e47ccac..44930b16 100644 --- 
a/.gitignore +++ b/.gitignore @@ -161,4 +161,3 @@ cython_debug/ # ruff Linter / Formatter .ruff_cache - diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py index f843f926..1a7bdced 100644 --- a/src/kagglehub/clients.py +++ b/src/kagglehub/clients.py @@ -95,8 +95,7 @@ def _check_for_version_update(self, response: requests.Response) -> None: if latest_version > current_version: logger.info( "Warning: Looks like you're using an outdated `kagglehub` " - f"version, please consider updating (latest version: { - latest_version})" + f"version, please consider updating (latest version: {latest_version})" ) def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> dict: @@ -152,8 +151,7 @@ def download_file(self, path: str, out_file: str, resource_handle: Optional[Reso logger.info(f"Download already complete ({size_read} bytes).") return - logger.info(f"Resuming download from {size_read} bytes ({ - total_size - size_read} bytes left)...") + logger.info(f"Resuming download from {size_read} bytes ({total_size - size_read} bytes left)...") # Send the request again with the 'Range' header. 
with requests.get( @@ -223,8 +221,7 @@ def __init__(self) -> None: if jwt_token is None: msg = ( "A JWT Token is required to call Kaggle, " - f"but none found in environment variable { - KAGGLE_JWT_TOKEN_ENV_VAR_NAME}" + f"but none found in environment variable {KAGGLE_JWT_TOKEN_ENV_VAR_NAME}" ) raise CredentialError(msg) @@ -232,8 +229,7 @@ def __init__(self) -> None: if data_proxy_token is None: msg = ( "A Data Proxy Token is required to call Kaggle, " - f"but none found in environment variable { - KAGGLE_DATA_PROXY_TOKEN_ENV_VAR_NAME}" + f"but none found in environment variable {KAGGLE_DATA_PROXY_TOKEN_ENV_VAR_NAME}" ) raise CredentialError(msg) diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py index 86a4e907..851df77d 100644 --- a/src/kagglehub/gcs_upload.py +++ b/src/kagglehub/gcs_upload.py @@ -121,8 +121,7 @@ def _upload_blob(file_path: str, model_type: str, ctx_factory: Optional[Callable try: f.seek(uploaded_bytes) reader_wrapper = CallbackIOWrapper(pbar.update, f, "read") - headers["Content-Range"] = f"bytes { - uploaded_bytes}-{file_size - 1}/{file_size}" + headers["Content-Range"] = f"bytes {uploaded_bytes}-{file_size - 1}/{file_size}" upload_response = requests.put( session_uri, headers=headers, data=reader_wrapper, timeout=REQUEST_TIMEOUT ) @@ -132,12 +131,12 @@ def _upload_blob(file_path: str, model_type: str, ctx_factory: Optional[Callable elif upload_response.status_code == 308: # Resume Incomplete # noqa: PLR2004 uploaded_bytes = _check_uploaded_size(session_uri, file_size) else: - upload_failed_exception = f"Upload failed with status code { - upload_response.status_code}: {upload_response.text}" + upload_failed_exception = ( + f"Upload failed with status code {upload_response.status_code}: {upload_response.text}" + ) raise BackendError(upload_failed_exception) except (requests.ConnectionError, requests.Timeout) as e: - logger.info(f"Network issue: {e}, retrying in { - backoff_factor} seconds...") + logger.info(f"Network issue: 
{e}, retrying in {backoff_factor} seconds...") time.sleep(backoff_factor) backoff_factor = min(backoff_factor * 2, 60) retry_count += 1 diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index 6702138f..6762fe3b 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -45,8 +45,7 @@ def _create_model_instance_version( data = {"versionNotes": version_notes, "files": [{"token": file_token} for file_token in files]} api_client = KaggleApiV1Client(ctx_factory) api_client.post( - f"/models/{model_handle.owner}/{model_handle.model}/{ - model_handle.framework}/{model_handle.variation}/create/version", + f"/models/{model_handle.owner}/{model_handle.model}/{model_handle.framework}/{model_handle.variation}/create/version", data, ) logger.info( From 415c00a8b198c38b1a459179e0d6930301beb061 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 15:50:36 -0400 Subject: [PATCH 08/16] Fixed build failures. --- src/kagglehub/clients.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py index 1a7bdced..8cd22753 100644 --- a/src/kagglehub/clients.py +++ b/src/kagglehub/clients.py @@ -82,10 +82,10 @@ def get_user_agent() -> str: class KaggleApiV1Client: BASE_PATH = "api/v1" - def __init__(self, ctx_factory: Callable[[], TraceContext]) -> None: + def __init__(self, ctx_factory: Optional[Callable[[], TraceContext]] = None) -> None: self.credentials = get_kaggle_credentials() self.endpoint = get_kaggle_api_endpoint() - self.ctx_factory = ctx_factory if ctx_factory is not None else default_context_factory + self.ctx_factory = default_context_factory if ctx_factory is None else ctx_factory def _check_for_version_update(self, response: requests.Response) -> None: latest_version_str = response.headers.get("X-Kaggle-HubVersion") From f36cb26cddbace0b975e0a0b88ade0140c7e933b Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 16:32:35 
-0400 Subject: [PATCH 09/16] Fixing CI build issues --- src/kagglehub/tracing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kagglehub/tracing.py b/src/kagglehub/tracing.py index 5b709c7a..9c3404be 100644 --- a/src/kagglehub/tracing.py +++ b/src/kagglehub/tracing.py @@ -29,7 +29,7 @@ def __exit__( exc_val: Optional[BaseException], traceback: Optional[TracebackType], ) -> bool: - pass + return def next(self) -> str: """ From ca4722f4c0bb7ea66d802566d0debdd8ed4f8183 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 16:35:27 -0400 Subject: [PATCH 10/16] Moved to testing dir --- .vscode/settings.json | 2 +- src/kagglehub/tracing_test.py => tests/test_tracing.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/kagglehub/tracing_test.py => tests/test_tracing.py (100%) diff --git a/.vscode/settings.json b/.vscode/settings.json index f9679744..5e3a3ada 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,7 @@ "python.testing.unittestArgs": [ "-v", "-s", - "./src", + "./tests", "-p", "*test*.py" ], diff --git a/src/kagglehub/tracing_test.py b/tests/test_tracing.py similarity index 100% rename from src/kagglehub/tracing_test.py rename to tests/test_tracing.py From 9d50417ef74d50d7150b391141f1ce6f3954cb89 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 16:39:34 -0400 Subject: [PATCH 11/16] Added return type --- src/kagglehub/tracing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kagglehub/tracing.py b/src/kagglehub/tracing.py index 9c3404be..96d7ac19 100644 --- a/src/kagglehub/tracing.py +++ b/src/kagglehub/tracing.py @@ -29,7 +29,7 @@ def __exit__( exc_val: Optional[BaseException], traceback: Optional[TracebackType], ) -> bool: - return + return False def next(self) -> str: """ From 2c02beef70f5a88b220eb93668754acfc07b2d45 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 22 Apr 2024 16:44:54 -0400 Subject: [PATCH 12/16] Fixed context 
manager --- src/kagglehub/tracing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kagglehub/tracing.py b/src/kagglehub/tracing.py index 96d7ac19..f71d9eb7 100644 --- a/src/kagglehub/tracing.py +++ b/src/kagglehub/tracing.py @@ -28,8 +28,8 @@ def __exit__( exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], traceback: Optional[TracebackType], - ) -> bool: - return False + ) -> None: + return def next(self) -> str: """ From 7d2a26b9c0a10103cb2cb5c69bf2d883f7524f21 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Fri, 26 Apr 2024 20:50:53 +0000 Subject: [PATCH 13/16] Changed logging name to not conflict with standard library Added documentation around supporting Added instructions for hatch setup and integration in vscode Other minor fixes related to linting --- .vscode/extensions.json | 28 ++++++++++++++++++++++++++++ README.md | 30 ++++++++++++++++++++++++++++++ src/kagglehub/__init__.py | 2 +- src/kagglehub/gcs_upload.py | 4 +++- src/kagglehub/logger.py | 30 ++++++++++++++++++++++++++++++ src/kagglehub/logging.py | 16 ---------------- src/kagglehub/models.py | 3 ++- 7 files changed, 94 insertions(+), 19 deletions(-) create mode 100644 .vscode/extensions.json create mode 100644 src/kagglehub/logger.py delete mode 100644 src/kagglehub/logging.py diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..6b841e79 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,28 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. + // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp + + // List of extensions which should be recommended for users of this workspace. 
+ "recommendations": [ + "ms-azuretools.vscode-docker", + "tamasfe.even-better-toml", + "eamodio.gitlens", + "ms-python.isort", + "ms-python.mypy-type-checker", + "ms-python.debugpy", + "ms-python.python", + "ms-python.vscode-pylance", + "charliermarsh.ruff", + "ms-python.autopep8", + "ms-vscode-remote.remote-containers", + "ms-vscode.remote-explorer", + "ms-vscode.remote-server", + "ms-vscode-remote.remote-ssh-edit", + "ms-vscode-remote.remote-ssh", + "ms-toolsai.jupyter-keymap" + ], + // List of extensions recommended by VS Code that should not be recommended for users of this workspace. + "unwantedRecommendations": [ + + ] +} \ No newline at end of file diff --git a/README.md b/README.md index b5398cb4..6982417b 100644 --- a/README.md +++ b/README.md @@ -165,3 +165,33 @@ The following shows how to run `hatch run lint:all` but this also works for any # Use specific Python version (Must be a valid tag from: https://hub.docker.com/_/python) ./docker-hatch -v 3.9 run lint:all ``` + +## Vscode setup + +### Prerequisites +Install the recommended extensions. + +### Instructions + +Configure hatch to create virtual env in project folder. +``` +hatch config set dirs.env.virtual .env +``` + +After, create all the python environments needed by running `hatch -e all run tests`. + +Finally, configure vscode to use one of the selected environments: +`cmd + shift + p` -> `python: Select Interpreter` -> Pick one of the folders in `./.env` + +## Support + +The kagglehub library has configured automatic logging which is stored in a log folder. 
The log destination is resolved via the [os.path.expanduser](https://docs.python.org/3/library/os.path.html#os.path.expanduser) + +The table below contains possible locations: +| os | log path | +|---------|------------------------------------------------| +| osx | /user/$USERNAME/.kaggle/logs/kagglehub.log | +| linux | ~/.kaggle/logs/kagglehub.log | +| windows | C:\Users\\%USERNAME%\\.kaggle\logs\kagglehub.log | + +Please include the log to help troubleshoot issues. diff --git a/src/kagglehub/__init__.py b/src/kagglehub/__init__.py index 65cc7441..bed98f0d 100644 --- a/src/kagglehub/__init__.py +++ b/src/kagglehub/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.2.4" -import kagglehub.logging # configures the library logger. +import kagglehub.logger # configures the library logger. from kagglehub import colab_cache_resolver, http_resolver, kaggle_cache_resolver, registry from kagglehub.auth import login from kagglehub.models import model_download, model_upload diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py index d3129168..36fa5141 100644 --- a/src/kagglehub/gcs_upload.py +++ b/src/kagglehub/gcs_upload.py @@ -155,7 +155,9 @@ def _upload_blob(file_path: str, model_type: str) -> str: def upload_files_and_directories( - folder: str, model_type: str, quiet: bool = False # noqa: FBT002, FBT001 + folder: str, + model_type: str, + quiet: bool = False, # noqa: FBT002, FBT001 ) -> UploadDirectoryInfo: # Count the total number of files file_count = 0 diff --git a/src/kagglehub/logger.py b/src/kagglehub/logger.py new file mode 100644 index 00000000..ec2af958 --- /dev/null +++ b/src/kagglehub/logger.py @@ -0,0 +1,30 @@ +import logging +from logging.handlers import RotatingFileHandler +from pathlib import Path + +from kagglehub.config import get_log_verbosity + + +def _configure_logger() -> None: + library_name = __name__.split(".")[0] # i.e. 
"kagglehub" + library_logger = logging.getLogger(library_name) + while library_logger.handlers: + library_logger.handlers.pop() + + log_dir = Path.home() / ".kaggle" / "logs" + log_dir.mkdir(exist_ok=True, parents=True) + file_handler = RotatingFileHandler(str(log_dir / "kagglehub.log"), maxBytes=1024 * 1024 * 5, backupCount=5) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(threadName)s - %(funcName)s - %(message)s" + ) + file_handler.setFormatter(formatter) + file_handler.setLevel(logging.DEBUG) + library_logger.addHandler(file_handler) + library_logger.addHandler(logging.StreamHandler()) + # Disable propagation of the library log outputs. + # This prevents the same message again from being printed again if a root logger is defined. + library_logger.propagate = False + library_logger.setLevel(get_log_verbosity()) + + +_configure_logger() diff --git a/src/kagglehub/logging.py b/src/kagglehub/logging.py deleted file mode 100644 index 4fe419f5..00000000 --- a/src/kagglehub/logging.py +++ /dev/null @@ -1,16 +0,0 @@ -import logging - -from kagglehub.config import get_log_verbosity - - -def _configure_logger() -> None: - library_name = __name__.split(".")[0] # i.e. "kagglehub" - library_logger = logging.getLogger(library_name) - library_logger.addHandler(logging.StreamHandler()) - # Disable propagation of the library log outputs. - # This prevents the same message again from being printed again if a root logger is defined. - library_logger.propagate = False - library_logger.setLevel(get_log_verbosity()) - - -_configure_logger() diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index bfa5dd98..c70c95da 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -22,6 +22,7 @@ def model_download(handle: str, path: Optional[str] = None, *, force_download: O A string representing the path to the requested model files. 
""" h = parse_model_handle(handle) + logger.info(f"Downloading Model: {handle}") return registry.model_resolver(h, path, force_download=force_download) @@ -38,7 +39,7 @@ def model_upload( """ # parse slug h = parse_model_handle(handle) - + logger.info(f"Uploading Model {handle}") if h.is_versioned(): is_versioned_exception = "The model handle should not include the version" raise ValueError(is_versioned_exception) From fdeb341872d89db40c2e9c31f127576b1e424768 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 29 Apr 2024 15:08:17 +0000 Subject: [PATCH 14/16] Fixed formatting on readme files --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6982417b..aef77e8a 100644 --- a/README.md +++ b/README.md @@ -188,10 +188,10 @@ Finally, configure vscode to use one of the selected environments: The kagglehub library has configured automatic logging which is stored in a log folder. The log destination is resolved via the [os.path.expanduser](https://docs.python.org/3/library/os.path.html#os.path.expanduser) The table below contains possible locations: -| os | log path | -|---------|------------------------------------------------| -| osx | /user/$USERNAME/.kaggle/logs/kagglehub.log | -| linux | ~/.kaggle/logs/kagglehub.log | +| os | log path | +|---------|--------------------------------------------------| +| osx | /user/$USERNAME/.kaggle/logs/kagglehub.log | +| linux | ~/.kaggle/logs/kagglehub.log | | windows | C:\Users\\%USERNAME%\\.kaggle\logs\kagglehub.log | Please include the log to help troubleshoot issues. 
From 59a4144256646f127afb8df200c7d6ce2c2af07e Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 29 Apr 2024 15:46:35 +0000 Subject: [PATCH 15/16] Merged refactoring --- pyproject.toml | 1 + src/kagglehub/gcs_upload.py | 22 +++++++++++++++------- src/kagglehub/logger.py | 2 +- src/kagglehub/models.py | 2 +- src/kagglehub/models_helpers.py | 3 +-- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 67d6a33f..cae83107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ select = [ "ANN001", "ANN002", "ANN003", + "ANN102", "ANN201", "ANN202", "ANN401", diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py index 1e422857..737a39e9 100644 --- a/src/kagglehub/gcs_upload.py +++ b/src/kagglehub/gcs_upload.py @@ -4,7 +4,7 @@ import zipfile from datetime import datetime from tempfile import TemporaryDirectory -from typing import Dict, Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import requests from requests.exceptions import ConnectionError, Timeout @@ -164,8 +164,10 @@ def _upload_blob(file_path: str, model_type: str, ctx_factory: Optional[Callable def upload_files_and_directories( - folder: str, model_type: str, quiet: bool = False, # noqa: FBT002, FBT001 - ctx_factory: Callable[[],TraceContext] = None + folder: str, + model_type: str, + quiet: bool = False, # noqa: FBT002, FBT001 + ctx_factory: Optional[Callable[[], TraceContext]] = None, ) -> UploadDirectoryInfo: # Count the total number of files file_count = 0 @@ -232,7 +234,7 @@ def _upload_file_or_folder( file_or_folder_name: str, model_type: str, quiet: bool = False, # noqa: FBT002, FBT001 - ctx_factory : + ctx_factory: Optional[Callable[[], TraceContext]] = None, ) -> Optional[str]: """ Uploads a file or each file inside a folder individually from a specified path to a remote service. 
@@ -247,11 +249,17 @@ def _upload_file_or_folder( """ full_path = os.path.join(parent_path, file_or_folder_name) if os.path.isfile(full_path): - return _upload_file(file_or_folder_name, full_path, quiet, model_type) + return _upload_file(file_or_folder_name, full_path, quiet, model_type, ctx_factory) return None -def _upload_file(file_name: str, full_path: str, quiet: bool, model_type: str) -> Optional[str]: # noqa: FBT001 +def _upload_file( + file_name: str, + full_path: str, + quiet: bool, # noqa: FBT001 + model_type: str, + ctx_factory: Optional[Callable[[], TraceContext]] = None, +) -> Optional[str]: """Helper function to upload a single file Parameters ========== @@ -266,7 +274,7 @@ def _upload_file(file_name: str, full_path: str, quiet: bool, model_type: str) - logger.info("Starting upload for file " + file_name) content_length = os.path.getsize(full_path) - token = _upload_blob(full_path, model_type) + token = _upload_blob(full_path, model_type, ctx_factory) if not quiet: logger.info("Upload successful: " + file_name + " (" + File.get_size(content_length) + ")") return token diff --git a/src/kagglehub/logger.py b/src/kagglehub/logger.py index 74a770a8..4fe419f5 100644 --- a/src/kagglehub/logger.py +++ b/src/kagglehub/logger.py @@ -13,4 +13,4 @@ def _configure_logger() -> None: library_logger.setLevel(get_log_verbosity()) -# _configure_logger() +_configure_logger() diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 01996b0d..f1d1970b 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -53,7 +53,7 @@ def shared_context_factory() -> TraceContext: create_model_if_missing(h.owner, h.model, shared_context_factory) # Upload the model files to GCS - tokens = upload_files_and_directories(local_model_dir, "model", shared_context_factory) + tokens = upload_files_and_directories(local_model_dir, "model", quiet=False, ctx_factory=shared_context_factory) # Create a model instance if it doesn't exist, and create a new instance version if 
an instance exists create_model_instance_or_version(h, tokens, license_name, version_notes, shared_context_factory) diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index c6d29697..ae75b0bf 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -22,7 +22,7 @@ def _create_model_instance( model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, license_name: Optional[str] = None, - ctx_factory: Optional[Callable[[], TraceContext]] = None + ctx_factory: Optional[Callable[[], TraceContext]] = None, ) -> None: serialized_data = files_and_directories.serialize() data = { @@ -40,7 +40,6 @@ def _create_model_instance( def _create_model_instance_version( - model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, version_notes: str = "", From e1acc7ce0d1d3a3965f12a271795c53b2d7c8392 Mon Sep 17 00:00:00 2001 From: Nesh Devanathan Date: Mon, 29 Apr 2024 16:07:19 +0000 Subject: [PATCH 16/16] Added server side test for traceparent header --- tests/server_stubs/model_upload_stub.py | 15 +++++++++++++ tests/test_tracing.py | 29 +++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/tests/server_stubs/model_upload_stub.py b/tests/server_stubs/model_upload_stub.py index d151b61f..c071b2f5 100644 --- a/tests/server_stubs/model_upload_stub.py +++ b/tests/server_stubs/model_upload_stub.py @@ -18,12 +18,19 @@ class SharedData: files: List[str] = field(default_factory=list) simulate_308: bool = False blob_request_count: int = 0 + traceparent_header_count: int = 0 shared_data: SharedData = SharedData() lock = threading.Lock() +def _increment_traceparent() -> None: + lock.acquire() + shared_data.traceparent_header_count += 1 + lock.release() + + def _increment_blob_request() -> None: lock.acquire() shared_data.blob_request_count += 1 @@ -41,6 +48,7 @@ def reset() -> None: shared_data.files = [] shared_data.blob_request_count = 0 shared_data.simulate_308 = False + 
shared_data.traceparent_header_count = 0 lock.release() @@ -50,6 +58,13 @@ def simulate_308(*, state: bool) -> None: lock.release() +@app.before_request +def before_req() -> None: + traceparent = request.headers.get("traceparent") + if traceparent: + _increment_traceparent() + + @app.route("/", methods=["HEAD"]) def head() -> ResponseReturnValue: return "", 200 diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 22ebf133..038a518d 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -1,6 +1,13 @@ import unittest +from pathlib import Path +from tempfile import TemporaryDirectory +from kagglehub.models import model_upload from kagglehub.tracing import TraceContext +from tests.fixtures import BaseTestCase + +from .server_stubs import model_upload_stub as stub +from .server_stubs import serv _CANONICAL_EXAMPLE = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" @@ -39,5 +46,27 @@ def test_notempty(self) -> None: self.assertNotEqual(span, f"{0:08x}") +class TestModelUpload(BaseTestCase): + def setUp(self) -> None: + stub.reset() + + @classmethod + def setUpClass(cls): # noqa: ANN102 + serv.start_server(stub.app) + + @classmethod + def tearDownClass(cls): # noqa: ANN102 + serv.stop_server() + + def test_model_upload_instance_with_valid_handle(self) -> None: + with TemporaryDirectory() as temp_dir: + test_filepath = Path(temp_dir) / "temp_test_file" + test_filepath.touch() # Create a temporary file in the temporary directory + model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, "Apache 2.0", "model_type") + self.assertEqual(len(stub.shared_data.files), 1) + self.assertIn("temp_test_file", stub.shared_data.files) + self.assertGreaterEqual(stub.shared_data.traceparent_header_count, 2) + + if __name__ == "__main__": unittest.main()