Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(spm): correctly encode enums in query params and headers #111

Merged
merged 1 commit into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cryosparc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from . import registry
from .errors import APIError
from .json_util import api_default, api_object_hook
from .json_util import api_default, api_encode, api_object_hook
from .models.auth import Token
from .stream import Streamable

Expand Down Expand Up @@ -74,10 +74,10 @@ def _construct_request(self, _path: str, _schema, *args, **kwargs) -> Tuple[str,
_path = _path.replace("{%s}" % param_name, _uriencode(param))
elif param_in == "query" and param_name in kwargs:
# query param must be in kwargs
query_params[param_name] = kwargs.pop(param_name)
query_params[param_name] = api_encode(kwargs.pop(param_name))
elif param_in == "header" and (header_name := param_name.replace("-", "_")) in kwargs:
# header must be in kwargs
headers[param_name] = kwargs.pop(header_name)
headers[param_name] = api_encode(kwargs.pop(header_name))
elif param_in == "header" and param_name in client_headers:
pass # in default headers, no action required
elif param_schema["required"]:
Expand Down
15 changes: 15 additions & 0 deletions cryosparc/json_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import base64
from datetime import datetime
from enum import Enum
from pathlib import PurePath
from typing import Any, Mapping

import numpy as n
from pydantic import BaseModel


def api_encode(obj: Any):
"""
Recursively encode any object for transmission through the API.
"""
if isinstance(obj, dict):
return {k: api_encode(v) for k, v in obj}
elif isinstance(obj, list):
return [api_encode(v) for v in obj]
else:
return api_default(obj)


def api_default(obj: Any) -> Any:
"""
json.dump "default" argument for sending objects over a JSON API. Ensures
Expand All @@ -28,6 +41,8 @@ def api_default(obj: Any) -> Any:
return binary_to_json(obj)
elif isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, Enum):
return obj.value
elif isinstance(obj, PurePath):
return str(obj)
elif isinstance(obj, BaseModel):
Expand Down
Loading