From 996c59f09967e8b03f8625a53fe7a3f9b7a93531 Mon Sep 17 00:00:00 2001 From: eric <eeoneric@gmail.com> Date: Tue, 28 Jan 2025 12:48:05 -0500 Subject: [PATCH] optimize --- posthog/hogql/database/database.py | 84 ++++++++++---------- posthog/hogql/database/test/test_database.py | 67 ++++++++++++++++ 2 files changed, 109 insertions(+), 42 deletions(-) diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index c014099691818..cd47ba88baac8 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeAlias, Union, cast from zoneinfo import ZoneInfo, ZoneInfoNotFoundError -from django.db.models import Q +from django.db.models import Q, Prefetch from pydantic import BaseModel, ConfigDict from sentry_sdk import capture_exception @@ -91,8 +91,6 @@ SessionTableVersion, ) from posthog.warehouse.models.external_data_job import ExternalDataJob -from posthog.warehouse.models.external_data_schema import ExternalDataSchema -from posthog.warehouse.models.external_data_source import ExternalDataSource from posthog.warehouse.models.table import ( DataWarehouseTable, DataWarehouseTableColumns, @@ -572,27 +570,34 @@ def serialize_database( fields_dict = {field.name: field for field in fields} tables[table_key] = DatabaseSchemaPostHogTable(fields=fields_dict, id=table_key, name=table_key) - # Data Warehouse Tables + # Data Warehouse Tables and Views - Fetch all related data in one go warehouse_table_names = context.database.get_warehouse_tables() - warehouse_tables = ( - list( - DataWarehouseTable.objects.select_related("credential", "external_data_source") - .filter(Q(deleted=False) | Q(deleted__isnull=True), team_id=context.team_id, name__in=warehouse_table_names) - .all() - ) - if len(warehouse_table_names) > 0 - else [] - ) - warehouse_schemas = ( - list( - ExternalDataSchema.objects.exclude(deleted=True) - 
.filter(team_id=context.team_id, table_id__in=[table.id for table in warehouse_tables]) - .all() + views = context.database.get_views() + + # Fetch warehouse tables with related data in a single query + warehouse_tables_with_data = ( + DataWarehouseTable.objects.select_related("credential", "external_data_source") + .prefetch_related( + "externaldataschema_set", + Prefetch( + "external_data_source__jobs", + queryset=ExternalDataJob.objects.filter(status="Completed", team_id=context.team_id).order_by( + "-created_at" + )[:1], + to_attr="latest_completed_job", + ), ) - if len(warehouse_tables) > 0 + .filter(Q(deleted=False) | Q(deleted__isnull=True), team_id=context.team_id, name__in=warehouse_table_names) + .all() + if warehouse_table_names else [] ) - for warehouse_table in warehouse_tables: + + # Fetch all views in a single query + all_views = DataWarehouseSavedQuery.objects.filter(team_id=context.team_id).exclude(deleted=True).all() if views else [] + + # Process warehouse tables + for warehouse_table in warehouse_tables_with_data: table_key = warehouse_table.name field_input = {} @@ -603,14 +608,12 @@ def serialize_database( fields = serialize_fields(field_input, context, table_key, warehouse_table.columns, table_type="external") fields_dict = {field.name: field for field in fields} - # Schema - schema_filter: list[ExternalDataSchema] = list( - filter(lambda schema: schema.table_id == warehouse_table.id, warehouse_schemas) - ) - if len(schema_filter) == 0: - schema: DatabaseSchemaSchema | None = None + # Get schema from prefetched data (skip soft-deleted schemas, matching the old exclude(deleted=True)) + schema_data = [s for s in warehouse_table.externaldataschema_set.all() if not s.deleted] + if not schema_data: + schema = None else: - db_schema = schema_filter[0] + db_schema = schema_data[0] schema = DatabaseSchemaSchema( id=str(db_schema.id), name=db_schema.name, @@ -620,15 +623,15 @@ def serialize_database( last_synced_at=str(db_schema.last_synced_at), ) - # Source + # Get source from prefetched data if warehouse_table.external_data_source is None: - 
source: DatabaseSchemaSource | None = None + source = None else: - db_source: ExternalDataSource = warehouse_table.external_data_source + db_source = warehouse_table.external_data_source latest_completed_run = ( - ExternalDataJob.objects.filter(pipeline_id=db_source.pk, status="Completed", team_id=context.team_id) - .order_by("-created_at") - .first() + db_source.latest_completed_job[0] + if hasattr(db_source, "latest_completed_job") and db_source.latest_completed_job + else None ) source = DatabaseSchemaSource( id=str(db_source.source_id), @@ -648,9 +651,8 @@ def serialize_database( source=source, ) - # Views - views = context.database.get_views() - all_views = list(DataWarehouseSavedQuery.objects.filter(team_id=context.team_id).exclude(deleted=True)) + # Process views using prefetched data + views_dict = {view.name: view for view in all_views} for view_name in views: view: SavedQuery | None = getattr(context.database, view_name, None) if view is None: @@ -659,15 +661,13 @@ def serialize_database( fields = serialize_fields(view.fields, context, view_name, table_type="external") fields_dict = {field.name: field for field in fields} - saved_query: list[DataWarehouseSavedQuery] = list( - filter(lambda saved_query: saved_query.name == view_name, all_views) - ) - if len(saved_query) != 0: + saved_query = views_dict.get(view_name) + if saved_query: tables[view_name] = DatabaseSchemaViewTable( fields=fields_dict, - id=str(saved_query[0].pk), + id=str(saved_query.pk), name=view.name, - query=HogQLQuery(query=saved_query[0].query["query"]), + query=HogQLQuery(query=saved_query.query["query"]), ) return tables diff --git a/posthog/hogql/database/test/test_database.py b/posthog/hogql/database/test/test_database.py index c4da114c178d5..41d9a80f3501c 100644 --- a/posthog/hogql/database/test/test_database.py +++ b/posthog/hogql/database/test/test_database.py @@ -214,6 +214,73 @@ def test_serialize_database_warehouse_table_source(self): assert field.type == "string" assert 
field.schema_valid is True + def test_serialize_database_warehouse_table_source_query_count(self): + source = ExternalDataSource.objects.create( + team=self.team, + source_id="source_id_1", + connection_id="connection_id_1", + status=ExternalDataSource.Status.COMPLETED, + source_type=ExternalDataSource.Type.STRIPE, + ) + credentials = DataWarehouseCredential.objects.create(access_key="blah", access_secret="blah", team=self.team) + warehouse_table = DataWarehouseTable.objects.create( + name="table_1", + format="Parquet", + team=self.team, + external_data_source=source, + external_data_source_id=source.id, + credential=credentials, + url_pattern="https://bucket.s3/data/*", + columns={"id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True}}, + ) + ExternalDataSchema.objects.create( + team=self.team, + name="table_1", + source=source, + table=warehouse_table, + should_sync=True, + last_synced_at="2024-01-01", + ) + + database = create_hogql_database(team_id=self.team.pk) + + with self.assertNumQueries(3): + serialize_database(HogQLContext(team_id=self.team.pk, database=database)) + + for i in range(5): + source = ExternalDataSource.objects.create( + team=self.team, + source_id=f"source_id_{i+2}", + connection_id=f"connection_id_{i+2}", + status=ExternalDataSource.Status.COMPLETED, + source_type=ExternalDataSource.Type.STRIPE, + ) + warehouse_table = DataWarehouseTable.objects.create( + name=f"table_{i+2}", + format="Parquet", + team=self.team, + external_data_source=source, + external_data_source_id=source.id, + credential=credentials, + url_pattern="https://bucket.s3/data/*", + columns={ + "id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True} + }, + ) + ExternalDataSchema.objects.create( + team=self.team, + name=f"table_{i+2}", + source=source, + table=warehouse_table, + should_sync=True, + last_synced_at="2024-01-01", + ) + + database = create_hogql_database(team_id=self.team.pk) + + with 
self.assertNumQueries(3): + serialize_database(HogQLContext(team_id=self.team.pk, database=database)) + @patch("posthog.hogql.query.sync_execute", return_value=([], [])) @pytest.mark.usefixtures("unittest_snapshot") def test_database_with_warehouse_tables(self, patch_execute):