diff --git a/examples/api_node_examples/api_client.py b/examples/api_node_examples/api_client.py index 763dd76..9cdeca3 100644 --- a/examples/api_node_examples/api_client.py +++ b/examples/api_node_examples/api_client.py @@ -36,7 +36,7 @@ def __init__( self.input_tick_channel = input_tick_channel self.input_response_channel = input_response_channel self.output_channel = output_channel - self.output_message_type: type[Message[RestRequest]] = Message[request_class] # type: ignore[valid-type, assignment] + self.output_message_type: type[Message[RestRequest]] = Message[request_class] # type: ignore[valid-type] async def event_handler( self, channel: str, message: Message[RestResponse | Tick] diff --git a/src/aact/nodes/base.py b/src/aact/nodes/base.py index 8e3dd95..c2dbcbf 100644 --- a/src/aact/nodes/base.py +++ b/src/aact/nodes/base.py @@ -210,10 +210,8 @@ async def _wait_for_input( if message["type"] == "message" and channel in self.input_channel_types: try: data = Message[ - self.input_channel_types[channel] - ].model_validate_json( # type: ignore - message["data"] - ) + self.input_channel_types[channel] # type: ignore[name-defined] + ].model_validate_json(message["data"]) except ValidationError as e: self.logger.error( f"Failed to validate message from {channel}: {message['data']}. Error: {e}" diff --git a/tests/messages/test_rest.py b/tests/messages/test_rest.py index c0532f7..4d4f43a 100644 --- a/tests/messages/test_rest.py +++ b/tests/messages/test_rest.py @@ -1,7 +1,7 @@ from aact.messages import get_rest_request_class, get_rest_response_class, Text -def test_get_rest_request_class(): +def test_get_rest_request_class() -> None: request_class = get_rest_request_class(Text) assert request_class.__name__ == "RestRequest[Text]" diff --git a/tests/nodes/main.py b/tests/nodes/main.py new file mode 100644 index 0000000..51bf021 --- /dev/null +++ b/tests/nodes/main.py @@ -0,0 +1,125 @@ +from fastapi import FastAPI, Response, HTTPException, File, UploadFile +from fastapi.responses import PlainTextResponse, HTMLResponse, StreamingResponse +from typing import Generator, Dict, Any +from pydantic import BaseModel +import json + +app = FastAPI(title="HTTP Test Server") + + +@app.get("/") +async def root() -> Dict[str, str]: + return {"message": "Welcome to the HTTP test server"} + + +# Basic REST endpoints +class Item(BaseModel): + item_id: int + name: str + + +class ItemResponse(BaseModel): + item: Item + message: str + + +@app.get("/items/{item_id}") +async def get_item(item_id: int) -> Item: + return Item.model_validate({"item_id": item_id, "name": f"Test Item {item_id}"}) + + +@app.post("/items") +async def create_item(item: Item) -> ItemResponse: + return ItemResponse.model_validate({"message": "Item created", "item": item}) + + +@app.put("/items/{item_id}") +async def update_item(item_id: int, item: Item) -> ItemResponse: + return ItemResponse.model_validate( + {"message": f"Item {item_id} updated", "item": item} + ) + + +@app.delete("/items/{item_id}") +async def delete_item(item_id: int) -> Dict[str, str]: + return {"message": f"Item {item_id} deleted"} + + +# Different content types +@app.get("/text", response_class=PlainTextResponse) +async def get_text() -> str: + return "This is a plain text response" + + +@app.get("/html", response_class=HTMLResponse) +async def get_html() -> str: + return """ + + +
+This is a test HTML response
+ + + """ + + +@app.get("/binary") +async def get_binary() -> Response: + content = b"Binary data response" + return Response(content=content, media_type="application/octet-stream") + + +# Status codes +@app.get("/error/404") +async def error_404() -> HTTPException: + raise HTTPException(status_code=404, detail="Item not found") + + +@app.get("/error/500") +async def error_500() -> HTTPException: + raise HTTPException(status_code=500, detail="Internal server error") + + +class FileResponse(BaseModel): + filename: str + content_type: str + size: int + + +# File upload +@app.post("/upload") +async def upload_file(file: UploadFile = File(...)) -> FileResponse: + contents = await file.read() + return FileResponse.model_validate( + { + "filename": file.filename, + "content_type": file.content_type, + "size": len(contents), + } + ) + + +# Streaming response +@app.get("/stream") +async def stream_data() -> StreamingResponse: + def generate() -> Generator[str, None, None]: + for i in range(5): + yield json.dumps({"chunk": i}) + "\n" + + return StreamingResponse(generate(), media_type="application/x-ndjson") + + +# Echo endpoint +@app.post("/echo") +async def echo(request_data: Dict[str, Any]) -> Dict[str, Any]: + return request_data + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/tests/nodes/test_api.py b/tests/nodes/test_api.py new file mode 100644 index 0000000..e69de29