From a9374ed3407158a3a56bdc92db1eee0f6abe9511 Mon Sep 17 00:00:00 2001
From: sMezaOrellana
Date: Thu, 31 Aug 2023 14:37:30 +0200
Subject: [PATCH] Rewrite expression parser to support more complex expressions (#37)

Co-authored-by: Schamper <1254028+Schamper@users.noreply.github.com>
---
 dissect/cstruct/exceptions.py |   8 +
 dissect/cstruct/expression.py | 329 ++++++++++++++++++++++++++++------
 tests/test_expression.py      |  39 +++-
 tests/test_struct.py          |   2 +-
 4 files changed, 318 insertions(+), 60 deletions(-)

diff --git a/dissect/cstruct/exceptions.py b/dissect/cstruct/exceptions.py
index fb5208c..e899d7a 100644
--- a/dissect/cstruct/exceptions.py
+++ b/dissect/cstruct/exceptions.py
@@ -16,3 +16,11 @@ class NullPointerDereference(Error):
 
 class ArraySizeError(Error):
     pass
+
+
+class ExpressionParserError(Error):
+    pass
+
+
+class ExpressionTokenizerError(Error):
+    pass
diff --git a/dissect/cstruct/expression.py b/dissect/cstruct/expression.py
index 1cb03f5..53e1cf2 100644
--- a/dissect/cstruct/expression.py
+++ b/dissect/cstruct/expression.py
@@ -1,84 +1,301 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict
+import string
+from typing import TYPE_CHECKING, Callable, Optional, Union
+
+from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError
 
 if TYPE_CHECKING:
     from dissect.cstruct import cstruct
 
+HEXBIN_SUFFIX = {"x", "X", "b", "B"}
+
+
+class ExpressionTokenizer:
+    def __init__(self, expression: str):
+        self.expression = expression
+        self.pos = 0
+        self.tokens = []
+
+    def equal(self, token: str, expected: Union[str, set[str]]) -> bool:
+        if isinstance(expected, set):
+            return token in expected
+        else:
+            return token == expected
+
+    def alnum(self, token: str) -> bool:
+        return token.isalnum()
+
+    def alpha(self, token: str) -> bool:
+        return token.isalpha()
+
+    def digit(self, token: str) -> bool:
+        return token.isdigit()
+
+    def hexdigit(self, token: str) -> bool:
+        return token in string.hexdigits
+
+    def operator(self, token: str) -> bool:
+        return token in {"*", "/", "+", "-", "%", "&", "^", "|", "(", ")", "~"}
+
+    def match(
+        self,
+        func: Optional[Callable[[str], bool]] = None,
+        expected: Optional[str] = None,
+        consume: bool = True,
+        append: bool = True,
+    ) -> bool:
+        if self.eol():
+            return False
+
+        token = self.get_token()
+
+        if expected and self.equal(token, expected):
+            if append:
+                self.tokens.append(token)
+            if consume:
+                self.consume()
+            return True
+
+        if func and func(token):
+            if append:
+                self.tokens.append(token)
+            if consume:
+                self.consume()
+            return True
+
+        return False
+
+    def consume(self) -> None:
+        self.pos += 1
+
+    def eol(self) -> bool:
+        return self.pos >= len(self.expression)
+
+    def get_token(self) -> str:
+        if self.eol():
+            raise ExpressionTokenizerError(f"Out of bounds index: {self.pos}, length: {len(self.expression)}")
+        return self.expression[self.pos]
+
+    def tokenize(self) -> list[str]:
+        token = ""
+
+        # Loop over the expression; runs in linear time
+        while not self.eol():
+            # If token is a single character operator add it to tokens
+            if self.match(self.operator):
+                continue
+
+            # If token is a single digit, keep looping over expression and build the number
+            elif self.match(self.digit, consume=False, append=False):
+                token += self.get_token()
+                self.consume()
+
+                # Support for binary and hexadecimal notation
+                if self.match(expected=HEXBIN_SUFFIX, consume=False, append=False):
+                    token += self.get_token()
+                    self.consume()
+
+                while self.match(self.hexdigit, consume=False, append=False):
+                    token += self.get_token()
+                    self.consume()
+                    if self.eol():
+                        break
+
+                # Checks for suffixes in numbers
+                if self.match(expected={"u", "U"}, consume=False, append=False):
+                    self.consume()
+                    self.match(expected={"l", "L"}, append=False)
+                    self.match(expected={"l", "L"}, append=False)
+
+                elif self.match(expected={"l", "L"}, append=False):
+                    self.match(expected={"l", "L"}, append=False)
+                    self.match(expected={"u", "U"}, append=False)
+                else:
+                    pass
+
+                # Number cannot end on x or b in the case of binary or hexadecimal notation
+                if len(token) == 2 and token[-1] in HEXBIN_SUFFIX:
+                    raise ExpressionTokenizerError("Invalid binary or hex notation")
+
+                if len(token) > 1 and token[0] == "0" and token[1] not in HEXBIN_SUFFIX:
+                    token = token[:1] + "o" + token[1:]
+
+                self.tokens.append(token)
+                token = ""
+
+            # If token is alpha or underscore we need to build the identifier
+            elif self.match(self.alpha, consume=False, append=False) or self.match(
+                expected="_", consume=False, append=False
+            ):
+                while self.match(self.alnum, consume=False, append=False) or self.match(
+                    expected="_", consume=False, append=False
+                ):
+                    token += self.get_token()
+                    self.consume()
+                    if self.eol():
+                        break
+                self.tokens.append(token)
+                token = ""
+            # If token is a two-character operator, make sure the next character completes it and append to tokens
+            elif self.match(expected=">", append=False) and self.match(expected=">", append=False):
+                self.tokens.append(">>")
+            elif self.match(expected="<", append=False) and self.match(expected="<", append=False):
+                self.tokens.append("<<")
+            elif self.match(expected=" ", append=False):
+                continue
+            else:
+                raise ExpressionTokenizerError(
+                    f"Tokenizer does not recognize following token '{self.expression[self.pos]}'"
+                )
+        return self.tokens
+
+
 class Expression:
-    """Expression parser for simple calculations in definitions."""
-
-    operators = [
-        ("*", lambda a, b: a * b),
-        ("/", lambda a, b: a // b),
-        ("%", lambda a, b: a % b),
-        ("+", lambda a, b: a + b),
-        ("-", lambda a, b: a - b),
-        (">>", lambda a, b: a >> b),
-        ("<<", lambda a, b: a << b),
-        ("&", lambda a, b: a & b),
-        ("^", lambda a, b: a ^ b),
-        ("|", lambda a, b: a | b),
-    ]
+    """Expression parser for calculations in definitions."""
+
+    operators = {
+        "|": lambda a, b: a | b,
+        "^": lambda a, b: a ^ b,
+        "&": lambda a, b: a & b,
+        "<<": lambda a, b: a << b,
+        ">>": lambda a, b: a >> b,
+        "+": lambda a, b: a + b,
+        "-": lambda a, b: a - b,
+        "*": lambda a, b: a * b,
+        "/": lambda a, b: a // b,
+        "%": lambda a, b: a % b,
+        "u": lambda a: -a,
+        "~": lambda a: ~a,
+    }
+
+    precedence_levels = {
+        "|": 0,
+        "^": 1,
+        "&": 2,
+        "<<": 3,
+        ">>": 3,
+        "+": 4,
+        "-": 4,
+        "*": 5,
+        "/": 5,
+        "%": 5,
+        "u": 6,
+        "~": 6,
+        "sizeof": 6,
+    }
 
     def __init__(self, cstruct: cstruct, expression: str):
         self.cstruct = cstruct
         self.expression = expression
+        self.tokens = ExpressionTokenizer(expression).tokenize()
+        self.stack = []
+        self.queue = []
 
     def __repr__(self) -> str:
         return self.expression
 
-    def evaluate(self, context: Dict[str, int] = None) -> int:
-        context = context or {}
-        levels = []
-        buf = ""
+    def precedence(self, o1: str, o2: str) -> bool:
+        return self.precedence_levels[o1] >= self.precedence_levels[o2]
 
-        for i in range(len(self.expression)):
-            if self.expression[i] == "(":
-                levels.append(buf)
-                buf = ""
-                continue
+    def evaluate_exp(self) -> None:
+        operator = self.stack.pop(-1)
+        res = 0
 
-            if self.expression[i] == ")":
-                if levels[-1] == "sizeof":
-                    value = len(self.cstruct.resolve(buf))
-                    levels[-1] = ""
-                else:
-                    value = self.evaluate_part(buf, context)
-                buf = levels.pop()
-                buf += str(value)
-                continue
+        if len(self.queue) < 1:
+            raise ExpressionParserError("Invalid expression: not enough operands")
+
+        right = self.queue.pop(-1)
+        if operator in ("u", "~"):
+            res = self.operators[operator](right)
+        else:
+            if len(self.queue) < 1:
+                raise ExpressionParserError("Invalid expression: not enough operands")
+
+            left = self.queue.pop(-1)
+            res = self.operators[operator](left, right)
 
-            buf += self.expression[i]
+        self.queue.append(res)
+
+    def is_number(self, token: str) -> bool:
+        return token.isnumeric() or (len(token) > 2 and token[0] == "0" and token[1] in ("x", "X", "b", "B", "o", "O"))
+
+    def evaluate(self, context: Optional[dict[str, int]] = None) -> int:
+        """Evaluates an expression using a Shunting-Yard implementation."""
+
+        self.stack = []
+        self.queue = []
+        operators = set(self.operators.keys())
+
+        context = context or {}
+        tmp_expression = self.tokens
 
-        return self.evaluate_part(buf, context)
+        # Unary minus tokens; we change the semantics of '-' depending on the previous token
+        for i in range(len(self.tokens)):
+            if self.tokens[i] == "-":
+                if i == 0:
+                    self.tokens[i] = "u"
+                    continue
+                if self.tokens[i - 1] in operators or self.tokens[i - 1] == "u" or self.tokens[i - 1] == "(":
+                    self.tokens[i] = "u"
+                    continue
 
-    def evaluate_part(self, buf: str, context: Dict[str, int]) -> int:
-        buf = buf.strip()
+        i = 0
+        while i < len(tmp_expression):
+            current_token = tmp_expression[i]
+            if self.is_number(current_token):
+                self.queue.append(int(current_token, 0))
+            elif current_token in context:
+                self.queue.append(int(context[current_token]))
+            elif current_token in self.cstruct.consts:
+                self.queue.append(int(self.cstruct.consts[current_token]))
+            elif current_token == "u":
+                self.stack.append(current_token)
+            elif current_token == "~":
+                self.stack.append(current_token)
+            elif current_token == "sizeof":
+                if len(tmp_expression) < i + 4 or (tmp_expression[i + 1] != "(" or tmp_expression[i + 3] != ")"):
+                    raise ExpressionParserError("Invalid sizeof operation")
+                self.queue.append(len(self.cstruct.resolve(tmp_expression[i + 2])))
+                i += 3
+            elif current_token in operators:
+                while (
+                    len(self.stack) != 0 and self.stack[-1] != "(" and (self.precedence(self.stack[-1], current_token))
+                ):
+                    self.evaluate_exp()
+                self.stack.append(current_token)
+            elif current_token == "(":
+                if i > 0:
+                    previous_token = tmp_expression[i - 1]
+                    if self.is_number(previous_token):
+                        raise ExpressionParserError(
+                            f"Parser expected sizeof or an arithmetic operator instead got: '{previous_token}'"
+                        )
 
-        # Very simple way to support an expression(part) that is a single,
-        # negative value. To use negative values in more complex expressions,
-        # they must be wrapped in brackets, e.g.: 2 * (-5).
-        #
-        # To have full support for the negation operator a proper expression
-        # parser must be build.
-        if buf.startswith("-") and buf[1:].isnumeric():
-            return int(buf)
+                self.stack.append(current_token)
+            elif current_token == ")":
+                if i > 0:
+                    previous_token = tmp_expression[i - 1]
+                    if previous_token == "(":
+                        raise ExpressionParserError(
+                            f"Parser expected an expression, instead received empty parenthesis. Index: {i}"
+                        )
 
-        for operator in self.operators:
-            if operator[0] in buf:
-                a, b = buf.rsplit(operator[0], 1)
+                if len(self.stack) == 0:
+                    raise ExpressionParserError("Invalid expression")
 
-                return operator[1](self.evaluate_part(a, context), self.evaluate_part(b, context))
+                while self.stack[-1] != "(":
+                    self.evaluate_exp()
 
-        if buf in context:
-            return context[buf]
+                self.stack.pop(-1)
+            else:
+                raise ExpressionParserError(f"Unmatched token: '{current_token}'")
+            i += 1
 
-        if buf.startswith("0x"):
-            return int(buf, 16)
+        while len(self.stack) != 0:
+            if self.stack[-1] == "(":
+                raise ExpressionParserError("Invalid expression")
 
-        if buf in self.cstruct.consts:
-            return int(self.cstruct.consts[buf])
+            self.evaluate_exp()
 
-        return int(buf)
+        return self.queue[0]
diff --git a/tests/test_expression.py b/tests/test_expression.py
index 549c8c3..c1d0fc0 100644
--- a/tests/test_expression.py
+++ b/tests/test_expression.py
@@ -1,6 +1,7 @@
 import pytest
 
 from dissect import cstruct
+from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError
 from dissect.cstruct.expression import Expression
 
 testdata = [
@@ -44,13 +45,24 @@
     ("0 | 1", 1),
     ("1 | 1", 1),
     ("1 | 2", 3),
-    # This type of expression is not supported by the parser and will fail.
-    # ('4 * 1 + 1', 5),
+    ("1 | 2 | 4", 7),
+    ("1 & 1 * 4", 0),
+    ("(1 & 1) * 4", 4),
+    ("4 * 1 + 1", 5),
     ("-42", -42),
     ("42 + (-42)", 0),
     ("A + 5", 13),
     ("21 - B", 8),
     ("A + B", 21),
+    ("~1", -2),
+    ("~(A + 5)", ~13),
+    ("10l", 10),
+    ("10ll", 10),
+    ("10ull", 10),
+    ("010ULL", 8),
+    ("0Xf0 >> 4", 0xF),
+    ("0x1B", 0x1B),
+    ("0x1b", 0x1B),
 ]
 
 
@@ -67,11 +79,32 @@ def id_fn(val):
 
 
 @pytest.mark.parametrize("expression, answer", testdata, ids=id_fn)
-def test_expression(expression, answer):
+def test_expression(expression: str, answer: int) -> None:
     parser = Expression(Consts(), expression)
     assert parser.evaluate() == answer
 
 
+@pytest.mark.parametrize(
+    "expression, exception, message",
+    [
+        ("0b", ExpressionTokenizerError, "Invalid binary or hex notation"),
+        ("0x", ExpressionTokenizerError, "Invalid binary or hex notation"),
+        ("$", ExpressionTokenizerError, "Tokenizer does not recognize following token '\\$'"),
+        ("-", ExpressionParserError, "Invalid expression: not enough operands"),
+        ("(", ExpressionParserError, "Invalid expression"),
+        (")", ExpressionParserError, "Invalid expression"),
+        ("()", ExpressionParserError, "Parser expected an expression, instead received empty parenthesis. Index: 1"),
+        ("0()", ExpressionParserError, "Parser expected sizeof or an arithmetic operator instead got: '0'"),
+        ("sizeof)", ExpressionParserError, "Invalid sizeof operation"),
+        ("sizeof(0 +)", ExpressionParserError, "Invalid sizeof operation"),
+    ],
+)
+def test_expression_failure(expression: str, exception: type, message: str) -> None:
+    with pytest.raises(exception, match=message):
+        parser = Expression(Consts(), expression)
+        parser.evaluate()
+
+
 def test_sizeof():
     d = """
     struct test {
diff --git a/tests/test_struct.py b/tests/test_struct.py
index c94dae1..f08c0f8 100644
--- a/tests/test_struct.py
+++ b/tests/test_struct.py
@@ -114,7 +114,7 @@ def test_struct_expressions(compiled):
     #define const 1
     struct test {
         uint8   flag;
-        uint8   data_1[flag & 1 * 4];
+        uint8   data_1[(flag & 1) * 4];
         uint8   data_2[flag & (1 << 2)];
         uint8   data_3[const];
     };