Skip to content

Commit

Permalink
Merge pull request #9 from untitledaiproject/ethan/simple-authorization
Browse files Browse the repository at this point in the history
Basic single user token authorization
  • Loading branch information
etown authored Feb 6, 2024
2 parents 8737b67 + 0b0bfe6 commit b8a4a85
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 23 deletions.
1 change: 1 addition & 0 deletions clients/ios/Shared/AppConstants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import Foundation
struct AppConstants {
static let apiBaseURL = ""
static let clientToken = ""
static let bleServiceUUID = "03d5d5c4-a86c-11ee-9d89-8f2089a49e7e"
static let bleAudioCharacteristicUUID = "b189a505-a86c-11ee-a5fb-8f2089a49e7e"

Expand Down
14 changes: 14 additions & 0 deletions clients/ios/Shared/Extensions/URLRequest+Extensions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//
// URLRequest+Extensions.swift
// UntitledAI
//
// Created by ethan on 2/3/24.
//

import Foundation

extension URLRequest {
mutating func addCommonHeaders() {
self.addValue("Bearer \(AppConstants.clientToken)", forHTTPHeaderField: "Authorization")
}
}
2 changes: 2 additions & 0 deletions clients/ios/Shared/Files/FileUploadTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ fileprivate func uploadFile(_ url: URL, contentType: String) async -> Bool {
// Request type
let url = URL(string: "\(AppConstants.apiBaseURL)/capture/upload_chunk")!
var request = URLRequest(url: url)
request.addCommonHeaders()
request.httpMethod = "POST"
request.setValue("multipart/form-data;boundary=\(form.boundary)", forHTTPHeaderField: "Content-Type")

Expand Down Expand Up @@ -244,6 +245,7 @@ fileprivate func processCapture(_ captureUUID: String) async -> Bool {
// Request type
let url = URL(string: "\(AppConstants.apiBaseURL)/capture/process_capture")!
var request = URLRequest(url: url)
request.addCommonHeaders()
request.httpMethod = "POST"
request.setValue("multipart/form-data;boundary=\(form.boundary)", forHTTPHeaderField: "Content-Type")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class NetworkManager : NSObject, URLSessionDataDelegate {
fatalError("Invalid URL")
}
var request = URLRequest(url: url)
request.addCommonHeaders()
request.httpMethod = "POST"
request.setValue("application/octet-stream", forHTTPHeaderField: "Content-Type")
let task = self.urlSession.uploadTask(withStreamedRequest: request)
Expand Down Expand Up @@ -64,6 +65,7 @@ class NetworkManager : NSObject, URLSessionDataDelegate {

// Signal end
var request = URLRequest(url: url)
request.addCommonHeaders()
request.httpMethod = "POST"
let config = URLSessionConfiguration.default
let session = URLSession(configuration: config)
Expand Down
6 changes: 6 additions & 0 deletions clients/ios/UntitledAI.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
AA84B1C92B604FB000FD654C /* ConversationsViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = AA84B1C82B604FB000FD654C /* ConversationsViewModel.swift */; };
AA84B1CC2B606AED00FD654C /* ConversationDetailView.swift in Sources */ = {isa = PBXBuildFile; fileRef = AA84B1CB2B606AED00FD654C /* ConversationDetailView.swift */; };
AA84B1CE2B606E4400FD654C /* ConversationsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = AA84B1CD2B606E4400FD654C /* ConversationsView.swift */; };
AAA6006E2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */; };
AAA6006F2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */; };
AAB582132B530FD400CB72B8 /* UntitledAIApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAB582122B530FD400CB72B8 /* UntitledAIApp.swift */; };
AAB582152B530FD400CB72B8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = AAB582142B530FD400CB72B8 /* ContentView.swift */; };
AAB582172B530FD600CB72B8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = AAB582162B530FD600CB72B8 /* Assets.xcassets */; };
Expand Down Expand Up @@ -152,6 +154,7 @@
AA84B1C82B604FB000FD654C /* ConversationsViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationsViewModel.swift; sourceTree = "<group>"; };
AA84B1CB2B606AED00FD654C /* ConversationDetailView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationDetailView.swift; sourceTree = "<group>"; };
AA84B1CD2B606E4400FD654C /* ConversationsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationsView.swift; sourceTree = "<group>"; };
AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "URLRequest+Extensions.swift"; sourceTree = "<group>"; };
AAB5820F2B530FD400CB72B8 /* UntitledAI.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = UntitledAI.app; sourceTree = BUILT_PRODUCTS_DIR; };
AAB582122B530FD400CB72B8 /* UntitledAIApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UntitledAIApp.swift; sourceTree = "<group>"; };
AAB582142B530FD400CB72B8 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -484,6 +487,7 @@
children = (
CCD7A5C42B6217830079D9D4 /* URL+Extensions.swift */,
CCD7A5C72B6219AD0079D9D4 /* UUID+Extensions.swift */,
AAA6006D2B6ED40C004200FF /* URLRequest+Extensions.swift */,
);
path = Extensions;
sourceTree = "<group>";
Expand Down Expand Up @@ -779,6 +783,7 @@
AAB7147F2B64162800412DB7 /* Capture.swift in Sources */,
AAB582132B530FD400CB72B8 /* UntitledAIApp.swift in Sources */,
AA5DC9CB2B629E770017376F /* LocationManager.swift in Sources */,
AAA6006E2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */,
AAB714832B65AC1A00412DB7 /* WatchConnectivityManager.swift in Sources */,
CCD7A5C22B62173E0079D9D4 /* AudioFileWriter.swift in Sources */,
);
Expand All @@ -805,6 +810,7 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
AAA6006F2B6ED40C004200FF /* URLRequest+Extensions.swift in Sources */,
CCD7A5C82B6219AD0079D9D4 /* UUID+Extensions.swift in Sources */,
CCD7A5BD2B6216060079D9D4 /* MultipartForm.swift in Sources */,
AAB5826F2B53108D00CB72B8 /* NetworkManager.swift in Sources */,
Expand Down
7 changes: 5 additions & 2 deletions clients/ios/UntitledAI/Services/API.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ class API {

func fetchConversations(completionHandler: @escaping (ConversationsResponse) -> Void) {
guard let url = URL(string: "\(AppConstants.apiBaseURL)/conversations/") else { return }

let task = URLSession.shared.dataTask(with: url) { (data, response, error) in
var request = URLRequest(url: url)
request.addCommonHeaders()
let task = URLSession.shared.dataTask(with: request) { (data, response, error) in
if let error = error {
print("Error: \(error)")
} else if let data = data {
Expand All @@ -33,6 +34,7 @@ class API {
func deleteConversation(_ id: Int, completion: @escaping (Bool) -> Void) {
let url = URL(string: "\(AppConstants.apiBaseURL)/conversations/\(id)")!
var request = URLRequest(url: url)
request.addCommonHeaders()
request.httpMethod = "DELETE"
URLSession.shared.dataTask(with: request) { data, response, error in
if let error = error {
Expand All @@ -55,6 +57,7 @@ class API {
}

var request = URLRequest(url: url)
request.addCommonHeaders()
request.httpMethod = "POST"
request.addValue("application/json", forHTTPHeaderField: "Content-Type")

Expand Down
3 changes: 2 additions & 1 deletion clients/ios/UntitledAI/Services/SocketManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class SocketManager: ObservableObject {
.reconnects(true),
.reconnectAttempts(-1),
.reconnectWait(1),
.reconnectWaitMax(5)
.reconnectWaitMax(5),
.extraHeaders(["Authorization": "Bearer \(AppConstants.clientToken)"])
])
socket = socketManager.defaultSocket

Expand Down
1 change: 1 addition & 0 deletions untitledai/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class CapturesConfiguration(BaseModel):

class UserConfiguration(BaseModel):
name: str
client_token: str
voice_sample_filepath: Optional[str] = None

class DeepgramConfiguration(BaseModel):
Expand Down
38 changes: 29 additions & 9 deletions untitledai/server/app_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from dataclasses import dataclass, field
import os
from typing import Dict

from fastapi import FastAPI, Request

from fastapi import FastAPI, HTTPException, Request, Depends, Header
from typing import Optional
from ..core.config import Configuration
from ..services import ConversationService, LLMService, NotificationService
from .streaming_capture_handler import StreamingCaptureHandler
Expand Down Expand Up @@ -38,14 +37,35 @@ def get(from_obj: FastAPI | Request) -> AppState:
return from_obj.app.state._app_state
else:
raise TypeError("`from_obj` must be of type `FastAPI` or `Request`")

@staticmethod
def get_from_request(request: Request) -> AppState:
return request.app.state._app_state


@staticmethod
def get_db(request: Request):
app_state: AppState = AppState.get(request)
return next(app_state.database.get_db())


@staticmethod
async def _parse_and_verify_token(authorization: str, expected_token: str):
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header missing")

parts = authorization.split()
if len(parts) != 2 or parts[0].lower() != 'bearer':
raise HTTPException(status_code=401, detail="Invalid token type")

token = parts[1]
if token != expected_token:
raise HTTPException(status_code=403, detail="Invalid or expired token")

@staticmethod
async def authenticate_request(request: Request, authorization: Optional[str] = Header(None)):
app_state: AppState = AppState.get(request)
await AppState._parse_and_verify_token(authorization, app_state.config.user.client_token)
return app_state

@staticmethod
async def authenticate_socket(environ: dict):
headers = {k.decode('utf-8').lower(): v.decode('utf-8') for k, v in environ.get('asgi.scope', {}).get('headers', [])}
authorization = headers.get('authorization')
app_state: AppState = AppState.get(environ['asgi.scope']['app'])
await AppState._parse_and_verify_token(authorization, app_state.config.user.client_token)
return app_state
8 changes: 7 additions & 1 deletion untitledai/server/capture_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ def __init__(self, app_state):
def mount_to(self, app: FastAPI, at_path: str):
app.mount(path=at_path, app=self._app)

async def on_connect(self, path, sid, *args):
async def on_connect(self, path, sid, environ):
logger.info(f'Connected: {sid}')
try:
await self._app_state.authenticate_socket(environ)
except ValueError as e:
logger.error(f"Authentication failed for {sid}: {e}")
await self._sio.disconnect(sid)
return False

async def on_disconnect(self, path, sid, *args):
logger.info(f'Disconnected: {sid}')
Expand Down
4 changes: 2 additions & 2 deletions untitledai/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import logging
import asyncio
from colorama import init, Fore, Style, Back
from fastapi import Depends

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,7 +80,6 @@ def create_server_app(config: Configuration) -> FastAPI:
transcription_service = AsyncTranscriptionServiceFactory.get_service(config)
notification_service = NotificationService(config.notification)
conversation_service = ConversationService(config, database, transcription_service, notification_service)

# Create server app
app = FastAPI()
app.state._app_state = AppState(
Expand Down Expand Up @@ -110,7 +110,7 @@ async def shutdown_event():

# Base routing
@app.get("/")
def read_root():
async def read_root(app_state: AppState = Depends(AppState.authenticate_request)):
return "UntitledAI is running!"

return app
10 changes: 5 additions & 5 deletions untitledai/server/routes/capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def find_audio_filepath(audio_directory: str, capture_uuid: str) -> str | None:
supported_upload_file_extensions = set([ "pcm", "wav", "aac", "m4a" ])

@router.post("/capture/streaming_post/{capture_uuid}")
async def streaming_post(request: Request, capture_uuid: str, device_type: str, app_state = Depends(AppState.get_from_request)):
async def streaming_post(request: Request, capture_uuid: str, device_type: str, app_state: AppState = Depends(AppState.authenticate_request)):
logger.info('Client connected')
try:
if capture_uuid not in app_state.capture_handlers:
Expand All @@ -62,7 +62,7 @@ async def streaming_post(request: Request, capture_uuid: str, device_type: str,


@router.post("/capture/streaming_post/{capture_uuid}/complete")
async def complete_audio(request: Request, background_tasks: BackgroundTasks, capture_uuid: str, app_state = Depends(AppState.get_from_request)):
async def complete_audio(request: Request, background_tasks: BackgroundTasks, capture_uuid: str, app_state: AppState = Depends(AppState.authenticate_request)):
logger.info(f"Completing audio capture for {capture_uuid}")
if capture_uuid not in app_state.capture_handlers:
logger.error(f"Capture session not found: {capture_uuid}")
Expand All @@ -78,7 +78,7 @@ async def upload_chunk(request: Request,
capture_uuid: Annotated[str, Form()],
timestamp: Annotated[str, Form()],
device_type: Annotated[str, Form()],
app_state = Depends(AppState.get_from_request)):
app_state: AppState = Depends(AppState.authenticate_request)):
try:
# Validate file format
file_extension = os.path.splitext(file.filename)[1].lstrip(".")
Expand Down Expand Up @@ -129,7 +129,7 @@ async def upload_chunk(request: Request,


@router.post("/capture/process_capture")
async def process_capture(request: Request, capture_uuid: Annotated[str, Form()], app_state = Depends(AppState.get_from_request)):
async def process_capture(request: Request, capture_uuid: Annotated[str, Form()], app_state: AppState = Depends(AppState.authenticate_request)):
try:
# Get capture file
filepath = find_audio_filepath(audio_directory=app_state.config.captures.capture_dir, capture_uuid=capture_uuid)
Expand Down Expand Up @@ -161,7 +161,7 @@ async def process_capture(request: Request, capture_uuid: Annotated[str, Form()]


@router.post("/capture/location")
async def receive_location(location: Location, db: Session = Depends(AppState.get_db)):
async def receive_location(location: Location, db: Session = Depends(AppState.get_db), app_state: AppState = Depends(AppState.authenticate_request)):
try:
logger.info(f"Received location: {location}")
new_location = create_location(db, location)
Expand Down
8 changes: 5 additions & 3 deletions untitledai/server/routes/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
def read_conversations(
offset: int = 0,
limit: int = Query(default=100),
db: Session = Depends(AppState.get_db)
db: Session = Depends(AppState.get_db),
app_state: AppState = Depends(AppState.authenticate_request)
):
conversations = get_all_conversations(db, offset, limit)
return ConversationsResponse(conversations=conversations)

@router.delete("/conversations/{conversation_id}/")
@router.delete("/conversations/{conversation_id}")
def delete_conversation_endpoint(
conversation_id: int,
db: Session = Depends(AppState.get_db)
db: Session = Depends(AppState.get_db),
app_state: AppState = Depends(AppState.authenticate_request)
):
success = delete_conversation(db, conversation_id)
if not success:
Expand Down

0 comments on commit b8a4a85

Please sign in to comment.