Skip to content

Commit

Permalink
refactor series route
Browse files Browse the repository at this point in the history
  • Loading branch information
gleasonw committed Sep 29, 2024
1 parent 7a8ed2c commit 1eb1456
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 77 deletions.
111 changes: 111 additions & 0 deletions app/gallicagram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from datetime import datetime
from io import StringIO
import time
from typing import Dict, Literal, Optional

import aiohttp
from fastapi import HTTPException
import pandas as pd
from pydantic import BaseModel, validator


class GallicagramInput(BaseModel):
term: str
start_date: Optional[int] = 1789
end_date: Optional[int] = 1950
grouping: Literal["mois", "annee"] = "mois"
source: Literal["livres", "presse", "lemonde"] = "presse"
link_term: Optional[str] = None

class Config:
validate_assignment = True

@validator("start_date")
def set_start_date(cls, date):
return date or 1789

@validator("end_date")
def set_end_date(cls, date):
return date or 1950

def getCacheKey(self):
return f"{self.term}-{self.start_date}-{self.end_date}-{self.grouping}-{self.source}-{self.link_term}"


async def fetch_series(input: GallicagramInput):
return await do_dataframe_fetch(
"https://shiny.ens-paris-saclay.fr/guni/query",
{
"corpus": input.source,
"mot": input.term.lower(),
"from": input.start_date,
"to": input.end_date,
},
)


async def fetch_series_linked_term(input: GallicagramInput):
if input.link_term is None:
print("No link term")
return
return await do_dataframe_fetch(
"https://shiny.ens-paris-saclay.fr/guni/contain",
{
"corpus": input.source,
"mot1": input.term.lower(),
"mot2": input.link_term.lower(),
"from": input.start_date,
"to": input.end_date,
},
)


async def do_dataframe_fetch(url: str, params: Dict):
print(f"Fetching {url} with params {params}")
async with aiohttp.ClientSession() as session:
start = time.time()
async with session.get(url, params=params) as response:
print(f"Fetched {response.url}")
print(f"Took {time.time() - start} seconds")
if response.status != 200:
raise HTTPException(
status_code=503, detail="Could not connect to Gallicagram! Egads!"
)
return pd.read_csv(StringIO(await response.text()))


def transform_series(series_dataframe: pd.DataFrame, input: GallicagramInput):
if input.grouping == "mois" and input.source != "livres":
series_dataframe = (
series_dataframe.groupby(["annee", "mois", "gram"])
.agg({"n": "sum", "total": "sum"})
.reset_index()
)
if input.grouping == "annee":
series_dataframe = (
series_dataframe.groupby(["annee", "gram"])
.agg({"n": "sum", "total": "sum"})
.reset_index()
)

def calc_ratio(row):
if row.total == 0:
return 0
return row.n / row.total

series_dataframe["ratio"] = series_dataframe.apply(
lambda row: calc_ratio(row), axis=1
)
if all(series_dataframe.ratio == 0):
raise HTTPException(status_code=404, detail="No occurrences of the term found")

def get_unix_timestamp(row) -> int:
year = int(row.get("annee", 0))
month = int(row.get("mois", 1))

dt = datetime(year, month, 1)
return int(dt.timestamp() * 1000)

return series_dataframe.apply(
lambda row: (get_unix_timestamp(row), row["ratio"]), axis=1
).tolist()
115 changes: 38 additions & 77 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
ContextSnippets,
ExtractRoot,
)
from app.gallicagram import (
GallicagramInput,
fetch_series,
fetch_series_linked_term,
transform_series,
)
from app.mostFrequent import get_gallica_core
from app.fetch import APIRequest, fetch_queries_concurrently
from app.imageSnippet import ImageQuery, ImageSnippet
Expand Down Expand Up @@ -52,7 +58,7 @@
dotenv.load_dotenv()


logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
# logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))


MAX_PAPERS_TO_SEARCH = 600
Expand All @@ -61,7 +67,6 @@
gallica_session: aiohttp.ClientSession

# todo
# setup logfire
# fix multi-term search for gallicagram


Expand Down Expand Up @@ -702,7 +707,9 @@ class Series(BaseModel):
name: str


series_cache: Dict[str, Series] = {}
series_cache: Dict[str, Tuple[Series, datetime]] = {}

CACHE_TIMEOUT = 60 * 60 * 24 * 2 # 2 days


@app.get("/api/series")
Expand All @@ -714,89 +721,43 @@ async def get(
source: Literal["livres", "presse", "lemonde"] = "presse",
link_term: Optional[str] = None,
) -> Series:
key = f"{term}-{start_date}-{end_date}-{grouping}-{source}-{link_term}"
if key in series_cache:
return series_cache[key]
debut = start_date or 1789
fin = end_date or 1950
if link_term:
series_dataframe = await fetch_series_dataframe(
"https://shiny.ens-paris-saclay.fr/guni/contain",
{
"corpus": source,
"mot1": term.lower(),
"mot2": link_term.lower(),
"from": debut,
"to": fin,
},
)
else:
series_dataframe = await fetch_series_dataframe(
"https://shiny.ens-paris-saclay.fr/guni/query",
{
"corpus": source,
"mot": term.lower(),
"from": debut,
"to": fin,
},
)

if grouping == "mois" and source != "livres":
series_dataframe = (
series_dataframe.groupby(["annee", "mois", "gram"])
.agg({"n": "sum", "total": "sum"})
.reset_index()
)
if grouping == "annee":
series_dataframe = (
series_dataframe.groupby(["annee", "gram"])
.agg({"n": "sum", "total": "sum"})
.reset_index()
)

def calc_ratio(row):
if row.total == 0:
return 0
return row.n / row.total

series_dataframe["ratio"] = series_dataframe.apply(
lambda row: calc_ratio(row), axis=1
args_object = GallicagramInput(
term=term,
start_date=start_date,
end_date=end_date,
grouping=grouping,
source=source,
link_term=link_term,
)
if all(series_dataframe.ratio == 0):
raise HTTPException(status_code=404, detail="No records found")
key = args_object.getCacheKey()

def get_unix_timestamp(row) -> int:
year = int(row.get("annee", 0))
month = int(row.get("mois", 1))
if key in series_cache:
cache_time = series_cache[key][1]
now = datetime.now()
if (now - cache_time).total_seconds() < CACHE_TIMEOUT:
print("cache hit")
return series_cache[key][0]
del series_cache[key]

fetched_dataframe = (
await fetch_series(args_object)
if link_term is None
else await fetch_series_linked_term(args_object)
)

dt = datetime(year, month, 1)
return int(dt.timestamp() * 1000)
if fetched_dataframe is None:
raise HTTPException(status_code=404, detail="No records found")

data = series_dataframe.apply(
lambda row: (get_unix_timestamp(row), row["ratio"]), axis=1
).tolist()
transformed_dataframe = transform_series(fetched_dataframe, args_object)

series = Series(
data=data,
fetched_dataframe = Series(
data=transformed_dataframe,
name=term,
)
series_cache[key] = series
return series


async def fetch_series_dataframe(url: str, params: Dict):
async with aiohttp.ClientSession() as session:
start = time.time()
async with session.get(url, params=params) as response:
print(f"Fetched {response.url}")
print(f"Took {time.time() - start} seconds")
if response.status != 200:
raise HTTPException(
status_code=503, detail="Could not connect to Gallicagram! Egads!"
)
return pd.read_csv(StringIO(await response.text()))
series_cache[key] = fetched_dataframe, datetime.now()
return fetched_dataframe


if __name__ == "__main__":
print("hi fly")
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8080)))

0 comments on commit 1eb1456

Please sign in to comment.