From 996c59f09967e8b03f8625a53fe7a3f9b7a93531 Mon Sep 17 00:00:00 2001
From: eric <eeoneric@gmail.com>
Date: Tue, 28 Jan 2025 12:48:05 -0500
Subject: [PATCH] optimize: avoid N+1 queries in serialize_database via prefetch

---
 posthog/hogql/database/database.py           | 84 ++++++++++----------
 posthog/hogql/database/test/test_database.py | 67 ++++++++++++++++
 2 files changed, 109 insertions(+), 42 deletions(-)

diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py
index c014099691818..cd47ba88baac8 100644
--- a/posthog/hogql/database/database.py
+++ b/posthog/hogql/database/database.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeAlias, Union, cast
 from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
 
-from django.db.models import Q
+from django.db.models import Q, Prefetch
 from pydantic import BaseModel, ConfigDict
 from sentry_sdk import capture_exception
 
@@ -91,8 +91,6 @@
     SessionTableVersion,
 )
 from posthog.warehouse.models.external_data_job import ExternalDataJob
-from posthog.warehouse.models.external_data_schema import ExternalDataSchema
-from posthog.warehouse.models.external_data_source import ExternalDataSource
 from posthog.warehouse.models.table import (
     DataWarehouseTable,
     DataWarehouseTableColumns,
@@ -572,27 +570,34 @@ def serialize_database(
         fields_dict = {field.name: field for field in fields}
         tables[table_key] = DatabaseSchemaPostHogTable(fields=fields_dict, id=table_key, name=table_key)
 
-    # Data Warehouse Tables
+    # Data Warehouse Tables and Views - Fetch all related data in one go
     warehouse_table_names = context.database.get_warehouse_tables()
-    warehouse_tables = (
-        list(
-            DataWarehouseTable.objects.select_related("credential", "external_data_source")
-            .filter(Q(deleted=False) | Q(deleted__isnull=True), team_id=context.team_id, name__in=warehouse_table_names)
-            .all()
-        )
-        if len(warehouse_table_names) > 0
-        else []
-    )
-    warehouse_schemas = (
-        list(
-            ExternalDataSchema.objects.exclude(deleted=True)
-            .filter(team_id=context.team_id, table_id__in=[table.id for table in warehouse_tables])
-            .all()
+    views = context.database.get_views()
+
+    # Fetch warehouse tables with related data in a single query
+    warehouse_tables_with_data = (
+        DataWarehouseTable.objects.select_related("credential", "external_data_source")
+        .prefetch_related(
+            "externaldataschema_set",
+            Prefetch(
+                "external_data_source__jobs",
+                queryset=ExternalDataJob.objects.filter(status="Completed", team_id=context.team_id).order_by(
+                    "-created_at"
+                )[:1],
+                to_attr="latest_completed_job",
+            ),
         )
-        if len(warehouse_tables) > 0
+        .filter(Q(deleted=False) | Q(deleted__isnull=True), team_id=context.team_id, name__in=warehouse_table_names)
+        .all()
+        if warehouse_table_names
         else []
     )
-    for warehouse_table in warehouse_tables:
+
+    # Fetch all views in a single query
+    all_views = DataWarehouseSavedQuery.objects.filter(team_id=context.team_id).exclude(deleted=True).all() if views else []
+
+    # Process warehouse tables
+    for warehouse_table in warehouse_tables_with_data:
         table_key = warehouse_table.name
 
         field_input = {}
@@ -603,14 +608,12 @@ def serialize_database(
         fields = serialize_fields(field_input, context, table_key, warehouse_table.columns, table_type="external")
         fields_dict = {field.name: field for field in fields}
 
-        # Schema
-        schema_filter: list[ExternalDataSchema] = list(
-            filter(lambda schema: schema.table_id == warehouse_table.id, warehouse_schemas)
-        )
-        if len(schema_filter) == 0:
-            schema: DatabaseSchemaSchema | None = None
+        # Get schema from prefetched data
+        schema_data = [s for s in warehouse_table.externaldataschema_set.all() if not s.deleted]
+        if not schema_data:
+            schema = None
         else:
-            db_schema = schema_filter[0]
+            db_schema = schema_data[0]
             schema = DatabaseSchemaSchema(
                 id=str(db_schema.id),
                 name=db_schema.name,
@@ -620,15 +623,15 @@ def serialize_database(
                 last_synced_at=str(db_schema.last_synced_at),
             )
 
-        # Source
+        # Get source from prefetched data
         if warehouse_table.external_data_source is None:
-            source: DatabaseSchemaSource | None = None
+            source = None
         else:
-            db_source: ExternalDataSource = warehouse_table.external_data_source
+            db_source = warehouse_table.external_data_source
             latest_completed_run = (
-                ExternalDataJob.objects.filter(pipeline_id=db_source.pk, status="Completed", team_id=context.team_id)
-                .order_by("-created_at")
-                .first()
+                db_source.latest_completed_job[0]
+                if hasattr(db_source, "latest_completed_job") and db_source.latest_completed_job
+                else None
             )
             source = DatabaseSchemaSource(
                 id=str(db_source.source_id),
@@ -648,9 +651,8 @@ def serialize_database(
             source=source,
         )
 
-    # Views
-    views = context.database.get_views()
-    all_views = list(DataWarehouseSavedQuery.objects.filter(team_id=context.team_id).exclude(deleted=True))
+    # Process views using prefetched data
+    views_dict = {view.name: view for view in all_views}
     for view_name in views:
         view: SavedQuery | None = getattr(context.database, view_name, None)
         if view is None:
@@ -659,15 +661,13 @@ def serialize_database(
         fields = serialize_fields(view.fields, context, view_name, table_type="external")
         fields_dict = {field.name: field for field in fields}
 
-        saved_query: list[DataWarehouseSavedQuery] = list(
-            filter(lambda saved_query: saved_query.name == view_name, all_views)
-        )
-        if len(saved_query) != 0:
+        saved_query = views_dict.get(view_name)
+        if saved_query:
             tables[view_name] = DatabaseSchemaViewTable(
                 fields=fields_dict,
-                id=str(saved_query[0].pk),
+                id=str(saved_query.pk),
                 name=view.name,
-                query=HogQLQuery(query=saved_query[0].query["query"]),
+                query=HogQLQuery(query=saved_query.query["query"]),
             )
 
     return tables
diff --git a/posthog/hogql/database/test/test_database.py b/posthog/hogql/database/test/test_database.py
index c4da114c178d5..41d9a80f3501c 100644
--- a/posthog/hogql/database/test/test_database.py
+++ b/posthog/hogql/database/test/test_database.py
@@ -214,6 +214,73 @@ def test_serialize_database_warehouse_table_source(self):
         assert field.type == "string"
         assert field.schema_valid is True
 
+    def test_serialize_database_warehouse_table_source_query_count(self):
+        source = ExternalDataSource.objects.create(
+            team=self.team,
+            source_id="source_id_1",
+            connection_id="connection_id_1",
+            status=ExternalDataSource.Status.COMPLETED,
+            source_type=ExternalDataSource.Type.STRIPE,
+        )
+        credentials = DataWarehouseCredential.objects.create(access_key="blah", access_secret="blah", team=self.team)
+        warehouse_table = DataWarehouseTable.objects.create(
+            name="table_1",
+            format="Parquet",
+            team=self.team,
+            external_data_source=source,
+            external_data_source_id=source.id,
+            credential=credentials,
+            url_pattern="https://bucket.s3/data/*",
+            columns={"id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True}},
+        )
+        ExternalDataSchema.objects.create(
+            team=self.team,
+            name="table_1",
+            source=source,
+            table=warehouse_table,
+            should_sync=True,
+            last_synced_at="2024-01-01",
+        )
+
+        database = create_hogql_database(team_id=self.team.pk)
+
+        with self.assertNumQueries(3):
+            serialize_database(HogQLContext(team_id=self.team.pk, database=database))
+
+        for i in range(5):
+            source = ExternalDataSource.objects.create(
+                team=self.team,
+                source_id=f"source_id_{i+2}",
+                connection_id=f"connection_id_{i+2}",
+                status=ExternalDataSource.Status.COMPLETED,
+                source_type=ExternalDataSource.Type.STRIPE,
+            )
+            warehouse_table = DataWarehouseTable.objects.create(
+                name=f"table_{i+2}",
+                format="Parquet",
+                team=self.team,
+                external_data_source=source,
+                external_data_source_id=source.id,
+                credential=credentials,
+                url_pattern="https://bucket.s3/data/*",
+                columns={
+                    "id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True}
+                },
+            )
+            ExternalDataSchema.objects.create(
+                team=self.team,
+                name=f"table_{i+2}",
+                source=source,
+                table=warehouse_table,
+                should_sync=True,
+                last_synced_at="2024-01-01",
+            )
+
+        database = create_hogql_database(team_id=self.team.pk)
+
+        with self.assertNumQueries(3):
+            serialize_database(HogQLContext(team_id=self.team.pk, database=database))
+
     @patch("posthog.hogql.query.sync_execute", return_value=([], []))
     @pytest.mark.usefixtures("unittest_snapshot")
     def test_database_with_warehouse_tables(self, patch_execute):