diff --git a/app/gallicagram.py b/app/gallicagram.py
new file mode 100644
index 0000000..93ec0a2
--- /dev/null
+++ b/app/gallicagram.py
@@ -0,0 +1,155 @@
+import asyncio
+from datetime import datetime, timezone
+from io import StringIO
+import time
+from typing import Dict, List, Literal, Optional, Tuple
+
+import aiohttp
+from fastapi import HTTPException
+import pandas as pd
+from pydantic import BaseModel, validator
+
+
+class GallicagramInput(BaseModel):
+    """Validated query parameters for one Gallicagram series request."""
+
+    term: str
+    start_date: Optional[int] = 1789
+    end_date: Optional[int] = 1950
+    grouping: Literal["mois", "annee"] = "mois"
+    source: Literal["livres", "presse", "lemonde"] = "presse"
+    link_term: Optional[str] = None
+
+    class Config:
+        validate_assignment = True
+
+    @validator("start_date")
+    def set_start_date(cls, date):
+        """Replace a null or zero start year with the 1789 default."""
+        return date or 1789
+
+    @validator("end_date")
+    def set_end_date(cls, date):
+        """Replace a null or zero end year with the 1950 default."""
+        return date or 1950
+
+    def getCacheKey(self):
+        """Return a string built from every field, for use as a cache key."""
+        return f"{self.term}-{self.start_date}-{self.end_date}-{self.grouping}-{self.source}-{self.link_term}"
+
+
+async def fetch_series(input: GallicagramInput) -> pd.DataFrame:
+    """Fetch the raw counts for ``input.term`` from the ``query`` endpoint."""
+    return await do_dataframe_fetch(
+        "https://shiny.ens-paris-saclay.fr/guni/query",
+        {
+            "corpus": input.source,
+            "mot": input.term.lower(),
+            "from": input.start_date,
+            "to": input.end_date,
+        },
+    )
+
+
+async def fetch_series_linked_term(
+    input: GallicagramInput,
+) -> Optional[pd.DataFrame]:
+    """Fetch the raw counts for a term pair from the ``contain`` endpoint.
+
+    Returns None without calling the API when ``input.link_term`` is missing
+    or empty, so a blank value is never sent as ``mot2``.
+    """
+    if not input.link_term:
+        print("No link term")
+        return None
+    return await do_dataframe_fetch(
+        "https://shiny.ens-paris-saclay.fr/guni/contain",
+        {
+            "corpus": input.source,
+            "mot1": input.term.lower(),
+            "mot2": input.link_term.lower(),
+            "from": input.start_date,
+            "to": input.end_date,
+        },
+    )
+
+
+async def do_dataframe_fetch(url: str, params: Dict) -> pd.DataFrame:
+    """GET ``url`` with ``params`` and parse the response body as CSV.
+
+    Raises:
+        HTTPException: 503 when the request fails at the network level,
+            times out, or returns a status other than 200.
+    """
+    print(f"Fetching {url} with params {params}")
+    try:
+        async with aiohttp.ClientSession() as session:
+            start = time.time()
+            async with session.get(url, params=params) as response:
+                print(f"Fetched {response.url}")
+                print(f"Took {time.time() - start} seconds")
+                if response.status != 200:
+                    raise HTTPException(
+                        status_code=503,
+                        detail="Could not connect to Gallicagram! Egads!",
+                    )
+                body = await response.text()
+    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
+        # Without this, a refused connection or timeout surfaces as a bare 500.
+        raise HTTPException(
+            status_code=503, detail="Could not connect to Gallicagram! Egads!"
+        ) from err
+    return pd.read_csv(StringIO(body))
+
+
+def _first_of_month_ms(year: int, month: int) -> int:
+    """Return the epoch timestamp in milliseconds of ``year-month-01`` UTC."""
+    # A naive datetime would be read in the server's local timezone, shifting
+    # every point by the host's UTC offset.
+    return int(datetime(year, month, 1, tzinfo=timezone.utc).timestamp() * 1000)
+
+
+def transform_series(
+    series_dataframe: pd.DataFrame, input: GallicagramInput
+) -> List[Tuple[int, float]]:
+    """Convert raw Gallicagram counts into ``(timestamp_ms, ratio)`` points.
+
+    Counts are summed per month or per year according to ``input.grouping``
+    (the "livres" corpus is never grouped by month). Each resulting row yields
+    ``n / total`` paired with the timestamp of the first day of its period.
+    The dataframe passed in is left unmodified.
+
+    Raises:
+        HTTPException: 404 when every ratio is zero, including when there
+            are no rows at all.
+    """
+    if input.grouping == "mois" and input.source != "livres":
+        group_columns = ["annee", "mois", "gram"]
+    elif input.grouping == "annee":
+        group_columns = ["annee", "gram"]
+    else:
+        group_columns = None
+    if group_columns is not None:
+        series_dataframe = (
+            series_dataframe.groupby(group_columns)
+            .agg({"n": "sum", "total": "sum"})
+            .reset_index()
+        )
+
+    totals = series_dataframe["total"]
+    # Mask zero totals before dividing so empty periods become 0, not inf/NaN.
+    ratios = (series_dataframe["n"] / totals.where(totals != 0)).fillna(0)
+    if (ratios == 0).all():
+        raise HTTPException(status_code=404, detail="No occurrences of the term found")
+
+    years = series_dataframe["annee"]
+    if "mois" in series_dataframe.columns:
+        months = series_dataframe["mois"]
+    else:
+        # Rows without a month column are pinned to January.
+        months = [1] * len(years)
+
+    return [
+        (_first_of_month_ms(int(year), int(month)), float(ratio))
+        for year, month, ratio in zip(years, months, ratios)
+    ]
diff --git a/app/main.py b/app/main.py
index 61df815..bf40878 100644
--- a/app/main.py
+++ b/app/main.py
@@ -13,6 +13,12 @@
     ContextSnippets,
     ExtractRoot,
 )
+from app.gallicagram import (
+    GallicagramInput,
+    fetch_series,
+    fetch_series_linked_term,
+    transform_series,
+)
 from app.mostFrequent import get_gallica_core
 from app.fetch import APIRequest, fetch_queries_concurrently
 from app.imageSnippet import ImageQuery, ImageSnippet
@@ -52,7 +58,7 @@
 
 dotenv.load_dotenv()
 
-logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
+# logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
 
 MAX_PAPERS_TO_SEARCH = 600
 
@@ -61,7 +67,6 @@
 gallica_session: aiohttp.ClientSession
 
 # todo
-# setup logfire
 # fix multi-term search for gallicagram
 
 
@@ -702,7 +707,9 @@ class Series(BaseModel):
     name: str
 
 
-series_cache: Dict[str, Series] = {}
+series_cache: Dict[str, Tuple[Series, datetime]] = {}
+
+CACHE_TIMEOUT = 60 * 60 * 24 * 2  # 2 days
 
 
 @app.get("/api/series")
@@ -714,89 +721,47 @@ async def get(
     source: Literal["livres", "presse", "lemonde"] = "presse",
     link_term: Optional[str] = None,
 ) -> Series:
-    key = f"{term}-{start_date}-{end_date}-{grouping}-{source}-{link_term}"
-    if key in series_cache:
-        return series_cache[key]
-    debut = start_date or 1789
-    fin = end_date or 1950
-    if link_term:
-        series_dataframe = await fetch_series_dataframe(
-            "https://shiny.ens-paris-saclay.fr/guni/contain",
-            {
-                "corpus": source,
-                "mot1": term.lower(),
-                "mot2": link_term.lower(),
-                "from": debut,
-                "to": fin,
-            },
-        )
-    else:
-        series_dataframe = await fetch_series_dataframe(
-            "https://shiny.ens-paris-saclay.fr/guni/query",
-            {
-                "corpus": source,
-                "mot": term.lower(),
-                "from": debut,
-                "to": fin,
-            },
-        )
-    if grouping == "mois" and source != "livres":
-        series_dataframe = (
-            series_dataframe.groupby(["annee", "mois", "gram"])
-            .agg({"n": "sum", "total": "sum"})
-            .reset_index()
-        )
-    if grouping == "annee":
-        series_dataframe = (
-            series_dataframe.groupby(["annee", "gram"])
-            .agg({"n": "sum", "total": "sum"})
-            .reset_index()
-        )
-
-    def calc_ratio(row):
-        if row.total == 0:
-            return 0
-        return row.n / row.total
-
-    series_dataframe["ratio"] = series_dataframe.apply(
-        lambda row: calc_ratio(row), axis=1
+    """Return the Gallicagram series for ``term``, using the in-memory cache."""
+    args_object = GallicagramInput(
+        term=term,
+        start_date=start_date,
+        end_date=end_date,
+        grouping=grouping,
+        source=source,
+        link_term=link_term,
     )
 
-    if all(series_dataframe.ratio == 0):
-        raise HTTPException(status_code=404, detail="No records found")
+    key = args_object.getCacheKey()
 
-    def get_unix_timestamp(row) -> int:
-        year = int(row.get("annee", 0))
-        month = int(row.get("mois", 1))
+    cached = series_cache.get(key)
+    if cached is not None:
+        cached_series, cached_at = cached
+        if (datetime.now() - cached_at).total_seconds() < CACHE_TIMEOUT:
+            print("cache hit")
+            return cached_series
+        # Stale entry: drop it and fall through to a fresh fetch.
+        del series_cache[key]
+
+    # Truthiness, not ``is None``: a blank ``link_term=`` query parameter
+    # arrives as "" and must be treated like an absent one.
+    fetched_dataframe = (
+        await fetch_series_linked_term(args_object)
+        if link_term
+        else await fetch_series(args_object)
+    )
 
-        dt = datetime(year, month, 1)
-        return int(dt.timestamp() * 1000)
+    if fetched_dataframe is None:
+        raise HTTPException(status_code=404, detail="No records found")
 
-    data = series_dataframe.apply(
-        lambda row: (get_unix_timestamp(row), row["ratio"]), axis=1
-    ).tolist()
+    points = transform_series(fetched_dataframe, args_object)
 
-    series = Series(
-        data=data,
+    result = Series(
+        data=points,
         name=term,
     )
-    series_cache[key] = series
-    return series
-
-
-async def fetch_series_dataframe(url: str, params: Dict):
-    async with aiohttp.ClientSession() as session:
-        start = time.time()
-        async with session.get(url, params=params) as response:
-            print(f"Fetched {response.url}")
-            print(f"Took {time.time() - start} seconds")
-            if response.status != 200:
-                raise HTTPException(
-                    status_code=503, detail="Could not connect to Gallicagram! Egads!"
-                )
-            return pd.read_csv(StringIO(await response.text()))
+    series_cache[key] = result, datetime.now()
+    return result
 
 
 if __name__ == "__main__":
-    print("hi fly")
     uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8080)))