diff --git a/toolbox/src/access/db/__init__.py b/toolbox/src/access/db/__init__.py index c0f8cf84..366ead8b 100644 --- a/toolbox/src/access/db/__init__.py +++ b/toolbox/src/access/db/__init__.py @@ -62,11 +62,3 @@ class Transformations(Base): SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base.metadata.create_all(bind=engine) - - -def populate(): - """ - Populate the database with basic information (purposes, exceptions, transformations, ...) provided by the config file - """ - - pass diff --git a/toolbox/src/access/pap.py b/toolbox/src/access/pap.py index d6ce8c35..f47be277 100644 --- a/toolbox/src/access/pap.py +++ b/toolbox/src/access/pap.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -pap_router = APIRouter() +router = APIRouter() class PurposeIn(BaseModel): @@ -26,7 +26,7 @@ def emit(self, record): self.log_list.append(record.msg) -@pap_router.post("/purpose/{purpose_name}", status_code=200) +@router.post("/purpose/{purpose_name}", status_code=200) async def add_purpose(purpose: PurposeIn): purpose_dict = purpose.dict() with session_scope() as session: @@ -40,14 +40,14 @@ async def add_purpose(purpose: PurposeIn): return {**purpose_dict, "id": 0} -@pap_router.get("/purposes") +@router.get("/purposes") def list_purposes(): with session_scope() as session: purposes = session.query(Purposes).all() return purposes -@pap_router.get("/policy") +@router.get("/policy") def get_policy(): log_list = [] handler = ListHandler(log_list) @@ -57,7 +57,7 @@ def get_policy(): return log_list -@pap_router.put("/exception") +@router.put("/exception") def add_exception(item): items_dict = item.dict() # Request body purpose_name = items_dict["purpose"] diff --git a/toolbox/src/auth/__init__.py b/toolbox/src/auth/__init__.py new file mode 100644 index 00000000..9be5ecbc --- /dev/null +++ b/toolbox/src/auth/__init__.py @@ -0,0 +1,24 @@ +from fastapi import Request, HTTPException +from jose import jwt, JWTError +import requests +import os + + +project_nb = os.getenv("GCP_PROJECT_NB") +project_name = os.getenv("GCP_PROJECT_NAME") +AUDIENCE = f"/projects/{project_nb}/apps/{project_name}" +GOOGLE_CERTS_URL = "https://www.googleapis.com/oauth2/v1/certs" + + +async def iap_jwt_middleware(request: Request, call_next): + token = request.headers.get('x-goog-iap-jwt-assertion') + if token: + try: + certs = requests.get(GOOGLE_CERTS_URL).json() + payload = jwt.decode(token, certs, algorithms=['RS256'], audience=AUDIENCE) + request.state.user = payload.get("email") + except JWTError as e: + raise HTTPException(status_code=401, detail="Invalid token") + response = await call_next(request) + return response + diff --git a/toolbox/src/cloud/gcp/__init__.py b/toolbox/src/cloud/gcp/__init__.py index 7df3eab0..47c71418 100644 --- a/toolbox/src/cloud/gcp/__init__.py +++ b/toolbox/src/cloud/gcp/__init__.py @@ -6,6 +6,6 @@ from fastapi import APIRouter -gcp_router = APIRouter() +router = APIRouter() import cloud.gcp.bq import cloud.gcp.storage diff --git a/toolbox/src/cloud/gcp/bq.py b/toolbox/src/cloud/gcp/bq.py index 52b37ae8..fec611ab 100644 --- a/toolbox/src/cloud/gcp/bq.py +++ b/toolbox/src/cloud/gcp/bq.py @@ -4,21 +4,21 @@ from google.cloud.exceptions import GoogleCloudError import logging -from . import gcp_router, credentials +from . import router, credentials logger = logging.getLogger(__name__) client = bigquery.Client(credentials=credentials, project=os.getenv("GCP_PROJECT_NAME")) -@gcp_router.get("/bq/{query}") +@router.get("/bq/{query}") def query_bq(query: str): """Executes a client query in BigQuery and returns the result.""" query_job = client.query(query) return list(query_job) -@gcp_router.post("/bq") +@router.post("/bq") def create_bq_dataset(dataset_id: str, location: str) -> str: """ Create a new dataset in BigQuery based on the query params. @@ -32,7 +32,7 @@ def create_bq_dataset(dataset_id: str, location: str) -> str: return {"result": "success"} -@gcp_router.delete("/bq") +@router.delete("/bq") def delete_bq_dataset(dataset_id: str) -> str: """Deletes a dataset in BigQuery.""" client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) diff --git a/toolbox/src/cloud/gcp/storage.py b/toolbox/src/cloud/gcp/storage.py index c114c7d3..9fc191df 100644 --- a/toolbox/src/cloud/gcp/storage.py +++ b/toolbox/src/cloud/gcp/storage.py @@ -13,7 +13,7 @@ import transformations from access.pep import tag_content, control_access from utils import calculate_image_hash -from . import gcp_router, credentials +from . import router, credentials logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ bucket_name = "company_directory" # TODO: convert to a config variable -@gcp_router.post("/blob") +@router.post("/blob") async def upload_object(request: Request, file: UploadFile = File(...)): """ Uploads a file to a Cloud Storage bucket. @@ -54,7 +54,7 @@ async def upload_object(request: Request, file: UploadFile = File(...)): return {"result": "success"} -@gcp_router.get("/blob/{source_blob_name}", dependencies=[Depends(control_access)]) +@router.get("/blob/{source_blob_name}", dependencies=[Depends(control_access)]) def download_object(source_blob_name: str): """Return a blob from a bucket in Google Cloud Storage.""" # TODO: purpose must match, this is done in the PEP @@ -82,7 +82,7 @@ def create_bucket(bucket_name: str): # TODO: would be interesting to limit access to this endpoint -@gcp_router.get("/bucket") +@router.get("/bucket") def list_bucket(): """ Returns all the blobs in a bucket Google Cloud Storage. diff --git a/toolbox/src/main.py b/toolbox/src/main.py index 9712156b..41346750 100644 --- a/toolbox/src/main.py +++ b/toolbox/src/main.py @@ -3,12 +3,11 @@ load_dotenv() import logging -from fastapi import FastAPI, Request, Depends -# Middleware -from access.pep import control_access +from fastapi import FastAPI, Request # Router -from access.pap import pap_router -from cloud.gcp import gcp_router +from access.pap import router as pap_router +from cloud.gcp import router as cloud_router +from auth import iap_jwt_middleware # Configure app-wide logging # N.B.: logs are automatically handle by the built-in interface of the cloud provider @@ -20,10 +19,9 @@ # Our main process app = FastAPI(debug=True) - +app.add_middleware(iap_jwt_middleware) app.include_router(pap_router, prefix="/api/v1/pap") -app.include_router(gcp_router, prefix="/api/v1/gcp") -# app.include_router(gcp_router, prefix="/api/v1/gcp", dependencies=[Depends(control_access())]) +app.include_router(cloud_router, prefix="/api/v1/gcp") @app.on_event("startup")