From 50eed8ee7679a9625da7a29240656bebd78c8ef9 Mon Sep 17 00:00:00 2001 From: Eric Vin <8935814+Eric-Vin@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:04:06 -0700 Subject: [PATCH] Requirement Parsing Fixes (#299) * Added deep boolean operator tests. * Moved requirement atomic checks to separate transformer. --- src/scenic/syntax/compiler.py | 30 ++++++++++++++---------------- tests/syntax/test_requirements.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/scenic/syntax/compiler.py b/src/scenic/syntax/compiler.py index bc15f1fee..5328c0c6d 100644 --- a/src/scenic/syntax/compiler.py +++ b/src/scenic/syntax/compiler.py @@ -236,11 +236,20 @@ def makeSyntaxError(self, msg, node: ast.AST) -> ScenicParseError: } +class AtomicCheckTransformer(Transformer): + def visit_Call(self, node: ast.Call): + func = node.func + if isinstance(func, ast.Name) and func.id in TEMPORAL_PREFIX_OPS: + self.makeSyntaxError( + f'malformed use of the "{func.id}" temporal operator', node + ) + return self.generic_visit(node) + + class PropositionTransformer(Transformer): def __init__(self, filename="") -> None: super().__init__(filename) self.nextSyntaxId = 0 - self.inAtomic = False def transform( self, node: ast.AST, nextSyntaxId=0 @@ -262,12 +271,9 @@ def transform( return newNode, self.nextSyntaxId def generic_visit(self, node): - # Don't recurse inside atomics. - old_inAtomic = self.inAtomic - self.inAtomic = True - super_val = super().generic_visit(node) - self.inAtomic = old_inAtomic - return super_val + acv = AtomicCheckTransformer(self.filename) + acv.visit(node) + return node def _register_requirement_syntax(self, syntax): """register requirement syntax for later use @@ -346,7 +352,7 @@ def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST: def visit_UnaryOp(self, node): # rewrite `not` in requirements into a proposition factory - if not isinstance(node.op, ast.Not) or self.inAtomic: + if not isinstance(node.op, ast.Not): return self.generic_visit(node) lineNum = ast.Constant(node.lineno) @@ -367,14 +373,6 @@ def visit_UnaryOp(self, node): ) return ast.copy_location(newNode, node) - def visit_Call(self, node: ast.Call): - func = node.func - if isinstance(func, ast.Name) and func.id in TEMPORAL_PREFIX_OPS: - self.makeSyntaxError( - f'malformed use of the "{func.id}" temporal operator', node - ) - return self.generic_visit(node) - def visit_Always(self, node: s.Always): value = self.visit(node.value) if not self.is_proposition_factory(value): diff --git a/tests/syntax/test_requirements.py b/tests/syntax/test_requirements.py index 0c06dd832..0c7699fe0 100644 --- a/tests/syntax/test_requirements.py +++ b/tests/syntax/test_requirements.py @@ -508,3 +508,33 @@ def test_deep_not(): require all(not o.x > 0 for o in objs) """ ) + + +def test_deep_and(): + with pytest.raises(RejectionException): + sampleSceneFrom( + """ + objs = [new Object at 10@10, new Object at 20@20] + require all(o.x > 0 and o.x < 0 for o in objs) + """ + ) + + +def test_deep_or(): + with pytest.raises(RejectionException): + sampleSceneFrom( + """ + objs = [new Object at 10@10, new Object at 20@20] + require all(o.x < 0 or o.x < -1 for o in objs) + """ + ) + + +def test_temporal_in_atomic(): + with pytest.raises(ScenicSyntaxError): + sampleSceneFrom( + """ + objs = [new Object at 10@10, new Object at 20@20] + require all(eventually(o.x > 0) for o in objs) + """ + )