Skip to content

Commit

Permalink
make it a bit more secure
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra committed Feb 3, 2025
1 parent ad7a6cd commit 76a66c2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
12 changes: 12 additions & 0 deletions posthog/hogql/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,15 @@ def test_deserialize_hx_ast_error(self):
}
)
self.assertEqual(str(e.exception), "Invalid or missing '__hx_ast' kind: Invalid")

with self.assertRaises(ValueError) as e:
deserialize_hx_ast(
{
"__hx_ast": "Call",
"name": "hello",
"args": ["invalid type"],
}
)
self.assertEqual(
str(e.exception), "Invalid type for field 'args' in AST node 'Call'. Expected 'Expr', got 'str'"
)
32 changes: 25 additions & 7 deletions posthog/hogql/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import fields
from typing import Any, Union, get_args, get_origin

from posthog.hogql.ast import AST_CLASSES, AST, Constant
from posthog.hogql.ast import AST_CLASSES, AST


def unwrap_optional(t):
Expand All @@ -21,6 +21,16 @@ def is_ast_subclass(t):
return isinstance(t, type) and issubclass(t, AST)


def is_simple_value(value: Any) -> bool:
if isinstance(value, int) or isinstance(value, float) or isinstance(value, str) or isinstance(value, bool):
return True
if isinstance(value, list):
return all(is_simple_value(item) for item in value)
if isinstance(value, dict):
return all(isinstance(key, str) and is_simple_value(val) for key, val in value.items())
return False


def deserialize_hx_ast(hog_ast: dict) -> AST:
kind = hog_ast.get("__hx_ast", None)
if kind is None or kind not in AST_CLASSES:
Expand All @@ -37,21 +47,29 @@ def deserialize_hx_ast(hog_ast: dict) -> AST:
if isinstance(value, dict) and "__hx_ast" in value:
init_args[key] = deserialize_hx_ast(value)
elif isinstance(value, list):
field_type = unwrap_list(cls_fields[key])
init_args[key] = []
for item in value:
if isinstance(item, dict) and "__hx_ast" in item:
init_args[key].append(deserialize_hx_ast(item))
elif is_ast_subclass(field_type):
init_args[key].append(Constant(value=item))
else:
elif is_simple_value(item):
field_type = unwrap_list(cls_fields[key])
if is_ast_subclass(field_type):
raise ValueError(
f"Invalid type for field '{key}' in AST node '{kind}'. Expected '{field_type.__name__}', got '{type(item).__name__}'"
)
init_args[key].append(item)
else:
raise ValueError(f"Unexpected value for field '{key}' in AST node '{kind}'")
else:
field_type = unwrap_optional(cls_fields[key])
if is_ast_subclass(field_type):
init_args[key] = Constant(value=value)
else:
raise ValueError(
f"Invalid type for field '{key}' in AST node '{kind}'. Expected {field_type}, got {type(item)}"
)
elif is_simple_value(value):
init_args[key] = value
else:
raise ValueError(f"Unexpected value for field '{key}' in AST node '{kind}'")
else:
raise ValueError(f"Unexpected field '{key}' for AST node '{kind}'")

Expand Down

0 comments on commit 76a66c2

Please sign in to comment.