diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 50b305a185003..f2fedede548fb 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -3397,6 +3397,89 @@ ], "type": "object" }, + "CachedTracesQueryResponse": { + "additionalProperties": false, + "properties": { + "cache_key": { + "type": "string" + }, + "cache_target_age": { + "format": "date-time", + "type": "string" + }, + "calculation_trigger": { + "description": "What triggered the calculation of the query, leave empty if user/immediate", + "type": "string" + }, + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "is_cached": { + "type": "boolean" + }, + "last_refresh": { + "format": "date-time", + "type": "string" + }, + "limit": { + "$ref": "#/definitions/integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "next_allowed_client_refresh": { + "format": "date-time", + "type": "string" + }, + "offset": { + "$ref": "#/definitions/integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/LLMTrace" + }, + "type": "array" + }, + "timezone": { + "type": "string" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": [ + "cache_key", + "is_cached", + "last_refresh", + "next_allowed_client_refresh", + "results", + "timezone" + ], + "type": "object" + }, "CachedTrendsQueryResponse": { "additionalProperties": false, "properties": { @@ -8480,6 +8563,124 @@ "enum": ["minute", "hour", "day", "week", "month"], "type": "string" }, + "LLMGeneration": { + "additionalProperties": false, + "properties": { + "baseUrl": { + "type": "string" + }, + "createdAt": { + "type": "string" + }, + "httpStatus": { + "type": "number" + }, + "id": { + "type": "string" + }, + "input": { + "items": {}, + "type": "array" + }, + "inputCost": { + "type": "number" + }, + "inputTokens": { + "type": "number" + }, + "latency": { + "type": "number" + }, + "model": { + "type": "string" + }, + "output": {}, + "outputCost": { + "type": "number" + }, + "outputTokens": { + "type": "number" + }, + "provider": { + "type": "string" + }, + "totalCost": { + "type": "number" + } + }, + "required": ["id", "createdAt", "input", "latency"], + "type": "object" + }, + "LLMTrace": { + "additionalProperties": false, + "properties": { + "createdAt": { + "type": "string" + }, + "events": { + "items": { + "$ref": "#/definitions/LLMGeneration" + }, + "type": "array" + }, + "id": { + "type": "string" + }, + "inputCost": { + "type": "number" + }, + "inputTokens": { + "type": "number" + }, + "outputCost": { + "type": "number" + }, + "outputTokens": { + "type": "number" + }, + "person": { + "$ref": "#/definitions/LLMTracePerson" + }, + "totalCost": { + "type": "number" + }, + "totalLatency": { + "type": "number" + } + }, + "required": [ + "id", + "createdAt", + "person", + "totalLatency", + "inputTokens", + "outputTokens", + 
"inputCost", + "outputCost", + "totalCost", + "events" + ], + "type": "object" + }, + "LLMTracePerson": { + "additionalProperties": false, + "properties": { + "createdAt": { + "type": "string" + }, + "distinctId": { + "type": "string" + }, + "properties": { + "type": "object" + }, + "uuid": { + "type": "string" + } + }, + "required": ["uuid", "createdAt", "properties", "distinctId"], + "type": "object" + }, "LifecycleFilter": { "additionalProperties": false, "properties": { @@ -8765,7 +8966,8 @@ "SuggestedQuestionsQuery", "TeamTaxonomyQuery", "EventTaxonomyQuery", - "ActorsPropertyTaxonomyQuery" + "ActorsPropertyTaxonomyQuery", + "TracesQuery" ], "type": "string" }, @@ -11184,6 +11386,57 @@ }, "required": ["results"], "type": "object" + }, + { + "additionalProperties": false, + "properties": { + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "limit": { + "$ref": "#/definitions/integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "$ref": "#/definitions/integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." + }, + "results": { + "items": { + "$ref": "#/definitions/LLMTrace" + }, + "type": "array" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": ["results"], + "type": "object" } ] }, @@ -11302,6 +11555,9 @@ }, { "$ref": "#/definitions/ActorsPropertyTaxonomyQuery" + }, + { + "$ref": "#/definitions/TracesQuery" } ], "required": ["kind"], @@ -12833,6 +13089,87 @@ "required": ["events"], "type": "object" }, + "TracesQuery": { + "additionalProperties": false, + "properties": { + "dateRange": { + "$ref": "#/definitions/DateRange" + }, + "kind": { + "const": "TracesQuery", + "type": "string" + }, + "limit": { + "$ref": "#/definitions/integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "$ref": "#/definitions/integer" + }, + "response": { + "$ref": "#/definitions/TracesQueryResponse" + }, + "traceId": { + "type": "string" + } + }, + "required": ["kind"], + "type": "object" + }, + "TracesQueryResponse": { + "additionalProperties": false, + "properties": { + "columns": { + "items": { + "type": "string" + }, + "type": "array" + }, + "error": { + "description": "Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + "type": "string" + }, + "hasMore": { + "type": "boolean" + }, + "hogql": { + "description": "Generated HogQL query.", + "type": "string" + }, + "limit": { + "$ref": "#/definitions/integer" + }, + "modifiers": { + "$ref": "#/definitions/HogQLQueryModifiers", + "description": "Modifiers used when performing the query" + }, + "offset": { + "$ref": "#/definitions/integer" + }, + "query_status": { + "$ref": "#/definitions/QueryStatus", + "description": "Query status indicates whether next to the provided data, a query is still running." 
+ }, + "results": { + "items": { + "$ref": "#/definitions/LLMTrace" + }, + "type": "array" + }, + "timings": { + "description": "Measured timings for different parts of the query generation process", + "items": { + "$ref": "#/definitions/QueryTiming" + }, + "type": "array" + } + }, + "required": ["results"], + "type": "object" + }, "TrendsAlertConfig": { "additionalProperties": false, "properties": { diff --git a/frontend/src/queries/schema/schema-general.ts b/frontend/src/queries/schema/schema-general.ts index de6d62e4bdfa0..2148a6ef0b3e8 100644 --- a/frontend/src/queries/schema/schema-general.ts +++ b/frontend/src/queries/schema/schema-general.ts @@ -106,6 +106,7 @@ export enum NodeKind { TeamTaxonomyQuery = 'TeamTaxonomyQuery', EventTaxonomyQuery = 'EventTaxonomyQuery', ActorsPropertyTaxonomyQuery = 'ActorsPropertyTaxonomyQuery', + TracesQuery = 'TracesQuery', } export type AnyDataNode = @@ -181,6 +182,7 @@ export type QuerySchema = | TeamTaxonomyQuery | EventTaxonomyQuery | ActorsPropertyTaxonomyQuery + | TracesQuery // Keep this, because QuerySchema itself will be collapsed as it is used in other models export type QuerySchemaRoot = QuerySchema @@ -557,6 +559,7 @@ export interface EventsQueryPersonColumn { } distinct_id: string } + export interface EventsQuery extends DataNode { kind: NodeKind.EventsQuery /** Return a limited set of data. Required. */ @@ -2148,3 +2151,57 @@ export enum DefaultChannelTypes { Affiliate = 'Affiliate', Unknown = 'Unknown', } + +export interface LLMGeneration { + id: string + createdAt: string + input: any[] + latency: number + output?: any + provider?: string + model?: string + inputTokens?: number + outputTokens?: number + inputCost?: number + outputCost?: number + totalCost?: number + httpStatus?: number + baseUrl?: string +} + +export interface LLMTracePerson { + uuid: string + createdAt: string + properties: Record + distinctId: string +} + +export interface LLMTrace { + id: string + createdAt: string + person: LLMTracePerson + totalLatency: number + inputTokens: number + outputTokens: number + inputCost: number + outputCost: number + totalCost: number + events: LLMGeneration[] +} + +export interface TracesQueryResponse extends AnalyticsQueryResponseBase { + hasMore?: boolean + limit?: integer + offset?: integer + columns?: string[] +} + +export interface TracesQuery extends DataNode { + kind: NodeKind.TracesQuery + traceId?: string + dateRange?: DateRange + limit?: integer + offset?: integer +} + +export type CachedTracesQueryResponse = CachedQueryResponse diff --git a/frontend/src/scenes/saved-insights/SavedInsights.tsx b/frontend/src/scenes/saved-insights/SavedInsights.tsx index 8ebaa5a400c8d..3e910396fd7c9 100644 --- a/frontend/src/scenes/saved-insights/SavedInsights.tsx +++ b/frontend/src/scenes/saved-insights/SavedInsights.tsx @@ -1,6 +1,7 @@ import './SavedInsights.scss' import { + IconAI, IconBrackets, IconCorrelationAnalysis, IconCursor, @@ -313,6 +314,11 @@ export const QUERY_TYPES_METADATA: Record = { icon: IconHogQL, inMenu: false, }, + [NodeKind.TracesQuery]: { + name: 'LLM Observability Traces', + icon: IconAI, + inMenu: false, + }, } export const INSIGHT_TYPES_METADATA: Record = { diff --git a/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr new file mode 100644 index 0000000000000..c5931f8c340fd --- /dev/null +++ b/posthog/hogql_queries/ai/test/__snapshots__/test_traces_query_runner.ambr @@ -0,0 +1,221 @@ +# serializer version: 1 +# 
name: TestTracesQueryRunner.test_field_mapping + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 101 + OFFSET 0 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- +# name: TestTracesQueryRunner.test_pagination + ''' + SELECT 
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 6 + OFFSET 0 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- +# name: TestTracesQueryRunner.test_pagination.1 + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 
'$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 6 + OFFSET 5 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- +# name: TestTracesQueryRunner.test_pagination.2 + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + 
min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC')))) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 6 + OFFSET 10 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- +# name: TestTracesQueryRunner.test_trace_id_filter + ''' + SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_trace_id'), ''), 'null'), '^"|"$', '') AS id, + min(toTimeZone(events.timestamp, 'UTC')) AS trace_timestamp, + 
tuple(max(events__person.id), max(events.distinct_id), max(events__person.created_at), max(events__person.properties)) AS person, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_latency'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_latency, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_tokens'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_tokens, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_input_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS input_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_output_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS output_cost, + sum(accurateCastOrNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, '$ai_total_cost_usd'), ''), 'null'), '^"|"$', ''), 'Float64')) AS total_cost, + arraySort(x -> x.2, groupArray(tuple(events.uuid, toTimeZone(events.timestamp, 'UTC'), events.properties))) AS events + FROM events + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 99999) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__override ON equals(events.distinct_id, events__override.distinct_id) + LEFT JOIN + (SELECT person.id AS id, + toTimeZone(person.created_at, 'UTC') AS created_at, + person.properties AS properties + FROM person + WHERE and(equals(person.team_id, 99999), ifNull(in(tuple(person.id, person.version), + (SELECT person.id AS id, max(person.version) AS version + FROM person + WHERE equals(person.team_id, 99999) + GROUP BY person.id + HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(toTimeZone(person.created_at, 'UTC'), person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)))), 0)) SETTINGS optimize_aggregation_in_order=1) AS events__person ON equals(if(not(empty(events__override.distinct_id)), events__override.person_id, events.person_id), events__person.id) + WHERE and(equals(events.team_id, 99999), equals(events.event, '$ai_generation'), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-08 23:50:00', 6, 'UTC'))), lessOrEquals(toTimeZone(events.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2025-01-16 00:10:59', 6, 'UTC'))), ifNull(equals(id, 'trace1'), 0)) + GROUP BY id + ORDER BY trace_timestamp DESC + LIMIT 101 + OFFSET 0 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 + ''' +# --- diff --git a/posthog/hogql_queries/ai/test/test_traces_query_runner.py b/posthog/hogql_queries/ai/test/test_traces_query_runner.py new file mode 100644 index 0000000000000..5956ffefc8441 --- /dev/null +++ b/posthog/hogql_queries/ai/test/test_traces_query_runner.py @@ -0,0 
+1,437 @@ +import json +import uuid +from datetime import UTC, datetime +from typing import Any, Literal, TypedDict +from uuid import UUID + +from freezegun import freeze_time + +from posthog.hogql_queries.ai.traces_query_runner import TracesQueryRunner +from posthog.models import PropertyDefinition, Team +from posthog.models.property_definition import PropertyType +from posthog.schema import DateRange, LLMGeneration, LLMTrace, TracesQuery +from posthog.test.base import ( + BaseTest, + ClickhouseTestMixin, + _create_event, + _create_person, + snapshot_clickhouse_queries, +) + + +class InputMessage(TypedDict): + role: Literal["user", "assistant"] + content: str + + +class OutputMessage(TypedDict): + role: Literal["user", "assistant", "tool"] + content: str + + +def _calculate_tokens(messages: str | list[InputMessage] | list[OutputMessage]) -> int: + if isinstance(messages, str): + message = messages + else: + message = "".join([message["content"] for message in messages]) + return len(message) + + +def _create_ai_generation_event( + input: str | list[InputMessage] = "Foo", + output: str | list[OutputMessage] = "Bar", + team: Team | None = None, + distinct_id: str | None = None, + trace_id: str | None = None, + properties: dict[str, Any] | None = None, + timestamp: datetime | None = None, + event_uuid: str | UUID | None = None, +): + input_tokens = _calculate_tokens(input) + output_tokens = _calculate_tokens(output) + + if isinstance(input, str): + input_messages: list[InputMessage] = [{"role": "user", "content": input}] + else: + input_messages = input + + if isinstance(output, str): + output_messages: list[OutputMessage] = [{"role": "assistant", "content": output}] + else: + output_messages = output + + props = { + "$ai_trace_id": trace_id or str(uuid.uuid4()), + "$ai_latency": 1, + "$ai_input": json.dumps(input_messages), + "$ai_output": json.dumps({"choices": output_messages}), + "$ai_input_tokens": input_tokens, + "$ai_output_tokens": output_tokens, + "$ai_input_cost_usd": input_tokens, + "$ai_output_cost_usd": output_tokens, + "$ai_total_cost_usd": input_tokens + output_tokens, + } + if properties: + props.update(properties) + + _create_event( + event="$ai_generation", + distinct_id=distinct_id, + properties=props, + team=team, + timestamp=timestamp, + event_uuid=str(event_uuid) if event_uuid else None, + ) + + +class TestTracesQueryRunner(ClickhouseTestMixin, BaseTest): + def setUp(self): + super().setUp() + self._create_properties() + + def _create_properties(self): + numeric_props = { + "$ai_latency", + "$ai_input_tokens", + "$ai_output_tokens", + "$ai_input_cost_usd", + "$ai_output_cost_usd", + "$ai_total_cost_usd", + } + models_to_create = [] + for prop in numeric_props: + prop_model = PropertyDefinition( + team=self.team, + name=prop, + type=PropertyDefinition.Type.EVENT, + property_type=PropertyType.Numeric, + ) + models_to_create.append(prop_model) + PropertyDefinition.objects.bulk_create(models_to_create) + + def assertTraceEqual(self, trace: LLMTrace, expected_trace: dict): + trace_dict = trace.model_dump() + for key, value in expected_trace.items(): + self.assertEqual(trace_dict[key], value, f"Field {key} does not match") + + def assertEventEqual(self, event: LLMGeneration, expected_event: dict): + event_dict = event.model_dump() + for key, value in expected_event.items(): + self.assertEqual(event_dict[key], value, f"Field {key} does not match") + + @freeze_time("2025-01-16T00:00:00Z") + @snapshot_clickhouse_queries + def test_field_mapping(self): + 
_create_person(distinct_ids=["person1"], team=self.team) + _create_person(distinct_ids=["person2"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + input="Foo", + output="Bar", + team=self.team, + timestamp=datetime(2025, 1, 15, 0), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + input="Bar", + output="Baz", + team=self.team, + timestamp=datetime(2025, 1, 15, 1), + ) + _create_ai_generation_event( + distinct_id="person2", + trace_id="trace2", + input="Foo", + output="Bar", + team=self.team, + timestamp=datetime(2025, 1, 14), + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(response.results), 2) + + trace = response.results[0] + + self.assertTraceEqual( + trace, + { + "id": "trace1", + "createdAt": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), + "totalLatency": 2.0, + "inputTokens": 6.0, + "outputTokens": 6.0, + "inputCost": 6.0, + "outputCost": 6.0, + "totalCost": 12.0, + }, + ) + self.assertEqual(trace.person.distinctId, "person1") + + self.assertEqual(len(trace.events), 2) + event = trace.events[0] + self.assertIsNotNone(event.id) + self.assertEventEqual( + event, + { + "createdAt": datetime(2025, 1, 15, 0, tzinfo=UTC).isoformat(), + "input": [{"role": "user", "content": "Foo"}], + "output": {"choices": [{"role": "assistant", "content": "Bar"}]}, + "latency": 1, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, + }, + ) + + event = trace.events[1] + self.assertIsNotNone(event.id) + self.assertEventEqual( + event, + { + "createdAt": datetime(2025, 1, 15, 1, tzinfo=UTC).isoformat(), + "input": [{"role": "user", "content": "Bar"}], + "output": {"choices": [{"role": "assistant", "content": "Baz"}]}, + "latency": 1, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, + "baseUrl": None, + "httpStatus": None, + }, + ) + + trace = response.results[1] + self.assertTraceEqual( + trace, + { + "id": "trace2", + "createdAt": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), + "totalLatency": 1, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, + }, + ) + self.assertEqual(trace.person.distinctId, "person2") + self.assertEqual(len(trace.events), 1) + event = trace.events[0] + self.assertIsNotNone(event.id) + self.assertEventEqual( + event, + { + "createdAt": datetime(2025, 1, 14, tzinfo=UTC).isoformat(), + "input": [{"role": "user", "content": "Foo"}], + "output": {"choices": [{"role": "assistant", "content": "Bar"}]}, + "latency": 1, + "inputTokens": 3, + "outputTokens": 3, + "inputCost": 3, + "outputCost": 3, + "totalCost": 6, + "baseUrl": None, + "httpStatus": None, + }, + ) + + @freeze_time("2025-01-16T00:00:00Z") + @snapshot_clickhouse_queries + def test_trace_id_filter(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_person(distinct_ids=["person2"], team=self.team) + _create_ai_generation_event(distinct_id="person1", trace_id="trace1", team=self.team) + _create_ai_generation_event(distinct_id="person2", trace_id="trace2", team=self.team) + + response = TracesQueryRunner(team=self.team, query=TracesQuery(traceId="trace1")).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + @freeze_time("2025-01-16T00:00:00Z") + @snapshot_clickhouse_queries + def test_pagination(self): + _create_person(distinct_ids=["person1"], team=self.team) + 
_create_person(distinct_ids=["person2"], team=self.team) + for i in range(11): + _create_ai_generation_event( + distinct_id="person1" if i % 2 == 0 else "person2", + team=self.team, + trace_id=f"trace_{i}", + timestamp=datetime(2025, 1, 15, i), + ) + response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=5, offset=0)).calculate() + self.assertEqual(response.hasMore, True) + self.assertEqual(len(response.results), 5) + self.assertEqual(response.results[0].id, "trace_10") + self.assertEqual(response.results[1].id, "trace_9") + self.assertEqual(response.results[2].id, "trace_8") + self.assertEqual(response.results[3].id, "trace_7") + self.assertEqual(response.results[4].id, "trace_6") + + response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=5, offset=5)).calculate() + self.assertEqual(response.hasMore, True) + self.assertEqual(len(response.results), 5) + self.assertEqual(response.results[0].id, "trace_5") + self.assertEqual(response.results[1].id, "trace_4") + self.assertEqual(response.results[2].id, "trace_3") + self.assertEqual(response.results[3].id, "trace_2") + self.assertEqual(response.results[4].id, "trace_1") + + response = TracesQueryRunner(team=self.team, query=TracesQuery(limit=5, offset=10)).calculate() + self.assertEqual(response.hasMore, False) + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace_0") + + @freeze_time("2025-01-16T00:00:00Z") + def test_maps_all_fields(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + properties={ + "$ai_latency": 10.5, + "$ai_provider": "posthog", + "$ai_model": "hog-destroyer", + "$ai_http_status": 200, + "$ai_base_url": "https://us.posthog.com", + }, + ) + + response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + self.assertEqual(response.results[0].totalLatency, 10.5) + self.assertEventEqual( + response.results[0].events[0], + { + "latency": 10.5, + "provider": "posthog", + "model": "hog-destroyer", + "httpStatus": 200, + "baseUrl": "https://us.posthog.com", + }, + ) + + @freeze_time("2025-01-01T00:00:00Z") + def test_person_properties(self): + _create_person(distinct_ids=["person1"], team=self.team, properties={"email": "test@posthog.com"}) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + ) + response = TracesQueryRunner(team=self.team, query=TracesQuery()).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].person.createdAt, "2025-01-01T00:00:00+00:00") + self.assertEqual(response.results[0].person.properties, {"email": "test@posthog.com"}) + self.assertEqual(response.results[0].person.distinctId, "person1") + + @freeze_time("2025-01-16T00:00:00Z") + def test_date_range(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + timestamp=datetime(2025, 1, 15), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace2", + team=self.team, + timestamp=datetime(2024, 12, 1), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace3", + team=self.team, + timestamp=datetime(2024, 11, 1), + ) + + response = TracesQueryRunner( + team=self.team, query=TracesQuery(dateRange=DateRange(date_from="-1m")) + ).calculate()
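+ # With time frozen at 2025-01-16, a one-month window ("-1m") covers only trace1 (2025-01-15)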
+ self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + response = TracesQueryRunner( + team=self.team, query=TracesQuery(dateRange=DateRange(date_from="-2m")) + ).calculate() + self.assertEqual(len(response.results), 2) + self.assertEqual(response.results[0].id, "trace1") + self.assertEqual(response.results[1].id, "trace2") + + response = TracesQueryRunner( + team=self.team, query=TracesQuery(dateRange=DateRange(date_from="-3m", date_to="-2m")) + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace3") + + def test_capture_range(self): + _create_person(distinct_ids=["person1"], team=self.team) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 0), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace1", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 10), + ) + + response = TracesQueryRunner( + team=self.team, + query=TracesQuery(dateRange=DateRange(date_from="2024-12-01T00:00:00Z", date_to="2024-12-01T00:10:00Z")), + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + # Date is after the capture range + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace2", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 11), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace2", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 12), + ) + response = TracesQueryRunner( + team=self.team, + query=TracesQuery(dateRange=DateRange(date_from="2024-12-01T00:00:00Z", date_to="2024-12-01T00:10:00Z")), + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") + + # Date is before the capture range + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace3", + team=self.team, + timestamp=datetime(2024, 11, 30, 23, 59), + ) + _create_ai_generation_event( + distinct_id="person1", + trace_id="trace3", + team=self.team, + timestamp=datetime(2024, 12, 1, 0, 0), + ) + response = TracesQueryRunner( + team=self.team, + query=TracesQuery(dateRange=DateRange(date_from="2024-12-01T00:00:00Z", date_to="2024-12-01T00:10:00Z")), + ).calculate() + self.assertEqual(len(response.results), 1) + self.assertEqual(response.results[0].id, "trace1") diff --git a/posthog/hogql_queries/ai/traces_query_runner.py b/posthog/hogql_queries/ai/traces_query_runner.py new file mode 100644 index 0000000000000..df9a706029688 --- /dev/null +++ b/posthog/hogql_queries/ai/traces_query_runner.py @@ -0,0 +1,239 @@ +from datetime import datetime, timedelta +from functools import cached_property +from typing import cast +from uuid import UUID + +import orjson +import structlog + +from posthog.hogql import ast +from posthog.hogql.constants import LimitContext +from posthog.hogql.parser import parse_select +from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator +from posthog.hogql_queries.query_runner import QueryRunner +from posthog.hogql_queries.utils.query_date_range import QueryDateRange +from posthog.schema import ( + CachedTracesQueryResponse, + IntervalType, + LLMGeneration, + LLMTrace, + LLMTracePerson, + NodeKind, + TracesQuery, + TracesQueryResponse, +) + +logger = structlog.get_logger(__name__) + + +class TracesQueryDateRange(QueryDateRange): + """ + Extends the QueryDateRange to include a capture range of 10 minutes before and 
after the date range. + It's a naive assumption that a trace finishes generating within 10 minutes of the first event so we can apply the date filters. + """ + + CAPTURE_RANGE_MINUTES = 10 + + def date_from_for_filtering(self) -> datetime: + return super().date_from() + + def date_to_for_filtering(self) -> datetime: + return super().date_to() + + def date_from(self) -> datetime: + return super().date_from() - timedelta(minutes=self.CAPTURE_RANGE_MINUTES) + + def date_to(self) -> datetime: + return super().date_to() + timedelta(minutes=self.CAPTURE_RANGE_MINUTES) + + +class TracesQueryRunner(QueryRunner): + query: TracesQuery + response: TracesQueryResponse + cached_response: CachedTracesQueryResponse + paginator: HogQLHasMorePaginator + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.paginator = HogQLHasMorePaginator.from_limit_context( + limit_context=LimitContext.QUERY, + limit=self.query.limit if self.query.limit else None, + offset=self.query.offset, + ) + + def to_query(self) -> ast.SelectQuery: + query = self._get_event_query() + query.where = self._get_where_clause() + return query + + def calculate(self): + with self.timings.measure("traces_query_hogql_execute"): + query_result = self.paginator.execute_hogql_query( + query=self.to_query(), + team=self.team, + query_type=NodeKind.TRACES_QUERY, + timings=self.timings, + modifiers=self.modifiers, + limit_context=self.limit_context, + ) + + columns: list[str] = query_result.columns or [] + results = self._map_results(columns, query_result.results) + + return TracesQueryResponse( + columns=columns, + results=results, + timings=query_result.timings, + hogql=query_result.hogql, + modifiers=self.modifiers, + **self.paginator.response_params(), + ) + + @cached_property + def _date_range(self): + # Minute-level precision for 10m capture range + return TracesQueryDateRange(self.query.dateRange, self.team, IntervalType.MINUTE, datetime.now()) + + def _map_results(self, columns: list[str], query_results: list): + TRACE_FIELDS_MAPPING = { + "id": "id", + "created_at": "createdAt", + "person": "person", + "total_latency": "totalLatency", + "input_tokens": "inputTokens", + "output_tokens": "outputTokens", + "input_cost": "inputCost", + "output_cost": "outputCost", + "total_cost": "totalCost", + "events": "events", + } + mapped_results = [dict(zip(columns, value)) for value in query_results] + traces = [] + + for result in mapped_results: + # Exclude traces that are outside of the capture range. 
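+ # The SQL scan above ran against the range widened by CAPTURE_RANGE_MINUTES, so traces starting near a boundary keep all of their events; + # here we re-apply the exact requested bounds and keep only traces whose first event falls inside them.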
+ timestamp_dt = cast(datetime, result["trace_timestamp"]) + if ( + timestamp_dt < self._date_range.date_from_for_filtering() + or timestamp_dt > self._date_range.date_to_for_filtering() + ): + continue + + generations = [] + for uuid, timestamp, properties in result["events"]: + generations.append(self._map_generation(uuid, timestamp, properties)) + + trace_dict = { + **result, + "created_at": timestamp_dt.isoformat(), + "person": self._map_person(result["person"]), + "events": generations, + } + # Remap keys from snake case to camel case + trace = LLMTrace.model_validate( + {TRACE_FIELDS_MAPPING[key]: value for key, value in trace_dict.items() if key in TRACE_FIELDS_MAPPING} + ) + traces.append(trace) + + return traces + + def _map_generation(self, event_uuid: UUID, event_timestamp: datetime, event_properties: str) -> LLMGeneration: + properties: dict = orjson.loads(event_properties) + + GENERATION_MAPPING = { + "$ai_input": "input", + "$ai_latency": "latency", + "$ai_output": "output", + "$ai_provider": "provider", + "$ai_model": "model", + "$ai_input_tokens": "inputTokens", + "$ai_output_tokens": "outputTokens", + "$ai_input_cost_usd": "inputCost", + "$ai_output_cost_usd": "outputCost", + "$ai_total_cost_usd": "totalCost", + "$ai_http_status": "httpStatus", + "$ai_base_url": "baseUrl", + } + GENERATION_JSON_FIELDS = {"$ai_input", "$ai_output"} + + generation = { + "id": str(event_uuid), + "createdAt": event_timestamp.isoformat(), + } + + for event_prop, model_prop in GENERATION_MAPPING.items(): + if event_prop in properties: + if event_prop in GENERATION_JSON_FIELDS: + generation[model_prop] = orjson.loads(properties[event_prop]) + else: + generation[model_prop] = properties[event_prop] + + return LLMGeneration.model_validate(generation) + + def _map_person(self, person: tuple[UUID, UUID, datetime, str]) -> LLMTracePerson: + uuid, distinct_id, created_at, properties = person + return LLMTracePerson( + uuid=str(uuid), + distinctId=str(distinct_id), + createdAt=created_at.isoformat(), + properties=orjson.loads(properties) if properties else {}, + ) + + def _get_event_query(self) -> ast.SelectQuery: + query = parse_select( + """ + SELECT + properties.$ai_trace_id as id, + min(timestamp) as trace_timestamp, + tuple(max(person.id), max(distinct_id), max(person.created_at), max(person.properties)) as person, + sum(properties.$ai_latency) as total_latency, + sum(properties.$ai_input_tokens) as input_tokens, + sum(properties.$ai_output_tokens) as output_tokens, + sum(properties.$ai_input_cost_usd) as input_cost, + sum(properties.$ai_output_cost_usd) as output_cost, + sum(properties.$ai_total_cost_usd) as total_cost, + arraySort(x -> x.2, groupArray(tuple(uuid, timestamp, properties))) as events + FROM + events + GROUP BY + id + ORDER BY + trace_timestamp DESC + """ + ) + return cast(ast.SelectQuery, query) + + def _get_where_clause(self): + timestamp_field = ast.Field(chain=["events", "timestamp"]) + + exprs: list[ast.Expr] = [ + ast.CompareOperation( + left=ast.Field(chain=["event"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value="$ai_generation"), + ), + ast.CompareOperation( + op=ast.CompareOperationOp.GtEq, + left=timestamp_field, + right=self._date_range.date_from_as_hogql(), + ), + ast.CompareOperation( + op=ast.CompareOperationOp.LtEq, + left=timestamp_field, + right=self._date_range.date_to_as_hogql(), + ), + ] + + if self.query.traceId is not None: + exprs.append( + ast.CompareOperation( + left=ast.Field(chain=["id"]), + op=ast.CompareOperationOp.Eq, + 
right=ast.Constant(value=self.query.traceId), + ), + ) + + return ast.And(exprs=exprs) + + def _get_order_by_clause(self): + return [ast.OrderExpr(expr=ast.Field(chain=["trace_timestamp"]), order="DESC")] diff --git a/posthog/hogql_queries/query_runner.py b/posthog/hogql_queries/query_runner.py index abcaddde2d3cf..608acc282968f 100644 --- a/posthog/hogql_queries/query_runner.py +++ b/posthog/hogql_queries/query_runner.py @@ -54,6 +54,7 @@ StickinessQuery, SuggestedQuestionsQuery, TeamTaxonomyQuery, + TracesQuery, TrendsQuery, WebGoalsQuery, WebOverviewQuery, @@ -421,6 +422,16 @@ def get_query_runner( limit_context=limit_context, modifiers=modifiers, ) + if kind == "TracesQuery": + from .ai.traces_query_runner import TracesQueryRunner + + return TracesQueryRunner( + query=cast(TracesQuery | dict[str, Any], query), + team=team, + timings=timings, + limit_context=limit_context, + modifiers=modifiers, + ) raise ValueError(f"Can't get a runner for an unknown query kind: {kind}") diff --git a/posthog/schema.py b/posthog/schema.py index 53da7ed1a0b9b..ce9dc633f47a1 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -966,6 +966,36 @@ class IntervalType(StrEnum): MONTH = "month" +class LLMGeneration(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + baseUrl: Optional[str] = None + createdAt: str + httpStatus: Optional[float] = None + id: str + input: list + inputCost: Optional[float] = None + inputTokens: Optional[float] = None + latency: float + model: Optional[str] = None + output: Optional[Any] = None + outputCost: Optional[float] = None + outputTokens: Optional[float] = None + provider: Optional[str] = None + totalCost: Optional[float] = None + + +class LLMTracePerson(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + createdAt: str + distinctId: str + properties: dict[str, Any] + uuid: str + + class LifecycleToggle(StrEnum): NEW = "new" RESURRECTING = "resurrecting" @@ -1029,6 +1059,7 @@ class NodeKind(StrEnum): TEAM_TAXONOMY_QUERY = "TeamTaxonomyQuery" EVENT_TAXONOMY_QUERY = "EventTaxonomyQuery" ACTORS_PROPERTY_TAXONOMY_QUERY = "ActorsPropertyTaxonomyQuery" + TRACES_QUERY = "TracesQuery" class PathCleaningFilter(BaseModel): @@ -2184,6 +2215,22 @@ class InsightThreshold(BaseModel): type: InsightThresholdType +class LLMTrace(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + createdAt: str + events: list[LLMGeneration] + id: str + inputCost: float + inputTokens: float + outputCost: float + outputTokens: float + person: LLMTracePerson + totalCost: float + totalLatency: float + + class LifecycleFilter(BaseModel): model_config = ConfigDict( extra="forbid", @@ -2647,6 +2694,31 @@ class TestCachedBasicQueryResponse(BaseModel): ) +class TracesQueryResponse(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." 
+ ) + results: list[LLMTrace] + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class TrendsAlertConfig(BaseModel): model_config = ConfigDict( extra="forbid", @@ -3636,6 +3708,40 @@ class CachedTeamTaxonomyQueryResponse(BaseModel): ) +class CachedTracesQueryResponse(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + cache_key: str + cache_target_age: Optional[AwareDatetime] = None + calculation_trigger: Optional[str] = Field( + default=None, description="What triggered the calculation of the query, leave empty if user/immediate" + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + is_cached: bool + last_refresh: AwareDatetime + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + next_allowed_client_refresh: AwareDatetime + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." + ) + results: list[LLMTrace] + timezone: str + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class CachedTrendsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", @@ -5264,6 +5370,31 @@ class QueryResponseAlternative39(BaseModel): ) +class QueryResponseAlternative40(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + columns: Optional[list[str]] = None + error: Optional[str] = Field( + default=None, + description="Query error. Returned only if 'explain' or `modifiers.debug` is true. Throws an error otherwise.", + ) + hasMore: Optional[bool] = None + hogql: Optional[str] = Field(default=None, description="Generated HogQL query.") + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + query_status: Optional[QueryStatus] = Field( + default=None, description="Query status indicates whether next to the provided data, a query is still running." 
+ ) + results: list[LLMTrace] + timings: Optional[list[QueryTiming]] = Field( + default=None, description="Measured timings for different parts of the query generation process" + ) + + class RecordingsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", @@ -5335,6 +5466,21 @@ class TeamTaxonomyQueryResponse(BaseModel): ) +class TracesQuery(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + dateRange: Optional[DateRange] = None + kind: Literal["TracesQuery"] = "TracesQuery" + limit: Optional[int] = None + modifiers: Optional[HogQLQueryModifiers] = Field( + default=None, description="Modifiers used when performing the query" + ) + offset: Optional[int] = None + response: Optional[TracesQueryResponse] = None + traceId: Optional[str] = None + + class VisualizationMessage(BaseModel): model_config = ConfigDict( extra="forbid", @@ -6782,6 +6928,7 @@ class QueryResponseAlternative( QueryResponseAlternative37, QueryResponseAlternative38, QueryResponseAlternative39, + QueryResponseAlternative40, ] ] ): @@ -6821,6 +6968,7 @@ class QueryResponseAlternative( QueryResponseAlternative37, QueryResponseAlternative38, QueryResponseAlternative39, + QueryResponseAlternative40, ] @@ -7220,6 +7368,7 @@ class QueryRequest(BaseModel): TeamTaxonomyQuery, EventTaxonomyQuery, ActorsPropertyTaxonomyQuery, + TracesQuery, ] = Field( ..., description=( @@ -7287,6 +7436,7 @@ class QuerySchemaRoot( TeamTaxonomyQuery, EventTaxonomyQuery, ActorsPropertyTaxonomyQuery, + TracesQuery, ] ] ): @@ -7328,6 +7478,7 @@ class QuerySchemaRoot( TeamTaxonomyQuery, EventTaxonomyQuery, ActorsPropertyTaxonomyQuery, + TracesQuery, ] = Field(..., discriminator="kind")