Skip to content

Commit

Permalink
Merge branch 'develop' into feature/s3-boto
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshawkes authored Apr 29, 2024
2 parents 2c30eb3 + 7e41214 commit cc868cf
Show file tree
Hide file tree
Showing 35 changed files with 622 additions and 204 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ htmlcov
validated.yaml
merged.yaml
polytope_server.egg-info
**/build
**/build
.venv
69 changes: 31 additions & 38 deletions polytope_server/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ def __init__(self, config):

self.collections = collection.create_collections(config.get("collections"))

self.user_limit = self.broker_config.get("user_limit", None)

def run(self):

logging.info("Starting broker...")
logging.info("Maximum Queue Size: {}".format(self.max_queue_size))
logging.info("User Request Limit: {}".format(self.user_limit))

while not time.sleep(self.scheduling_interval):
self.check_requests()
Expand Down Expand Up @@ -87,45 +84,41 @@ def check_requests(self):

def check_limits(self, active_requests, request):

logging.debug("Checking limits for request {}".format(request.id))

# User limits
if self.user_limit is not None:
user_active_requests = sum(qr.user == request.user for qr in active_requests)
if user_active_requests >= self.user_limit:
logging.debug("User has {} of {} active requests".format(user_active_requests, self.user_limit))
return False

# Collection limits
collection_total_limit = self.collections[request.collection].limits.get("total", None)
if collection_total_limit is not None:
collection_active_requests = sum(qr.collection == request.collection for qr in active_requests)
if collection_active_requests >= collection_total_limit:
logging.debug(
"Collection has {} of {} total active requests".format(
collection_active_requests, collection_total_limit
)
)
return False

# Collection-user limits
collection_user_limit = self.collections[request.collection].limits.get("per-user", None)
if collection_user_limit is not None:
collection_user_active_requests = sum(
(qr.collection == request.collection and qr.user == request.user) for qr in active_requests
)
if collection_user_active_requests >= collection_user_limit:
logging.debug(
"User has {} of {} active requests in collection {}".format(
collection_user_active_requests,
collection_user_limit,
request.collection,
)
)
logging.debug(f"Checking limits for request {request.id}")

# Get collection limits and calculate active requests
collection = self.collections[request.collection]
collection_limits = collection.limits
collection_total_limit = collection_limits.get("total")
collection_active_requests = sum(qr.collection == request.collection for qr in active_requests)
logging.debug(f"Collection {request.collection} has {collection_active_requests} active requests")

# Check collection total limit
if collection_total_limit is not None and collection_active_requests >= collection_total_limit:
logging.debug(f"Collection has {collection_active_requests} of {collection_total_limit} total active requests")
return False

# Determine the effective limit based on role or per-user setting
role_limits = collection_limits.get("per-role", {}).get(request.user.realm, {})
limit = max((role_limits.get(role, 0) for role in request.user.roles), default=0)
if limit == 0: # Use collection per-user limit if no role-specific limit
limit = collection_limits.get("per-user", 0)

# Check if user exceeds the effective limit
if limit > 0:
user_active_requests = sum(qr.collection == request.collection and qr.user == request.user for qr in active_requests)
if user_active_requests >= limit:
logging.debug(f"User {request.user} has {user_active_requests} of {limit} active requests in collection {request.collection}")
return False
else:
logging.debug(f"User {request.user} has {user_active_requests} of {limit} active requests in collection {request.collection}")
return True

# Allow if no limits are exceeded
logging.debug(f"No limit for user {request.user} in collection {request.collection}")
return True


def enqueue(self, request):

logging.info("Queuing request", extra={"request_id": request.id})
Expand Down
2 changes: 1 addition & 1 deletion polytope_server/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def authenticate(self, auth_header) -> User:
www_authenticate=self.auth_info,
)

user.roles = ["default"]
user.roles.append("default")

# Visit all authorizers to append additional roles and attributes
for authorizer in self.authorizers:
Expand Down
3 changes: 2 additions & 1 deletion polytope_server/common/authentication/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def name(self) -> str:
"plain": "PlainAuthentication",
"keycloak": "KeycloakAuthentication",
"federation": "FederationAuthentication",
"jwt": "JWTAuthentication",
"jwt" : "JWTAuthentication",
"openid_offline_access" : "OpenIDOfflineAuthentication",
}


Expand Down
7 changes: 7 additions & 0 deletions polytope_server/common/authentication/jwt_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, name, realm, config):
self.config = config

self.certs_url = config["cert_url"]
self.client_id = config["client_id"]

super().__init__(name, realm, config)

Expand All @@ -57,8 +58,14 @@ def authenticate(self, credentials: str) -> User:
token=credentials, algorithms=jwt.get_unverified_header(credentials).get("alg"), key=certs
)

logging.info("Decoded JWT: {}".format(decoded_token))


user = User(decoded_token["sub"], self.realm())

roles = decoded_token.get("resource_access", {}).get(self.client_id, {}).get("roles", [])
user.roles.extend(roles)

logging.info("Found user {} from decoded JWT".format(user))
except Exception as e:
logging.info("Failed to authenticate user from JWT")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

from datetime import datetime

import pymongo

from .. import mongo_client_factory
from ..auth import User
from ..exceptions import ForbiddenRequest
from ..metric_collector import MongoStorageMetricCollector
Expand All @@ -39,19 +38,18 @@ class ApiKeyMongoAuthentication(authentication.Authentication):
"""

def __init__(self, name, realm, config):

self.config = config
host = config.get("host", "localhost")
port = config.get("port", "27017")
uri = config.get("uri", "mongodb://localhost:27017")
collection = config.get("collection", "keys")
username = config.get("username")
password = config.get("password")

endpoint = "{}:{}".format(host, port)
self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False)
self.mongo_client = mongo_client_factory.create_client(uri, username, password)
self.database = self.mongo_client.keys
self.keys = self.database[collection]
assert realm == "polytope"

self.storage_metric_collector = MongoStorageMetricCollector(endpoint, self.mongo_client, "keys", collection)
self.storage_metric_collector = MongoStorageMetricCollector(uri, self.mongo_client, "keys", collection)

super().__init__(name, realm, config)

Expand All @@ -62,7 +60,6 @@ def authentication_info(self):
return "Authenticate with Polytope API Key from ../auth/keys"

def authenticate(self, credentials: str) -> User:

# credentials should be of the form '<ApiKey>'
res = self.keys.find_one({"key.key": credentials})
if res is None:
Expand Down
15 changes: 6 additions & 9 deletions polytope_server/common/authentication/mongodb_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
import binascii
import hashlib

import pymongo

from .. import mongo_client_factory
from ..auth import User
from ..exceptions import ForbiddenRequest
from ..metric_collector import MongoStorageMetricCollector
Expand All @@ -32,19 +31,18 @@

class MongoAuthentication(authentication.Authentication):
def __init__(self, name, realm, config):

self.config = config
host = config.get("host", "localhost")
port = config.get("port", "27017")
uri = config.get("uri", "mongodb://localhost:27017")
collection = config.get("collection", "users")
username = config.get("username")
password = config.get("password")

endpoint = "{}:{}".format(host, port)
self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False)
self.mongo_client = mongo_client_factory.create_client(uri, username, password)
self.database = self.mongo_client.authentication
self.users = self.database[collection]

self.storage_metric_collector = MongoStorageMetricCollector(
endpoint, self.mongo_client, "authentication", collection
uri, self.mongo_client, "authentication", collection
)

super().__init__(name, realm, config)
Expand All @@ -59,7 +57,6 @@ def authentication_info(self):
return "Authenticate with username and password"

def authenticate(self, credentials: str) -> User:

# credentials should be of the form 'base64(<username>:<API_key>)'
try:
decoded = base64.b64decode(credentials).decode("utf-8")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#
# Copyright 2022 European Centre for Medium-Range Weather Forecasts (ECMWF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation nor
# does it submit to any jurisdiction.
#

import logging
import os
import requests
from jose import jwt

from ..auth import User
from ..caching import cache
from . import authentication
from ..exceptions import ForbiddenRequest


class OpenIDOfflineAuthentication(authentication.Authentication):
def __init__(self, name, realm, config):
self.config = config

self.certs_url = config["cert_url"]
self.public_client_id = config["public_client_id"]
self.private_client_id = config["private_client_id"]
self.private_client_secret = config["private_client_secret"]
self.iam_url = config["iam_url"]
self.iam_realm = config["iam_realm"]


super().__init__(name, realm, config)

def authentication_type(self):
return "Bearer"

def authentication_info(self):
return "Authenticate with OpenID offline_access token"

@cache(lifetime=120)
def get_certs(self):
return requests.get(self.certs_url).json()

@cache(lifetime=120)
def check_offline_access_token(self, token: str) -> bool:
"""
We check if the token is recognised by the IAM service, and we cache this result.
We cannot simply try to get the access token because we would spam the IAM server with invalid tokens, and the
failure at that point would not be cached.
"""
keycloak_token_introspection = self.iam_url + "/realms/" + self.iam_realm + "/protocol/openid-connect/token/introspect"
introspection_data = {
"token": token
}
b_auth = requests.auth.HTTPBasicAuth(self.private_client_id, self.private_client_secret)
resp = requests.post(url=keycloak_token_introspection, data=introspection_data, auth=b_auth).json()
if resp["active"] and resp["token_type"] == "Offline":
return True
else:
return False

@cache(lifetime=120)
def authenticate(self, credentials: str) -> User:

try:

# Check if this is a valid offline_access token
if not self.check_offline_access_token(credentials):
raise ForbiddenRequest("Not a valid offline_access token")

# Generate an access token from the offline_access token (like a refresh token)
refresh_data = {
"client_id": self.public_client_id,
"grant_type": "refresh_token",
"refresh_token": credentials
}
keycloak_token_endpoint = self.iam_url + "/realms/" + self.iam_realm + "/protocol/openid-connect/token"
resp = requests.post(url=keycloak_token_endpoint, data=refresh_data)
token = resp.json()['access_token']

certs = self.get_certs()
decoded_token = jwt.decode(token=token,
algorithms=jwt.get_unverified_header(token).get('alg'),
key=certs
)

logging.info("Decoded JWT: {}".format(decoded_token))

user = User(decoded_token["sub"], self.realm())

roles = decoded_token.get("resource_access", {}).get(self.public_client_id, {}).get("roles", [])
user.roles.extend(roles)

logging.info("Found user {} from openid offline_access token".format(user))

except Exception as e:
logging.info("Failed to authenticate user from openid offline_access token")
logging.info(e)
raise ForbiddenRequest("Could not authenticate user from openid offline_access token")
return user


def collect_metric_info(self):
return {}
14 changes: 6 additions & 8 deletions polytope_server/common/authorization/mongodb_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
# does it submit to any jurisdiction.
#

import pymongo

from .. import mongo_client_factory
from ..auth import User
from ..metric_collector import MongoStorageMetricCollector
from . import authorization
Expand All @@ -29,23 +28,22 @@ class MongoDBAuthorization(authorization.Authorization):
def __init__(self, name, realm, config):
self.config = config
assert self.config["type"] == "mongodb"
self.host = config.get("host", "localhost")
self.port = config.get("port", "27017")
self.uri = config.get("uri", "mongodb://localhost:27017")
self.collection = config.get("collection", "users")
username = config.get("username")
password = config.get("password")

endpoint = "{}:{}".format(self.host, self.port)
self.mongo_client = pymongo.MongoClient(endpoint, journal=True, connect=False)
self.mongo_client = mongo_client_factory.create_client(self.uri, username, password)
self.database = self.mongo_client.authentication
self.users = self.database[self.collection]

self.storage_metric_collector = MongoStorageMetricCollector(
endpoint, self.mongo_client, "authentication", self.collection
self.uri, self.mongo_client, "authentication", self.collection
)

super().__init__(name, realm, config)

def get_roles(self, user: User) -> list:

if user.realm != self.realm():
raise ValueError(
"Trying to authorize a user in the wrong realm, expected {}, got {}".format(self.realm(), user.realm)
Expand Down
Loading

0 comments on commit cc868cf

Please sign in to comment.