chore(data-warehouse): optimize databaseschemaquery (#27977)
EDsCODE authored and adamleithp committed Jan 29, 2025
1 parent 5981434 commit dbda5b2
Showing 2 changed files with 109 additions and 42 deletions.
84 changes: 42 additions & 42 deletions posthog/hogql/database/database.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeAlias, Union, cast
 from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

-from django.db.models import Q
+from django.db.models import Q, Prefetch
 from pydantic import BaseModel, ConfigDict
 from sentry_sdk import capture_exception

@@ -91,8 +91,6 @@
     SessionTableVersion,
 )
 from posthog.warehouse.models.external_data_job import ExternalDataJob
-from posthog.warehouse.models.external_data_schema import ExternalDataSchema
-from posthog.warehouse.models.external_data_source import ExternalDataSource
 from posthog.warehouse.models.table import (
     DataWarehouseTable,
     DataWarehouseTableColumns,
@@ -572,27 +570,34 @@ def serialize_database(
         fields_dict = {field.name: field for field in fields}
         tables[table_key] = DatabaseSchemaPostHogTable(fields=fields_dict, id=table_key, name=table_key)

-    # Data Warehouse Tables
+    # Data Warehouse Tables and Views - Fetch all related data in one go
     warehouse_table_names = context.database.get_warehouse_tables()
-    warehouse_tables = (
-        list(
-            DataWarehouseTable.objects.select_related("credential", "external_data_source")
-            .filter(Q(deleted=False) | Q(deleted__isnull=True), team_id=context.team_id, name__in=warehouse_table_names)
-            .all()
-        )
-        if len(warehouse_table_names) > 0
-        else []
-    )
-    warehouse_schemas = (
-        list(
-            ExternalDataSchema.objects.exclude(deleted=True)
-            .filter(team_id=context.team_id, table_id__in=[table.id for table in warehouse_tables])
-            .all()
-        )
-        if len(warehouse_tables) > 0
-        else []
-    )
-    for warehouse_table in warehouse_tables:
+    views = context.database.get_views()
+
+    # Fetch warehouse tables with related data in a single query
+    warehouse_tables_with_data = (
+        DataWarehouseTable.objects.select_related("credential", "external_data_source")
+        .prefetch_related(
+            "externaldataschema_set",
+            Prefetch(
+                "external_data_source__jobs",
+                queryset=ExternalDataJob.objects.filter(status="Completed", team_id=context.team_id).order_by(
+                    "-created_at"
+                )[:1],
+                to_attr="latest_completed_job",
+            ),
+        )
+        .filter(Q(deleted=False) | Q(deleted__isnull=True), team_id=context.team_id, name__in=warehouse_table_names)
+        .all()
+        if warehouse_table_names
+        else []
+    )
+
+    # Fetch all views in a single query
+    all_views = DataWarehouseSavedQuery.objects.filter(team_id=context.team_id, deleted=False).all() if views else []
+
+    # Process warehouse tables
+    for warehouse_table in warehouse_tables_with_data:
         table_key = warehouse_table.name

         field_input = {}
@@ -603,14 +608,12 @@ def serialize_database(
         fields = serialize_fields(field_input, context, table_key, warehouse_table.columns, table_type="external")
         fields_dict = {field.name: field for field in fields}

-        # Schema
-        schema_filter: list[ExternalDataSchema] = list(
-            filter(lambda schema: schema.table_id == warehouse_table.id, warehouse_schemas)
-        )
-        if len(schema_filter) == 0:
-            schema: DatabaseSchemaSchema | None = None
+        # Get schema from prefetched data
+        schema_data = list(warehouse_table.externaldataschema_set.all())
+        if not schema_data:
+            schema = None
         else:
-            db_schema = schema_filter[0]
+            db_schema = schema_data[0]
             schema = DatabaseSchemaSchema(
                 id=str(db_schema.id),
                 name=db_schema.name,
@@ -620,15 +623,15 @@ def serialize_database(
                 last_synced_at=str(db_schema.last_synced_at),
             )

-        # Source
+        # Get source from prefetched data
         if warehouse_table.external_data_source is None:
-            source: DatabaseSchemaSource | None = None
+            source = None
         else:
-            db_source: ExternalDataSource = warehouse_table.external_data_source
+            db_source = warehouse_table.external_data_source
             latest_completed_run = (
-                ExternalDataJob.objects.filter(pipeline_id=db_source.pk, status="Completed", team_id=context.team_id)
-                .order_by("-created_at")
-                .first()
+                db_source.latest_completed_job[0]
+                if hasattr(db_source, "latest_completed_job") and db_source.latest_completed_job
+                else None
             )
             source = DatabaseSchemaSource(
                 id=str(db_source.source_id),
@@ -648,9 +651,8 @@ def serialize_database(
             source=source,
         )

-    # Views
-    views = context.database.get_views()
-    all_views = list(DataWarehouseSavedQuery.objects.filter(team_id=context.team_id).exclude(deleted=True))
+    # Process views using prefetched data
+    views_dict = {view.name: view for view in all_views}
     for view_name in views:
         view: SavedQuery | None = getattr(context.database, view_name, None)
         if view is None:
@@ -659,15 +661,13 @@ def serialize_database(
         fields = serialize_fields(view.fields, context, view_name, table_type="external")
         fields_dict = {field.name: field for field in fields}

-        saved_query: list[DataWarehouseSavedQuery] = list(
-            filter(lambda saved_query: saved_query.name == view_name, all_views)
-        )
-        if len(saved_query) != 0:
+        saved_query = views_dict.get(view_name)
+        if saved_query:
             tables[view_name] = DatabaseSchemaViewTable(
                 fields=fields_dict,
-                id=str(saved_query[0].pk),
+                id=str(saved_query.pk),
                 name=view.name,
-                query=HogQLQuery(query=saved_query[0].query["query"]),
+                query=HogQLQuery(query=saved_query.query["query"]),
             )

     return tables
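The heart of the change is Django's Prefetch object with to_attr: instead of one ExternalDataJob query per warehouse table inside the loop, the newest completed job for every source is fetched up front and attached to each row. What follows is a minimal sketch of that pattern, with hypothetical Source and Job models standing in for ExternalDataSource and ExternalDataJob; note that prefetching a sliced queryset such as [:1] requires Django 4.2 or newer.

# A minimal sketch of the Prefetch/to_attr pattern above; Source and Job are
# hypothetical stand-ins for ExternalDataSource and ExternalDataJob.
from django.db import models
from django.db.models import Prefetch


class Source(models.Model):
    name = models.CharField(max_length=100)


class Job(models.Model):
    source = models.ForeignKey(Source, on_delete=models.CASCADE, related_name="jobs")
    status = models.CharField(max_length=20)
    created_at = models.DateTimeField(auto_now_add=True)


def sources_with_latest_job():
    # Two queries in total, however many sources exist: one for the sources,
    # one (window-function based, Django 4.2+) for the newest completed job
    # of each source.
    return Source.objects.prefetch_related(
        Prefetch(
            "jobs",
            queryset=Job.objects.filter(status="Completed").order_by("-created_at")[:1],
            to_attr="latest_completed_job",  # a plain list on each Source instance
        )
    )


# Usage mirrors the serializer above:
# for source in sources_with_latest_job():
#     latest = source.latest_completed_job[0] if source.latest_completed_job else None

Because to_attr materializes the prefetch as a plain list, the hasattr guard in the diff above covers the case where a source was loaded without the prefetch applied.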
67 changes: 67 additions & 0 deletions posthog/hogql/database/test/test_database.py
@@ -214,6 +214,73 @@ def test_serialize_database_warehouse_table_source(self):
         assert field.type == "string"
         assert field.schema_valid is True

+    def test_serialize_database_warehouse_table_source_query_count(self):
+        source = ExternalDataSource.objects.create(
+            team=self.team,
+            source_id="source_id_1",
+            connection_id="connection_id_1",
+            status=ExternalDataSource.Status.COMPLETED,
+            source_type=ExternalDataSource.Type.STRIPE,
+        )
+        credentials = DataWarehouseCredential.objects.create(access_key="blah", access_secret="blah", team=self.team)
+        warehouse_table = DataWarehouseTable.objects.create(
+            name="table_1",
+            format="Parquet",
+            team=self.team,
+            external_data_source=source,
+            external_data_source_id=source.id,
+            credential=credentials,
+            url_pattern="https://bucket.s3/data/*",
+            columns={"id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True}},
+        )
+        ExternalDataSchema.objects.create(
+            team=self.team,
+            name="table_1",
+            source=source,
+            table=warehouse_table,
+            should_sync=True,
+            last_synced_at="2024-01-01",
+        )
+
+        database = create_hogql_database(team_id=self.team.pk)
+
+        with self.assertNumQueries(3):
+            serialize_database(HogQLContext(team_id=self.team.pk, database=database))
+
+        for i in range(5):
+            source = ExternalDataSource.objects.create(
+                team=self.team,
+                source_id=f"source_id_{i+2}",
+                connection_id=f"connection_id_{i+2}",
+                status=ExternalDataSource.Status.COMPLETED,
+                source_type=ExternalDataSource.Type.STRIPE,
+            )
+            warehouse_table = DataWarehouseTable.objects.create(
+                name=f"table_{i+2}",
+                format="Parquet",
+                team=self.team,
+                external_data_source=source,
+                external_data_source_id=source.id,
+                credential=credentials,
+                url_pattern="https://bucket.s3/data/*",
+                columns={
+                    "id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True}
+                },
+            )
+            ExternalDataSchema.objects.create(
+                team=self.team,
+                name=f"table_{i+2}",
+                source=source,
+                table=warehouse_table,
+                should_sync=True,
+                last_synced_at="2024-01-01",
+            )
+
+        database = create_hogql_database(team_id=self.team.pk)
+
+        with self.assertNumQueries(3):
+            serialize_database(HogQLContext(team_id=self.team.pk, database=database))
+
     @patch("posthog.hogql.query.sync_execute", return_value=([], []))
     @pytest.mark.usefixtures("unittest_snapshot")
     def test_database_with_warehouse_tables(self, patch_execute):
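The new test pins serialize_database to a constant query count: three queries with a single warehouse table, and still three after five more sources and tables are created, so any reintroduced per-table lookup fails the assertion. Below is a minimal sketch of the same assertNumQueries guard, reusing the hypothetical Source/Job models and helper from the sketch above.

# A minimal query-count guard, assuming the hypothetical Source/Job models
# and the sources_with_latest_job() helper from the previous sketch.
from django.test import TestCase


class LatestJobQueryCountTest(TestCase):
    def test_query_count_stays_constant(self):
        for i in range(5):
            src = Source.objects.create(name=f"source_{i}")
            Job.objects.create(source=src, status="Completed")

        # Two queries (sources + one prefetch) regardless of row count; an
        # N+1 regression would instead need 1 + number_of_sources queries.
        with self.assertNumQueries(2):
            for src in sources_with_latest_job():
                self.assertTrue(src.latest_completed_job)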
