From 9d4dab928ad151debd5bd1254bf96d966d1ca9b9 Mon Sep 17 00:00:00 2001 From: Nick Frasser <1693461+nfrasser@users.noreply.github.com> Date: Tue, 24 Dec 2024 15:24:44 -0500 Subject: [PATCH] fix(api): correctly encode query params and headers If given a list of enums, it should encode their values --- cryosparc/api.py | 6 +++--- cryosparc/json_util.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/cryosparc/api.py b/cryosparc/api.py index dba3587f..20eff6a1 100644 --- a/cryosparc/api.py +++ b/cryosparc/api.py @@ -10,7 +10,7 @@ from . import registry from .errors import APIError -from .json_util import api_default, api_object_hook +from .json_util import api_default, api_encode, api_object_hook from .models.auth import Token from .stream import Streamable @@ -74,10 +74,10 @@ def _construct_request(self, _path: str, _schema, *args, **kwargs) -> Tuple[str, _path = _path.replace("{%s}" % param_name, _uriencode(param)) elif param_in == "query" and param_name in kwargs: # query param must be in kwargs - query_params[param_name] = kwargs.pop(param_name) + query_params[param_name] = api_encode(kwargs.pop(param_name)) elif param_in == "header" and (header_name := param_name.replace("-", "_")) in kwargs: # header must be in kwargs - headers[param_name] = kwargs.pop(header_name) + headers[param_name] = api_encode(kwargs.pop(header_name)) elif param_in == "header" and param_name in client_headers: pass # in default headers, no action required elif param_schema["required"]: diff --git a/cryosparc/json_util.py b/cryosparc/json_util.py index 4859e209..4100e647 100644 --- a/cryosparc/json_util.py +++ b/cryosparc/json_util.py @@ -1,5 +1,6 @@ import base64 from datetime import datetime +from enum import Enum from pathlib import PurePath from typing import Any, Mapping @@ -7,6 +8,18 @@ from pydantic import BaseModel +def api_encode(obj: Any): + """ + Recursively encode any object for transmission through the API. + """ + if isinstance(obj, dict): + return {k: api_encode(v) for k, v in obj} + elif isinstance(obj, list): + return [api_encode(v) for v in obj] + else: + return api_default(obj) + + def api_default(obj: Any) -> Any: """ json.dump "default" argument for sending objects over a JSON API. Ensures @@ -28,6 +41,8 @@ def api_default(obj: Any) -> Any: return binary_to_json(obj) elif isinstance(obj, datetime): return obj.isoformat() + elif isinstance(obj, Enum): + return obj.value elif isinstance(obj, PurePath): return str(obj) elif isinstance(obj, BaseModel):