From ca15798f0e79e556efcc3a27241f9329c597e839 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Mon, 13 Jan 2025 20:40:27 +0100 Subject: [PATCH 01/15] feat: query --- frontend/src/queries/schema.json | 363 +++++++++++++++++- frontend/src/queries/schema.ts | 33 ++ .../ai/test/test_traces_query_runner.py | 90 +++++ .../hogql_queries/ai/traces_query_runner.py | 127 ++++++ posthog/schema.py | 201 +++++++++- 5 files changed, 802 insertions(+), 12 deletions(-) create mode 100644 posthog/hogql_queries/ai/test/test_traces_query_runner.py create mode 100644 posthog/hogql_queries/ai/traces_query_runner.py diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 6339746e595ef..a4bc84efa716f 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -1,6 +1,40 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "definitions": { + "AISpan": { + "additionalProperties": false, + "properties": { + "id": { + "type": "string" + }, + "input": { + "items": {}, + "type": "array" + }, + "output": { + "items": {}, + "type": "array" + } + }, + "required": ["id", "input", "output"], + "type": "object" + }, + "AITrace": { + "additionalProperties": false, + "properties": { + "id": { + "type": "string" + }, + "spans": { + "items": { + "$ref": "#/definitions/AISpan" + }, + "type": "array" + } + }, + "required": ["id", "spans"], + "type": "object" + }, "ActionConversionGoal": { "additionalProperties": false, "properties": { @@ -369,6 +403,9 @@ }, { "$ref": "#/definitions/RecordingsQuery" + }, + { + "$ref": "#/definitions/TracesQuery" } ] }, @@ -486,6 +523,9 @@ }, { "$ref": "#/definitions/ErrorTrackingQueryResponse" + }, + { + "$ref": "#/definitions/TracesQueryResponse" } ] }, @@ -3321,6 +3361,89 @@ ], "type": "object" }, + "CachedTracesQueryResponse": { + "additionalProperties": false, + "properties": { + "cache_key": { + "type": "string" + }, + "cache_target_age": { + "format": "date-time", + "type": "string" + }, + "calculation_trigger": { + "description": "What triggered the calculation of the query, leave empty if user/immediate", + "type": "string" + }, + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "is_cached": { + "type": "boolean" + }, + "last_refresh": { + "format": "date-time", + "type": "string" + }, + "limit": { + "type": "integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "next_allowed_client_refresh": { + "format": "date-time", + "type": "string" + }, + "offset": { + "type": "integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/AITrace" + }, + "type": "array" + }, + "timezone": { + "type": "string" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": [ + "cache_key", + "is_cached", + "last_refresh", + "next_allowed_client_refresh", + "results", + "timezone" + ], + "type": "object" + }, "CachedTrendsQueryResponse": { "additionalProperties": false, "properties": { @@ -4738,6 +4861,57 @@ "variants" ], "type": "object" + }, + { + "additionalProperties": false, + "properties": { + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "limit": { + "type": "integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "type": "integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/AITrace" + }, + "type": "array" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": ["results"], + "type": "object" } ] }, @@ -4855,6 +5029,9 @@ }, { "$ref": "#/definitions/ExperimentTrendsQuery" + }, + { + "$ref": "#/definitions/TracesQuery" } ], "description": "Source of the events" @@ -8686,7 +8863,8 @@ "SuggestedQuestionsQuery", "TeamTaxonomyQuery", "EventTaxonomyQuery", - "ActorsPropertyTaxonomyQuery" + "ActorsPropertyTaxonomyQuery", + "TracesQuery" ], "type": "string" }, @@ -10698,6 +10876,57 @@ ], "type": "object" }, + { + "additionalProperties": false, + "properties": { + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "limit": { + "type": "integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "type": "integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/AITrace" + }, + "type": "array" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": ["results"], + "type": "object" + }, { "additionalProperties": false, "properties": { @@ -11105,6 +11334,57 @@ }, "required": ["results"], "type": "object" + }, + { + "additionalProperties": false, + "properties": { + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "limit": { + "type": "integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "type": "integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/AITrace" + }, + "type": "array" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": ["results"], + "type": "object" } ] }, @@ -11223,6 +11503,9 @@ }, { "$ref": "#/definitions/ActorsPropertyTaxonomyQuery" + }, + { + "$ref": "#/definitions/TracesQuery" } ], "required": ["kind"], @@ -12754,6 +13037,84 @@ "required": ["events"], "type": "object" }, + "TracesQuery": { + "additionalProperties": false, + "properties": { + "kind": { + "const": "TracesQuery", + "type": "string" + }, + "limit": { + "type": "integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "type": "integer" + }, + "response": { + "$ref": "#/definitions/TracesQueryResponse" + }, + "trace_id": { + "type": "string" + } + }, + "required": ["kind"], + "type": "object" + }, + "TracesQueryResponse": { + "additionalProperties": false, + "properties": { + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "limit": { + "type": "integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "type": "integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/AITrace" + }, + "type": "array" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": ["results"], + "type": "object" + }, "TrendsAlertConfig": { "additionalProperties": false, "properties": { diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index e13773cbcc23b..5e87d0c2953bc 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -113,6 +113,7 @@ export enum NodeKind { TeamTaxonomyQuery = 'TeamTaxonomyQuery', EventTaxonomyQuery = 'EventTaxonomyQuery', ActorsPropertyTaxonomyQuery = 'ActorsPropertyTaxonomyQuery', + TracesQuery = 'TracesQuery', } export type AnyDataNode = @@ -137,6 +138,7 @@ export type AnyDataNode = | ExperimentFunnelsQuery | ExperimentTrendsQuery | RecordingsQuery + | TracesQuery /** * @discriminator kind @@ -188,6 +190,7 @@ export type QuerySchema = | TeamTaxonomyQuery | EventTaxonomyQuery | ActorsPropertyTaxonomyQuery + | TracesQuery // Keep this, because QuerySchema itself will be collapsed as it is used in other models export type QuerySchemaRoot = QuerySchema @@ -218,6 +221,7 @@ export type AnyResponseType = | EventsNode['response'] | EventsQueryResponse | ErrorTrackingQueryResponse + | TracesQueryResponse /** @internal - no need to emit to schema.json. */ export interface DataNode = Record> extends Node { @@ -637,6 +641,7 @@ export interface DataTableNode | ErrorTrackingQuery | ExperimentFunnelsQuery | ExperimentTrendsQuery + | TracesQuery )['response'] > >, @@ -657,6 +662,7 @@ export interface DataTableNode | ErrorTrackingQuery | ExperimentFunnelsQuery | ExperimentTrendsQuery + | TracesQuery /** Columns shown in the table, unless the `source` provides them. */ columns?: HogQLExpression[] /** Columns that aren't shown in the table, even if in columns or returned data */ @@ -2523,6 +2529,33 @@ export type ActorsPropertyTaxonomyQueryResponse = AnalyticsQueryResponseBase +export interface AISpan { + id: string + input: any[] + output: any[] +} + +export interface AITrace { + id: string + spans: AISpan[] +} + +export interface TracesQueryResponse extends AnalyticsQueryResponseBase { + hasMore?: boolean + limit?: integer + offset?: integer + columns?: string[] +} + +export interface TracesQuery extends DataNode { + kind: NodeKind.TracesQuery + trace_id?: string + limit?: integer + offset?: integer +} + +export type CachedTracesQueryResponse = CachedQueryResponse + export enum AssistantMessageType { Human = 'human', Assistant = 'ai', diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py new file mode 100644 index 0000000000000..9759a4f036ace --- /dev/null +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -0,0 +1,90 @@ +from typing import Any, Literal, TypedDict + +from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner +from posthog.models import Team +from posthog.schema import TracesQuery +from posthog.test.base import ( + BaseTest, + ClickhouseTestMixin, + _create_event, + _create_person, + snapshot_clickhouse_queries, +) + + +class InputMessage(TypedDict): + role: Literal["user", "assistant"] + content: str + + +class OutputMessage(TypedDict): + role: Literal["user", "assistant", "tool"] + content: str + + +def _calculate_tokens(messages: str | list[InputMessage] | list[OutputMessage]) -> int: + if isinstance(messages, str): + message = messages + else: + message = "".join([message["content"] for message in messages]) + return len(message) + + +def _create_ai_generation_event( + input: str | list[InputMessage] | None = "What is the capital of Spain?", + output: str | list[OutputMessage] | None = "Madrid", + team: Team | None = None, + distinct_id: str | None = None, + trace_id: str | None = None, + properties: dict[str, Any] | None = None, +): + input_tokens = _calculate_tokens(input) + output_tokens = _calculate_tokens(output) + props = { + "$ai_trace_id": trace_id, + "$ai_latency": 1, + "$ai_input_tokens": input_tokens, + "$ai_output_tokens": output_tokens, + "$ai_input_cost_usd": input_tokens, + "$ai_output_cost_usd": output_tokens, + "$ai_total_cost_usd": input_tokens + output_tokens, + } + if properties: + props.update(properties) + _create_event( + event="$ai_generation", + distinct_id=distinct_id, + properties=props, + team=team, + ) + + +class TestTracesQueryRunner(ClickhouseTestMixin, BaseTest): + @snapshot_clickhouse_queries + def test_traces_query_runner(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_person(distinct_ids=["person2"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + input="Foo", + output="Bar", + team=self.team, + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + input="Foo", + output="Bar", + team=self.team, + ) + _create_ai_generation_event( + distinct_id="person2", + trace_id="trace2", + input="Foo", + output="Bar", + team=self.team, + ) + + results = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(results.results), 2) diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py new file mode 100644 index 0000000000000..72d7f23778208 --- /dev/null +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -0,0 +1,127 @@ +import structlog + +from posthog.hogql import ast +from posthog.hogql.constants import LimitContext +from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator +from posthog.hogql_queries.query_runner import QueryRunner +from posthog.schema import ( + CachedTracesQueryResponse, + NodeKind, + TracesQuery, + TracesQueryResponse, +) + +logger = structlog.get_logger(__name__) + + +""" +select + properties.$ai_trace_id as trace_id, + min(timestamp) as trace_timestamp, + max(person.properties) as person, + sum(properties.$ai_latency) as total_latency, + sum(properties.$ai_input_tokens) as input_tokens, + sum(properties.$ai_output_tokens) as output_tokens, + sum(properties.$ai_input_cost_usd) as input_cost, + sum(properties.$ai_output_cost_usd) as output_cost, + sum(properties.$ai_total_cost_usd) as total_cost, + arraySort(x -> x.1, groupArray(tuple(timestamp, properties))) as spans +from events +where + event = '$ai_generation' +group by + trace_id +order by + trace_timestamp desc +""" + + +class TracesQueryRunner(QueryRunner): + query: TracesQuery + response: TracesQueryResponse + cached_response: CachedTracesQueryResponse + paginator: HogQLHasMorePaginator + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.paginator = HogQLHasMorePaginator.from_limit_context( + limit_context=LimitContext.QUERY, + limit=self.query.limit if self.query.limit else None, + offset=self.query.offset, + ) + + def to_query(self) -> ast.SelectQuery: + return ast.SelectQuery( + select=self._get_select_fields(), + select_from=ast.JoinExpr(table=ast.Field(chain=["events"])), + where=self._get_where_clause(), + order_by=self._get_order_by_clause(), + group_by=[ast.Field(chain=["trace_id"])], + ) + + def calculate(self): + with self.timings.measure("error_tracking_query_hogql_execute"): + query_result = self.paginator.execute_hogql_query( + query=self.to_query(), + team=self.team, + query_type=NodeKind.TRACES_QUERY, + timings=self.timings, + modifiers=self.modifiers, + limit_context=self.limit_context, + ) + + columns: list[str] = query_result.columns or [] + results = self._map_results(columns, query_result.results) + + return TracesQueryResponse( + columns=columns, + results=results, + timings=query_result.timings, + hogql=query_result.hogql, + modifiers=self.modifiers, + **self.paginator.response_params(), + ) + + def _map_results(self, columns: list[str], query_results: list): + mapped_results = [dict(zip(columns, value)) for value in query_results] + return mapped_results + + def _get_select_fields(self) -> list[ast.Expr]: + return [ + ast.Field(chain=["properties", "$ai_trace_id"]), + ast.Call(name="min", args=[ast.Field(chain=["timestamp"])]), + ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), + ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_latency"])]), + ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_tokens"])]), + ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_tokens"])]), + ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_cost_usd"])]), + ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_cost_usd"])]), + ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_total_cost_usd"])]), + ast.Call( + name="arraySort", + args=[ + ast.Lambda(args=["x"], expr=ast.Field(chain=["x", "1"])), + ast.Call( + name="groupArray", + args=[ + ast.Tuple( + exprs=[ + ast.Field(chain=["timestamp"]), + ast.Field(chain=["properties"]), + ] + ) + ], + ), + ], + ), + ] + + def _get_where_clause(self): + return ast.CompareOperation( + left=ast.Field(chain=["event"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value="$ai_generation"), + ) + + def _get_order_by_clause(self): + return [ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")] diff --git a/posthog/schema.py b/posthog/schema.py index 325004e8cab1b..a407a43950206 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -12,6 +12,23 @@ class SchemaRoot(RootModel[Any]): root: Any +class AISpan(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + id: str + input: list + output: list + + +class AITrace(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + id: str + spans: list[AISpan] + + class ActionConversionGoal(BaseModel): model_config = ConfigDict( extra="forbid", @@ -1196,6 +1213,7 @@ class NodeKind(StrEnum): TEAM_TAXONOMY_QUERY = "TeamTaxonomyQuery" EVENT_TAXONOMY_QUERY = "EventTaxonomyQuery" ACTORS_PROPERTY_TAXONOMY_QUERY = "ActorsPropertyTaxonomyQuery" + TRACES_QUERY = "TracesQuery" class PathCleaningFilter(BaseModel): @@ -1343,7 +1361,7 @@ class QueryResponseAlternative7(BaseModel): warnings: list[HogQLNotice] -class QueryResponseAlternative36(BaseModel): +class QueryResponseAlternative37(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -2874,6 +2892,31 @@ class QueryResponseAlternative25(BaseModel): class QueryResponseAlternative28(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." + ) + results: list[AITrace] + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + +class QueryResponseAlternative29(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -2895,7 +2938,7 @@ class QueryResponseAlternative28(BaseModel): ) -class QueryResponseAlternative29(BaseModel): +class QueryResponseAlternative30(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -2917,7 +2960,7 @@ class QueryResponseAlternative29(BaseModel): ) -class QueryResponseAlternative31(BaseModel): +class QueryResponseAlternative32(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -2938,7 +2981,7 @@ class QueryResponseAlternative31(BaseModel): ) -class QueryResponseAlternative34(BaseModel): +class QueryResponseAlternative35(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -2964,7 +3007,7 @@ class QueryResponseAlternative34(BaseModel): types: Optional[list] = None -class QueryResponseAlternative37(BaseModel): +class QueryResponseAlternative38(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -2985,7 +3028,7 @@ class QueryResponseAlternative37(BaseModel): ) -class QueryResponseAlternative38(BaseModel): +class QueryResponseAlternative39(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -3006,7 +3049,7 @@ class QueryResponseAlternative38(BaseModel): ) -class QueryResponseAlternative39(BaseModel): +class QueryResponseAlternative40(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -3027,6 +3070,31 @@ class QueryResponseAlternative39(BaseModel): ) +class QueryResponseAlternative41(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." + ) + results: list[AITrace] + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class ResultCustomization(RootModel[Union[ResultCustomizationByValue, ResultCustomizationByPosition]]): root: Union[ResultCustomizationByValue, ResultCustomizationByPosition] @@ -3368,6 +3436,31 @@ class TestCachedBasicQueryResponse(BaseModel): ) +class TracesQueryResponse(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." + ) + results: list[AITrace] + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class TrendsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", @@ -4379,6 +4472,40 @@ class CachedTeamTaxonomyQueryResponse(BaseModel): ) +class CachedTracesQueryResponse(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + cache_key: str + cache_target_age: Optional[AwareDatetime] = None + calculation_trigger: Optional[str] = Field( + default=None, description="What triggered the calculation of the query, leave empty if user/immediate" + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + is_cached: bool + last_refresh: AwareDatetime + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + next_allowed_client_refresh: AwareDatetime + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." + ) + results: list[AITrace] + timezone: str + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class CachedTrendsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", @@ -4763,6 +4890,31 @@ class Response8(BaseModel): ) +class Response11(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." + ) + results: list[AITrace] + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class DataWarehouseNode(BaseModel): model_config = ConfigDict( extra="forbid", @@ -5428,7 +5580,7 @@ class QueryResponseAlternative9(BaseModel): ) -class QueryResponseAlternative30(BaseModel): +class QueryResponseAlternative31(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5522,6 +5674,20 @@ class TeamTaxonomyQuery(BaseModel): response: Optional[TeamTaxonomyQueryResponse] = None +class TracesQuery(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + kind: Literal["TracesQuery"] = "TracesQuery" + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + response: Optional[TracesQueryResponse] = None + trace_id: Optional[str] = None + + class VisualizationMessage(BaseModel): model_config = ConfigDict( extra="forbid", @@ -5704,6 +5870,7 @@ class AnyResponseType( Any, EventsQueryResponse, ErrorTrackingQueryResponse, + TracesQueryResponse, ] ] ): @@ -5716,6 +5883,7 @@ class AnyResponseType( Any, EventsQueryResponse, ErrorTrackingQueryResponse, + TracesQueryResponse, ] @@ -6540,7 +6708,7 @@ class QueryResponseAlternative27(BaseModel): variants: list[ExperimentVariantTrendsBaseStats] -class QueryResponseAlternative35(BaseModel): +class QueryResponseAlternative36(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -6587,12 +6755,14 @@ class QueryResponseAlternative( QueryResponseAlternative29, QueryResponseAlternative30, QueryResponseAlternative31, - QueryResponseAlternative34, + QueryResponseAlternative32, QueryResponseAlternative35, QueryResponseAlternative36, QueryResponseAlternative37, QueryResponseAlternative38, QueryResponseAlternative39, + QueryResponseAlternative40, + QueryResponseAlternative41, ] ] ): @@ -6625,12 +6795,14 @@ class QueryResponseAlternative( QueryResponseAlternative29, QueryResponseAlternative30, QueryResponseAlternative31, - QueryResponseAlternative34, + QueryResponseAlternative32, QueryResponseAlternative35, QueryResponseAlternative36, QueryResponseAlternative37, QueryResponseAlternative38, QueryResponseAlternative39, + QueryResponseAlternative40, + QueryResponseAlternative41, ] @@ -7013,6 +7185,7 @@ class DataTableNode(BaseModel): Response8, Response9, Response10, + Response11, ] ] = None showActions: Optional[bool] = Field(default=None, description="Show the kebab menu at the end of the row") @@ -7055,6 +7228,7 @@ class DataTableNode(BaseModel): ErrorTrackingQuery, ExperimentFunnelsQuery, ExperimentTrendsQuery, + TracesQuery, ] = Field(..., description="Source of the events") @@ -7095,6 +7269,7 @@ class HogQLAutocomplete(BaseModel): ExperimentFunnelsQuery, ExperimentTrendsQuery, RecordingsQuery, + TracesQuery, ] ] = Field(default=None, description="Query in whose context to validate.") startPosition: int = Field(..., description="Start position of the editor word") @@ -7139,6 +7314,7 @@ class HogQLMetadata(BaseModel): ExperimentFunnelsQuery, ExperimentTrendsQuery, RecordingsQuery, + TracesQuery, ] ] = Field( default=None, @@ -7196,6 +7372,7 @@ class QueryRequest(BaseModel): TeamTaxonomyQuery, EventTaxonomyQuery, ActorsPropertyTaxonomyQuery, + TracesQuery, ] = Field( ..., description=( @@ -7263,6 +7440,7 @@ class QuerySchemaRoot( TeamTaxonomyQuery, EventTaxonomyQuery, ActorsPropertyTaxonomyQuery, + TracesQuery, ] ] ): @@ -7304,6 +7482,7 @@ class QuerySchemaRoot( TeamTaxonomyQuery, EventTaxonomyQuery, ActorsPropertyTaxonomyQuery, + TracesQuery, ] = Field(..., discriminator="kind") From d76093b6734d7cc791e3b534ef91f7a2621e8637 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 10:27:52 +0100 Subject: [PATCH 02/15] fix: query aliases --- .../hogql_queries/ai/traces_query_runner.py | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index 72d7f23778208..e574c35ba59cc 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -88,31 +88,48 @@ def _map_results(self, columns: list[str], query_results: list): def _get_select_fields(self) -> list[ast.Expr]: return [ - ast.Field(chain=["properties", "$ai_trace_id"]), - ast.Call(name="min", args=[ast.Field(chain=["timestamp"])]), - ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), - ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_latency"])]), - ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_tokens"])]), - ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_tokens"])]), - ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_cost_usd"])]), - ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_cost_usd"])]), - ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_total_cost_usd"])]), - ast.Call( - name="arraySort", - args=[ - ast.Lambda(args=["x"], expr=ast.Field(chain=["x", "1"])), - ast.Call( - name="groupArray", - args=[ - ast.Tuple( - exprs=[ - ast.Field(chain=["timestamp"]), - ast.Field(chain=["properties"]), - ] - ) - ], - ), - ], + ast.Alias(expr=ast.Field(chain=["properties", "$ai_trace_id"]), alias="trace_id"), + ast.Alias(expr=ast.Call(name="min", args=[ast.Field(chain=["timestamp"])]), alias="trace_timestamp"), + ast.Alias(expr=ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), alias="person"), + ast.Alias( + expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_latency"])]), + alias="total_latency", + ), + ast.Alias( + expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_tokens"])]), + alias="input_tokens", + ), + ast.Alias( + expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_tokens"])]), + alias="output_tokens", + ), + ast.Alias( + expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_cost_usd"])]), + alias="input_cost", + ), + ast.Alias( + expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_cost_usd"])]), + alias="output_cost", + ), + ast.Alias( + expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_total_cost_usd"])]), + alias="total_cost", + ), + ast.Alias( + expr=ast.Call( + name="arraySort", + args=[ + ast.Lambda( + args=["x"], + expr=ast.Call(name="tupleElement", args=[ast.Field(chain=["x"]), ast.Constant(value=1)]), + ), + ast.Call( + name="groupArray", + args=[ast.Tuple(exprs=[ast.Field(chain=["timestamp"]), ast.Field(chain=["properties"])])], + ), + ], + ), + alias="spans", ), ] @@ -124,4 +141,4 @@ def _get_where_clause(self): ) def _get_order_by_clause(self): - return [ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")] + return [ast.OrderExpr(expr=ast.Field(chain=["trace_timestamp"]), order="DESC")] From 3922f9d0f60ceac6a0d0e266ec32ca84cfc8fbe5 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 13:27:20 +0100 Subject: [PATCH 03/15] feat: trace and generation models --- frontend/src/queries/schema.json | 181 ++++++++++++++++-- frontend/src/queries/schema.ts | 26 ++- .../test_traces_query_runner.ambr | 44 +++++ .../ai/test/test_traces_query_runner.py | 46 ++++- .../hogql_queries/ai/traces_query_runner.py | 87 ++++++++- posthog/schema.py | 57 ++++-- 6 files changed, 388 insertions(+), 53 deletions(-) create mode 100644 posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index a4bc84efa716f..6bdf20ffbe24c 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -1,9 +1,18 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "definitions": { - "AISpan": { + "AIGeneration": { "additionalProperties": false, "properties": { + "base_url": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "http_status": { + "type": "number" + }, "id": { "type": "string" }, @@ -11,28 +20,87 @@ "items": {}, "type": "array" }, + "input_cost": { + "type": "number" + }, + "input_tokens": { + "type": "number" + }, + "latency": { + "type": "number" + }, + "model": { + "type": "string" + }, "output": { "items": {}, "type": "array" + }, + "output_cost": { + "type": "number" + }, + "output_tokens": { + "type": "number" + }, + "provider": { + "type": "string" + }, + "total_cost": { + "type": "number" } }, - "required": ["id", "input", "output"], + "required": ["id", "created_at", "input", "latency"], "type": "object" }, "AITrace": { "additionalProperties": false, "properties": { - "id": { + "created_at": { "type": "string" }, - "spans": { + "events": { "items": { - "$ref": "#/definitions/AISpan" + "$ref": "#/definitions/AIGeneration" }, "type": "array" + }, + "id": { + "type": "string" + }, + "input_cost": { + "type": "number" + }, + "input_tokens": { + "type": "number" + }, + "output_cost": { + "type": "number" + }, + "output_tokens": { + "type": "number" + }, + "person": { + "type": "object" + }, + "total_cost": { + "type": "number" + }, + "total_latency": { + "type": "number" } }, - "required": ["id", "spans"], + "required": [ + "id", + "created_at", + "person", + "total_latency", + "input_tokens", + "output_tokens", + "input_cost", + "output_cost", + "total_cost", + "events" + ], "type": "object" }, "ActionConversionGoal": { @@ -877,8 +945,27 @@ "description": "`icontains` - case insensitive contains. `not_icontains` - case insensitive does not contain. `regex` - matches the regex pattern. `not_regex` - does not match the regex pattern." }, "type": { - "enum": ["event", "person", "session", "feature"], - "type": "string" + "anyOf": [ + { + "const": "event", + "description": "Event properties", + "type": "string" + }, + { + "const": "person", + "description": "Person properties", + "type": "string" + }, + { + "const": "session", + "type": "string" + }, + { + "const": "feature", + "description": "Event property with \"$feature/\" prepended", + "type": "string" + } + ] }, "value": { "description": "Only use property values from the plan. If the operator is `regex` or `not_regex`, the value must be a valid ClickHouse regex pattern to match against. Otherwise, the value must be a substring that will be matched against the property value.", @@ -900,8 +987,27 @@ "description": "`exact` - exact match of any of the values. `is_not` - does not match any of the values." }, "type": { - "enum": ["event", "person", "session", "feature"], - "type": "string" + "anyOf": [ + { + "const": "event", + "description": "Event properties", + "type": "string" + }, + { + "const": "person", + "description": "Person properties", + "type": "string" + }, + { + "const": "session", + "type": "string" + }, + { + "const": "feature", + "description": "Event property with \"$feature/\" prepended", + "type": "string" + } + ] }, "value": { "description": "Only use property values from the plan. Always use strings as values. If you have a number, convert it to a string first. If you have a boolean, convert it to a string \"true\" or \"false\".", @@ -925,8 +1031,27 @@ "$ref": "#/definitions/AssistantDateTimePropertyFilterOperator" }, "type": { - "enum": ["event", "person", "session", "feature"], - "type": "string" + "anyOf": [ + { + "const": "event", + "description": "Event properties", + "type": "string" + }, + { + "const": "person", + "description": "Person properties", + "type": "string" + }, + { + "const": "session", + "type": "string" + }, + { + "const": "feature", + "description": "Event property with \"$feature/\" prepended", + "type": "string" + } + ] }, "value": { "description": "Value must be a date in ISO 8601 format.", @@ -948,8 +1073,27 @@ "description": "`is_set` - the property has any value. `is_not_set` - the property doesn't have a value or wasn't collected." }, "type": { - "enum": ["event", "person", "session", "feature"], - "type": "string" + "anyOf": [ + { + "const": "event", + "description": "Event properties", + "type": "string" + }, + { + "const": "person", + "description": "Person properties", + "type": "string" + }, + { + "const": "session", + "type": "string" + }, + { + "const": "feature", + "description": "Event property with \"$feature/\" prepended", + "type": "string" + } + ] } }, "required": ["key", "operator", "type"], @@ -11526,8 +11670,7 @@ "type": "integer" }, "end_time": { - "description": "When did the query execution task finish (whether successfully or not).", - "format": "date-time", + "description": "When did the query execution task finish (whether successfully or not). @format date-time", "type": "string" }, "error": { @@ -11556,8 +11699,7 @@ "type": "array" }, "pickup_time": { - "description": "When was the query execution task picked up by a worker.", - "format": "date-time", + "description": "When was the query execution task picked up by a worker. @format date-time", "type": "string" }, "query_async": { @@ -11571,8 +11713,7 @@ }, "results": {}, "start_time": { - "description": "When was query execution task enqueued.", - "format": "date-time", + "description": "When was query execution task enqueued. @format date-time", "type": "string" }, "task_id": { diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index 5e87d0c2953bc..30995197f787e 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -568,6 +568,7 @@ export interface EventsQueryPersonColumn { } distinct_id: string } + export interface EventsQuery extends DataNode { kind: NodeKind.EventsQuery /** Return a limited set of data. Required. */ @@ -2529,15 +2530,34 @@ export type ActorsPropertyTaxonomyQueryResponse = AnalyticsQueryResponseBase -export interface AISpan { +export interface AIGeneration { id: string + created_at: string input: any[] - output: any[] + latency: number + output?: any[] + provider?: string + model?: string + input_tokens?: number + output_tokens?: number + input_cost?: number + output_cost?: number + total_cost?: number + http_status?: number + base_url?: string } export interface AITrace { id: string - spans: AISpan[] + created_at: string + person: Record + total_latency: number + input_tokens: number + output_tokens: number + input_cost: number + output_cost: number + total_cost: number + events: AIGeneration[] } export interface TracesQueryResponse extends AnalyticsQueryResponseBase { diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr new file mode 100644 index 0000000000000..dfa8a3c411bea --- /dev/null +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -0,0 +1,44 @@ +# serializer version: 1 +# name: TestTracesQueryRunner.test_traces_query_runner + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + max(events__person.properties) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> tupleElement(x, 1), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS spans + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 101 + OFFSET 0 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index 9759a4f036ace..3278c35bf1db0 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -1,7 +1,9 @@ +import json from typing import Any, Literal, TypedDict from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner -from posthog.models import Team +from posthog.models import PropertyDefinition, Team +from posthog.models.property_definition import PropertyType from posthog.schema import TracesQuery from posthog.test.base import ( BaseTest, @@ -31,8 +33,8 @@ def _calculate_tokens(messages: str | list[InputMessage] | list[OutputMessage]) def _create_ai_generation_event( - input: str | list[InputMessage] | None = "What is the capital of Spain?", - output: str | list[OutputMessage] | None = "Madrid", + input: str | list[InputMessage] | None = "Foo", + output: str | list[OutputMessage] | None = "Bar", team: Team | None = None, distinct_id: str | None = None, trace_id: str | None = None, @@ -40,9 +42,22 @@ def _create_ai_generation_event( ): input_tokens = _calculate_tokens(input) output_tokens = _calculate_tokens(output) + + if isinstance(input, str): + input_messages = [{"role": "user", "content": input}] + else: + input_messages = input + + if isinstance(output, str): + output_messages = [{"role": "assistant", "content": output}] + else: + output_messages = output + props = { "$ai_trace_id": trace_id, "$ai_latency": 1, + "$ai_input": json.dumps(input_messages), + "$ai_output": json.dumps(output_messages), "$ai_input_tokens": input_tokens, "$ai_output_tokens": output_tokens, "$ai_input_cost_usd": input_tokens, @@ -51,6 +66,7 @@ def _create_ai_generation_event( } if properties: props.update(properties) + _create_event( event="$ai_generation", distinct_id=distinct_id, @@ -60,6 +76,30 @@ def _create_ai_generation_event( class TestTracesQueryRunner(ClickhouseTestMixin, BaseTest): + def setUp(self): + super().setUp() + self._create_properties() + + def _create_properties(self): + numeric_props = { + "$ai_latency", + "$ai_input_tokens", + "$ai_output_tokens", + "$ai_input_cost_usd", + "$ai_output_cost_usd", + "$ai_total_cost_usd", + } + models_to_create = [] + for prop in numeric_props: + prop_model = PropertyDefinition( + team=self.team, + name=prop, + type=PropertyDefinition.Type.EVENT, + property_type=PropertyType.Numeric, + ) + models_to_create.append(prop_model) + PropertyDefinition.objects.bulk_create(models_to_create) + @snapshot_clickhouse_queries def test_traces_query_runner(self): _create_person(distinct_ids=["person1"], team=self.team) diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index e574c35ba59cc..a6a980f36e2e2 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -1,3 +1,8 @@ +from datetime import datetime +from typing import cast +from uuid import UUID + +import orjson import structlog from posthog.hogql import ast @@ -5,6 +10,8 @@ from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator from posthog.hogql_queries.query_runner import QueryRunner from posthog.schema import ( + AIGeneration, + AITrace, CachedTracesQueryResponse, NodeKind, TracesQuery, @@ -25,7 +32,7 @@ sum(properties.$ai_input_cost_usd) as input_cost, sum(properties.$ai_output_cost_usd) as output_cost, sum(properties.$ai_total_cost_usd) as total_cost, - arraySort(x -> x.1, groupArray(tuple(timestamp, properties))) as spans + arraySort(x -> x.1, groupArray(tuple(timestamp, properties))) as events from events where event = '$ai_generation' @@ -56,7 +63,7 @@ def to_query(self) -> ast.SelectQuery: select_from=ast.JoinExpr(table=ast.Field(chain=["events"])), where=self._get_where_clause(), order_by=self._get_order_by_clause(), - group_by=[ast.Field(chain=["trace_id"])], + group_by=[ast.Field(chain=["id"])], ) def calculate(self): @@ -83,12 +90,72 @@ def calculate(self): ) def _map_results(self, columns: list[str], query_results: list): + TRACE_FIELDS = { + "id", + "created_at", + "person", + "total_latency", + "input_tokens", + "output_tokens", + "input_cost", + "output_cost", + "total_cost", + "events", + } mapped_results = [dict(zip(columns, value)) for value in query_results] - return mapped_results + traces = [] + + for result in mapped_results: + generations = [] + for uuid, timestamp, properties in result["events"]: + generations.append(self._map_generation(uuid, timestamp, properties)) + trace_dict = { + **result, + "created_at": cast(datetime, result["trace_timestamp"]).isoformat(), + "person": orjson.loads(result["person"]), + "events": generations, + } + trace = AITrace.model_validate({key: value for key, value in trace_dict.items() if key in TRACE_FIELDS}) + traces.append(trace) + + return traces + + def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_properties: str) -> AIGeneration: + properties: dict = orjson.loads(event_properties) + + GENERATION_MAPPING = { + "$ai_input": "input", + "$ai_latency": "latency", + "$ai_output": "output", + "$ai_provider": "provider", + "$ai_model": "model", + "$ai_input_tokens": "input_tokens", + "$ai_output_tokens": "output_tokens", + "$ai_input_cost_usd": "input_cost", + "$ai_output_cost_usd": "output_cost", + "$ai_total_cost_usd": "total_cost", + "$ai_http_status": "http_status", + "$ai_base_url": "base_url", + } + GENERATION_JSON_FIELDS = {"$ai_input", "$ai_output"} + + generation = { + "id": str(event_uuid), + "created_at": event_timestamp.isoformat(), + } + + for event_prop, model_prop in GENERATION_MAPPING.items(): + if event_prop in properties: + if event_prop in GENERATION_JSON_FIELDS: + generation[model_prop] = orjson.loads(properties[event_prop]) + else: + generation[model_prop] = properties[event_prop] + + return AIGeneration.model_validate(generation) def _get_select_fields(self) -> list[ast.Expr]: return [ - ast.Alias(expr=ast.Field(chain=["properties", "$ai_trace_id"]), alias="trace_id"), + ast.Alias(expr=ast.Field(chain=["properties", "$ai_trace_id"]), alias="id"), ast.Alias(expr=ast.Call(name="min", args=[ast.Field(chain=["timestamp"])]), alias="trace_timestamp"), ast.Alias(expr=ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), alias="person"), ast.Alias( @@ -125,11 +192,19 @@ def _get_select_fields(self) -> list[ast.Expr]: ), ast.Call( name="groupArray", - args=[ast.Tuple(exprs=[ast.Field(chain=["timestamp"]), ast.Field(chain=["properties"])])], + args=[ + ast.Tuple( + exprs=[ + ast.Field(chain=["uuid"]), + ast.Field(chain=["timestamp"]), + ast.Field(chain=["properties"]), + ] + ) + ], ), ], ), - alias="spans", + alias="events", ), ] diff --git a/posthog/schema.py b/posthog/schema.py index a407a43950206..3b53d495f75de 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -12,21 +12,40 @@ class SchemaRoot(RootModel[Any]): root: Any -class AISpan(BaseModel): +class AIGeneration(BaseModel): model_config = ConfigDict( extra="forbid", ) + base_url: Optional[str] = None + created_at: str + http_status: Optional[float] = None id: str input: list - output: list + input_cost: Optional[float] = None + input_tokens: Optional[float] = None + latency: float + model: Optional[str] = None + output: Optional[list] = None + output_cost: Optional[float] = None + output_tokens: Optional[float] = None + provider: Optional[str] = None + total_cost: Optional[float] = None class AITrace(BaseModel): model_config = ConfigDict( extra="forbid", ) + created_at: str + events: list[AIGeneration] id: str - spans: list[AISpan] + input_cost: float + input_tokens: float + output_cost: float + output_tokens: float + person: dict[str, Any] + total_cost: float + total_latency: float class ActionConversionGoal(BaseModel): @@ -156,13 +175,6 @@ class AssistantGenericMultipleBreakdownFilter(BaseModel): type: AssistantEventMultipleBreakdownFilterType -class Type(StrEnum): - EVENT = "event" - PERSON = "person" - SESSION = "session" - FEATURE = "feature" - - class AssistantGenericPropertyFilter2(BaseModel): model_config = ConfigDict( extra="forbid", @@ -171,7 +183,7 @@ class AssistantGenericPropertyFilter2(BaseModel): operator: AssistantArrayPropertyFilterOperator = Field( ..., description="`exact` - exact match of any of the values. `is_not` - does not match any of the values." ) - type: Type + type: str value: list[str] = Field( ..., description=( @@ -187,7 +199,7 @@ class AssistantGenericPropertyFilter3(BaseModel): ) key: str = Field(..., description="Use one of the properties the user has provided in the plan.") operator: AssistantDateTimePropertyFilterOperator - type: Type + type: str value: str = Field(..., description="Value must be a date in ISO 8601 format.") @@ -644,7 +656,7 @@ class DatabaseSchemaSource(BaseModel): status: str -class Type4(StrEnum): +class Type(StrEnum): POSTHOG = "posthog" DATA_WAREHOUSE = "data_warehouse" VIEW = "view" @@ -1380,8 +1392,9 @@ class QueryStatus(BaseModel): ), ) dashboard_id: Optional[int] = None - end_time: Optional[AwareDatetime] = Field( - default=None, description="When did the query execution task finish (whether successfully or not)." + end_time: Optional[str] = Field( + default=None, + description="When did the query execution task finish (whether successfully or not). @format date-time", ) error: Optional[bool] = Field( default=False, @@ -1394,13 +1407,15 @@ class QueryStatus(BaseModel): id: str insight_id: Optional[int] = None labels: Optional[list[str]] = None - pickup_time: Optional[AwareDatetime] = Field( - default=None, description="When was the query execution task picked up by a worker." + pickup_time: Optional[str] = Field( + default=None, description="When was the query execution task picked up by a worker. @format date-time" ) query_async: Literal[True] = Field(default=True, description="ONLY async queries use QueryStatus.") query_progress: Optional[ClickhouseQueryProgress] = None results: Optional[Any] = None - start_time: Optional[AwareDatetime] = Field(default=None, description="When was query execution task enqueued.") + start_time: Optional[str] = Field( + default=None, description="When was query execution task enqueued. @format date-time" + ) task_id: Optional[str] = None team_id: int @@ -1927,7 +1942,7 @@ class AssistantGenericPropertyFilter1(BaseModel): " matches the regex pattern. `not_regex` - does not match the regex pattern." ), ) - type: Type + type: str value: str = Field( ..., description=( @@ -1950,7 +1965,7 @@ class AssistantGenericPropertyFilter4(BaseModel): " collected." ), ) - type: Type + type: str class AssistantGroupPropertyFilter1(BaseModel): @@ -2238,7 +2253,7 @@ class DatabaseSchemaTableCommon(BaseModel): fields: dict[str, DatabaseSchemaField] id: str name: str - type: Type4 + type: Type class ElementPropertyFilter(BaseModel): From 7c4bb87f317fc82cc4b2508131438b0d8900e38b Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 15:12:16 +0100 Subject: [PATCH 04/15] test: retrieving a list --- frontend/src/queries/schema.json | 5 +- frontend/src/queries/schema.ts | 2 +- .../test_traces_query_runner.ambr | 4 +- .../ai/test/test_traces_query_runner.py | 130 ++++++++++++++++-- .../hogql_queries/ai/traces_query_runner.py | 2 +- posthog/schema.py | 2 +- 6 files changed, 125 insertions(+), 20 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 6bdf20ffbe24c..da4fa05f52693 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -32,10 +32,7 @@ "model": { "type": "string" }, - "output": { - "items": {}, - "type": "array" - }, + "output": {}, "output_cost": { "type": "number" }, diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index 30995197f787e..f36bcfbc98d1d 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2535,7 +2535,7 @@ export interface AIGeneration { created_at: string input: any[] latency: number - output?: any[] + output?: any provider?: string model?: string input_tokens?: number diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index dfa8a3c411bea..d157dd29397a9 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: TestTracesQueryRunner.test_traces_query_runner +# name: TestTracesQueryRunner.test_field_mapping ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, @@ -10,7 +10,7 @@ sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, - arraySort(x -> tupleElement(x, 1), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS spans + arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events FROM events LEFT OUTER JOIN (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index 3278c35bf1db0..ec1d40fc78073 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -1,10 +1,12 @@ import json +from datetime import UTC, datetime from typing import Any, Literal, TypedDict +from uuid import UUID from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner from posthog.models import PropertyDefinition, Team from posthog.models.property_definition import PropertyType -from posthog.schema import TracesQuery +from posthog.schema import AIGeneration, AITrace, TracesQuery from posthog.test.base import ( BaseTest, ClickhouseTestMixin, @@ -33,23 +35,25 @@ def _calculate_tokens(messages: str | list[InputMessage] | list[OutputMessage]) def _create_ai_generation_event( - input: str | list[InputMessage] | None = "Foo", - output: str | list[OutputMessage] | None = "Bar", + input: str | list[InputMessage] = "Foo", + output: str | list[OutputMessage] = "Bar", team: Team | None = None, distinct_id: str | None = None, trace_id: str | None = None, properties: dict[str, Any] | None = None, + timestamp: datetime | None = None, + event_uuid: str | UUID | None = None, ): input_tokens = _calculate_tokens(input) output_tokens = _calculate_tokens(output) if isinstance(input, str): - input_messages = [{"role": "user", "content": input}] + input_messages: list[InputMessage] = [{"role": "user", "content": input}] else: input_messages = input if isinstance(output, str): - output_messages = [{"role": "assistant", "content": output}] + output_messages: list[OutputMessage] = [{"role": "assistant", "content": output}] else: output_messages = output @@ -57,7 +61,7 @@ def _create_ai_generation_event( "$ai_trace_id": trace_id, "$ai_latency": 1, "$ai_input": json.dumps(input_messages), - "$ai_output": json.dumps(output_messages), + "$ai_output": json.dumps({"choices": output_messages}), "$ai_input_tokens": input_tokens, "$ai_output_tokens": output_tokens, "$ai_input_cost_usd": input_tokens, @@ -72,6 +76,8 @@ def _create_ai_generation_event( distinct_id=distinct_id, properties=props, team=team, + timestamp=timestamp, + event_uuid=str(event_uuid) if event_uuid else None, ) @@ -100,8 +106,18 @@ def _create_properties(self): models_to_create.append(prop_model) PropertyDefinition.objects.bulk_create(models_to_create) + def assertTraceEqual(self, trace: AITrace, expected_trace: dict): + trace_dict = trace.model_dump() + for key, value in expected_trace.items(): + self.assertEqual(trace_dict[key], value, f"Field {key} does not match") + + def assertEventEqual(self, event: AIGeneration, expected_event: dict): + event_dict = event.model_dump() + for key, value in expected_event.items(): + self.assertEqual(event_dict[key], value, f"Field {key} does not match") + @snapshot_clickhouse_queries - def test_traces_query_runner(self): + def test_field_mapping(self): _create_person(distinct_ids=["person1"], team=self.team) _create_person(distinct_ids=["person2"], team=self.team) _create_ai_generation_event( @@ -110,13 +126,15 @@ def test_traces_query_runner(self): input="Foo", output="Bar", team=self.team, + timestamp=datetime(2025, 1, 15, 0), ) _create_ai_generation_event( distinct_id="person1", trace_id="trace1", - input="Foo", - output="Bar", + input="Bar", + output="Baz", team=self.team, + timestamp=datetime(2025, 1, 15, 1), ) _create_ai_generation_event( distinct_id="person2", @@ -124,7 +142,97 @@ def test_traces_query_runner(self): input="Foo", output="Bar", team=self.team, + timestamp=datetime(2025, 1, 14), + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(response.results), 2) + + trace = response.results[0] + + self.assertTraceEqual( + trace, + { + "id": "trace1", + "created_at": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), + "total_latency": 2.0, + "input_tokens": 6.0, + "output_tokens": 6.0, + "input_cost": 6.0, + "output_cost": 6.0, + "total_cost": 12.0, + "person": {}, + }, ) - results = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() - self.assertEqual(len(results.results), 2) + self.assertEqual(len(trace.events), 2) + event = trace.events[0] + self.assertIsNotNone(event.id) + self.assertEventEqual( + event, + { + "created_at": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), + "input": [{"role": "user", "content": "Foo"}], + "output": {"choices": [{"role": "assistant", "content": "Bar"}]}, + "latency": 1, + "input_tokens": 3, + "output_tokens": 3, + "input_cost": 3, + "output_cost": 3, + "total_cost": 6, + }, + ) + + event = trace.events[1] + self.assertIsNotNone(event.id) + self.assertEventEqual( + event, + { + "created_at": datetime(2025, 1, 15, 1, tzinfo=UTC).isoformat(), + "input": [{"role": "user", "content": "Bar"}], + "output": {"choices": [{"role": "assistant", "content": "Baz"}]}, + "latency": 1, + "input_tokens": 3, + "output_tokens": 3, + "input_cost": 3, + "output_cost": 3, + "total_cost": 6, + "base_url": None, + "http_status": None, + }, + ) + + trace = response.results[1] + self.assertTraceEqual( + trace, + { + "id": "trace2", + "created_at": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), + "total_latency": 1, + "input_tokens": 3, + "output_tokens": 3, + "input_cost": 3, + "output_cost": 3, + "total_cost": 6, + "person": {}, + }, + ) + self.assertEqual(len(trace.events), 1) + event = trace.events[0] + self.assertIsNotNone(event.id) + self.assertEventEqual( + event, + { + "created_at": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), + "input": [{"role": "user", "content": "Foo"}], + "output": {"choices": [{"role": "assistant", "content": "Bar"}]}, + "latency": 1, + "input_tokens": 3, + "output_tokens": 3, + "input_cost": 3, + "output_cost": 3, + "total_cost": 6, + "base_url": None, + "http_status": None, + }, + ) diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index a6a980f36e2e2..f94bdb10479b8 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -188,7 +188,7 @@ def _get_select_fields(self) -> list[ast.Expr]: args=[ ast.Lambda( args=["x"], - expr=ast.Call(name="tupleElement", args=[ast.Field(chain=["x"]), ast.Constant(value=1)]), + expr=ast.Call(name="tupleElement", args=[ast.Field(chain=["x"]), ast.Constant(value=2)]), ), ast.Call( name="groupArray", diff --git a/posthog/schema.py b/posthog/schema.py index 3b53d495f75de..2c93cab56082c 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -25,7 +25,7 @@ class AIGeneration(BaseModel): input_tokens: Optional[float] = None latency: float model: Optional[str] = None - output: Optional[list] = None + output: Optional[Any] = None output_cost: Optional[float] = None output_tokens: Optional[float] = None provider: Optional[str] = None From ac5da15c784bbd972ffb11d06e690b4e9a818bb4 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 15:16:23 +0100 Subject: [PATCH 05/15] feat: add a filter by traceId --- frontend/src/queries/schema.json | 2 +- frontend/src/queries/schema.ts | 2 +- .../test_traces_query_runner.ambr | 43 +++++++++++++++++++ .../ai/test/test_traces_query_runner.py | 23 ++++++++++ .../hogql_queries/ai/traces_query_runner.py | 14 +++++- posthog/schema.py | 2 +- 6 files changed, 82 insertions(+), 4 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index da4fa05f52693..e7cb80a395c74 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -13195,7 +13195,7 @@ "response": { "$ref": "#/definitions/TracesQueryResponse" }, - "trace_id": { + "traceId": { "type": "string" } }, diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index f36bcfbc98d1d..5435fad9c040d 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2569,7 +2569,7 @@ export interface TracesQueryResponse extends AnalyticsQueryResponseBase { kind: NodeKind.TracesQuery - trace_id?: string + traceId?: string limit?: integer offset?: integer } diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index d157dd29397a9..62a1471e22c0a 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -42,3 +42,46 @@ max_bytes_before_external_group_by=0 ''' # --- +# name: TestTracesQueryRunner.test_trace_id_filter + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + max(events__person.properties) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), ifNull(equals(id, 'trace1'), 0)) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 101 + OFFSET 0 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index ec1d40fc78073..c4cc16bdbf381 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -236,3 +236,26 @@ def test_field_mapping(self): "http_status": None, }, ) + + @snapshot_clickhouse_queries + def test_trace_id_filter(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_person(distinct_ids=["person2"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + input="Foo", + output="Bar", + team=self.team, + ) + _create_ai_generation_event( + distinct_id="person2", + trace_id="trace2", + input="Foo", + output="Bar", + team=self.team, + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery(traceId="trace1")).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index f94bdb10479b8..1d0f61f7467d9 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -209,11 +209,23 @@ def _get_select_fields(self) -> list[ast.Expr]: ] def _get_where_clause(self): - return ast.CompareOperation( + event_expr = ast.CompareOperation( left=ast.Field(chain=["event"]), op=ast.CompareOperationOp.Eq, right=ast.Constant(value="$ai_generation"), ) + if self.query.traceId is not None: + return ast.And( + exprs=[ + event_expr, + ast.CompareOperation( + left=ast.Field(chain=["id"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value=self.query.traceId), + ), + ] + ) + return event_expr def _get_order_by_clause(self): return [ast.OrderExpr(expr=ast.Field(chain=["trace_timestamp"]), order="DESC")] diff --git a/posthog/schema.py b/posthog/schema.py index 2c93cab56082c..6ce1e205b39bb 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -5700,7 +5700,7 @@ class TracesQuery(BaseModel): ) offset: Optional[int] = None response: Optional[TracesQueryResponse] = None - trace_id: Optional[str] = None + traceId: Optional[str] = None class VisualizationMessage(BaseModel): From f59e7ea340029e69b0358a564585a6d86b210138 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 15:24:11 +0100 Subject: [PATCH 06/15] test: pagination --- .../test_traces_query_runner.ambr | 129 ++++++++++++++++++ .../ai/test/test_traces_query_runner.py | 53 +++++-- 2 files changed, 167 insertions(+), 15 deletions(-) diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index 62a1471e22c0a..acb39bc86f57d 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -42,6 +42,135 @@ max_bytes_before_external_group_by=0 ''' # --- +# name: TestTracesQueryRunner.test_pagination + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + max(events__person.properties) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 5 + OFFSET 0 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- +# name: TestTracesQueryRunner.test_pagination.1 + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + max(events__person.properties) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 5 + OFFSET 5 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- +# name: TestTracesQueryRunner.test_pagination.2 + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + max(events__person.properties) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 5 + OFFSET 10 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- # name: TestTracesQueryRunner.test_trace_id_filter ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index c4cc16bdbf381..aa745b88b640d 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -1,4 +1,5 @@ import json +import uuid from datetime import UTC, datetime from typing import Any, Literal, TypedDict from uuid import UUID @@ -58,7 +59,7 @@ def _create_ai_generation_event( output_messages = output props = { - "$ai_trace_id": trace_id, + "$ai_trace_id": trace_id or str(uuid.uuid4()), "$ai_latency": 1, "$ai_input": json.dumps(input_messages), "$ai_output": json.dumps({"choices": output_messages}), @@ -241,21 +242,43 @@ def test_field_mapping(self): def test_trace_id_filter(self): _create_person(distinct_ids=["person1"], team=self.team) _create_person(distinct_ids=["person2"], team=self.team) - _create_ai_generation_event( - distinct_id="person1", - trace_id="trace1", - input="Foo", - output="Bar", - team=self.team, - ) - _create_ai_generation_event( - distinct_id="person2", - trace_id="trace2", - input="Foo", - output="Bar", - team=self.team, - ) + _create_ai_generation_event(distinct_id="person1", trace_id="trace1", team=self.team) + _create_ai_generation_event(distinct_id="person2", trace_id="trace2", team=self.team) response = TracesQueryRunner(team=self.team, query=TracesQuery(traceId="trace1")).calculate() self.assertEqual(len(response.results), 1) self.assertEqual(response.results[0].id, "trace1") + + @snapshot_clickhouse_queries + def test_pagination(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_person(distinct_ids=["person2"], team=self.team) + for i in range(11): + _create_ai_generation_event( + distinct_id="person1" if i % 2 == 0 else "person2", + team=self.team, + trace_id=f"trace_{i}", + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=4, offset=0)).calculate() + self.assertEqual(response.hasMore, True) + self.assertEqual(len(response.results), 5) + self.assertEqual(response.results[0].id, "trace_10") + self.assertEqual(response.results[1].id, "trace_9") + self.assertEqual(response.results[2].id, "trace_8") + self.assertEqual(response.results[3].id, "trace_7") + self.assertEqual(response.results[4].id, "trace_6") + + response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=4, offset=5)).calculate() + self.assertEqual(response.hasMore, True) + self.assertEqual(len(response.results), 5) + self.assertEqual(response.results[0].id, "trace_5") + self.assertEqual(response.results[1].id, "trace_4") + self.assertEqual(response.results[2].id, "trace_3") + self.assertEqual(response.results[3].id, "trace_2") + self.assertEqual(response.results[4].id, "trace_1") + + response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=4, offset=10)).calculate() + self.assertEqual(response.hasMore, False) + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace_0") From 678cd6ccb576a14ee86ef52423dc3d4249d5e1dd Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 15:27:43 +0100 Subject: [PATCH 07/15] test: more field tests --- .../ai/test/test_traces_query_runner.py | 42 +++++++++++++++++++ .../hogql_queries/ai/traces_query_runner.py | 2 +- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index aa745b88b640d..c09108ef120c8 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -282,3 +282,45 @@ def test_pagination(self): self.assertEqual(response.hasMore, False) self.assertEqual(len(response.results), 1) self.assertEqual(response.results[0].id, "trace_0") + + def test_maps_all_fields(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + properties={ + "$ai_latency": 10.5, + "$ai_provider": "posthog", + "$ai_model": "hog-destroyer", + "$ai_http_status": 200, + "$ai_base_url": "https://us.posthog.com", + }, + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + self.assertEqual(response.results[0].total_latency, 10.5) + self.assertEventEqual( + response.results[0].events[0], + { + "latency": 10.5, + "provider": "posthog", + "model": "hog-destroyer", + "http_status": 200, + "base_url": "https://us.posthog.com", + }, + ) + + def test_person_properties(self): + _create_person(distinct_ids=["person1"], team=self.team, properties={"email": "test@posthog.com"}) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].person, {"email": "test@posthog.com"}) diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index 1d0f61f7467d9..54264ccf633f5 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -67,7 +67,7 @@ def to_query(self) -> ast.SelectQuery: ) def calculate(self): - with self.timings.measure("error_tracking_query_hogql_execute"): + with self.timings.measure("traces_query_hogql_execute"): query_result = self.paginator.execute_hogql_query( query=self.to_query(), team=self.team, From a5b278999757f84ef94c7bc84f84ec8e83f33456 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 16:54:43 +0100 Subject: [PATCH 08/15] feat: mirror events query for persons --- frontend/src/queries/schema.json | 21 ++++++++++++- frontend/src/queries/schema.ts | 9 +++++- .../test_traces_query_runner.ambr | 15 ++++++---- .../ai/test/test_traces_query_runner.py | 23 ++++++++------ .../hogql_queries/ai/traces_query_runner.py | 27 +++++++++++++++-- posthog/schema.py | 30 ++++++++++++------- 6 files changed, 97 insertions(+), 28 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index e7cb80a395c74..0ba7fc18e4b8d 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -77,7 +77,7 @@ "type": "number" }, "person": { - "type": "object" + "$ref": "#/definitions/AITracePerson" }, "total_cost": { "type": "number" @@ -100,6 +100,25 @@ ], "type": "object" }, + "AITracePerson": { + "additionalProperties": false, + "properties": { + "created_at": { + "type": "string" + }, + "distinct_id": { + "type": "string" + }, + "properties": { + "type": "object" + }, + "uuid": { + "type": "string" + } + }, + "required": ["uuid", "created_at", "properties", "distinct_id"], + "type": "object" + }, "ActionConversionGoal": { "additionalProperties": false, "properties": { diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index 5435fad9c040d..a58641d6e0bbc 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2547,10 +2547,17 @@ export interface AIGeneration { base_url?: string } +export interface AITracePerson { + uuid: string + created_at: string + properties: Record + distinct_id: string +} + export interface AITrace { id: string created_at: string - person: Record + person: AITracePerson total_latency: number input_tokens: number output_tokens: number diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index acb39bc86f57d..ded095025fea7 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -3,7 +3,7 @@ ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, - max(events__person.properties) AS person, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, @@ -21,6 +21,7 @@ HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) LEFT JOIN (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, person.properties AS properties FROM person WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), @@ -46,7 +47,7 @@ ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, - max(events__person.properties) AS person, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, @@ -64,6 +65,7 @@ HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) LEFT JOIN (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, person.properties AS properties FROM person WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), @@ -89,7 +91,7 @@ ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, - max(events__person.properties) AS person, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, @@ -107,6 +109,7 @@ HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) LEFT JOIN (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, person.properties AS properties FROM person WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), @@ -132,7 +135,7 @@ ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, - max(events__person.properties) AS person, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, @@ -150,6 +153,7 @@ HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) LEFT JOIN (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, person.properties AS properties FROM person WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), @@ -175,7 +179,7 @@ ''' SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, - max(events__person.properties) AS person, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, @@ -193,6 +197,7 @@ HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) LEFT JOIN (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, person.properties AS properties FROM person WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index c09108ef120c8..259a15abcba69 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -4,6 +4,8 @@ from typing import Any, Literal, TypedDict from uuid import UUID +from freezegun import freeze_time + from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner from posthog.models import PropertyDefinition, Team from posthog.models.property_definition import PropertyType @@ -162,9 +164,9 @@ def test_field_mapping(self): "input_cost": 6.0, "output_cost": 6.0, "total_cost": 12.0, - "person": {}, }, ) + self.assertEqual(trace.person.distinct_id, "person1") self.assertEqual(len(trace.events), 2) event = trace.events[0] @@ -215,9 +217,9 @@ def test_field_mapping(self): "input_cost": 3, "output_cost": 3, "total_cost": 6, - "person": {}, }, ) + self.assertEqual(trace.person.distinct_id, "person2") self.assertEqual(len(trace.events), 1) event = trace.events[0] self.assertIsNotNone(event.id) @@ -314,13 +316,16 @@ def test_maps_all_fields(self): ) def test_person_properties(self): - _create_person(distinct_ids=["person1"], team=self.team, properties={"email": "test@posthog.com"}) - _create_ai_generation_event( - distinct_id="person1", - trace_id="trace1", - team=self.team, - ) + with freeze_time("2025-01-01T00:00:00Z"): + _create_person(distinct_ids=["person1"], team=self.team, properties={"email": "test@posthog.com"}) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + ) response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() self.assertEqual(len(response.results), 1) - self.assertEqual(response.results[0].person, {"email": "test@posthog.com"}) + self.assertEqual(response.results[0].person.created_at, "2025-01-01T00:00:00+00:00") + self.assertEqual(response.results[0].person.properties, {"email": "test@posthog.com"}) + self.assertEqual(response.results[0].person.distinct_id, "person1") diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index 54264ccf633f5..3ec9c9d732d63 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -12,6 +12,7 @@ from posthog.schema import ( AIGeneration, AITrace, + AITracePerson, CachedTracesQueryResponse, NodeKind, TracesQuery, @@ -112,7 +113,7 @@ def _map_results(self, columns: list[str], query_results: list): trace_dict = { **result, "created_at": cast(datetime, result["trace_timestamp"]).isoformat(), - "person": orjson.loads(result["person"]), + "person": self._map_person(result["person"]), "events": generations, } trace = AITrace.model_validate({key: value for key, value in trace_dict.items() if key in TRACE_FIELDS}) @@ -153,11 +154,20 @@ def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_pro return AIGeneration.model_validate(generation) + def _map_person(self, person: tuple[UUID, UUID, datetime, str]) -> AITracePerson: + uuid, distinct_id, created_at, properties = person + return AITracePerson( + uuid=str(uuid), + distinct_id=str(distinct_id), + created_at=created_at.isoformat(), + properties=orjson.loads(properties), + ) + def _get_select_fields(self) -> list[ast.Expr]: return [ ast.Alias(expr=ast.Field(chain=["properties", "$ai_trace_id"]), alias="id"), ast.Alias(expr=ast.Call(name="min", args=[ast.Field(chain=["timestamp"])]), alias="trace_timestamp"), - ast.Alias(expr=ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), alias="person"), + self._get_person_field(), ast.Alias( expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_latency"])]), alias="total_latency", @@ -208,6 +218,19 @@ def _get_select_fields(self) -> list[ast.Expr]: ), ] + def _get_person_field(self): + return ast.Alias( + expr=ast.Tuple( + exprs=[ + ast.Call(name="max", args=[ast.Field(chain=["person", "id"])]), + ast.Call(name="max", args=[ast.Field(chain=["distinct_id"])]), + ast.Call(name="max", args=[ast.Field(chain=["person", "created_at"])]), + ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), + ], + ), + alias="person", + ) + def _get_where_clause(self): event_expr = ast.CompareOperation( left=ast.Field(chain=["event"]), diff --git a/posthog/schema.py b/posthog/schema.py index 6ce1e205b39bb..1a1d33e0a5854 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -32,20 +32,14 @@ class AIGeneration(BaseModel): total_cost: Optional[float] = None -class AITrace(BaseModel): +class AITracePerson(BaseModel): model_config = ConfigDict( extra="forbid", ) created_at: str - events: list[AIGeneration] - id: str - input_cost: float - input_tokens: float - output_cost: float - output_tokens: float - person: dict[str, Any] - total_cost: float - total_latency: float + distinct_id: str + properties: dict[str, Any] + uuid: str class ActionConversionGoal(BaseModel): @@ -1792,6 +1786,22 @@ class NumericalKey(RootModel[str]): root: str +class AITrace(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + created_at: str + events: list[AIGeneration] + id: str + input_cost: float + input_tokens: float + output_cost: float + output_tokens: float + person: AITracePerson + total_cost: float + total_latency: float + + class AlertCondition(BaseModel): model_config = ConfigDict( extra="forbid", From 6438dcd912f343e10c26c1885c2404a954c5b440 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 16:56:09 +0100 Subject: [PATCH 09/15] chore: register a query runner --- posthog/hogql_queries/query_runner.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/posthog/hogql_queries/query_runner.py b/posthog/hogql_queries/query_runner.py index abcaddde2d3cf..608acc282968f 100644 --- a/posthog/hogql_queries/query_runner.py +++ b/posthog/hogql_queries/query_runner.py @@ -54,6 +54,7 @@ StickinessQuery, SuggestedQuestionsQuery, TeamTaxonomyQuery, + TracesQuery, TrendsQuery, WebGoalsQuery, WebOverviewQuery, @@ -421,6 +422,16 @@ def get_query_runner( limit_context=limit_context, modifiers=modifiers, ) + if kind == "TracesQuery": + from .ai.traces_query_runner import TracesQueryRunner + + return TracesQueryRunner( + query=cast(TracesQuery | dict[str, Any], query), + team=team, + timings=timings, + limit_context=limit_context, + modifiers=modifiers, + ) raise ValueError(f"Can't get a runner for an unknown query kind: {kind}") From b3af14b727d28f751166c615c45708a0251da3ff Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 17:15:19 +0100 Subject: [PATCH 10/15] fix: AI-prefixed names broke the schema --- frontend/src/queries/schema.json | 347 +++++++----------- frontend/src/queries/schema.ts | 12 +- .../hogql_queries/ai/traces_query_runner.py | 16 +- posthog/schema.py | 134 +++---- 4 files changed, 220 insertions(+), 289 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 0ba7fc18e4b8d..f8543ca643384 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -1,124 +1,6 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "definitions": { - "AIGeneration": { - "additionalProperties": false, - "properties": { - "base_url": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "http_status": { - "type": "number" - }, - "id": { - "type": "string" - }, - "input": { - "items": {}, - "type": "array" - }, - "input_cost": { - "type": "number" - }, - "input_tokens": { - "type": "number" - }, - "latency": { - "type": "number" - }, - "model": { - "type": "string" - }, - "output": {}, - "output_cost": { - "type": "number" - }, - "output_tokens": { - "type": "number" - }, - "provider": { - "type": "string" - }, - "total_cost": { - "type": "number" - } - }, - "required": ["id", "created_at", "input", "latency"], - "type": "object" - }, - "AITrace": { - "additionalProperties": false, - "properties": { - "created_at": { - "type": "string" - }, - "events": { - "items": { - "$ref": "#/definitions/AIGeneration" - }, - "type": "array" - }, - "id": { - "type": "string" - }, - "input_cost": { - "type": "number" - }, - "input_tokens": { - "type": "number" - }, - "output_cost": { - "type": "number" - }, - "output_tokens": { - "type": "number" - }, - "person": { - "$ref": "#/definitions/AITracePerson" - }, - "total_cost": { - "type": "number" - }, - "total_latency": { - "type": "number" - } - }, - "required": [ - "id", - "created_at", - "person", - "total_latency", - "input_tokens", - "output_tokens", - "input_cost", - "output_cost", - "total_cost", - "events" - ], - "type": "object" - }, - "AITracePerson": { - "additionalProperties": false, - "properties": { - "created_at": { - "type": "string" - }, - "distinct_id": { - "type": "string" - }, - "properties": { - "type": "object" - }, - "uuid": { - "type": "string" - } - }, - "required": ["uuid", "created_at", "properties", "distinct_id"], - "type": "object" - }, "ActionConversionGoal": { "additionalProperties": false, "properties": { @@ -961,27 +843,8 @@ "description": "`icontains` - case insensitive contains. `not_icontains` - case insensitive does not contain. `regex` - matches the regex pattern. `not_regex` - does not match the regex pattern." }, "type": { - "anyOf": [ - { - "const": "event", - "description": "Event properties", - "type": "string" - }, - { - "const": "person", - "description": "Person properties", - "type": "string" - }, - { - "const": "session", - "type": "string" - }, - { - "const": "feature", - "description": "Event property with \"$feature/\" prepended", - "type": "string" - } - ] + "enum": ["event", "person", "session", "feature"], + "type": "string" }, "value": { "description": "Only use property values from the plan. If the operator is `regex` or `not_regex`, the value must be a valid ClickHouse regex pattern to match against. Otherwise, the value must be a substring that will be matched against the property value.", @@ -1003,27 +866,8 @@ "description": "`exact` - exact match of any of the values. `is_not` - does not match any of the values." }, "type": { - "anyOf": [ - { - "const": "event", - "description": "Event properties", - "type": "string" - }, - { - "const": "person", - "description": "Person properties", - "type": "string" - }, - { - "const": "session", - "type": "string" - }, - { - "const": "feature", - "description": "Event property with \"$feature/\" prepended", - "type": "string" - } - ] + "enum": ["event", "person", "session", "feature"], + "type": "string" }, "value": { "description": "Only use property values from the plan. Always use strings as values. If you have a number, convert it to a string first. If you have a boolean, convert it to a string \"true\" or \"false\".", @@ -1047,27 +891,8 @@ "$ref": "#/definitions/AssistantDateTimePropertyFilterOperator" }, "type": { - "anyOf": [ - { - "const": "event", - "description": "Event properties", - "type": "string" - }, - { - "const": "person", - "description": "Person properties", - "type": "string" - }, - { - "const": "session", - "type": "string" - }, - { - "const": "feature", - "description": "Event property with \"$feature/\" prepended", - "type": "string" - } - ] + "enum": ["event", "person", "session", "feature"], + "type": "string" }, "value": { "description": "Value must be a date in ISO 8601 format.", @@ -1089,27 +914,8 @@ "description": "`is_set` - the property has any value. `is_not_set` - the property doesn't have a value or wasn't collected." }, "type": { - "anyOf": [ - { - "const": "event", - "description": "Event properties", - "type": "string" - }, - { - "const": "person", - "description": "Person properties", - "type": "string" - }, - { - "const": "session", - "type": "string" - }, - { - "const": "feature", - "description": "Event property with \"$feature/\" prepended", - "type": "string" - } - ] + "enum": ["event", "person", "session", "feature"], + "type": "string" } }, "required": ["key", "operator", "type"], @@ -3579,7 +3385,7 @@ }, "results": { "items": { - "$ref": "#/definitions/AITrace" + "$ref": "#/definitions/LLMTrace" }, "type": "array" }, @@ -5058,7 +4864,7 @@ }, "results": { "items": { - "$ref": "#/definitions/AITrace" + "$ref": "#/definitions/LLMTrace" }, "type": "array" }, @@ -8738,6 +8544,124 @@ "enum": ["minute", "hour", "day", "week", "month"], "type": "string" }, + "LLMGeneration": { + "additionalProperties": false, + "properties": { + "base_url": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "http_status": { + "type": "number" + }, + "id": { + "type": "string" + }, + "input": { + "items": {}, + "type": "array" + }, + "input_cost": { + "type": "number" + }, + "input_tokens": { + "type": "number" + }, + "latency": { + "type": "number" + }, + "model": { + "type": "string" + }, + "output": {}, + "output_cost": { + "type": "number" + }, + "output_tokens": { + "type": "number" + }, + "provider": { + "type": "string" + }, + "total_cost": { + "type": "number" + } + }, + "required": ["id", "created_at", "input", "latency"], + "type": "object" + }, + "LLMTrace": { + "additionalProperties": false, + "properties": { + "created_at": { + "type": "string" + }, + "events": { + "items": { + "$ref": "#/definitions/LLMGeneration" + }, + "type": "array" + }, + "id": { + "type": "string" + }, + "input_cost": { + "type": "number" + }, + "input_tokens": { + "type": "number" + }, + "output_cost": { + "type": "number" + }, + "output_tokens": { + "type": "number" + }, + "person": { + "$ref": "#/definitions/LLMTracePerson" + }, + "total_cost": { + "type": "number" + }, + "total_latency": { + "type": "number" + } + }, + "required": [ + "id", + "created_at", + "person", + "total_latency", + "input_tokens", + "output_tokens", + "input_cost", + "output_cost", + "total_cost", + "events" + ], + "type": "object" + }, + "LLMTracePerson": { + "additionalProperties": false, + "properties": { + "created_at": { + "type": "string" + }, + "distinct_id": { + "type": "string" + }, + "properties": { + "type": "object" + }, + "uuid": { + "type": "string" + } + }, + "required": ["uuid", "created_at", "properties", "distinct_id"], + "type": "object" + }, "LifecycleFilter": { "additionalProperties": false, "properties": { @@ -11072,7 +10996,7 @@ }, "results": { "items": { - "$ref": "#/definitions/AITrace" + "$ref": "#/definitions/LLMTrace" }, "type": "array" }, @@ -11531,7 +11455,7 @@ }, "results": { "items": { - "$ref": "#/definitions/AITrace" + "$ref": "#/definitions/LLMTrace" }, "type": "array" }, @@ -11686,7 +11610,8 @@ "type": "integer" }, "end_time": { - "description": "When did the query execution task finish (whether successfully or not). @format date-time", + "description": "When did the query execution task finish (whether successfully or not).", + "format": "date-time", "type": "string" }, "error": { @@ -11715,7 +11640,8 @@ "type": "array" }, "pickup_time": { - "description": "When was the query execution task picked up by a worker. @format date-time", + "description": "When was the query execution task picked up by a worker.", + "format": "date-time", "type": "string" }, "query_async": { @@ -11729,7 +11655,8 @@ }, "results": {}, "start_time": { - "description": "When was query execution task enqueued. @format date-time", + "description": "When was query execution task enqueued.", + "format": "date-time", "type": "string" }, "task_id": { @@ -13257,7 +13184,7 @@ }, "results": { "items": { - "$ref": "#/definitions/AITrace" + "$ref": "#/definitions/LLMTrace" }, "type": "array" }, diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index a58641d6e0bbc..37a7487a5f68d 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2530,7 +2530,7 @@ export type ActorsPropertyTaxonomyQueryResponse = AnalyticsQueryResponseBase -export interface AIGeneration { +export interface LLMGeneration { id: string created_at: string input: any[] @@ -2547,27 +2547,27 @@ export interface AIGeneration { base_url?: string } -export interface AITracePerson { +export interface LLMTracePerson { uuid: string created_at: string properties: Record distinct_id: string } -export interface AITrace { +export interface LLMTrace { id: string created_at: string - person: AITracePerson + person: LLMTracePerson total_latency: number input_tokens: number output_tokens: number input_cost: number output_cost: number total_cost: number - events: AIGeneration[] + events: LLMGeneration[] } -export interface TracesQueryResponse extends AnalyticsQueryResponseBase { +export interface TracesQueryResponse extends AnalyticsQueryResponseBase { hasMore?: boolean limit?: integer offset?: integer diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index 3ec9c9d732d63..f611cf62ef6a6 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -10,10 +10,10 @@ from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator from posthog.hogql_queries.query_runner import QueryRunner from posthog.schema import ( - AIGeneration, - AITrace, - AITracePerson, CachedTracesQueryResponse, + LLMGeneration, + LLMTrace, + LLMTracePerson, NodeKind, TracesQuery, TracesQueryResponse, @@ -116,12 +116,12 @@ def _map_results(self, columns: list[str], query_results: list): "person": self._map_person(result["person"]), "events": generations, } - trace = AITrace.model_validate({key: value for key, value in trace_dict.items() if key in TRACE_FIELDS}) + trace = LLMTrace.model_validate({key: value for key, value in trace_dict.items() if key in TRACE_FIELDS}) traces.append(trace) return traces - def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_properties: str) -> AIGeneration: + def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_properties: str) -> LLMGeneration: properties: dict = orjson.loads(event_properties) GENERATION_MAPPING = { @@ -152,11 +152,11 @@ def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_pro else: generation[model_prop] = properties[event_prop] - return AIGeneration.model_validate(generation) + return LLMGeneration.model_validate(generation) - def _map_person(self, person: tuple[UUID, UUID, datetime, str]) -> AITracePerson: + def _map_person(self, person: tuple[UUID, UUID, datetime, str]) -> LLMTracePerson: uuid, distinct_id, created_at, properties = person - return AITracePerson( + return LLMTracePerson( uuid=str(uuid), distinct_id=str(distinct_id), created_at=created_at.isoformat(), diff --git a/posthog/schema.py b/posthog/schema.py index 1a1d33e0a5854..8fa1b6e6d6120 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -12,36 +12,6 @@ class SchemaRoot(RootModel[Any]): root: Any -class AIGeneration(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - base_url: Optional[str] = None - created_at: str - http_status: Optional[float] = None - id: str - input: list - input_cost: Optional[float] = None - input_tokens: Optional[float] = None - latency: float - model: Optional[str] = None - output: Optional[Any] = None - output_cost: Optional[float] = None - output_tokens: Optional[float] = None - provider: Optional[str] = None - total_cost: Optional[float] = None - - -class AITracePerson(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - created_at: str - distinct_id: str - properties: dict[str, Any] - uuid: str - - class ActionConversionGoal(BaseModel): model_config = ConfigDict( extra="forbid", @@ -169,6 +139,13 @@ class AssistantGenericMultipleBreakdownFilter(BaseModel): type: AssistantEventMultipleBreakdownFilterType +class Type(StrEnum): + EVENT = "event" + PERSON = "person" + SESSION = "session" + FEATURE = "feature" + + class AssistantGenericPropertyFilter2(BaseModel): model_config = ConfigDict( extra="forbid", @@ -177,7 +154,7 @@ class AssistantGenericPropertyFilter2(BaseModel): operator: AssistantArrayPropertyFilterOperator = Field( ..., description="`exact` - exact match of any of the values. `is_not` - does not match any of the values." ) - type: str + type: Type value: list[str] = Field( ..., description=( @@ -193,7 +170,7 @@ class AssistantGenericPropertyFilter3(BaseModel): ) key: str = Field(..., description="Use one of the properties the user has provided in the plan.") operator: AssistantDateTimePropertyFilterOperator - type: str + type: Type value: str = Field(..., description="Value must be a date in ISO 8601 format.") @@ -650,7 +627,7 @@ class DatabaseSchemaSource(BaseModel): status: str -class Type(StrEnum): +class Type4(StrEnum): POSTHOG = "posthog" DATA_WAREHOUSE = "data_warehouse" VIEW = "view" @@ -1156,6 +1133,36 @@ class IntervalType(StrEnum): MONTH = "month" +class LLMGeneration(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + base_url: Optional[str] = None + created_at: str + http_status: Optional[float] = None + id: str + input: list + input_cost: Optional[float] = None + input_tokens: Optional[float] = None + latency: float + model: Optional[str] = None + output: Optional[Any] = None + output_cost: Optional[float] = None + output_tokens: Optional[float] = None + provider: Optional[str] = None + total_cost: Optional[float] = None + + +class LLMTracePerson(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + created_at: str + distinct_id: str + properties: dict[str, Any] + uuid: str + + class LifecycleToggle(StrEnum): NEW = "new" RESURRECTING = "resurrecting" @@ -1386,9 +1393,8 @@ class QueryStatus(BaseModel): ), ) dashboard_id: Optional[int] = None - end_time: Optional[str] = Field( - default=None, - description="When did the query execution task finish (whether successfully or not). @format date-time", + end_time: Optional[AwareDatetime] = Field( + default=None, description="When did the query execution task finish (whether successfully or not)." ) error: Optional[bool] = Field( default=False, @@ -1401,15 +1407,13 @@ class QueryStatus(BaseModel): id: str insight_id: Optional[int] = None labels: Optional[list[str]] = None - pickup_time: Optional[str] = Field( - default=None, description="When was the query execution task picked up by a worker. @format date-time" + pickup_time: Optional[AwareDatetime] = Field( + default=None, description="When was the query execution task picked up by a worker." ) query_async: Literal[True] = Field(default=True, description="ONLY async queries use QueryStatus.") query_progress: Optional[ClickhouseQueryProgress] = None results: Optional[Any] = None - start_time: Optional[str] = Field( - default=None, description="When was query execution task enqueued. @format date-time" - ) + start_time: Optional[AwareDatetime] = Field(default=None, description="When was query execution task enqueued.") task_id: Optional[str] = None team_id: int @@ -1786,22 +1790,6 @@ class NumericalKey(RootModel[str]): root: str -class AITrace(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - created_at: str - events: list[AIGeneration] - id: str - input_cost: float - input_tokens: float - output_cost: float - output_tokens: float - person: AITracePerson - total_cost: float - total_latency: float - - class AlertCondition(BaseModel): model_config = ConfigDict( extra="forbid", @@ -1952,7 +1940,7 @@ class AssistantGenericPropertyFilter1(BaseModel): " matches the regex pattern. `not_regex` - does not match the regex pattern." ), ) - type: str + type: Type value: str = Field( ..., description=( @@ -1975,7 +1963,7 @@ class AssistantGenericPropertyFilter4(BaseModel): " collected." ), ) - type: str + type: Type class AssistantGroupPropertyFilter1(BaseModel): @@ -2263,7 +2251,7 @@ class DatabaseSchemaTableCommon(BaseModel): fields: dict[str, DatabaseSchemaField] id: str name: str - type: Type + type: Type4 class ElementPropertyFilter(BaseModel): @@ -2451,6 +2439,22 @@ class InsightThreshold(BaseModel): type: InsightThresholdType +class LLMTrace(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + created_at: str + events: list[LLMGeneration] + id: str + input_cost: float + input_tokens: float + output_cost: float + output_tokens: float + person: LLMTracePerson + total_cost: float + total_latency: float + + class LifecycleFilter(BaseModel): model_config = ConfigDict( extra="forbid", @@ -2935,7 +2939,7 @@ class QueryResponseAlternative28(BaseModel): query_status: Optional[QueryStatus] = Field( default=None, description="Query status indicates whether next to the provided data, a query is still running." ) - results: list[AITrace] + results: list[LLMTrace] timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) @@ -3114,7 +3118,7 @@ class QueryResponseAlternative41(BaseModel): query_status: Optional[QueryStatus] = Field( default=None, description="Query status indicates whether next to the provided data, a query is still running." ) - results: list[AITrace] + results: list[LLMTrace] timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) @@ -3480,7 +3484,7 @@ class TracesQueryResponse(BaseModel): query_status: Optional[QueryStatus] = Field( default=None, description="Query status indicates whether next to the provided data, a query is still running." ) - results: list[AITrace] + results: list[LLMTrace] timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) @@ -4524,7 +4528,7 @@ class CachedTracesQueryResponse(BaseModel): query_status: Optional[QueryStatus] = Field( default=None, description="Query status indicates whether next to the provided data, a query is still running." ) - results: list[AITrace] + results: list[LLMTrace] timezone: str timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" @@ -4934,7 +4938,7 @@ class Response11(BaseModel): query_status: Optional[QueryStatus] = Field( default=None, description="Query status indicates whether next to the provided data, a query is still running." ) - results: list[AITrace] + results: list[LLMTrace] timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) From 117a9d19d9c9416033650dfabbf8b65a88cfe165 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 17:28:53 +0100 Subject: [PATCH 11/15] fix: schema renaming --- posthog/hogql_queries/ai/test/test_traces_query_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index 259a15abcba69..f8b8938c76dcb 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -9,7 +9,7 @@ from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner from posthog.models import PropertyDefinition, Team from posthog.models.property_definition import PropertyType -from posthog.schema import AIGeneration, AITrace, TracesQuery +from posthog.schema import LLMGeneration, LLMTrace, TracesQuery from posthog.test.base import ( BaseTest, ClickhouseTestMixin, @@ -109,12 +109,12 @@ def _create_properties(self): models_to_create.append(prop_model) PropertyDefinition.objects.bulk_create(models_to_create) - def assertTraceEqual(self, trace: AITrace, expected_trace: dict): + def assertTraceEqual(self, trace: LLMTrace, expected_trace: dict): trace_dict = trace.model_dump() for key, value in expected_trace.items(): self.assertEqual(trace_dict[key], value, f"Field {key} does not match") - def assertEventEqual(self, event: AIGeneration, expected_event: dict): + def assertEventEqual(self, event: LLMGeneration, expected_event: dict): event_dict = event.model_dump() for key, value in expected_event.items(): self.assertEqual(event_dict[key], value, f"Field {key} does not match") From 3a12ca58557dd75230e188a2fc16dd828edc9cab Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 17:48:35 +0100 Subject: [PATCH 12/15] fix: used the new query in wrong places --- frontend/src/queries/schema.json | 111 ------------------ frontend/src/queries/schema/schema-general.ts | 4 - .../scenes/saved-insights/SavedInsights.tsx | 5 + posthog/schema.py | 82 ++----------- 4 files changed, 17 insertions(+), 185 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 28225460762de..dba8034877c1e 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -369,9 +369,6 @@ }, { "$ref": "#/definitions/RecordingsQuery" - }, - { - "$ref": "#/definitions/TracesQuery" } ] }, @@ -489,9 +486,6 @@ }, { "$ref": "#/definitions/ErrorTrackingQueryResponse" - }, - { - "$ref": "#/definitions/TracesQueryResponse" } ] }, @@ -4907,57 +4901,6 @@ "variants" ], "type": "object" - }, - { - "additionalProperties": false, - "properties": { - "columns": { - "items": { - "type": "string" - }, - "type": "array" - }, - "error": { - "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", - "type": "string" - }, - "hasMore": { - "type": "boolean" - }, - "hogql": { - "description": "Generated HogQL query.", - "type": "string" - }, - "limit": { - "$ref": "#/definitions/integer" - }, - "modifiers": { - "$ref": "#/definitions/HogQLQueryModifiers", - "description": "Modifiers used when performing the query" - }, - "offset": { - "$ref": "#/definitions/integer" - }, - "query_status": { - "$ref": "#/definitions/QueryStatus", - "description": "Query status indicates whether next to the provided data, a query is still running." - }, - "results": { - "items": { - "$ref": "#/definitions/LLMTrace" - }, - "type": "array" - }, - "timings": { - "description": "Measured timings for different parts of the query generation process", - "items": { - "$ref": "#/definitions/QueryTiming" - }, - "type": "array" - } - }, - "required": ["results"], - "type": "object" } ] }, @@ -5075,9 +5018,6 @@ }, { "$ref": "#/definitions/ExperimentTrendsQuery" - }, - { - "$ref": "#/definitions/TracesQuery" } ], "description": "Source of the events" @@ -11039,57 +10979,6 @@ ], "type": "object" }, - { - "additionalProperties": false, - "properties": { - "columns": { - "items": { - "type": "string" - }, - "type": "array" - }, - "error": { - "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", - "type": "string" - }, - "hasMore": { - "type": "boolean" - }, - "hogql": { - "description": "Generated HogQL query.", - "type": "string" - }, - "limit": { - "$ref": "#/definitions/integer" - }, - "modifiers": { - "$ref": "#/definitions/HogQLQueryModifiers", - "description": "Modifiers used when performing the query" - }, - "offset": { - "$ref": "#/definitions/integer" - }, - "query_status": { - "$ref": "#/definitions/QueryStatus", - "description": "Query status indicates whether next to the provided data, a query is still running." - }, - "results": { - "items": { - "$ref": "#/definitions/LLMTrace" - }, - "type": "array" - }, - "timings": { - "description": "Measured timings for different parts of the query generation process", - "items": { - "$ref": "#/definitions/QueryTiming" - }, - "type": "array" - } - }, - "required": ["results"], - "type": "object" - }, { "additionalProperties": false, "properties": { diff --git a/frontend/src/queries/schema/schema-general.ts b/frontend/src/queries/schema/schema-general.ts index b13b83919a018..dba260d4ec2be 100644 --- a/frontend/src/queries/schema/schema-general.ts +++ b/frontend/src/queries/schema/schema-general.ts @@ -131,7 +131,6 @@ export type AnyDataNode = | ExperimentFunnelsQuery | ExperimentTrendsQuery | RecordingsQuery - | TracesQuery /** * @discriminator kind @@ -214,7 +213,6 @@ export type AnyResponseType = | EventsNode['response'] | EventsQueryResponse | ErrorTrackingQueryResponse - | TracesQueryResponse /** @internal - no need to emit to schema.json. */ export interface DataNode = Record> extends Node { @@ -635,7 +633,6 @@ export interface DataTableNode | ErrorTrackingQuery | ExperimentFunnelsQuery | ExperimentTrendsQuery - | TracesQuery )['response'] > >, @@ -656,7 +653,6 @@ export interface DataTableNode | ErrorTrackingQuery | ExperimentFunnelsQuery | ExperimentTrendsQuery - | TracesQuery /** Columns shown in the table, unless the `source` provides them. */ columns?: HogQLExpression[] /** Columns that aren't shown in the table, even if in columns or returned data */ diff --git a/frontend/src/scenes/saved-insights/SavedInsights.tsx b/frontend/src/scenes/saved-insights/SavedInsights.tsx index 8ebaa5a400c8d..aae326b16ab11 100644 --- a/frontend/src/scenes/saved-insights/SavedInsights.tsx +++ b/frontend/src/scenes/saved-insights/SavedInsights.tsx @@ -313,6 +313,11 @@ export const QUERY_TYPES_METADATA: Record = { icon: IconHogQL, inMenu: false, }, + [NodeKind.TracesQuery]: { + name: 'LLM Observability Traces', + icon: IconHogQL, + inMenu: false, + }, } export const INSIGHT_TYPES_METADATA: Record = { diff --git a/posthog/schema.py b/posthog/schema.py index 15a1ced9c5b0e..df66c17872058 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -1173,7 +1173,7 @@ class QueryResponseAlternative5(BaseModel): stdout: Optional[str] = None -class QueryResponseAlternative37(BaseModel): +class QueryResponseAlternative36(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -4096,31 +4096,6 @@ class Response8(BaseModel): ) -class Response11(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - columns: Optional[list[str]] = None - error: Optional[str] = Field( - default=None, - description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", - ) - hasMore: Optional[bool] = None - hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") - limit: Optional[int] = None - modifiers: Optional[HogQLQueryModifiers] = Field( - default=None, description="Modifiers used when performing the query" - ) - offset: Optional[int] = None - query_status: Optional[QueryStatus] = Field( - default=None, description="Query status indicates whether next to the provided data, a query is still running." - ) - results: list[LLMTrace] - timings: Optional[list[QueryTiming]] = Field( - default=None, description="Measured timings for different parts of the query generation process" - ) - - class DataWarehouseNode(BaseModel): model_config = ConfigDict( extra="forbid", @@ -5242,31 +5217,6 @@ class QueryResponseAlternative25(BaseModel): class QueryResponseAlternative28(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - columns: Optional[list[str]] = None - error: Optional[str] = Field( - default=None, - description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", - ) - hasMore: Optional[bool] = None - hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") - limit: Optional[int] = None - modifiers: Optional[HogQLQueryModifiers] = Field( - default=None, description="Modifiers used when performing the query" - ) - offset: Optional[int] = None - query_status: Optional[QueryStatus] = Field( - default=None, description="Query status indicates whether next to the provided data, a query is still running." - ) - results: list[LLMTrace] - timings: Optional[list[QueryTiming]] = Field( - default=None, description="Measured timings for different parts of the query generation process" - ) - - -class QueryResponseAlternative29(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5288,7 +5238,7 @@ class QueryResponseAlternative29(BaseModel): ) -class QueryResponseAlternative30(BaseModel): +class QueryResponseAlternative29(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5310,7 +5260,7 @@ class QueryResponseAlternative30(BaseModel): ) -class QueryResponseAlternative32(BaseModel): +class QueryResponseAlternative31(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5331,7 +5281,7 @@ class QueryResponseAlternative32(BaseModel): ) -class QueryResponseAlternative35(BaseModel): +class QueryResponseAlternative34(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5357,7 +5307,7 @@ class QueryResponseAlternative35(BaseModel): types: Optional[list] = None -class QueryResponseAlternative38(BaseModel): +class QueryResponseAlternative37(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5378,7 +5328,7 @@ class QueryResponseAlternative38(BaseModel): ) -class QueryResponseAlternative39(BaseModel): +class QueryResponseAlternative38(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5399,7 +5349,7 @@ class QueryResponseAlternative39(BaseModel): ) -class QueryResponseAlternative40(BaseModel): +class QueryResponseAlternative39(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5420,7 +5370,7 @@ class QueryResponseAlternative40(BaseModel): ) -class QueryResponseAlternative41(BaseModel): +class QueryResponseAlternative40(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -5711,7 +5661,6 @@ class AnyResponseType( Any, EventsQueryResponse, ErrorTrackingQueryResponse, - TracesQueryResponse, ] ] ): @@ -5724,7 +5673,6 @@ class AnyResponseType( Any, EventsQueryResponse, ErrorTrackingQueryResponse, - TracesQueryResponse, ] @@ -5936,7 +5884,7 @@ class PropertyGroupFilter(BaseModel): values: list[PropertyGroupFilterValue] -class QueryResponseAlternative31(BaseModel): +class QueryResponseAlternative30(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -6925,7 +6873,7 @@ class PathsQuery(BaseModel): samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") -class QueryResponseAlternative36(BaseModel): +class QueryResponseAlternative35(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -6973,14 +6921,13 @@ class QueryResponseAlternative( QueryResponseAlternative29, QueryResponseAlternative30, QueryResponseAlternative31, - QueryResponseAlternative32, + QueryResponseAlternative34, QueryResponseAlternative35, QueryResponseAlternative36, QueryResponseAlternative37, QueryResponseAlternative38, QueryResponseAlternative39, QueryResponseAlternative40, - QueryResponseAlternative41, ] ] ): @@ -7014,14 +6961,13 @@ class QueryResponseAlternative( QueryResponseAlternative29, QueryResponseAlternative30, QueryResponseAlternative31, - QueryResponseAlternative32, + QueryResponseAlternative34, QueryResponseAlternative35, QueryResponseAlternative36, QueryResponseAlternative37, QueryResponseAlternative38, QueryResponseAlternative39, QueryResponseAlternative40, - QueryResponseAlternative41, ] @@ -7238,7 +7184,6 @@ class DataTableNode(BaseModel): Response8, Response9, Response10, - Response11, ] ] = None showActions: Optional[bool] = Field(default=None, description="Show the kebab menu at the end of the row") @@ -7281,7 +7226,6 @@ class DataTableNode(BaseModel): ErrorTrackingQuery, ExperimentFunnelsQuery, ExperimentTrendsQuery, - TracesQuery, ] = Field(..., description="Source of the events") @@ -7322,7 +7266,6 @@ class HogQLAutocomplete(BaseModel): ExperimentFunnelsQuery, ExperimentTrendsQuery, RecordingsQuery, - TracesQuery, ] ] = Field(default=None, description="Query in whose context to validate.") startPosition: int = Field(..., description="Start position of the editor word") @@ -7367,7 +7310,6 @@ class HogQLMetadata(BaseModel): ExperimentFunnelsQuery, ExperimentTrendsQuery, RecordingsQuery, - TracesQuery, ] ] = Field( default=None, From 87cba508c28df6fa639a7a0223db6170bcfa7831 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 21:39:49 +0100 Subject: [PATCH 13/15] fix: correct date ranges --- frontend/src/queries/schema.json | 3 + frontend/src/queries/schema/schema-general.ts | 1 + .../test_traces_query_runner.ambr | 20 +- .../ai/test/test_traces_query_runner.py | 121 ++++++++++- .../hogql_queries/ai/traces_query_runner.py | 200 ++++++++---------- posthog/schema.py | 1 + 6 files changed, 218 insertions(+), 128 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index dba8034877c1e..448441ad8851e 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -13092,6 +13092,9 @@ "TracesQuery": { "additionalProperties": false, "properties": { + "dateRange": { + "$ref": "#/definitions/DateRange" + }, "kind": { "const": "TracesQuery", "type": "string" diff --git a/frontend/src/queries/schema/schema-general.ts b/frontend/src/queries/schema/schema-general.ts index dba260d4ec2be..8ef6b62fe8df2 100644 --- a/frontend/src/queries/schema/schema-general.ts +++ b/frontend/src/queries/schema/schema-general.ts @@ -2199,6 +2199,7 @@ export interface TracesQueryResponse extends AnalyticsQueryResponseBase { kind: NodeKind.TracesQuery traceId?: string + dateRange?: DateRange limit?: integer offset?: integer } diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index ded095025fea7..c8e34eaef2b75 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -10,7 +10,7 @@ sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, - arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events FROM events LEFT OUTER JOIN (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, @@ -30,7 +30,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-17 00:09:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 101 @@ -54,7 +54,7 @@ sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, - arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events FROM events LEFT OUTER JOIN (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, @@ -74,7 +74,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -98,7 +98,7 @@ sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, - arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events FROM events LEFT OUTER JOIN (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, @@ -118,7 +118,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -142,7 +142,7 @@ sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, - arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events FROM events LEFT OUTER JOIN (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, @@ -162,7 +162,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation')) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -186,7 +186,7 @@ sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, - arraySort(x -> tupleElement(x, 2), groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events FROM events LEFT OUTER JOIN (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, @@ -206,7 +206,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), ifNull(equals(id, 'trace1'), 0)) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC'))), ifNull(equals(id, 'trace1'), 0)) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 101 diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index f8b8938c76dcb..57704397abf0f 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -9,7 +9,7 @@ from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner from posthog.models import PropertyDefinition, Team from posthog.models.property_definition import PropertyType -from posthog.schema import LLMGeneration, LLMTrace, TracesQuery +from posthog.schema import DateRange, LLMGeneration, LLMTrace, TracesQuery from posthog.test.base import ( BaseTest, ClickhouseTestMixin, @@ -119,6 +119,7 @@ def assertEventEqual(self, event: LLMGeneration, expected_event: dict): for key, value in expected_event.items(): self.assertEqual(event_dict[key], value, f"Field {key} does not match") + @freeze_time("2025-01-16T00:00:00Z") @snapshot_clickhouse_queries def test_field_mapping(self): _create_person(distinct_ids=["person1"], team=self.team) @@ -315,17 +316,119 @@ def test_maps_all_fields(self): }, ) + @freeze_time("2025-01-01T00:00:00Z") def test_person_properties(self): - with freeze_time("2025-01-01T00:00:00Z"): - _create_person(distinct_ids=["person1"], team=self.team, properties={"email": "test@posthog.com"}) - _create_ai_generation_event( - distinct_id="person1", - trace_id="trace1", - team=self.team, - ) - + _create_person(distinct_ids=["person1"], team=self.team, properties={"email": "test@posthog.com"}) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + ) response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() self.assertEqual(len(response.results), 1) self.assertEqual(response.results[0].person.created_at, "2025-01-01T00:00:00+00:00") self.assertEqual(response.results[0].person.properties, {"email": "test@posthog.com"}) self.assertEqual(response.results[0].person.distinct_id, "person1") + + @freeze_time("2025-01-16T00:00:00Z") + def test_date_range(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + timestamp=datetime(2025, 1, 15), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace2", + team=self.team, + timestamp=datetime(2024, 12, 1), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace3", + team=self.team, + timestamp=datetime(2024, 11, 1), + ) + + response = TracesQueryRunner( + team=self.team, query=TracesQuery(dateRange=DateRange(date_from="-1m")) + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + response = TracesQueryRunner( + team=self.team, query=TracesQuery(dateRange=DateRange(date_from="-2m")) + ).calculate() + self.assertEqual(len(response.results), 2) + self.assertEqual(response.results[0].id, "trace1") + self.assertEqual(response.results[1].id, "trace2") + + response = TracesQueryRunner( + team=self.team, query=TracesQuery(dateRange=DateRange(date_from="-3m", date_to="-2m")) + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace3") + + def test_capture_range(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 0), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 10), + ) + + response = TracesQueryRunner( + team=self.team, + query=TracesQuery(dateRange=DateRange(date_from="2024-12-01T00:00:00Z", date_to="2024-12-01T00:10:00Z")), + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + # Date is after the capture range + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace2", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 11), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace2", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 12), + ) + response = TracesQueryRunner( + team=self.team, + query=TracesQuery(dateRange=DateRange(date_from="2024-12-01T00:00:00Z", date_to="2024-12-01T00:10:00Z")), + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + # Date is before the capture range + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace3", + team=self.team, + timestamp=datetime(2024, 11, 30, 23, 59), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace3", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 0), + ) + response = TracesQueryRunner( + team=self.team, + query=TracesQuery(dateRange=DateRange(date_from="2024-12-01T00:00:00Z", date_to="2024-12-01T00:10:00Z")), + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index f611cf62ef6a6..1588aa927fd5c 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -1,4 +1,5 @@ -from datetime import datetime +from datetime import datetime, timedelta +from functools import cached_property from typing import cast from uuid import UUID @@ -7,10 +8,13 @@ from posthog.hogql import ast from posthog.hogql.constants import LimitContext +from posthog.hogql.parser import parse_select from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator from posthog.hogql_queries.query_runner import QueryRunner +from posthog.hogql_queries.utils.query_date_range import QueryDateRange from posthog.schema import ( CachedTracesQueryResponse, + IntervalType, LLMGeneration, LLMTrace, LLMTracePerson, @@ -22,26 +26,25 @@ logger = structlog.get_logger(__name__) -""" -select - properties.$ai_trace_id as trace_id, - min(timestamp) as trace_timestamp, - max(person.properties) as person, - sum(properties.$ai_latency) as total_latency, - sum(properties.$ai_input_tokens) as input_tokens, - sum(properties.$ai_output_tokens) as output_tokens, - sum(properties.$ai_input_cost_usd) as input_cost, - sum(properties.$ai_output_cost_usd) as output_cost, - sum(properties.$ai_total_cost_usd) as total_cost, - arraySort(x -> x.1, groupArray(tuple(timestamp, properties))) as events -from events -where - event = '$ai_generation' -group by - trace_id -order by - trace_timestamp desc -""" +class TracesQueryDateRange(QueryDateRange): + """ + Extends the QueryDateRange to include a capture range of 10 minutes before and after the date range. + It's a naive assumption that a trace finishes generating within 10 minutes of the first event so we can apply the date filters. + """ + + CAPTURE_RANGE_MINUTES = 10 + + def date_from_for_filtering(self) -> datetime: + return super().date_from() + + def date_to_for_filtering(self) -> datetime: + return super().date_to() + + def date_from(self) -> datetime: + return super().date_from() - timedelta(minutes=self.CAPTURE_RANGE_MINUTES) + + def date_to(self) -> datetime: + return super().date_to() + timedelta(minutes=self.CAPTURE_RANGE_MINUTES) class TracesQueryRunner(QueryRunner): @@ -59,13 +62,9 @@ def __init__(self, *args, **kwargs): ) def to_query(self) -> ast.SelectQuery: - return ast.SelectQuery( - select=self._get_select_fields(), - select_from=ast.JoinExpr(table=ast.Field(chain=["events"])), - where=self._get_where_clause(), - order_by=self._get_order_by_clause(), - group_by=[ast.Field(chain=["id"])], - ) + query = self._get_event_query() + query.where = self._get_where_clause() + return query def calculate(self): with self.timings.measure("traces_query_hogql_execute"): @@ -90,6 +89,11 @@ def calculate(self): **self.paginator.response_params(), ) + @cached_property + def _date_range(self): + # Minute-level precision for 10m capture range + return TracesQueryDateRange(self.query.dateRange, self.team, IntervalType.MINUTE, datetime.now()) + def _map_results(self, columns: list[str], query_results: list): TRACE_FIELDS = { "id", @@ -107,12 +111,21 @@ def _map_results(self, columns: list[str], query_results: list): traces = [] for result in mapped_results: + # Exclude traces that are outside of the capture range. + timestamp_dt = cast(datetime, result["trace_timestamp"]) + if ( + timestamp_dt < self._date_range.date_from_for_filtering() + or timestamp_dt > self._date_range.date_to_for_filtering() + ): + continue + generations = [] for uuid, timestamp, properties in result["events"]: generations.append(self._map_generation(uuid, timestamp, properties)) + trace_dict = { **result, - "created_at": cast(datetime, result["trace_timestamp"]).isoformat(), + "created_at": timestamp_dt.isoformat(), "person": self._map_person(result["person"]), "events": generations, } @@ -160,95 +173,64 @@ def _map_person(self, person: tuple[UUID, UUID, datetime, str]) -> LLMTracePerso uuid=str(uuid), distinct_id=str(distinct_id), created_at=created_at.isoformat(), - properties=orjson.loads(properties), + properties=orjson.loads(properties) if properties else {}, ) - def _get_select_fields(self) -> list[ast.Expr]: - return [ - ast.Alias(expr=ast.Field(chain=["properties", "$ai_trace_id"]), alias="id"), - ast.Alias(expr=ast.Call(name="min", args=[ast.Field(chain=["timestamp"])]), alias="trace_timestamp"), - self._get_person_field(), - ast.Alias( - expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_latency"])]), - alias="total_latency", - ), - ast.Alias( - expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_tokens"])]), - alias="input_tokens", - ), - ast.Alias( - expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_tokens"])]), - alias="output_tokens", - ), - ast.Alias( - expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_input_cost_usd"])]), - alias="input_cost", - ), - ast.Alias( - expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_output_cost_usd"])]), - alias="output_cost", + def _get_event_query(self) -> ast.SelectQuery: + query = parse_select( + """ + SELECT + properties.$ai_trace_id as id, + min(timestamp) as trace_timestamp, + tuple(max(person.id), max(distinct_id), max(person.created_at), max(person.properties)) as person, + sum(properties.$ai_latency) as total_latency, + sum(properties.$ai_input_tokens) as input_tokens, + sum(properties.$ai_output_tokens) as output_tokens, + sum(properties.$ai_input_cost_usd) as input_cost, + sum(properties.$ai_output_cost_usd) as output_cost, + sum(properties.$ai_total_cost_usd) as total_cost, + arraySort(x -> x.2, groupArray(tuple(uuid, timestamp, properties))) as events + FROM + events + GROUP BY + id + ORDER BY + trace_timestamp DESC + """ + ) + return cast(ast.SelectQuery, query) + + def _get_where_clause(self): + timestamp_field = ast.Field(chain=["events", "timestamp"]) + + exprs: list[ast.Expr] = [ + ast.CompareOperation( + left=ast.Field(chain=["event"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value="$ai_generation"), ), - ast.Alias( - expr=ast.Call(name="sum", args=[ast.Field(chain=["properties", "$ai_total_cost_usd"])]), - alias="total_cost", + ast.CompareOperation( + op=ast.CompareOperationOp.GtEq, + left=timestamp_field, + right=self._date_range.date_from_as_hogql(), ), - ast.Alias( - expr=ast.Call( - name="arraySort", - args=[ - ast.Lambda( - args=["x"], - expr=ast.Call(name="tupleElement", args=[ast.Field(chain=["x"]), ast.Constant(value=2)]), - ), - ast.Call( - name="groupArray", - args=[ - ast.Tuple( - exprs=[ - ast.Field(chain=["uuid"]), - ast.Field(chain=["timestamp"]), - ast.Field(chain=["properties"]), - ] - ) - ], - ), - ], - ), - alias="events", + ast.CompareOperation( + op=ast.CompareOperationOp.LtEq, + left=timestamp_field, + right=self._date_range.date_to_as_hogql(), ), ] - def _get_person_field(self): - return ast.Alias( - expr=ast.Tuple( - exprs=[ - ast.Call(name="max", args=[ast.Field(chain=["person", "id"])]), - ast.Call(name="max", args=[ast.Field(chain=["distinct_id"])]), - ast.Call(name="max", args=[ast.Field(chain=["person", "created_at"])]), - ast.Call(name="max", args=[ast.Field(chain=["person", "properties"])]), - ], - ), - alias="person", - ) - - def _get_where_clause(self): - event_expr = ast.CompareOperation( - left=ast.Field(chain=["event"]), - op=ast.CompareOperationOp.Eq, - right=ast.Constant(value="$ai_generation"), - ) if self.query.traceId is not None: - return ast.And( - exprs=[ - event_expr, - ast.CompareOperation( - left=ast.Field(chain=["id"]), - op=ast.CompareOperationOp.Eq, - right=ast.Constant(value=self.query.traceId), - ), - ] + exprs.append( + ast.CompareOperation( + left=ast.Field(chain=["id"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value=self.query.traceId), + ), ) - return event_expr + + return ast.And(exprs=exprs) def _get_order_by_clause(self): return [ast.OrderExpr(expr=ast.Field(chain=["trace_timestamp"]), order="DESC")] diff --git a/posthog/schema.py b/posthog/schema.py index df66c17872058..25968784ec90f 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -5470,6 +5470,7 @@ class TracesQuery(BaseModel): model_config = ConfigDict( extra="forbid", ) + dateRange: Optional[DateRange] = None kind: Literal["TracesQuery"] = "TracesQuery" limit: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = Field( From 335ea2370e68b4981b4f220ba7a3babbdc2c6561 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 14 Jan 2025 21:49:17 +0100 Subject: [PATCH 14/15] fix: camel case in query schemas --- frontend/src/queries/schema.json | 52 +++++------ frontend/src/queries/schema/schema-general.ts | 34 ++++---- .../scenes/saved-insights/SavedInsights.tsx | 3 +- .../test_traces_query_runner.ambr | 10 +-- .../ai/test/test_traces_query_runner.py | 86 +++++++++---------- .../hogql_queries/ai/traces_query_runner.py | 47 +++++----- posthog/schema.py | 34 ++++---- 7 files changed, 135 insertions(+), 131 deletions(-) diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 448441ad8851e..f2fedede548fb 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -8566,13 +8566,13 @@ "LLMGeneration": { "additionalProperties": false, "properties": { - "base_url": { + "baseUrl": { "type": "string" }, - "created_at": { + "createdAt": { "type": "string" }, - "http_status": { + "httpStatus": { "type": "number" }, "id": { @@ -8582,10 +8582,10 @@ "items": {}, "type": "array" }, - "input_cost": { + "inputCost": { "type": "number" }, - "input_tokens": { + "inputTokens": { "type": "number" }, "latency": { @@ -8595,26 +8595,26 @@ "type": "string" }, "output": {}, - "output_cost": { + "outputCost": { "type": "number" }, - "output_tokens": { + "outputTokens": { "type": "number" }, "provider": { "type": "string" }, - "total_cost": { + "totalCost": { "type": "number" } }, - "required": ["id", "created_at", "input", "latency"], + "required": ["id", "createdAt", "input", "latency"], "type": "object" }, "LLMTrace": { "additionalProperties": false, "properties": { - "created_at": { + "createdAt": { "type": "string" }, "events": { @@ -8626,38 +8626,38 @@ "id": { "type": "string" }, - "input_cost": { + "inputCost": { "type": "number" }, - "input_tokens": { + "inputTokens": { "type": "number" }, - "output_cost": { + "outputCost": { "type": "number" }, - "output_tokens": { + "outputTokens": { "type": "number" }, "person": { "$ref": "#/definitions/LLMTracePerson" }, - "total_cost": { + "totalCost": { "type": "number" }, - "total_latency": { + "totalLatency": { "type": "number" } }, "required": [ "id", - "created_at", + "createdAt", "person", - "total_latency", - "input_tokens", - "output_tokens", - "input_cost", - "output_cost", - "total_cost", + "totalLatency", + "inputTokens", + "outputTokens", + "inputCost", + "outputCost", + "totalCost", "events" ], "type": "object" @@ -8665,10 +8665,10 @@ "LLMTracePerson": { "additionalProperties": false, "properties": { - "created_at": { + "createdAt": { "type": "string" }, - "distinct_id": { + "distinctId": { "type": "string" }, "properties": { @@ -8678,7 +8678,7 @@ "type": "string" } }, - "required": ["uuid", "created_at", "properties", "distinct_id"], + "required": ["uuid", "createdAt", "properties", "distinctId"], "type": "object" }, "LifecycleFilter": { diff --git a/frontend/src/queries/schema/schema-general.ts b/frontend/src/queries/schema/schema-general.ts index 8ef6b62fe8df2..2148a6ef0b3e8 100644 --- a/frontend/src/queries/schema/schema-general.ts +++ b/frontend/src/queries/schema/schema-general.ts @@ -2154,38 +2154,38 @@ export enum DefaultChannelTypes { export interface LLMGeneration { id: string - created_at: string + createdAt: string input: any[] latency: number output?: any provider?: string model?: string - input_tokens?: number - output_tokens?: number - input_cost?: number - output_cost?: number - total_cost?: number - http_status?: number - base_url?: string + inputTokens?: number + outputTokens?: number + inputCost?: number + outputCost?: number + totalCost?: number + httpStatus?: number + baseUrl?: string } export interface LLMTracePerson { uuid: string - created_at: string + createdAt: string properties: Record - distinct_id: string + distinctId: string } export interface LLMTrace { id: string - created_at: string + createdAt: string person: LLMTracePerson - total_latency: number - input_tokens: number - output_tokens: number - input_cost: number - output_cost: number - total_cost: number + totalLatency: number + inputTokens: number + outputTokens: number + inputCost: number + outputCost: number + totalCost: number events: LLMGeneration[] } diff --git a/frontend/src/scenes/saved-insights/SavedInsights.tsx b/frontend/src/scenes/saved-insights/SavedInsights.tsx index aae326b16ab11..3e910396fd7c9 100644 --- a/frontend/src/scenes/saved-insights/SavedInsights.tsx +++ b/frontend/src/scenes/saved-insights/SavedInsights.tsx @@ -1,6 +1,7 @@ import './SavedInsights.scss' import { + IconAI, IconBrackets, IconCorrelationAnalysis, IconCursor, @@ -315,7 +316,7 @@ export const QUERY_TYPES_METADATA: Record = { }, [NodeKind.TracesQuery]: { name: 'LLM Observability Traces', - icon: IconHogQL, + icon: IconAI, inMenu: false, }, } diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index c8e34eaef2b75..72156bc23bfdc 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -30,7 +30,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-17 00:09:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 101 @@ -74,7 +74,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -118,7 +118,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -162,7 +162,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -206,7 +206,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-15 00:09:59', 6, 'UTC'))), ifNull(equals(id, 'trace1'), 0)) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC'))), ifNull(equals(id, 'trace1'), 0)) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 101 diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index 57704397abf0f..6a03217e85529 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -158,16 +158,16 @@ def test_field_mapping(self): trace, { "id": "trace1", - "created_at": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), - "total_latency": 2.0, - "input_tokens": 6.0, - "output_tokens": 6.0, - "input_cost": 6.0, - "output_cost": 6.0, - "total_cost": 12.0, + "createdAt": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), + "totalLatency": 2.0, + "inputTokens": 6.0, + "outputTokens": 6.0, + "inputCost": 6.0, + "outputCost": 6.0, + "totalCost": 12.0, }, ) - self.assertEqual(trace.person.distinct_id, "person1") + self.assertEqual(trace.person.distinctId, "person1") self.assertEqual(len(trace.events), 2) event = trace.events[0] @@ -175,15 +175,15 @@ def test_field_mapping(self): self.assertEventEqual( event, { - "created_at": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), + "createdAt": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), "input": [{"role": "user", "content": "Foo"}], "output": {"choices": [{"role": "assistant", "content": "Bar"}]}, "latency": 1, - "input_tokens": 3, - "output_tokens": 3, - "input_cost": 3, - "output_cost": 3, - "total_cost": 6, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, }, ) @@ -192,17 +192,17 @@ def test_field_mapping(self): self.assertEventEqual( event, { - "created_at": datetime(2025, 1, 15, 1, tzinfo=UTC).isoformat(), + "createdAt": datetime(2025, 1, 15, 1, tzinfo=UTC).isoformat(), "input": [{"role": "user", "content": "Bar"}], "output": {"choices": [{"role": "assistant", "content": "Baz"}]}, "latency": 1, - "input_tokens": 3, - "output_tokens": 3, - "input_cost": 3, - "output_cost": 3, - "total_cost": 6, - "base_url": None, - "http_status": None, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, + "baseUrl": None, + "httpStatus": None, }, ) @@ -211,33 +211,33 @@ def test_field_mapping(self): trace, { "id": "trace2", - "created_at": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), - "total_latency": 1, - "input_tokens": 3, - "output_tokens": 3, - "input_cost": 3, - "output_cost": 3, - "total_cost": 6, + "createdAt": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), + "totalLatency": 1, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, }, ) - self.assertEqual(trace.person.distinct_id, "person2") + self.assertEqual(trace.person.distinctId, "person2") self.assertEqual(len(trace.events), 1) event = trace.events[0] self.assertIsNotNone(event.id) self.assertEventEqual( event, { - "created_at": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), + "createdAt": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), "input": [{"role": "user", "content": "Foo"}], "output": {"choices": [{"role": "assistant", "content": "Bar"}]}, "latency": 1, - "input_tokens": 3, - "output_tokens": 3, - "input_cost": 3, - "output_cost": 3, - "total_cost": 6, - "base_url": None, - "http_status": None, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, + "baseUrl": None, + "httpStatus": None, }, ) @@ -304,15 +304,15 @@ def test_maps_all_fields(self): response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() self.assertEqual(len(response.results), 1) self.assertEqual(response.results[0].id, "trace1") - self.assertEqual(response.results[0].total_latency, 10.5) + self.assertEqual(response.results[0].totalLatency, 10.5) self.assertEventEqual( response.results[0].events[0], { "latency": 10.5, "provider": "posthog", "model": "hog-destroyer", - "http_status": 200, - "base_url": "https://us.posthog.com", + "httpStatus": 200, + "baseUrl": "https://us.posthog.com", }, ) @@ -326,9 +326,9 @@ def test_person_properties(self): ) response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() self.assertEqual(len(response.results), 1) - self.assertEqual(response.results[0].person.created_at, "2025-01-01T00:00:00+00:00") + self.assertEqual(response.results[0].person.createdAt, "2025-01-01T00:00:00+00:00") self.assertEqual(response.results[0].person.properties, {"email": "test@posthog.com"}) - self.assertEqual(response.results[0].person.distinct_id, "person1") + self.assertEqual(response.results[0].person.distinctId, "person1") @freeze_time("2025-01-16T00:00:00Z") def test_date_range(self): diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py index 1588aa927fd5c..df9a706029688 100644 --- a/posthog/hogql_queries/ai/traces_query_runner.py +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -95,17 +95,17 @@ def _date_range(self): return TracesQueryDateRange(self.query.dateRange, self.team, IntervalType.MINUTE, datetime.now()) def _map_results(self, columns: list[str], query_results: list): - TRACE_FIELDS = { - "id", - "created_at", - "person", - "total_latency", - "input_tokens", - "output_tokens", - "input_cost", - "output_cost", - "total_cost", - "events", + TRACE_FIELDS_MAPPING = { + "id": "id", + "created_at": "createdAt", + "person": "person", + "total_latency": "totalLatency", + "input_tokens": "inputTokens", + "output_tokens": "outputTokens", + "input_cost": "inputCost", + "output_cost": "outputCost", + "total_cost": "totalCost", + "events": "events", } mapped_results = [dict(zip(columns, value)) for value in query_results] traces = [] @@ -129,7 +129,10 @@ def _map_results(self, columns: list[str], query_results: list): "person": self._map_person(result["person"]), "events": generations, } - trace = LLMTrace.model_validate({key: value for key, value in trace_dict.items() if key in TRACE_FIELDS}) + # Remap keys from snake case to camel case + trace = LLMTrace.model_validate( + {TRACE_FIELDS_MAPPING[key]: value for key, value in trace_dict.items() if key in TRACE_FIELDS_MAPPING} + ) traces.append(trace) return traces @@ -143,19 +146,19 @@ def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_pro "$ai_output": "output", "$ai_provider": "provider", "$ai_model": "model", - "$ai_input_tokens": "input_tokens", - "$ai_output_tokens": "output_tokens", - "$ai_input_cost_usd": "input_cost", - "$ai_output_cost_usd": "output_cost", - "$ai_total_cost_usd": "total_cost", - "$ai_http_status": "http_status", - "$ai_base_url": "base_url", + "$ai_input_tokens": "inputTokens", + "$ai_output_tokens": "outputTokens", + "$ai_input_cost_usd": "inputCost", + "$ai_output_cost_usd": "outputCost", + "$ai_total_cost_usd": "totalCost", + "$ai_http_status": "httpStatus", + "$ai_base_url": "baseUrl", } GENERATION_JSON_FIELDS = {"$ai_input", "$ai_output"} generation = { "id": str(event_uuid), - "created_at": event_timestamp.isoformat(), + "createdAt": event_timestamp.isoformat(), } for event_prop, model_prop in GENERATION_MAPPING.items(): @@ -171,8 +174,8 @@ def _map_person(self, person: tuple[UUID, UUID, datetime, str]) -> LLMTracePerso uuid, distinct_id, created_at, properties = person return LLMTracePerson( uuid=str(uuid), - distinct_id=str(distinct_id), - created_at=created_at.isoformat(), + distinctId=str(distinct_id), + createdAt=created_at.isoformat(), properties=orjson.loads(properties) if properties else {}, ) diff --git a/posthog/schema.py b/posthog/schema.py index 25968784ec90f..ce9dc633f47a1 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -970,28 +970,28 @@ class LLMGeneration(BaseModel): model_config = ConfigDict( extra="forbid", ) - base_url: Optional[str] = None - created_at: str - http_status: Optional[float] = None + baseUrl: Optional[str] = None + createdAt: str + httpStatus: Optional[float] = None id: str input: list - input_cost: Optional[float] = None - input_tokens: Optional[float] = None + inputCost: Optional[float] = None + inputTokens: Optional[float] = None latency: float model: Optional[str] = None output: Optional[Any] = None - output_cost: Optional[float] = None - output_tokens: Optional[float] = None + outputCost: Optional[float] = None + outputTokens: Optional[float] = None provider: Optional[str] = None - total_cost: Optional[float] = None + totalCost: Optional[float] = None class LLMTracePerson(BaseModel): model_config = ConfigDict( extra="forbid", ) - created_at: str - distinct_id: str + createdAt: str + distinctId: str properties: dict[str, Any] uuid: str @@ -2219,16 +2219,16 @@ class LLMTrace(BaseModel): model_config = ConfigDict( extra="forbid", ) - created_at: str + createdAt: str events: list[LLMGeneration] id: str - input_cost: float - input_tokens: float - output_cost: float - output_tokens: float + inputCost: float + inputTokens: float + outputCost: float + outputTokens: float person: LLMTracePerson - total_cost: float - total_latency: float + totalCost: float + totalLatency: float class LifecycleFilter(BaseModel): From 42ea56d0dfb3d61f0d1029edf464495e0107b4f6 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Wed, 15 Jan 2025 10:00:51 +0100 Subject: [PATCH 15/15] fix: set snapshot dates --- .../ai/test/__snapshots__/test_traces_query_runner.ambr | 8 ++++---- posthog/hogql_queries/ai/test/test_traces_query_runner.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr index 72156bc23bfdc..c5931f8c340fd 100644 --- a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -74,7 +74,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -118,7 +118,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -162,7 +162,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC')))) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 5 @@ -206,7 +206,7 @@ WHERE equals(person.team_id, 99999) GROUP BY person.id HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) - WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-06 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-14 20:58:59', 6, 'UTC'))), ifNull(equals(id, 'trace1'), 0)) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC'))), ifNull(equals(id, 'trace1'), 0)) GROUP BY id ORDER BY trace_timestamp DESC LIMIT 101 diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py index 6a03217e85529..5956ffefc8441 100644 --- a/posthog/hogql_queries/ai/test/test_traces_query_runner.py +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -241,6 +241,7 @@ def test_field_mapping(self): }, ) + @freeze_time("2025-01-16T00:00:00Z") @snapshot_clickhouse_queries def test_trace_id_filter(self): _create_person(distinct_ids=["person1"], team=self.team) @@ -252,6 +253,7 @@ def test_trace_id_filter(self): self.assertEqual(len(response.results), 1) self.assertEqual(response.results[0].id, "trace1") + @freeze_time("2025-01-16T00:00:00Z") @snapshot_clickhouse_queries def test_pagination(self): _create_person(distinct_ids=["person1"], team=self.team) @@ -261,8 +263,8 @@ def test_pagination(self): distinct_id="person1" if i % 2 == 0 else "person2", team=self.team, trace_id=f"trace_{i}", + timestamp=datetime(2025, 1, 15, i), ) - response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=4, offset=0)).calculate() self.assertEqual(response.hasMore, True) self.assertEqual(len(response.results), 5) @@ -286,6 +288,7 @@ def test_pagination(self): self.assertEqual(len(response.results), 1) self.assertEqual(response.results[0].id, "trace_0") + @freeze_time("2025-01-16T00:00:00Z") def test_maps_all_fields(self): _create_person(distinct_ids=["person1"], team=self.team) _create_ai_generation_event(