fix(experiments): Fix a couple of issues in the ASOF LEFT JOIN (#26886)
danielbachhuber authored Dec 13, 2024
1 parent 91f5309 commit d4cc10e
Showing 2 changed files with 188 additions and 2 deletions.
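
The experiment queries join a data warehouse table to PostHog events with ASOF-style semantics: each warehouse row is attributed to the latest qualifying exposure at or before the row's timestamp, and rows without a match are kept (LEFT). A minimal sketch of that matching rule in plain Python, for orientation only; PostHog runs this as ClickHouse SQL, and the row shapes here are assumptions:

    # Illustrative ASOF LEFT JOIN for one subject: pick the latest exposure with
    # timestamp <= the usage row's "ds"; keep the row (paired with None) when
    # nothing matches, as a LEFT join would.
    from bisect import bisect_right
    from datetime import date

    def asof_left_join(usage_rows, exposures):
        exposures = sorted(exposures, key=lambda e: e["timestamp"])
        keys = [e["timestamp"] for e in exposures]
        joined = []
        for row in usage_rows:
            i = bisect_right(keys, row["ds"]) - 1  # latest exposure at or before ds
            joined.append((row, exposures[i] if i >= 0 else None))
        return joined

    rows = [{"userid": "user_test_3", "ds": date(2023, 1, 6), "usage": 800}]
    exposures = [
        {"variant": "control", "timestamp": date(2023, 1, 3)},
        {"variant": "test", "timestamp": date(2023, 1, 4)},
        {"variant": "control", "timestamp": date(2023, 1, 9)},
    ]
    print(asof_left_join(rows, exposures))  # pairs the row with the 2023-01-04 "test" exposure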
@@ -183,6 +183,69 @@ def create_data_warehouse_table_with_payments(self):
        )
        return table_name

    def create_data_warehouse_table_with_usage(self):
        if not OBJECT_STORAGE_ACCESS_KEY_ID or not OBJECT_STORAGE_SECRET_ACCESS_KEY:
            raise Exception("Missing vars")

        fs = s3fs.S3FileSystem(
            client_kwargs={
                "region_name": "us-east-1",
                "endpoint_url": OBJECT_STORAGE_ENDPOINT,
                "aws_access_key_id": OBJECT_STORAGE_ACCESS_KEY_ID,
                "aws_secret_access_key": OBJECT_STORAGE_SECRET_ACCESS_KEY,
            },
        )

        path_to_s3_object = "s3://" + OBJECT_STORAGE_BUCKET + f"/{TEST_BUCKET}"

        id = pa.array(["1", "2", "3", "4", "5"])
        date = pa.array(["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-06", "2023-01-07"])
        user_id = pa.array(["user_control_0", "user_test_1", "user_test_2", "user_test_3", "user_extra"])
        usage = pa.array([1000, 500, 750, 800, 900])
        names = ["id", "ds", "userid", "usage"]

        pq.write_to_dataset(
            pa.Table.from_arrays([id, date, user_id, usage], names=names),
            path_to_s3_object,
            filesystem=fs,
            use_dictionary=True,
            compression="snappy",
            version="2.0",
        )

        table_name = "usage"

        credential = DataWarehouseCredential.objects.create(
            access_key=OBJECT_STORAGE_ACCESS_KEY_ID,
            access_secret=OBJECT_STORAGE_SECRET_ACCESS_KEY,
            team=self.team,
        )

        DataWarehouseTable.objects.create(
            name=table_name,
            url_pattern=f"http://host.docker.internal:19000/{OBJECT_STORAGE_BUCKET}/{TEST_BUCKET}/*.parquet",
            format=DataWarehouseTable.TableFormat.Parquet,
            team=self.team,
            columns={
                "id": "String",
                "ds": "Date",
                "userid": "String",
                "usage": "Int64",
            },
            credential=credential,
        )

        DataWarehouseJoin.objects.create(
            team=self.team,
            source_table_name=table_name,
            source_table_key="userid",
            joining_table_name="events",
            joining_table_key="properties.$user_id",
            field_name="events",
            configuration={"experiments_optimized": True, "experiments_timestamp_key": "ds"},
        )
        return table_name

    @freeze_time("2020-01-01T12:00:00Z")
    def test_query_runner(self):
        feature_flag = self.create_feature_flag()
@@ -694,6 +757,128 @@ def test_query_runner_with_data_warehouse_series_avg_amount(self):
            [0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0],
        )

    def test_query_runner_with_data_warehouse_series_no_end_date_and_nested_id(self):
        table_name = self.create_data_warehouse_table_with_usage()

        feature_flag = self.create_feature_flag()
        experiment = self.create_experiment(
            feature_flag=feature_flag,
            start_date=datetime(2023, 1, 1),
        )

        feature_flag_property = f"$feature/{feature_flag.key}"

        count_query = TrendsQuery(
            series=[
                DataWarehouseNode(
                    id=table_name,
                    distinct_id_field="userid",
                    id_field="id",
                    table_name=table_name,
                    timestamp_field="ds",
                    math="avg",
                    math_property="usage",
                    math_property_type="data_warehouse_properties",
                )
            ]
        )
        exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")])

        experiment_query = ExperimentTrendsQuery(
            experiment_id=experiment.id,
            kind="ExperimentTrendsQuery",
            count_query=count_query,
            exposure_query=exposure_query,
        )

        experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}]
        experiment.save()

        # Populate exposure events
        for variant, count in [("control", 7), ("test", 9)]:
            for i in range(count):
                _create_event(
                    team=self.team,
                    event="$feature_flag_called",
                    distinct_id=f"distinct_{variant}_{i}",
                    properties={feature_flag_property: variant, "$user_id": f"user_{variant}_{i}"},
                    timestamp=datetime(2023, 1, i + 1),
                )
# "user_test_3" first exposure (feature_flag_property="control") is on 2023-01-03
# "user_test_3" relevant exposure (feature_flag_property="test") is on 2023-01-04
# "user_test_3" other event (feature_flag_property="control" is on 2023-01-05
# "user_test_3" purchase is on 2023-01-06
# "user_test_3" second exposure (feature_flag_property="control") is on 2023-01-09
# "user_test_3" should fall into the "test" variant, not the "control" variant
        _create_event(
            team=self.team,
            event="$feature_flag_called",
            distinct_id="distinct_test_3",
            properties={feature_flag_property: "control", "$user_id": "user_test_3"},
            timestamp=datetime(2023, 1, 3),
        )
        _create_event(
            team=self.team,
            event="Some other event",
            distinct_id="distinct_test_3",
            properties={feature_flag_property: "control", "$user_id": "user_test_3"},
            timestamp=datetime(2023, 1, 5),
        )
        _create_event(
            team=self.team,
            event="$feature_flag_called",
            distinct_id="distinct_test_3",
            properties={feature_flag_property: "control", "$user_id": "user_test_3"},
            timestamp=datetime(2023, 1, 9),
        )

        flush_persons_and_events()

        query_runner = ExperimentTrendsQueryRunner(
            query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team
        )
        with freeze_time("2023-01-07"):
            # Build and execute the query to get the ClickHouse SQL
            queries = query_runner.count_query_runner.to_queries()
            response = execute_hogql_query(
                query_type="TrendsQuery",
                query=queries[0],
                team=query_runner.count_query_runner.team,
                modifiers=query_runner.count_query_runner.modifiers,
                limit_context=query_runner.count_query_runner.limit_context,
            )

            # Assert the expected join condition in the clickhouse SQL
            expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_8)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_9)s, 6, %(hogql_val_10)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_11)s, 6, %(hogql_val_12)s))))) AS e__events ON"
            self.assertIn(expected_join_condition, str(response.clickhouse))

            result = query_runner.calculate()

        trend_result = cast(ExperimentTrendsQueryResponse, result)

        self.assertEqual(len(result.variants), 2)

        control_result = next(variant for variant in trend_result.variants if variant.key == "control")
        test_result = next(variant for variant in trend_result.variants if variant.key == "test")

        control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control")
        test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test")

        self.assertEqual(control_result.count, 1000)
        self.assertEqual(test_result.count, 2050)
        self.assertEqual(control_result.absolute_exposure, 1)
        self.assertEqual(test_result.absolute_exposure, 3)

        self.assertEqual(
            control_insight["data"][:10],
            [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0],
        )
        self.assertEqual(
            test_insight["data"][:10],
            [0.0, 500.0, 1250.0, 1250.0, 1250.0, 2050.0, 2050.0, 2050.0, 2050.0, 2050.0],
        )

    def test_query_runner_with_data_warehouse_series_expected_query(self):
        table_name = self.create_data_warehouse_table_with_payments()

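One regression the new test pins down: `joining_table_key` may be a nested events property such as `properties.$user_id`, while the join previously hardcoded `distinct_id` on the events side. Below is a standalone sketch of the fixed key construction; only the chain-splitting idea is taken from the diff, and the helper name is hypothetical:

    # A dotted joining_table_key becomes a HogQL-style field chain instead of
    # a hardcoded ["events", "distinct_id"].
    def join_key_chain(to_table, joining_table_key):
        return [to_table, *joining_table_key.split(".")]

    assert join_key_chain("events", "distinct_id") == ["events", "distinct_id"]
    assert join_key_chain("events", "properties.$user_id") == ["events", "properties", "$user_id"]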
posthog/warehouse/models/join.py (5 changes: 3 additions & 2 deletions)
@@ -126,7 +126,8 @@ def _join_function_for_experiments(
             for expr in node.where.exprs:
                 if isinstance(expr, ast.CompareOperation):
                     if expr.op == ast.CompareOperationOp.GtEq or expr.op == ast.CompareOperationOp.LtEq:
-                        if isinstance(expr.left, ast.Alias) and expr.left.expr.to_hogql() == timestamp_key:
+                        # Match within hogql string because it could be 'toDateTime(timestamp)'
+                        if isinstance(expr.left, ast.Alias) and timestamp_key in expr.left.expr.to_hogql():
                             whereExpr.append(
                                 ast.CompareOperation(
                                     op=expr.op, left=ast.Field(chain=["timestamp"]), right=expr.right
@@ -183,7 +184,7 @@ def _join_function_for_experiments(
                         ]
                     ),
                     op=ast.CompareOperationOp.Eq,
-                    right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]),
+                    right=ast.Field(chain=[join_to_add.to_table, *self.joining_table_key.split(".")]),
                 ),
                 ast.CompareOperation(
                     left=ast.Field(
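On the first hunk: the left side of a timestamp bound can render to HogQL as toDateTime(timestamp) rather than the bare key, so an exact string comparison can miss it and the bound is then not copied into the join. A sketch of the fixed predicate, reusing the ast names visible in the diff (the helper name is hypothetical):

    from posthog.hogql import ast

    # Hypothetical helper mirroring the fixed condition: treat a where-clause
    # expression as a timestamp bound when it is a >= or <= comparison whose
    # aliased left side contains the timestamp key anywhere in its HogQL
    # rendering, e.g. "toDateTime(timestamp)" as well as plain "timestamp".
    def is_timestamp_bound(expr, timestamp_key):
        return (
            isinstance(expr, ast.CompareOperation)
            and expr.op in (ast.CompareOperationOp.GtEq, ast.CompareOperationOp.LtEq)
            and isinstance(expr.left, ast.Alias)
            and timestamp_key in expr.left.expr.to_hogql()  # substring, not equality
        )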
