-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathserver.py
106 lines (84 loc) · 2.92 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# server.py
import os
import sys
import traceback
from typing import Tuple, List
from fastapi import FastAPI, Depends, HTTPException
from fastapi.logger import logger
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseSettings, BaseModel
from model import Model, get_model
class Settings(BaseSettings):
    """Runtime configuration for this server.

    NOTE: as a pydantic ``BaseSettings`` subclass, fields may also be filled
    from same-named environment variables when ``Settings()`` is created;
    the values below are only the class-level defaults.
    """

    # URL the server is reachable at; replaced with the public ngrok URL
    # when the dev tunnel is enabled.
    BASE_URL: str = "http://localhost:8000"
    # Dev-only ngrok switch: the default is True only when the USE_NGROK
    # environment variable is exactly the string "True".
    USE_NGROK: bool = os.getenv("USE_NGROK", "False") == "True"


# Module-level settings instance used by the rest of this file.
settings = Settings()
# The FastAPI application object served by Uvicorn.
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, with any method
# and any header, and allow credentials.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this server publicly.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
if settings.USE_NGROK:
    # pyngrok should only ever be installed or initialized in a dev
    # environment when this flag is set, so it is imported here rather than
    # at the top of the file.
    from pyngrok import ngrok

    # Dev server port: Uvicorn defaults to 8000, which can be overridden with
    # `--port N` or `--port=N`. The first occurrence wins. A trailing
    # `--port` with no value falls back to the default instead of raising
    # IndexError.
    port = 8000
    for _argv_index, _argv_item in enumerate(sys.argv):
        if _argv_item == "--port" and _argv_index + 1 < len(sys.argv):
            port = sys.argv[_argv_index + 1]
            break
        if _argv_item.startswith("--port="):
            port = _argv_item.partition("=")[2]
            break

    # Open an ngrok tunnel to the dev server. pyngrok >= 5 returns an
    # NgrokTunnel object rather than the URL string, so unwrap its
    # `public_url`. Older versions return the string itself.
    tunnel = ngrok.connect(port)
    public_url = getattr(tunnel, "public_url", tunnel)
    logger.warning('ngrok tunnel "{}" -> "http://127.0.0.1:{}/"'.format(public_url, port))

    # Update any base URLs or webhooks to use the public ngrok URL (a str).
    settings.BASE_URL = public_url
class DistanceRequest(BaseModel):
    """Request body for POST /measure_dis: the pair of sentences to compare."""

    # Exactly two sentences; passed unchanged to Model.measure_distance.
    sents: Tuple[str, str]
class DistanceResponse(BaseModel):
    """Response body for POST /measure_dis.

    The three fields hold, in order, the values unpacked from
    ``Model.measure_distance``.
    """

    # NOTE(review): whether this is cosine distance or cosine similarity is
    # decided by Model.measure_distance; confirm there.
    cosine: float
    manhattan: float
    euclidean: float
class KmeansRequest(BaseModel):
    """Request body for POST /cluster."""

    # Sentences to cluster; the /cluster endpoint rejects fewer than 2.
    corpus: List[str]
    # Requested number of clusters; /cluster requires 1 <= n_clusters <= len(corpus).
    n_clusters: int
class KmeansResponse(BaseModel):
    """Response body for POST /cluster."""

    # HTML string returned by Model.fit_kmeans (presumably a rendered
    # cluster plot; confirm in the model module).
    plot_html: str
@app.get("/")
def read_root():
    """Root endpoint; always answers with the static payload {"Hello": "World"}."""
    payload = dict(Hello="World")
    return payload
@app.post("/measure_dis", response_model=DistanceResponse)
def measure_sim(request: DistanceRequest, model: Model = Depends(get_model)):
    """Measure the distance between the two sentences in ``request.sents``.

    Returns the cosine, Manhattan and Euclidean values computed by
    ``model.measure_distance``.
    \f
    Raises:
        HTTPException: 422 when the model call (or unpacking its result)
            fails. The full traceback is logged server-side; the client
            receives only the final ``ExceptionType: message`` line.
    """
    try:
        cosine, manhattan, euclidean = model.measure_distance(request.sents)
    except Exception as exc:
        # Keep the stack trace in the server log instead of leaking file paths
        # and source lines to API clients.
        logger.exception("measure_distance failed for /measure_dis")
        raise HTTPException(
            status_code=422,
            detail="".join(traceback.format_exception_only(type(exc), exc)).strip(),
        ) from exc
    return DistanceResponse(
        cosine=cosine,
        manhattan=manhattan,
        euclidean=euclidean,
    )
@app.post("/cluster", response_model=KmeansResponse)
def cluster(request: KmeansRequest, model: Model = Depends(get_model)):
    """Cluster ``request.corpus`` into ``request.n_clusters`` groups.

    Returns the HTML string produced by ``model.fit_kmeans``.
    \f
    Raises:
        HTTPException: 400 when ``n_clusters`` < 1, the corpus has fewer than
            2 sentences, or there are fewer sentences than clusters.
            422 when the model call or building the response fails. The full
            traceback is logged server-side; the client receives only the
            final ``ExceptionType: message`` line.
    """
    n_sentences = len(request.corpus)
    # Request trace for debugging (previously a bare print() to stdout).
    logger.debug(
        "cluster request: n_clusters=%s, n_sentences=%s",
        request.n_clusters,
        n_sentences,
    )

    if request.n_clusters < 1:
        raise HTTPException(
            status_code=400, detail="Number of clusters must be greater than 0."
        )
    if n_sentences < 2:
        raise HTTPException(
            status_code=400, detail="Corpus must have at least 2 sentences."
        )
    if n_sentences < request.n_clusters:
        # Only n_sentences < n_clusters is rejected, so equal counts are
        # accepted; the message says "at least" to match this check.
        raise HTTPException(
            status_code=400,
            detail="Number of sentences must be at least the number of clusters.",
        )

    try:
        # Response construction stays inside the try so that a model result
        # failing KmeansResponse validation is still reported as 422.
        result = KmeansResponse(
            plot_html=model.fit_kmeans(request.corpus, request.n_clusters)
        )
    except Exception as exc:
        # Keep the stack trace in the server log instead of leaking file paths
        # and source lines to API clients.
        logger.exception("fit_kmeans failed for /cluster")
        raise HTTPException(
            status_code=422,
            detail="".join(traceback.format_exception_only(type(exc), exc)).strip(),
        ) from exc
    return result