From 22af360dcea8e867d4a842d749185329b15e003e Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Sat, 3 Feb 2024 12:17:07 -0800 Subject: [PATCH 1/3] Basic single user token authorization --- clients/ios/Shared/AppConstants.swift | 3 +- .../Extensions/URLRequest+Extensions.swift | 14 +++++++ clients/ios/Shared/Files/FileUploadTask.swift | 2 + .../Services/NetworkManager.swift | 2 + .../ios/UntitledAI.xcodeproj/project.pbxproj | 6 +++ clients/ios/UntitledAI/Services/API.swift | 7 +++- .../UntitledAI/Services/SocketManager.swift | 3 +- untitledai/core/config.py | 1 + untitledai/server/app_state.py | 38 ++++++++++++++----- untitledai/server/capture_socket.py | 8 +++- untitledai/server/main.py | 4 +- untitledai/server/routes/capture.py | 10 ++--- untitledai/server/routes/conversations.py | 8 ++-- 13 files changed, 82 insertions(+), 24 deletions(-) create mode 100644 clients/ios/Shared/Extensions/URLRequest+Extensions.swift diff --git a/clients/ios/Shared/AppConstants.swift b/clients/ios/Shared/AppConstants.swift index a815b5ee..722efa1f 100644 --- a/clients/ios/Shared/AppConstants.swift +++ b/clients/ios/Shared/AppConstants.swift @@ -7,7 +7,8 @@ import Foundation struct AppConstants { - static let apiBaseURL = "" + static let apiBaseURL = "https://3093adab5e6f.ngrok.app" + static let clientToken = "" static let bleServiceUUID = "03d5d5c4-a86c-11ee-9d89-8f2089a49e7e" static let bleAudioCharacteristicUUID = "b189a505-a86c-11ee-a5fb-8f2089a49e7e" diff --git a/clients/ios/Shared/Extensions/URLRequest+Extensions.swift b/clients/ios/Shared/Extensions/URLRequest+Extensions.swift new file mode 100644 index 00000000..3427d3cd --- /dev/null +++ b/clients/ios/Shared/Extensions/URLRequest+Extensions.swift @@ -0,0 +1,14 @@ +// +// URLRequest+Extensions.swift +// UntitledAI +// +// Created by ethan on 2/3/24. +// + +import Foundation + +extension URLRequest { + mutating func addCommonHeaders() { + self.addValue("Bearer \(AppConstants.clientToken)", forHTTPHeaderField: "Authorization") + } +} diff --git a/clients/ios/Shared/Files/FileUploadTask.swift b/clients/ios/Shared/Files/FileUploadTask.swift index 52eb7aa6..2125871b 100644 --- a/clients/ios/Shared/Files/FileUploadTask.swift +++ b/clients/ios/Shared/Files/FileUploadTask.swift @@ -205,6 +205,7 @@ fileprivate func uploadFile(_ url: URL, contentType: String) async -> Bool { // Request type let url = URL(string: "\(AppConstants.apiBaseURL)/capture/upload_chunk")! var request = URLRequest(url: url) + request.addCommonHeaders() request.httpMethod = "POST" request.setValue("multipart/form-data;boundary=\(form.boundary)", forHTTPHeaderField: "Content-Type") @@ -244,6 +245,7 @@ fileprivate func processCapture(_ captureUUID: String) async -> Bool { // Request type let url = URL(string: "\(AppConstants.apiBaseURL)/capture/process_capture")! var request = URLRequest(url: url) + request.addCommonHeaders() request.httpMethod = "POST" request.setValue("multipart/form-data;boundary=\(form.boundary)", forHTTPHeaderField: "Content-Type") diff --git a/clients/ios/UntitledAI Watch App/Services/NetworkManager.swift b/clients/ios/UntitledAI Watch App/Services/NetworkManager.swift index b9c5a050..f7a49c2c 100644 --- a/clients/ios/UntitledAI Watch App/Services/NetworkManager.swift +++ b/clients/ios/UntitledAI Watch App/Services/NetworkManager.swift @@ -32,6 +32,7 @@ class NetworkManager : NSObject, URLSessionDataDelegate { fatalError("Invalid URL") } var request = URLRequest(url: url) + request.addCommonHeaders() request.httpMethod = "POST" request.setValue("application/octet-stream", forHTTPHeaderField: "Content-Type") let task = self.urlSession.uploadTask(withStreamedRequest: request) @@ -64,6 +65,7 @@ class NetworkManager : NSObject, URLSessionDataDelegate { // Signal end var request = URLRequest(url: url) + request.addCommonHeaders() request.httpMethod = "POST" let config = URLSessionConfiguration.default let session = URLSession(configuration: config) diff --git a/clients/ios/UntitledAI.xcodeproj/project.pbxproj b/clients/ios/UntitledAI.xcodeproj/project.pbxproj index c56134f1..b863db0f 100644 --- a/clients/ios/UntitledAI.xcodeproj/project.pbxproj +++ b/clients/ios/UntitledAI.xcodeproj/project.pbxproj @@ -24,6 +24,8 @@ AA84B1C92B604FB000FD654C /* ConversationsViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = AA84B1C82B604FB000FD654C /* ConversationsViewModel.swift */; }; AA84B1CC2B606AED00FD654C /* ConversationDetailView.swift in Sources */ = {isa = PBXBuildFile; fileRef = AA84B1CB2B606AED00FD654C /* ConversationDetailView.swift */; }; AA84B1CE2B606E4400FD654C /* ConversationsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = AA84B1CD2B606E4400FD654C /* ConversationsView.swift */; }; + AAA6006E2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */; }; + AAA6006F2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */; }; AAB582132B530FD400CB72B8 /* UntitledAIApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAB582122B530FD400CB72B8 /* UntitledAIApp.swift */; }; AAB582152B530FD400CB72B8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAB582142B530FD400CB72B8 /* ContentView.swift */; }; AAB582172B530FD600CB72B8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = AAB582162B530FD600CB72B8 /* Assets.xcassets */; }; @@ -152,6 +154,7 @@ AA84B1C82B604FB000FD654C /* ConversationsViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationsViewModel.swift; sourceTree = ""; }; AA84B1CB2B606AED00FD654C /* ConversationDetailView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationDetailView.swift; sourceTree = ""; }; AA84B1CD2B606E4400FD654C /* ConversationsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationsView.swift; sourceTree = ""; }; + AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "URLRequest+Extensions.swift"; sourceTree = ""; }; AAB5820F2B530FD400CB72B8 /* UntitledAI.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = UntitledAI.app; sourceTree = BUILT_PRODUCTS_DIR; }; AAB582122B530FD400CB72B8 /* UntitledAIApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UntitledAIApp.swift; sourceTree = ""; }; AAB582142B530FD400CB72B8 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; @@ -484,6 +487,7 @@ children = ( CCD7A5C42B6217830079D9D4 /* URL+Extensions.swift */, CCD7A5C72B6219AD0079D9D4 /* UUID+Extensions.swift */, + AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */, ); path = Extensions; sourceTree = ""; @@ -779,6 +783,7 @@ AAB7147F2B64162800412DB7 /* Capture.swift in Sources */, AAB582132B530FD400CB72B8 /* UntitledAIApp.swift in Sources */, AA5DC9CB2B629E770017376F /* LocationManager.swift in Sources */, + AAA6006E2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */, AAB714832B65AC1A00412DB7 /* WatchConnectivityManager.swift in Sources */, CCD7A5C22B62173E0079D9D4 /* AudioFileWriter.swift in Sources */, ); @@ -805,6 +810,7 @@ isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( + AAA6006F2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */, CCD7A5C82B6219AD0079D9D4 /* UUID+Extensions.swift in Sources */, CCD7A5BD2B6216060079D9D4 /* MultipartForm.swift in Sources */, AAB5826F2B53108D00CB72B8 /* NetworkManager.swift in Sources */, diff --git a/clients/ios/UntitledAI/Services/API.swift b/clients/ios/UntitledAI/Services/API.swift index a6343d35..9fd74ed6 100644 --- a/clients/ios/UntitledAI/Services/API.swift +++ b/clients/ios/UntitledAI/Services/API.swift @@ -13,8 +13,9 @@ class API { func fetchConversations(completionHandler: @escaping (ConversationsResponse) -> Void) { guard let url = URL(string: "\(AppConstants.apiBaseURL)/conversations/") else { return } - - let task = URLSession.shared.dataTask(with: url) { (data, response, error) in + var request = URLRequest(url: url) + request.addCommonHeaders() + let task = URLSession.shared.dataTask(with: request) { (data, response, error) in if let error = error { print("Error: \(error)") } else if let data = data { @@ -33,6 +34,7 @@ class API { func deleteConversation(_ id: Int, completion: @escaping (Bool) -> Void) { let url = URL(string: "\(AppConstants.apiBaseURL)/conversations/\(id)")! var request = URLRequest(url: url) + request.addCommonHeaders() request.httpMethod = "DELETE" URLSession.shared.dataTask(with: request) { data, response, error in if let error = error { @@ -55,6 +57,7 @@ class API { } var request = URLRequest(url: url) + request.addCommonHeaders() request.httpMethod = "POST" request.addValue("application/json", forHTTPHeaderField: "Content-Type") diff --git a/clients/ios/UntitledAI/Services/SocketManager.swift b/clients/ios/UntitledAI/Services/SocketManager.swift index bf727bfa..8bc302b6 100644 --- a/clients/ios/UntitledAI/Services/SocketManager.swift +++ b/clients/ios/UntitledAI/Services/SocketManager.swift @@ -22,7 +22,8 @@ class SocketManager: ObservableObject { .reconnects(true), .reconnectAttempts(-1), .reconnectWait(1), - .reconnectWaitMax(5) + .reconnectWaitMax(5), + .extraHeaders(["Authorization": "Bearer \(AppConstants.clientToken)"]) ]) socket = socketManager.defaultSocket diff --git a/untitledai/core/config.py b/untitledai/core/config.py index c578a1dd..3f2c132c 100644 --- a/untitledai/core/config.py +++ b/untitledai/core/config.py @@ -23,6 +23,7 @@ class CapturesConfiguration(BaseModel): class UserConfiguration(BaseModel): name: str + client_token: str voice_sample_filepath: Optional[str] = None class DeepgramConfiguration(BaseModel): diff --git a/untitledai/server/app_state.py b/untitledai/server/app_state.py index 9315f65c..afd04ecf 100644 --- a/untitledai/server/app_state.py +++ b/untitledai/server/app_state.py @@ -2,9 +2,8 @@ from dataclasses import dataclass, field import os from typing import Dict - -from fastapi import FastAPI, Request - +from fastapi import FastAPI, HTTPException, Request, Depends, Header +from typing import Optional from ..core.config import Configuration from ..services import ConversationService, LLMService, NotificationService from .streaming_capture_handler import StreamingCaptureHandler @@ -38,14 +37,35 @@ def get(from_obj: FastAPI | Request) -> AppState: return from_obj.app.state._app_state else: raise TypeError("`from_obj` must be of type `FastAPI` or `Request`") - - @staticmethod - def get_from_request(request: Request) -> AppState: - return request.app.state._app_state - + @staticmethod def get_db(request: Request): app_state: AppState = AppState.get(request) return next(app_state.database.get_db()) - \ No newline at end of file + @staticmethod + async def _parse_and_verify_token(authorization: str, expected_token: str): + if not authorization: + raise HTTPException(status_code=401, detail="Authorization header missing") + + parts = authorization.split() + if len(parts) != 2 or parts[0].lower() != 'bearer': + raise HTTPException(status_code=401, detail="Invalid token type") + + token = parts[1] + if token != expected_token: + raise HTTPException(status_code=403, detail="Invalid or expired token") + + @staticmethod + async def authenticate_request(request: Request, authorization: Optional[str] = Header(None)): + app_state = AppState.get(request) + await AppState._parse_and_verify_token(authorization, app_state.config.user.client_token) + return app_state + + @staticmethod + async def authenticate_socket(environ: dict): + headers = {k.decode('utf-8').lower(): v.decode('utf-8') for k, v in environ.get('asgi.scope', {}).get('headers', [])} + authorization = headers.get('authorization') + app_state = AppState.get(environ['asgi.scope']['app']) + await AppState._parse_and_verify_token(authorization, app_state.config.user.client_token) + return app_state \ No newline at end of file diff --git a/untitledai/server/capture_socket.py b/untitledai/server/capture_socket.py index 93ce6621..92b9f375 100644 --- a/untitledai/server/capture_socket.py +++ b/untitledai/server/capture_socket.py @@ -39,8 +39,14 @@ def __init__(self, app_state): def mount_to(self, app: FastAPI, at_path: str): app.mount(path=at_path, app=self._app) - async def on_connect(self, path, sid, *args): + async def on_connect(self, path, sid, environ): logger.info(f'Connected: {sid}') + try: + await self._app_state.authenticate_socket(environ) + except ValueError as e: + logger.error(f"Authentication failed for {sid}: {e}") + await self._sio.disconnect(sid) + return False async def on_disconnect(self, path, sid, *args): logger.info(f'Disconnected: {sid}') diff --git a/untitledai/server/main.py b/untitledai/server/main.py index 2fcc4083..15160b66 100644 --- a/untitledai/server/main.py +++ b/untitledai/server/main.py @@ -24,6 +24,7 @@ import logging import asyncio from colorama import init, Fore, Style, Back +from fastapi import Depends logger = logging.getLogger(__name__) @@ -79,7 +80,6 @@ def create_server_app(config: Configuration) -> FastAPI: transcription_service = AsyncTranscriptionServiceFactory.get_service(config) notification_service = NotificationService(config.notification) conversation_service = ConversationService(config, database, transcription_service, notification_service) - # Create server app app = FastAPI() app.state._app_state = AppState( @@ -110,7 +110,7 @@ async def shutdown_event(): # Base routing @app.get("/") - def read_root(): + async def read_root(app_state: AppState = Depends(AppState.authenticate_request)): return "UntitledAI is running!" return app diff --git a/untitledai/server/routes/capture.py b/untitledai/server/routes/capture.py index 6cae73b8..64435439 100644 --- a/untitledai/server/routes/capture.py +++ b/untitledai/server/routes/capture.py @@ -39,7 +39,7 @@ def find_audio_filepath(audio_directory: str, capture_uuid: str) -> str | None: supported_upload_file_extensions = set([ "pcm", "wav", "aac", "m4a" ]) @router.post("/capture/streaming_post/{capture_uuid}") -async def streaming_post(request: Request, capture_uuid: str, device_type: str, app_state = Depends(AppState.get_from_request)): +async def streaming_post(request: Request, capture_uuid: str, device_type: str, app_state: AppState = Depends(AppState.authenticate_request)): logger.info('Client connected') try: if capture_uuid not in app_state.capture_handlers: @@ -62,7 +62,7 @@ async def streaming_post(request: Request, capture_uuid: str, device_type: str, @router.post("/capture/streaming_post/{capture_uuid}/complete") -async def complete_audio(request: Request, background_tasks: BackgroundTasks, capture_uuid: str, app_state = Depends(AppState.get_from_request)): +async def complete_audio(request: Request, background_tasks: BackgroundTasks, capture_uuid: str, app_state: AppState = Depends(AppState.authenticate_request)): logger.info(f"Completing audio capture for {capture_uuid}") if capture_uuid not in app_state.capture_handlers: logger.error(f"Capture session not found: {capture_uuid}") @@ -78,7 +78,7 @@ async def upload_chunk(request: Request, capture_uuid: Annotated[str, Form()], timestamp: Annotated[str, Form()], device_type: Annotated[str, Form()], - app_state = Depends(AppState.get_from_request)): + app_state: AppState = Depends(AppState.authenticate_request)): try: # Validate file format file_extension = os.path.splitext(file.filename)[1].lstrip(".") @@ -129,7 +129,7 @@ async def upload_chunk(request: Request, @router.post("/capture/process_capture") -async def process_capture(request: Request, capture_uuid: Annotated[str, Form()], app_state = Depends(AppState.get_from_request)): +async def process_capture(request: Request, capture_uuid: Annotated[str, Form()], app_state: AppState = Depends(AppState.authenticate_request)): try: # Get capture file filepath = find_audio_filepath(audio_directory=app_state.config.captures.capture_dir, capture_uuid=capture_uuid) @@ -161,7 +161,7 @@ async def process_capture(request: Request, capture_uuid: Annotated[str, Form()] @router.post("/capture/location") -async def receive_location(location: Location, db: Session = Depends(AppState.get_db)): +async def receive_location(location: Location, db: Session = Depends(AppState.get_db), app_state: AppState = Depends(AppState.authenticate_request)): try: logger.info(f"Received location: {location}") new_location = create_location(db, location) diff --git a/untitledai/server/routes/conversations.py b/untitledai/server/routes/conversations.py index 42640bc3..947aa93b 100644 --- a/untitledai/server/routes/conversations.py +++ b/untitledai/server/routes/conversations.py @@ -14,15 +14,17 @@ def read_conversations( offset: int = 0, limit: int = Query(default=100), - db: Session = Depends(AppState.get_db) + db: Session = Depends(AppState.get_db), + app_state: AppState = Depends(AppState.authenticate_request) ): conversations = get_all_conversations(db, offset, limit) return ConversationsResponse(conversations=conversations) -@router.delete("/conversations/{conversation_id}/") +@router.delete("/conversations/{conversation_id}") def delete_conversation_endpoint( conversation_id: int, - db: Session = Depends(AppState.get_db) + db: Session = Depends(AppState.get_db), + app_state: AppState = Depends(AppState.authenticate_request) ): success = delete_conversation(db, conversation_id) if not success: From a2a2e7d555dc0cdb8d799f5767f81591a4c84224 Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Sat, 3 Feb 2024 16:00:22 -0800 Subject: [PATCH 2/3] Remove debugging base url --- clients/ios/Shared/AppConstants.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/ios/Shared/AppConstants.swift b/clients/ios/Shared/AppConstants.swift index 722efa1f..c3eac812 100644 --- a/clients/ios/Shared/AppConstants.swift +++ b/clients/ios/Shared/AppConstants.swift @@ -7,7 +7,7 @@ import Foundation struct AppConstants { - static let apiBaseURL = "https://3093adab5e6f.ngrok.app" + static let apiBaseURL = "" static let clientToken = "" static let bleServiceUUID = "03d5d5c4-a86c-11ee-9d89-8f2089a49e7e" static let bleAudioCharacteristicUUID = "b189a505-a86c-11ee-a5fb-8f2089a49e7e" From 0b0bfe6460b3aed5f305e71bace2c66313721e2e Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Mon, 5 Feb 2024 17:00:45 -0800 Subject: [PATCH 3/3] Add app state type hints --- untitledai/server/app_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/untitledai/server/app_state.py b/untitledai/server/app_state.py index afd04ecf..392f33a0 100644 --- a/untitledai/server/app_state.py +++ b/untitledai/server/app_state.py @@ -58,7 +58,7 @@ async def _parse_and_verify_token(authorization: str, expected_token: str): @staticmethod async def authenticate_request(request: Request, authorization: Optional[str] = Header(None)): - app_state = AppState.get(request) + app_state: AppState = AppState.get(request) await AppState._parse_and_verify_token(authorization, app_state.config.user.client_token) return app_state @@ -66,6 +66,6 @@ async def authenticate_request(request: Request, authorization: Optional[str] = async def authenticate_socket(environ: dict): headers = {k.decode('utf-8').lower(): v.decode('utf-8') for k, v in environ.get('asgi.scope', {}).get('headers', [])} authorization = headers.get('authorization') - app_state = AppState.get(environ['asgi.scope']['app']) + app_state: AppState = AppState.get(environ['asgi.scope']['app']) await AppState._parse_and_verify_token(authorization, app_state.config.user.client_token) return app_state \ No newline at end of file