Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection pooling #27

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion space2stats_api/src/app/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ def get_summary(request: SummaryRequest):

@router.get("/fields", response_model=List[str])
def fields():
return get_available_fields()
return get_available_fields()
4 changes: 4 additions & 0 deletions space2stats_api/src/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class Settings(BaseSettings):
DB_USER: str
DB_PASSWORD: str
DB_TABLE_NAME: str

@property
def DB_CONNECTION_STRING(self) -> str:
return f"host={self.DB_HOST} port={self.DB_PORT} dbname={self.DB_NAME} user={self.DB_USER} password={self.DB_PASSWORD}"
69 changes: 26 additions & 43 deletions space2stats_api/src/app/utils/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import psycopg as pg
from psycopg_pool import ConnectionPool
from ..settings import Settings

settings = Settings()

DB_HOST = settings.DB_HOST
DB_PORT = settings.DB_PORT
DB_NAME = settings.DB_NAME
DB_USER = settings.DB_USER
DB_PASSWORD = settings.DB_PASSWORD
DB_TABLE_NAME = settings.DB_TABLE_NAME or "space2stats"
conninfo = settings.DB_CONNECTION_STRING
pool = ConnectionPool(conninfo=conninfo, min_size=1, max_size=10, open=True)


def get_summaries(fields, h3_ids):
Expand All @@ -20,26 +17,20 @@ def get_summaries(fields, h3_ids):
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(pg.sql.SQL(", ").join(cols), pg.sql.Identifier(DB_TABLE_NAME))
).format(pg.sql.SQL(", ").join(cols), pg.sql.Identifier(settings.DB_TABLE_NAME))
try:
conn = pg.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
)
cur = conn.cursor()
cur.execute(
sql_query,
[
h3_ids,
],
)
rows = cur.fetchall()
colnames = [desc[0] for desc in cur.description]
cur.close()
conn.close()
# Convert h3_ids to a list to ensure compatibility with psycopg
h3_ids = list(h3_ids)
with pool.connection() as conn:
zacdezgeo marked this conversation as resolved.
Show resolved Hide resolved
with conn.cursor() as cur:
cur.execute(
sql_query,
[
h3_ids,
],
)
rows = cur.fetchall()
colnames = [desc[0] for desc in cur.description]
except Exception as e:
raise e

Expand All @@ -53,24 +44,16 @@ def get_available_fields():
WHERE table_name = %s
"""
try:
conn = pg.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
)
cur = conn.cursor()
cur.execute(
sql_query,
[
DB_TABLE_NAME,
],
)
columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"]
cur.close()
conn.close()
with pool.connection() as conn:
with conn.cursor() as cur:
cur.execute(
sql_query,
[
settings.DB_TABLE_NAME,
],
)
columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"]
except Exception as e:
raise e

return columns
return columns
1 change: 1 addition & 0 deletions space2stats_api/src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python-dotenv
shapely
h3
psycopg[binary]
psycopg[pool]
httpx
geojson-pydantic
shapely
Expand Down
61 changes: 23 additions & 38 deletions space2stats_api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,20 @@ def test_read_root():
assert response.json() == {"message": "Welcome to Space2Stats!"}


@patch("psycopg.connect")
def test_get_summary(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]
@patch("src.app.routers.api.get_summaries")
def test_get_summary(mock_get_summaries):
mock_get_summaries.return_value = [("hex_1", 100, 200)], ["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"]

request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["field1", "field2"],
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
print(response_json)
assert isinstance(response_json, list)

for summary in response_json:
Expand All @@ -56,24 +53,21 @@ def test_get_summary(mock_connect):
assert len(summary) == len(request_payload["fields"]) + 1 # +1 for the 'hex_id'


@patch("psycopg.connect")
def test_get_summary_with_geometry_polygon(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]
@patch("src.app.routers.api.get_summaries")
def test_get_summary_with_geometry_polygon(mock_get_summaries):
mock_get_summaries.return_value = [("hex_1", 100, 200)], ["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"]

request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["field1", "field2"],
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
"geometry": "polygon",
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
print(response_json)
assert isinstance(response_json, list)

for summary in response_json:
Expand All @@ -82,29 +76,24 @@ def test_get_summary_with_geometry_polygon(mock_connect):
assert summary["geometry"]["type"] == "Polygon"
for field in request_payload["fields"]:
assert field in summary
assert (
len(summary) == len(request_payload["fields"]) + 2
) # +1 for the 'hex_id' and +1 for 'geometry'
assert len(summary) == len(request_payload["fields"]) + 2 # +1 for the 'hex_id' and +1 for 'geometry'


@patch("psycopg.connect")
def test_get_summary_with_geometry_point(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]
@patch("src.app.routers.api.get_summaries")
def test_get_summary_with_geometry_point(mock_get_summaries):
mock_get_summaries.return_value = [("hex_1", 100, 200)], ["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"]

request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["field1", "field2"],
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
"geometry": "point",
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
print(response_json)
assert isinstance(response_json, list)

for summary in response_json:
Expand All @@ -113,26 +102,22 @@ def test_get_summary_with_geometry_point(mock_connect):
assert summary["geometry"]["type"] == "Point"
for field in request_payload["fields"]:
assert field in summary
assert (
len(summary) == len(request_payload["fields"]) + 2
) # +1 for the 'hex_id' and +1 for 'geometry'
assert len(summary) == len(request_payload["fields"]) + 2 # +1 for the 'hex_id' and +1 for 'geometry'


@patch("psycopg.connect")
def test_get_fields(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.fetchall.return_value = [
("hex_id",),
("field1",),
("field2",),
("field3",),
]
@patch("src.app.routers.api.get_available_fields")
def test_get_fields(mock_get_available_fields):
mock_get_available_fields.return_value = ["sum_pop_2020", "sum_pop_f_10_2020", "field3"]

response = client.get("/fields")

assert response.status_code == 200
assert response.json() == ["field1", "field2", "field3"]
response_json = response.json()

expected_fields = ["sum_pop_2020", "sum_pop_f_10_2020", "field3"]
for field in expected_fields:
assert field in response_json


if __name__ == "__main__":
pytest.main()
pytest.main()
124 changes: 64 additions & 60 deletions space2stats_api/tests/test_db_utils.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,67 @@
import unittest
from unittest.mock import patch, Mock
from src.app.utils.db_utils import get_summaries, get_available_fields
from psycopg.sql import SQL, Identifier


@patch("psycopg.connect")
def test_get_summaries(mock_connect):
mock_conn = Mock()
mock_cursor = Mock()
mock_connect.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor

mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]

fields = ["field1", "field2"]
h3_ids = ["hex_1"]
rows, colnames = get_summaries(fields, h3_ids)

mock_connect.assert_called_once()
sql_query = SQL(
"""
SELECT {0}
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(
SQL(", ").join([Identifier(c) for c in ["hex_id"] + fields]),
Identifier("space2stats"),
)
mock_cursor.execute.assert_called_once_with(sql_query, [h3_ids])

assert rows == [("hex_1", 100, 200)]
assert colnames == ["hex_id", "field1", "field2"]


@patch("psycopg.connect")
def test_get_available_fields(mock_connect):
mock_conn = Mock()
mock_cursor = Mock()
mock_connect.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor

mock_cursor.fetchall.return_value = [("field1",), ("field2",), ("field3",)]

columns = get_available_fields()

mock_connect.assert_called_once()
mock_cursor.execute.assert_called_once_with(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = %s
""",
["space2stats"],
)

assert columns == ["field1", "field2", "field3"]
import pytest
from shapely.geometry import Polygon, mapping
from src.app.utils.h3_utils import generate_h3_ids, generate_h3_geometries

polygon_coords = [
[-74.3, 40.5],
[-73.7, 40.5],
[-73.7, 40.9],
[-74.3, 40.9],
[-74.3, 40.5],
]
polygon = Polygon(polygon_coords)
aoi_geojson = mapping(polygon)
resolution = 6


def test_generate_h3_ids_within():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "within")
print(f"Test 'within' - Generated H3 IDs: {h3_ids}")
assert len(h3_ids) > 0, "Expected at least one H3 ID"


def test_generate_h3_ids_touches():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
print(f"Test 'touches' - Generated H3 IDs: {h3_ids}")
assert len(h3_ids) > 0, "Expected at least one H3 ID"


def test_generate_h3_ids_centroid():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "centroid")
print(f"Test 'centroid' - Generated H3 IDs: {h3_ids}")
assert len(h3_ids) > 0, "Expected at least one H3 ID for centroid"


def test_generate_h3_ids_invalid_method():
with pytest.raises(ValueError, match="Invalid spatial join method"):
generate_h3_ids(aoi_geojson, resolution, "invalid_method")


def test_generate_h3_geometries_polygon():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
geometries = generate_h3_geometries(h3_ids, "polygon")
assert len(geometries) == len(
h3_ids
), "Expected the same number of geometries as H3 IDs"
for geom in geometries:
assert geom["type"] == "Polygon", "Expected Polygon geometry"


def test_generate_h3_geometries_point():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
geometries = generate_h3_geometries(h3_ids, "point")
assert len(geometries) == len(
h3_ids
), "Expected the same number of geometries as H3 IDs"
for geom in geometries:
assert geom["type"] == "Point", "Expected Point geometry"


def test_generate_h3_geometries_invalid_type():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
with pytest.raises(ValueError, match="Invalid geometry type"):
generate_h3_geometries(h3_ids, "invalid_type")


if __name__ == "__main__":
unittest.main()
pytest.main()
Loading