diff --git a/.gitignore b/.gitignore index eaaeec7..9dcb600 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ htmlcov validated.yaml merged.yaml polytope_server.egg-info -**/build \ No newline at end of file +**/build +.venv \ No newline at end of file diff --git a/polytope_server/broker/broker.py b/polytope_server/broker/broker.py index 7af98bd..168c89c 100644 --- a/polytope_server/broker/broker.py +++ b/polytope_server/broker/broker.py @@ -40,13 +40,10 @@ def __init__(self, config): self.collections = collection.create_collections(config.get("collections")) - self.user_limit = self.broker_config.get("user_limit", None) - def run(self): logging.info("Starting broker...") logging.info("Maximum Queue Size: {}".format(self.max_queue_size)) - logging.info("User Request Limit: {}".format(self.user_limit)) while not time.sleep(self.scheduling_interval): self.check_requests() @@ -87,45 +84,41 @@ def check_requests(self): def check_limits(self, active_requests, request): - logging.debug("Checking limits for request {}".format(request.id)) - - # User limits - if self.user_limit is not None: - user_active_requests = sum(qr.user == request.user for qr in active_requests) - if user_active_requests >= self.user_limit: - logging.debug("User has {} of {} active requests".format(user_active_requests, self.user_limit)) - return False - - # Collection limits - collection_total_limit = self.collections[request.collection].limits.get("total", None) - if collection_total_limit is not None: - collection_active_requests = sum(qr.collection == request.collection for qr in active_requests) - if collection_active_requests >= collection_total_limit: - logging.debug( - "Collection has {} of {} total active requests".format( - collection_active_requests, collection_total_limit - ) - ) - return False - - # Collection-user limits - collection_user_limit = self.collections[request.collection].limits.get("per-user", None) - if collection_user_limit is not None: - collection_user_active_requests = sum( - (qr.collection == request.collection and qr.user == request.user) for qr in active_requests - ) - if collection_user_active_requests >= collection_user_limit: - logging.debug( - "User has {} of {} active requests in collection {}".format( - collection_user_active_requests, - collection_user_limit, - request.collection, - ) - ) + logging.debug(f"Checking limits for request {request.id}") + + # Get collection limits and calculate active requests + collection = self.collections[request.collection] + collection_limits = collection.limits + collection_total_limit = collection_limits.get("total") + collection_active_requests = sum(qr.collection == request.collection for qr in active_requests) + logging.debug(f"Collection {request.collection} has {collection_active_requests} active requests") + + # Check collection total limit + if collection_total_limit is not None and collection_active_requests >= collection_total_limit: + logging.debug(f"Collection has {collection_active_requests} of {collection_total_limit} total active requests") + return False + + # Determine the effective limit based on role or per-user setting + role_limits = collection_limits.get("per-role", {}).get(request.user.realm, {}) + limit = max((role_limits.get(role, 0) for role in request.user.roles), default=0) + if limit == 0: # Use collection per-user limit if no role-specific limit + limit = collection_limits.get("per-user", 0) + + # Check if user exceeds the effective limit + if limit > 0: + user_active_requests = sum(qr.collection == 
request.collection and qr.user == request.user for qr in active_requests) + if user_active_requests >= limit: + logging.debug(f"User {request.user} has {user_active_requests} of {limit} active requests in collection {request.collection}") return False + else: + logging.debug(f"User {request.user} has {user_active_requests} of {limit} active requests in collection {request.collection}") + return True + # Allow if no limits are exceeded + logging.debug(f"No limit for user {request.user} in collection {request.collection}") return True + def enqueue(self, request): logging.info("Queuing request", extra={"request_id": request.id}) diff --git a/polytope_server/common/auth.py b/polytope_server/common/auth.py index 6754692..8bea551 100644 --- a/polytope_server/common/auth.py +++ b/polytope_server/common/auth.py @@ -117,7 +117,7 @@ def authenticate(self, auth_header) -> User: www_authenticate=self.auth_info, ) - user.roles = ["default"] + user.roles.append("default") # Visit all authorizers to append additional roles and attributes for authorizer in self.authorizers: diff --git a/polytope_server/common/authentication/authentication.py b/polytope_server/common/authentication/authentication.py index 0f05871..5227dd4 100644 --- a/polytope_server/common/authentication/authentication.py +++ b/polytope_server/common/authentication/authentication.py @@ -77,7 +77,8 @@ def name(self) -> str: "plain": "PlainAuthentication", "keycloak": "KeycloakAuthentication", "federation": "FederationAuthentication", - "jwt": "JWTAuthentication", + "jwt" : "JWTAuthentication", + "openid_offline_access" : "OpenIDOfflineAuthentication", } diff --git a/polytope_server/common/authentication/jwt_authentication.py b/polytope_server/common/authentication/jwt_authentication.py index 9034b14..6c8502f 100644 --- a/polytope_server/common/authentication/jwt_authentication.py +++ b/polytope_server/common/authentication/jwt_authentication.py @@ -35,6 +35,7 @@ def __init__(self, name, realm, config): self.config = config self.certs_url = config["cert_url"] + self.client_id = config["client_id"] super().__init__(name, realm, config) @@ -57,8 +58,14 @@ def authenticate(self, credentials: str) -> User: token=credentials, algorithms=jwt.get_unverified_header(credentials).get("alg"), key=certs ) + logging.info("Decoded JWT: {}".format(decoded_token)) + + user = User(decoded_token["sub"], self.realm()) + roles = decoded_token.get("resource_access", {}).get(self.client_id, {}).get("roles", []) + user.roles.extend(roles) + logging.info("Found user {} from decoded JWT".format(user)) except Exception as e: logging.info("Failed to authenticate user from JWT") diff --git a/polytope_server/common/authentication/mongoapikey_authentication.py b/polytope_server/common/authentication/mongoapikey_authentication.py index 5ebbdde..f2fcdfc 100644 --- a/polytope_server/common/authentication/mongoapikey_authentication.py +++ b/polytope_server/common/authentication/mongoapikey_authentication.py @@ -20,8 +20,7 @@ from datetime import datetime -import pymongo - +from .. 
import mongo_client_factory from ..auth import User from ..exceptions import ForbiddenRequest from ..metric_collector import MongoStorageMetricCollector @@ -39,19 +38,18 @@ class ApiKeyMongoAuthentication(authentication.Authentication): """ def __init__(self, name, realm, config): - self.config = config - host = config.get("host", "localhost") - port = config.get("port", "27017") + uri = config.get("uri", "mongodb://localhost:27017") collection = config.get("collection", "keys") + username = config.get("username") + password = config.get("password") - endpoint = "{}:{}".format(host, port) - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + self.mongo_client = mongo_client_factory.create_client(uri, username, password) self.database = self.mongo_client.keys self.keys = self.database[collection] assert realm == "polytope" - self.storage_metric_collector = MongoStorageMetricCollector(endpoint, self.mongo_client, "keys", collection) + self.storage_metric_collector = MongoStorageMetricCollector(uri, self.mongo_client, "keys", collection) super().__init__(name, realm, config) @@ -62,7 +60,6 @@ def authentication_info(self): return "Authenticate with Polytope API Key from ../auth/keys" def authenticate(self, credentials: str) -> User: - # credentials should be of the form '' res = self.keys.find_one({"key.key": credentials}) if res is None: diff --git a/polytope_server/common/authentication/mongodb_authentication.py b/polytope_server/common/authentication/mongodb_authentication.py index ff1f579..a1d0be3 100644 --- a/polytope_server/common/authentication/mongodb_authentication.py +++ b/polytope_server/common/authentication/mongodb_authentication.py @@ -22,8 +22,7 @@ import binascii import hashlib -import pymongo - +from .. import mongo_client_factory from ..auth import User from ..exceptions import ForbiddenRequest from ..metric_collector import MongoStorageMetricCollector @@ -32,19 +31,18 @@ class MongoAuthentication(authentication.Authentication): def __init__(self, name, realm, config): - self.config = config - host = config.get("host", "localhost") - port = config.get("port", "27017") + uri = config.get("uri", "mongodb://localhost:27017") collection = config.get("collection", "users") + username = config.get("username") + password = config.get("password") - endpoint = "{}:{}".format(host, port) - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + self.mongo_client = mongo_client_factory.create_client(uri, username, password) self.database = self.mongo_client.authentication self.users = self.database[collection] self.storage_metric_collector = MongoStorageMetricCollector( - endpoint, self.mongo_client, "authentication", collection + uri, self.mongo_client, "authentication", collection ) super().__init__(name, realm, config) @@ -59,7 +57,6 @@ def authentication_info(self): return "Authenticate with username and password" def authenticate(self, credentials: str) -> User: - # credentials should be of the form 'base64(:)' try: decoded = base64.b64decode(credentials).decode("utf-8") diff --git a/polytope_server/common/authentication/openid_offline_access_authentication.py b/polytope_server/common/authentication/openid_offline_access_authentication.py new file mode 100644 index 0000000..5550100 --- /dev/null +++ b/polytope_server/common/authentication/openid_offline_access_authentication.py @@ -0,0 +1,116 @@ +# +# Copyright 2022 European Centre for Medium-Range Weather Forecasts (ECMWF) +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation nor +# does it submit to any jurisdiction. +# + +import logging +import os +import requests +from jose import jwt + +from ..auth import User +from ..caching import cache +from . import authentication +from ..exceptions import ForbiddenRequest + + +class OpenIDOfflineAuthentication(authentication.Authentication): + def __init__(self, name, realm, config): + self.config = config + + self.certs_url = config["cert_url"] + self.public_client_id = config["public_client_id"] + self.private_client_id = config["private_client_id"] + self.private_client_secret = config["private_client_secret"] + self.iam_url = config["iam_url"] + self.iam_realm = config["iam_realm"] + + + super().__init__(name, realm, config) + + def authentication_type(self): + return "Bearer" + + def authentication_info(self): + return "Authenticate with OpenID offline_access token" + + @cache(lifetime=120) + def get_certs(self): + return requests.get(self.certs_url).json() + + @cache(lifetime=120) + def check_offline_access_token(self, token: str) -> bool: + """ + We check if the token is recognised by the IAM service, and we cache this result. + We cannot simply try to get the access token because we would spam the IAM server with invalid tokens, and the + failure at that point would not be cached. 
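+
+        Only two fields of the introspection response are relied upon below; an
+        illustrative (assumed, Keycloak-style) shape is: {"active": true, "token_type": "Offline"}.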
+ """ + keycloak_token_introspection = self.iam_url + "/realms/" + self.iam_realm + "/protocol/openid-connect/token/introspect" + introspection_data = { + "token": token + } + b_auth = requests.auth.HTTPBasicAuth(self.private_client_id, self.private_client_secret) + resp = requests.post(url=keycloak_token_introspection, data=introspection_data, auth=b_auth).json() + if resp["active"] and resp["token_type"] == "Offline": + return True + else: + return False + + @cache(lifetime=120) + def authenticate(self, credentials: str) -> User: + + try: + + # Check if this is a valid offline_access token + if not self.check_offline_access_token(credentials): + raise ForbiddenRequest("Not a valid offline_access token") + + # Generate an access token from the offline_access token (like a refresh token) + refresh_data = { + "client_id": self.public_client_id, + "grant_type": "refresh_token", + "refresh_token": credentials + } + keycloak_token_endpoint = self.iam_url + "/realms/" + self.iam_realm + "/protocol/openid-connect/token" + resp = requests.post(url=keycloak_token_endpoint, data=refresh_data) + token = resp.json()['access_token'] + + certs = self.get_certs() + decoded_token = jwt.decode(token=token, + algorithms=jwt.get_unverified_header(token).get('alg'), + key=certs + ) + + logging.info("Decoded JWT: {}".format(decoded_token)) + + user = User(decoded_token["sub"], self.realm()) + + roles = decoded_token.get("resource_access", {}).get(self.public_client_id, {}).get("roles", []) + user.roles.extend(roles) + + logging.info("Found user {} from openid offline_access token".format(user)) + + except Exception as e: + logging.info("Failed to authenticate user from openid offline_access token") + logging.info(e) + raise ForbiddenRequest("Could not authenticate user from openid offline_access token") + return user + + + def collect_metric_info(self): + return {} diff --git a/polytope_server/common/authorization/mongodb_authorization.py b/polytope_server/common/authorization/mongodb_authorization.py index 59fe716..48ac44c 100644 --- a/polytope_server/common/authorization/mongodb_authorization.py +++ b/polytope_server/common/authorization/mongodb_authorization.py @@ -18,8 +18,7 @@ # does it submit to any jurisdiction. # -import pymongo - +from .. import mongo_client_factory from ..auth import User from ..metric_collector import MongoStorageMetricCollector from . 
import authorization @@ -29,23 +28,22 @@ class MongoDBAuthorization(authorization.Authorization): def __init__(self, name, realm, config): self.config = config assert self.config["type"] == "mongodb" - self.host = config.get("host", "localhost") - self.port = config.get("port", "27017") + self.uri = config.get("uri", "mongodb://localhost:27017") self.collection = config.get("collection", "users") + username = config.get("username") + password = config.get("password") - endpoint = "{}:{}".format(self.host, self.port) - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + self.mongo_client = mongo_client_factory.create_client(self.uri, username, password) self.database = self.mongo_client.authentication self.users = self.database[self.collection] self.storage_metric_collector = MongoStorageMetricCollector( - endpoint, self.mongo_client, "authentication", self.collection + self.uri, self.mongo_client, "authentication", self.collection ) super().__init__(name, realm, config) def get_roles(self, user: User) -> list: - if user.realm != self.realm(): raise ValueError( "Trying to authorize a user in the wrong realm, expected {}, got {}".format(self.realm(), user.realm) diff --git a/polytope_server/common/caching/caching.py b/polytope_server/common/caching/caching.py index 652ba07..7cbd983 100644 --- a/polytope_server/common/caching/caching.py +++ b/polytope_server/common/caching/caching.py @@ -29,9 +29,9 @@ from typing import Dict, Union import pymemcache -import pymongo import redis +from .. import mongo_client_factory from ..metric import MetricType from ..metric_collector import ( DictStorageMetricCollector, @@ -195,17 +195,20 @@ def collect_metric_info(self): class MongoDBCaching(Caching): def __init__(self, cache_config): super().__init__(cache_config) - host = cache_config.get("host", "localhost") - port = cache_config.get("port", 27017) - endpoint = "{}:{}".format(host, port) + uri = cache_config.get("uri", "mongodb://localhost:27017") + + username = cache_config.get("username") + password = cache_config.get("password") + collection = cache_config.get("collection", "cache") - self.client = pymongo.MongoClient(host + ":" + str(port), journal=False, connect=False) + self.client = mongo_client_factory.create_client(uri, username, password,) + self.database = self.client.cache self.collection = self.database[collection] self.collection.create_index("expire_at", expireAfterSeconds=0) self.collection.update_one({"_id": "hits"}, {"$setOnInsert": {"n": 0}}, upsert=True) self.collection.update_one({"_id": "misses"}, {"$setOnInsert": {"n": 0}}, upsert=True) - self.storage_metric_collector = MongoStorageMetricCollector(endpoint, self.client, "cache", collection) + self.storage_metric_collector = MongoStorageMetricCollector(uri, self.client, "cache", collection) self.cache_metric_collector = MongoCacheMetricCollector(self.client, "cache", collection) def get_type(self): @@ -220,7 +223,6 @@ def get(self, key): return obj["data"] def set(self, key, object, lifetime): - if lifetime == 0 or lifetime is None: expiry = datetime.datetime.max else: @@ -324,7 +326,6 @@ def __call__(self, f): @functools.wraps(f) def wrapper(*args, **kwargs): - cache.cancelled = False if self.cache is None: diff --git a/polytope_server/common/config/schema.yaml b/polytope_server/common/config/schema.yaml index 889c79b..9d7b162 100644 --- a/polytope_server/common/config/schema.yaml +++ b/polytope_server/common/config/schema.yaml @@ -68,7 +68,7 @@ mapping: desc: point to a hosted mongodb type: map 
mapping: - endpoint: + uri: desc: host and port example: localhost:27017 type: str @@ -116,7 +116,7 @@ mapping: desc: point to a hosted mongodb type: map mapping: - endpoint: + uri: desc: host and port example: localhost:27017 type: str diff --git a/polytope_server/common/datasource/datasource.py b/polytope_server/common/datasource/datasource.py index 3151097..cf4c75e 100644 --- a/polytope_server/common/datasource/datasource.py +++ b/polytope_server/common/datasource/datasource.py @@ -113,6 +113,7 @@ def dispatch(self, request, input_data) -> bool: "echo": "EchoDataSource", "dummy": "DummyDataSource", "raise": "RaiseDataSource", + "ionbeam": "IonBeamDataSource" } diff --git a/polytope_server/common/datasource/fdb.py b/polytope_server/common/datasource/fdb.py index 609f923..b9cae2d 100644 --- a/polytope_server/common/datasource/fdb.py +++ b/polytope_server/common/datasource/fdb.py @@ -26,7 +26,6 @@ from datetime import datetime, timedelta from pathlib import Path -import pyfdb import yaml from dateutil.relativedelta import relativedelta @@ -47,7 +46,10 @@ def __init__(self, config): self.check_schema() os.environ["FDB5_CONFIG"] = json.dumps(self.fdb_config) + os.environ["FDB_CONFIG"] = json.dumps(self.fdb_config) os.environ["FDB5_HOME"] = self.config.get("fdb_home", "/opt/fdb") + os.environ["FDB_HOME"] = self.config.get("fdb_home", "/opt/fdb") + import pyfdb self.fdb = pyfdb.FDB() if "spaces" in self.fdb_config: @@ -143,14 +145,26 @@ def match(self, request): r = yaml.safe_load(request.user_request) or {} for k, v in self.match_rules.items(): + + # An empty match rule means that the key must not be present + if v is None or len(v) == 0: + if k in r: + raise Exception("Request containing key '{}' is not allowed".format(k)) + else: + continue # no more checks to do + # Check that all required keys exist - if k not in r: - raise Exception("Request does not contain expected key {}".format(k)) + if k not in r and not (v is None or len(v) == 0): + raise Exception("Request does not contain expected key '{}'".format(k)) + + # Process date rules if k == "date": self.date_check(r["date"], v) continue + # ... and check the value of other keys + v = [v] if isinstance(v, str) else v if r[k] not in v: raise Exception("got {} : {}, but expected one of {}".format(k, r[k], v)) diff --git a/polytope_server/common/datasource/ionbeam.py b/polytope_server/common/datasource/ionbeam.py new file mode 100644 index 0000000..4fec9bc --- /dev/null +++ b/polytope_server/common/datasource/ionbeam.py @@ -0,0 +1,134 @@ +# +# Copyright 2022 European Centre for Medium-Range Weather Forecasts (ECMWF) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation nor +# does it submit to any jurisdiction. +# +import yaml +import logging +import requests +from urllib.parse import urljoin + +from . 
import datasource +from requests import Request +from dataclasses import dataclass + +@dataclass +class IonBeamAPI: + endpoint : str + + def __post_init__(self): + assert not self.endpoint.endswith("/") + self.session = requests.Session() + + def get(self, path : str, **kwargs) -> requests.Response: + return self.session.get(f"{self.endpoint}/{path}", stream=True, **kwargs) + + def get_bytes(self, path : str, **kwargs) -> requests.Response: + kwargs["headers"] = kwargs.get("headers", {}) | { + 'Accept': 'application/octet-stream' + } + return self.get(path, **kwargs) + + def get_json(self, path, **kwargs): + return self.get(path, **kwargs).json() + + def list(self, request : dict[str, str] = {}): + return self.get_json("list", params = request) + + def head(self, request : dict[str, str] = {}): + return self.get_json("head", params = request) + + def retrieve(self, request : dict[str, str]) -> requests.Response: + return self.get_bytes("retrieve", params = request) + + def archive(self, request, file) -> requests.Response: + files = {'file': file} + return self.session.post(f"{self.endpoint}/archive", files=files, params = request) + + +class IonBeamDataSource(datasource.DataSource): + """ + Retrieve data from the IonBeam REST backend that lives here: https://github.com/ecmwf/IonBeam-Deployment/tree/main/docker/rest_api + + + """ + read_chunk_size = 2 * 1024 * 1024 + + def __init__(self, config): + """Instantiate a datasource for the IonBeam REST API""" + self.type = config["type"] + assert self.type == "ionbeam" + + self.match_rules = config.get("match", {}) + endpoint = config.get("api_endpoint", "http://iotdev-001:18201/api/v1/") + self.api = IonBeamAPI(endpoint) + + def mime_type(self) -> str: + """Returns the mimetype of the result""" + return "application/octet-stream" + + def get_type(self): + return self.type + + def archive(self, request : Request): + """Archive data, returns nothing but updates datasource state""" + r = yaml.safe_load(request.user_request) + keys = r["keys"] + + with open(r["path"], 'rb') as f: + return self.api.archive(keys, f) + + def list(self, request : Request) -> list: + request_keys = yaml.safe_load(request.user_request) + return self.api.list(request_keys) + + def retrieve(self, request : Request) -> bool: + """Retrieve data, returns nothing but updates datasource state""" + + request_keys = yaml.safe_load(request.user_request) + self.response = self.api.retrieve(request_keys) + return True + + def result(self, request : Request): + """Returns a generator for the resultant data""" + return self.response.iter_content(chunk_size = self.read_chunk_size, decode_unicode=False) + + def destroy(self, request) -> None: + """A hook to do essential freeing of resources, called upon success or failure""" + + # requests response objects with stream=True can remain open indefinitely if not read to completion + # or closed explicitly + if self.response: + self.response.close() + + def match(self, request: Request) -> None: + """Checks if the request matches the datasource, raises on failure""" + + r = yaml.safe_load(request.user_request) or {} + + for k, v in self.match_rules.items(): + # Check that all required keys exist + if k not in r: + raise Exception("Request does not contain expected key {}".format(k)) + # Process date rules + if k == "date": + # self.date_check(r["date"], v) + continue + # ... 
and check the value of other keys + v = [v] if isinstance(v, str) else v + if r[k] not in v: + raise Exception("got {} : {}, but expected one of {}".format(k, r[k], v)) diff --git a/polytope_server/common/identity/mongodb_identity.py b/polytope_server/common/identity/mongodb_identity.py index d6c68be..86a67ec 100644 --- a/polytope_server/common/identity/mongodb_identity.py +++ b/polytope_server/common/identity/mongodb_identity.py @@ -18,8 +18,7 @@ # does it submit to any jurisdiction. # -import pymongo - +from .. import mongo_client_factory from ..authentication.mongodb_authentication import MongoAuthentication from ..exceptions import Conflict, NotFound from ..metric_collector import MetricCollector, MongoStorageMetricCollector @@ -29,12 +28,17 @@ class MongoDBIdentity(identity.Identity): def __init__(self, config): self.config = config - self.host = config.get("host", "localhost") - self.port = config.get("port", "27017") + self.uri = config.get("uri", "mongodb://localhost:27017") + self.collection = config.get("collection", "users") + username = config.get("username") + password = config.get("password") - endpoint = "{}:{}".format(self.host, self.port) - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + self.mongo_client = mongo_client_factory.create_client( + self.uri, + username, + password, + ) self.database = self.mongo_client.authentication self.users = self.database[self.collection] self.realm = config.get("realm") @@ -47,12 +51,11 @@ def __init__(self, config): pass self.storage_metric_collector = MongoStorageMetricCollector( - endpoint, self.mongo_client, "authentication", self.collection + self.uri, self.mongo_client, "authentication", self.collection ) self.identity_metric_collector = MetricCollector() def add_user(self, username: str, password: str, roles: list) -> bool: - if self.users.find_one({"username": username}) is not None: raise Conflict("Username already registered") @@ -70,7 +73,6 @@ def add_user(self, username: str, password: str, roles: list) -> bool: return True def remove_user(self, username: str) -> bool: - result = self.users.delete_one({"username": username}) if result.deleted_count > 0: return True diff --git a/polytope_server/common/keygenerator/mongodb_keygenerator.py b/polytope_server/common/keygenerator/mongodb_keygenerator.py index 471f49b..489054e 100644 --- a/polytope_server/common/keygenerator/mongodb_keygenerator.py +++ b/polytope_server/common/keygenerator/mongodb_keygenerator.py @@ -22,8 +22,7 @@ import uuid from datetime import datetime, timedelta -import pymongo - +from .. 
import mongo_client_factory from ..auth import User from ..exceptions import ForbiddenRequest from ..metric_collector import MongoStorageMetricCollector @@ -34,19 +33,19 @@ class MongoKeyGenerator(keygenerator.KeyGenerator): def __init__(self, config): self.config = config assert self.config["type"] == "mongodb" - host = config.get("host", "localhost") - port = config.get("port", "27017") + uri = config.get("uri", "mongodb://localhost:27017") collection = config.get("collection", "keys") - endpoint = "{}:{}".format(host, port) - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + username = config.get("username") + password = config.get("password") + + self.mongo_client = mongo_client_factory.create_client(uri, username, password) self.database = self.mongo_client.keys self.keys = self.database[collection] self.realms = config.get("allowed_realms") - self.storage_metric_collector = MongoStorageMetricCollector(endpoint, self.mongo_client, "keys", collection) + self.storage_metric_collector = MongoStorageMetricCollector(uri, self.mongo_client, "keys", collection) def create_key(self, user: User) -> ApiKey: - if user.realm not in self.realms: raise ForbiddenRequest("Not allowed to create an API Key for users in realm {}".format(user.realm)) diff --git a/polytope_server/common/metric_collector/queue_metric_collector.py b/polytope_server/common/metric_collector/queue_metric_collector.py index 828a774..22827fa 100644 --- a/polytope_server/common/metric_collector/queue_metric_collector.py +++ b/polytope_server/common/metric_collector/queue_metric_collector.py @@ -46,3 +46,21 @@ def total_queued(self): channel = connection.channel() q = channel.queue_declare(queue=self.queue_name, durable=True, passive=True) return q.method.message_count + + +class SQSQueueMetricCollector(QueueMetricCollector): + def __init__(self, host, client): + self.host = host + self.client = client + + def total_queued(self): + response = self.client.get_queue_attributes( + QueueUrl=self.host, + AttributeNames=[ + "ApproximateNumberOfMessages", + "ApproximateNumberOfMessagesDelayed", + "ApproximateNumberOfMessagesNotVisible", + ], + ) + values = response.get("Attributes", {}).values() + return sum(map(int, values)) diff --git a/polytope_server/common/metric_collector/storage_metric_collector.py b/polytope_server/common/metric_collector/storage_metric_collector.py index 3c088d6..7c57ecb 100644 --- a/polytope_server/common/metric_collector/storage_metric_collector.py +++ b/polytope_server/common/metric_collector/storage_metric_collector.py @@ -112,7 +112,7 @@ def storage_space_used(self): return space_used def total_entries(self): - return self.store.count() + return self.store.count_documents({}) def db_name(self): return self.database diff --git a/polytope_server/common/metric_store/mongodb_metric_store.py b/polytope_server/common/metric_store/mongodb_metric_store.py index 9dbcdaf..b529426 100644 --- a/polytope_server/common/metric_store/mongodb_metric_store.py +++ b/polytope_server/common/metric_store/mongodb_metric_store.py @@ -22,6 +22,7 @@ import pymongo +from .. 
import mongo_client_factory from ..metric import ( CacheInfo, Metric, @@ -38,13 +39,13 @@ class MongoMetricStore(MetricStore): def __init__(self, config=None): - host = config.get("host", "localhost") - port = config.get("port", "27017") + uri = config.get("uri", "mongodb://localhost:27017") metric_collection = config.get("collection", "metrics") - endpoint = "{}:{}".format(host, port) + username = config.get("username") + password = config.get("password") - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + self.mongo_client = mongo_client_factory.create_client(uri, username, password) self.database = self.mongo_client.metric_store self.store = self.database[metric_collection] @@ -58,10 +59,10 @@ def __init__(self, config=None): } self.storage_metric_collector = MongoStorageMetricCollector( - endpoint, self.mongo_client, "metric_store", metric_collection + uri, self.mongo_client, "metric_store", metric_collection ) - logging.info("MongoClient configured to open at {}".format(endpoint)) + logging.info("MongoClient configured to open at {}".format(uri)) def get_type(self): return "mongodb" @@ -85,7 +86,6 @@ def get_metric(self, uuid): return None def get_metrics(self, ascending=None, descending=None, limit=None, **kwargs): - all_slots = [] found_type = None diff --git a/polytope_server/common/mongo_client_factory.py b/polytope_server/common/mongo_client_factory.py new file mode 100644 index 0000000..7e67327 --- /dev/null +++ b/polytope_server/common/mongo_client_factory.py @@ -0,0 +1,14 @@ +import typing + +import pymongo + + +def create_client( + uri: str, + username: typing.Optional[str] = None, + password: typing.Optional[str] = None, +) -> pymongo.MongoClient: + if username and password: + return pymongo.MongoClient(host=uri, journal=True, connect=False, username=username, password=password) + else: + return pymongo.MongoClient(host=uri, journal=True, connect=False) diff --git a/polytope_server/common/queue/queue.py b/polytope_server/common/queue/queue.py index 8e8ed85..9cb1e22 100644 --- a/polytope_server/common/queue/queue.py +++ b/polytope_server/common/queue/queue.py @@ -80,7 +80,7 @@ def collect_metric_info( """Collect dictionary of metrics""" -queue_dict = {"rabbitmq": "RabbitmqQueue"} +queue_dict = {"rabbitmq": "RabbitmqQueue", "sqs": "SQSQueue"} def create_queue(queue_config): diff --git a/polytope_server/common/queue/sqs_queue.py b/polytope_server/common/queue/sqs_queue.py new file mode 100644 index 0000000..dd1a392 --- /dev/null +++ b/polytope_server/common/queue/sqs_queue.py @@ -0,0 +1,89 @@ +import json +import logging +from uuid import uuid4 + +import boto3 + +from ..metric_collector import SQSQueueMetricCollector +from . import queue + + +class SQSQueue(queue.Queue): + def __init__(self, config): + queue_name = config.get("queue_name") + region = config.get("region") + self.keep_alive_interval = config.get("keep_alive_interval", 60) + self.visibility_timeout = config.get("visibility_timeout", 120) + + logging.getLogger("sqs").setLevel(logging.WARNING) + logging.getLogger("boto3").setLevel(logging.WARNING) + logging.getLogger("botocore").setLevel(logging.WARNING) + + self.client = boto3.client("sqs", region_name=region) + + self.queue_url = self.client.get_queue_url(QueueName=queue_name).get("QueueUrl") + self.check_connection() + self.queue_metric_collector = SQSQueueMetricCollector(self.queue_url, self.client) + + def enqueue(self, message): + # Messages need to have different a `MessageGroupId` so that they can be processed in parallel. 
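+        # (Assumes a FIFO queue: SQS only accepts `MessageGroupId` on FIFO queues, and
+        # messages sharing a group id are delivered one at a time, in order, so a
+        # per-request group id is what lets several workers consume concurrently.)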
+ self.client.send_message( + QueueUrl=self.queue_url, + MessageBody=json.dumps(message.body), + MessageGroupId=message.body.get("id", uuid4()), + ) + + def dequeue(self): + response = self.client.receive_message( + QueueUrl=self.queue_url, + VisibilityTimeout=self.visibility_timeout, # If processing takes more seconds, message will be read twice + MaxNumberOfMessages=1, + WaitTimeSeconds=20, + ) + if "Messages" not in response: + return None + + msg, *remainder = response["Messages"] + for item in remainder: + self.client.change_message_visibility( + QueueUrl=self.queue_url, ReceiptHandle=item["ReceiptHandle"], VisibilityTimeout=0 + ) + body = msg["Body"] + receipt_handle = msg["ReceiptHandle"] + + return queue.Message(json.loads(body), context=receipt_handle) + + def ack(self, message): + self.client.delete_message(QueueUrl=self.queue_url, ReceiptHandle=message.context) + + def nack(self, message): + self.client.change_message_visibility( + QueueUrl=self.queue_url, ReceiptHandle=message.context, VisibilityTimeout=0 + ) + + def keep_alive(self): + # Implemented for compatibility, disabled because each request to SQS is billed + pass + # return self.check_connection() + + def check_connection(self): + response = self.client.get_queue_attributes(QueueUrl=self.queue_url, AttributeNames=["CreatedTimestamp"]) + # Tries to parse response + return "Attributes" in response and "CreatedTimestamp" in response["Attributes"] + + def close_connection(self): + self.client.close() + + def count(self): + response = self.client.get_queue_attributes( + QueueUrl=self.queue_url, AttributeNames=["ApproximateNumberOfMessages"] + ) + num_messages = response["Attributes"]["ApproximateNumberOfMessages"] + + return int(num_messages) + + def get_type(self): + return "sqs" + + def collect_metric_info(self): + return self.queue_metric_collector.collect().serialize() diff --git a/polytope_server/common/request_store/mongodb_request_store.py b/polytope_server/common/request_store/mongodb_request_store.py index 1734e2d..1f43561 100644 --- a/polytope_server/common/request_store/mongodb_request_store.py +++ b/polytope_server/common/request_store/mongodb_request_store.py @@ -23,7 +23,7 @@ import pymongo -from .. import metric_store +from .. 
import metric_store, mongo_client_factory from ..metric import MetricType, RequestStatusChange from ..metric_collector import ( MongoRequestStoreMetricCollector, @@ -35,13 +35,12 @@ class MongoRequestStore(request_store.RequestStore): def __init__(self, config=None, metric_store_config=None): - host = config.get("host", "localhost") - port = config.get("port", "27017") + uri = config.get("uri", "mongodb://localhost:27017") request_collection = config.get("collection", "requests") + username = config.get("username") + password = config.get("password") - endpoint = "{}:{}".format(host, port) - - self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False) + self.mongo_client = mongo_client_factory.create_client(uri, username, password) self.database = self.mongo_client.request_store self.store = self.database[request_collection] @@ -50,11 +49,11 @@ def __init__(self, config=None, metric_store_config=None): self.metric_store = metric_store.create_metric_store(metric_store_config) self.storage_metric_collector = MongoStorageMetricCollector( - endpoint, self.mongo_client, "request_store", request_collection + uri, self.mongo_client, "request_store", request_collection ) self.request_store_metric_collector = MongoRequestStoreMetricCollector() - logging.info("MongoClient configured to open at {}".format(endpoint)) + logging.info("MongoClient configured to open at {}".format(uri)) def get_type(self): return "mongodb" @@ -87,7 +86,6 @@ def get_request(self, id): return None def get_requests(self, ascending=None, descending=None, limit=None, **kwargs): - if ascending: if ascending not in Request.__slots__: raise KeyError("Request has no key {}".format(ascending)) @@ -98,7 +96,6 @@ def get_requests(self, ascending=None, descending=None, limit=None, **kwargs): query = {} for k, v in kwargs.items(): - if k not in Request.__slots__: raise KeyError("Request has no key {}".format(k)) @@ -152,7 +149,6 @@ def update_request(self, request): return res def wipe(self): - if self.metric_store: res = self.get_requests() for i in res: diff --git a/polytope_server/common/staging/s3_staging.py b/polytope_server/common/staging/s3_staging.py index 628ce5e..c056145 100644 --- a/polytope_server/common/staging/s3_staging.py +++ b/polytope_server/common/staging/s3_staging.py @@ -47,6 +47,7 @@ def __init__(self, config): access_key = config.get("access_key", "") secret_key = config.get("secret_key", "") self.bucket = config.get("bucket", "default") + secure = config.get("secure", False) == True self.url = config.get("url", None) internal_url = "{}:{}".format(self.host, self.port) secure = config.get("use_ssl", False) @@ -56,17 +57,16 @@ def __init__(self, config): secret_key=secret_key, secure=secure, ) - self.internal_url = "http://" + internal_url + self.internal_url = ("https://" if secure else "http://") + internal_url try: self.client.make_bucket(self.bucket) + self.client.set_bucket_policy(self.bucket, self.bucket_policy()) except BucketAlreadyExists: pass except BucketAlreadyOwnedByYou: pass - self.client.set_bucket_policy(self.bucket, self.bucket_policy()) - self.storage_metric_collector = S3StorageMetricCollector(endpoint, self.client, self.bucket, self.get_type()) logging.info( diff --git a/polytope_server/common/user.py b/polytope_server/common/user.py index 9f522cc..60518cb 100644 --- a/polytope_server/common/user.py +++ b/polytope_server/common/user.py @@ -68,3 +68,6 @@ def serialize(self): v = self.__getattribute__(k) result[k] = v return result + + def __str__(self): + return 
f"User({self.realm}:{self.username})" diff --git a/polytope_server/frontend/common/flask_decorators.py b/polytope_server/frontend/common/flask_decorators.py index 2a5280e..f94d68f 100644 --- a/polytope_server/frontend/common/flask_decorators.py +++ b/polytope_server/frontend/common/flask_decorators.py @@ -18,7 +18,7 @@ # does it submit to any jurisdiction. # -import collections +import collections.abc import json from flask import Response @@ -31,13 +31,13 @@ def RequestSucceeded(response): - if not isinstance(response, collections.Mapping): + if not isinstance(response, collections.abc.Mapping): response = {"message": response} return Response(response=json.dumps(response), status=200, mimetype="application/json") def RequestAccepted(response): - if not isinstance(response, collections.Mapping): + if not isinstance(response, collections.abc.Mapping): response = {"message": response} if response["message"] == "": response["message"] = "Request {}".format(response["status"]) diff --git a/polytope_server/frontend/flask_handler.py b/polytope_server/frontend/flask_handler.py index b6fb5e1..5fb82be 100644 --- a/polytope_server/frontend/flask_handler.py +++ b/polytope_server/frontend/flask_handler.py @@ -29,6 +29,7 @@ from flask import Flask, request from flask_swagger_ui import get_swaggerui_blueprint from werkzeug.exceptions import default_exceptions +from werkzeug.middleware.proxy_fix import ProxyFix from ..common.exceptions import BadRequest, ForbiddenRequest, HTTPException, NotFound from ..version import __version__ @@ -47,9 +48,13 @@ def create_handler( collections, identity, apikeygenerator, + proxy_support: bool, ): handler = Flask(__name__) + if proxy_support: + handler.wsgi_app = ProxyFix(handler.wsgi_app, x_for=1, x_proto=1, x_host=1) + openapi_spec = "static/openapi.yaml" spec_path = pathlib.Path(__file__).parent.absolute() / openapi_spec with spec_path.open("r+", encoding="utf8") as f: @@ -63,6 +68,7 @@ def create_handler( SWAGGER_URL, tmp.name, config={"app_name": "Polytope", "spec": spec} ) handler.register_blueprint(SWAGGERUI_BLUEPRINT, url_prefix=SWAGGER_URL) + handler.register_blueprint(SWAGGERUI_BLUEPRINT, url_prefix='/') data_transfer = DataTransfer(request_store, staging) @@ -87,13 +93,6 @@ def handle_error(error): for code, ex in default_exceptions.items(): handler.errorhandler(code)(handle_error) - @handler.route("/", methods=["GET"]) - def root(): - this_dir = os.path.dirname(os.path.abspath(__file__)) + "/" - with open(this_dir + "web/index.html") as fh: - content = fh.read() - return content - def get_auth_header(request): return request.headers.get("Authorization", "") @@ -262,7 +261,6 @@ def only_json(): return handler def run_server(self, handler, server_type, host, port): - if server_type == "flask": # flask internal server for non-production environments # should only be used for testing and debugging diff --git a/polytope_server/frontend/frontend.py b/polytope_server/frontend/frontend.py index 2c4d7e1..6481118 100644 --- a/polytope_server/frontend/frontend.py +++ b/polytope_server/frontend/frontend.py @@ -59,7 +59,6 @@ def __init__(self, config): self.port = frontend_config.get("port", "5000") def run(self): - # create instances of authentication, request_store & staging request_store = create_request_store(self.config.get("request_store"), self.config.get("metric_store")) @@ -72,7 +71,15 @@ def run(self): handler_module = importlib.import_module("polytope_server.frontend." 
+ self.handler_type + "_handler") handler_class = getattr(handler_module, self.handler_dict[self.handler_type])() - handler = handler_class.create_handler(request_store, auth, staging, collections, identity, apikeygenerator) + handler = handler_class.create_handler( + request_store, + auth, + staging, + collections, + identity, + apikeygenerator, + self.config.get("frontend", {}).get("proxy_support", False), + ) logging.info("Starting frontend...") handler_class.run_server(handler, self.server_type, self.host, self.port) diff --git a/polytope_server/frontend/web/css/style.css b/polytope_server/frontend/web/css/style.css deleted file mode 100644 index 5629028..0000000 --- a/polytope_server/frontend/web/css/style.css +++ /dev/null @@ -1,27 +0,0 @@ -/* -Copyright 2022 European Centre for Medium-Range Weather Forecasts (ECMWF) - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -In applying this licence, ECMWF does not waive the privileges and immunities -granted to it by virtue of its status as an intergovernmental organisation nor -does it submit to any jurisdiction. -*/ - -.wide { - text-align: center; -} - -.wide h1 { - display: inline-block; -} diff --git a/polytope_server/frontend/web/index.html b/polytope_server/frontend/web/index.html deleted file mode 100644 index b578842..0000000 --- a/polytope_server/frontend/web/index.html +++ /dev/null @@ -1,38 +0,0 @@ - - - - - - - Polytope - - - -
- [deleted landing page markup; its visible content was:]
-   Polytope
-   This is an experimental service for EU projects, see
-   https://github.com/ecmwf-projects/polytope-server.
- - diff --git a/polytope_server/worker/worker.py b/polytope_server/worker/worker.py index 5bfac20..cc7b8b0 100644 --- a/polytope_server/worker/worker.py +++ b/polytope_server/worker/worker.py @@ -132,7 +132,6 @@ def update_metric(self): self.metric_store.update_metric(self.metric) def run(self): - self.queue = polytope_queue.create_queue(self.config.get("queue")) self.thread_pool = ThreadPoolExecutor(1) @@ -141,7 +140,6 @@ def run(self): self.update_metric() while not time.sleep(self.poll_interval): - self.queue.keep_alive() # No active request: try to pop from queue and process request in future thread @@ -247,6 +245,7 @@ def process_request(self, request): except Exception: request.user_message += "Failed to finalize request" + logging.info(request.user_message, extra={"request_id": id}) logging.exception("Failed to finalize request", extra={"request_id": id}) raise @@ -257,6 +256,7 @@ def process_request(self, request): if datasource is None: request.user_message += "Failed to process request." + logging.info(request.user_message, extra={"request_id": id}) raise Exception("Failed to process request.") else: request.user_message += "Success" @@ -304,7 +304,7 @@ def on_request_fail(self, request, exception): logging.exception("Request failed with exception.", extra={"request_id": request.id}) self.requests_failed += 1 - def on_process_terminated(self): + def on_process_terminated(self, signumm=None, frame=None): """Called when the worker is asked to exit whilst processing a request, and we want to reschedule the request""" if self.request is not None: diff --git a/pyproject.toml b/pyproject.toml index e34796e..85c3b07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,5 @@ [tool.black] -line-length = 120 \ No newline at end of file +line-length = 120 + +[tool.isort] +profile = "black" diff --git a/requirements.txt b/requirements.txt index ab5ce1a..dda9ffd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ flask-wtf==0.14.2 werkzeug==2.0 gunicorn==19.9.0 ecmwf-api-client==1.5.4 -pymongo==3.10.1 +pymongo==4.6.0 pymemcache==3.0.0 redis==3.4.1 markdown==3.2.1 @@ -30,4 +30,5 @@ flask-swagger-ui==3.25.0 ldap3==2.7 docker==4.2.0 python-keycloak==0.24.0 -boto3==1.17.108 \ No newline at end of file +python-jose +boto3==1.28.80 diff --git a/tests/unit/test_mongo_client_factory.py b/tests/unit/test_mongo_client_factory.py new file mode 100644 index 0000000..81ffe81 --- /dev/null +++ b/tests/unit/test_mongo_client_factory.py @@ -0,0 +1,44 @@ +import typing +from unittest import mock + +from polytope_server.common import mongo_client_factory + + +@mock.patch("polytope_server.common.mongo_client_factory.pymongo.MongoClient", autospec=True) +def test_create_without_credentials(mock_mongo: mock.Mock): + mongo_client_factory.create_client("mongodb://host:123") + + _verify(mock_mongo, "mongodb://host:123", None, None) + + +@mock.patch("polytope_server.common.mongo_client_factory.pymongo.MongoClient", autospec=True) +def test_create_without_password_credentials(mock_mongo: mock.Mock): + mongo_client_factory.create_client("mongodb+srv://host:123", username="admin") + + _verify(mock_mongo, "mongodb+srv://host:123", None, None) + + +@mock.patch("polytope_server.common.mongo_client_factory.pymongo.MongoClient", autospec=True) +def test_create_without_username_credentials(mock_mongo: mock.Mock): + mongo_client_factory.create_client("host:123", password="password") + + _verify(mock_mongo, "host:123", None, None) + + 
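+# Note: create_client only forwards credentials when both username and password are
+# given, so the three cases above expect the client to be built from the URI alone.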
+@mock.patch("polytope_server.common.mongo_client_factory.pymongo.MongoClient", autospec=True) +def test_create_with_credentials(mock_mongo: mock.Mock): + mongo_client_factory.create_client("mongodb+srv://host", username="admin", password="est123123") + + _verify(mock_mongo, "mongodb+srv://host", "admin", "est123123") + + +def _verify( + mock_mongo: mock.Mock, endpoint: str, username: typing.Optional[str] = None, password: typing.Optional[str] = None +): + mock_mongo.assert_called_once() + args, kwargs = mock_mongo.call_args + assert args[0] == endpoint + if username: + assert kwargs["username"] == username + if password: + assert kwargs["password"] == password diff --git a/tests/unit/test_s3_staging.py b/tests/unit/test_s3_staging.py new file mode 100644 index 0000000..41407e3 --- /dev/null +++ b/tests/unit/test_s3_staging.py @@ -0,0 +1,49 @@ +from unittest import mock +from polytope_server.common.staging.s3_staging import S3Staging + + +class DummyMinioClient: + def __init__(self) -> None: + self._region = None + + def make_bucket(self, bucket, region): + return "Dummy make bucket" + + def set_bucket_policy(self, bucket, policy): + return "Dummy set bucket policy" + + +class Test: + @mock.patch("polytope_server.common.staging.s3_staging.Minio", autospec=True) + def test_s3_staging_secure_false(self, mock_minio: mock.Mock): + mock_minio.return_value = DummyMinioClient() + s3Staging = S3Staging(config={"secure": False}) + + self.verify_secure_flag_and_internal_url(mock_minio, s3Staging, False) + + @mock.patch("polytope_server.common.staging.s3_staging.Minio", autospec=True) + def test_s3_staging_secure_any_value_false(self, mock_minio: mock.Mock): + mock_minio.return_value = DummyMinioClient() + s3Staging = S3Staging(config={"secure": "sdafsdfs"}) + + self.verify_secure_flag_and_internal_url(mock_minio, s3Staging, False) + + @mock.patch("polytope_server.common.staging.s3_staging.Minio", autospec=True) + def test_s3_staging_secure_default(self, mock_minio: mock.Mock): + mock_minio.return_value = DummyMinioClient() + s3Staging = S3Staging(config={}) + + self.verify_secure_flag_and_internal_url(mock_minio, s3Staging, False) + + @mock.patch("polytope_server.common.staging.s3_staging.Minio", autospec=True) + def test_s3_staging_secure_true(self, mock_minio: mock.Mock): + mock_minio.return_value = DummyMinioClient() + s3Staging = S3Staging(config={"secure": True}) + + self.verify_secure_flag_and_internal_url(mock_minio, s3Staging, True) + + def verify_secure_flag_and_internal_url(self, mock_minio: mock.Mock, s3Staging: S3Staging, secure: bool): + mock_minio.assert_called_once() + _, kwargs = mock_minio.call_args + assert kwargs["secure"] == secure + assert s3Staging.get_internal_url("test").startswith("https" if secure else "http")
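For reference, a minimal sketch of how the Mongo-backed components are now wired up through the new factory; the URI, credentials and collection below are illustrative placeholders rather than values taken from any real deployment:

    from polytope_server.common import mongo_client_factory

    # Credentials are optional: when username/password are omitted, create_client()
    # builds the MongoClient from the URI alone, as in the factory above.
    client = mongo_client_factory.create_client(
        "mongodb://localhost:27017", username="polytope", password="example-password"
    )
    requests = client.request_store["requests"]  # same defaults as MongoRequestStore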