diff --git a/pyecoreocl/compiler.py b/pyecoreocl/compiler.py index b5f9edd..c02ccd5 100644 --- a/pyecoreocl/compiler.py +++ b/pyecoreocl/compiler.py @@ -62,7 +62,7 @@ def visitComparisonBinaryOperation(self, ctx): def visitCollectionCall(self, ctx): # We currently don't support implicity collection wrapping operation = ctx.attname.text - rule = rules.get(operation, default_collection_call) + rule = RuleSet.rules_collections.get(operation, default_collection_call) rule(self, ctx) def visitBooleanBinaryOperation(self, ctx): @@ -280,17 +280,34 @@ def visitUnreservedName(self, ctx): return self.visitChildren(ctx) -rules = {} +class RuleSet: + rules_collections = {} + rules_primitives = {} -def call_rule(fun): - names = fun.__name__[5:].split("_") - function_name = "".join([names[0], *(f.capitalize() for f in names[1:])]) - rules[function_name] = fun +def register_rule_set(decorator): + register_name = decorator.__name__.split("_")[0] + + def inner(fun): + names = fun.__name__[5:].split("_") + function_name = "".join([names[0], *(f.capitalize() for f in names[1:])]) + getattr(RuleSet, f"rules_{register_name}s")[function_name] = fun + return fun + + return inner + + +@register_rule_set +def collection_rule(fun): return fun -@call_rule +@register_rule_set +def primitive_rule(fun): + return fun + + +@collection_rule def rule_collect_nested(emitter, ctx): emitter.inline("(") emitter.visit(ctx.argExp().body) @@ -305,7 +322,7 @@ def rule_collect_nested(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_collect(emitter, ctx): emitter.inline("ocl.flatten(") emitter.visit(ctx.argExp().body) @@ -320,7 +337,7 @@ def rule_collect(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_for_all(emitter, ctx): emitter.inline(f"all(") emitter.visit(ctx.argExp().body) @@ -335,7 +352,7 @@ def rule_for_all(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_exists(emitter, ctx): emitter.inline(f"any(") emitter.visit(ctx.argExp().body) @@ -350,14 +367,14 @@ def rule_exists(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_one(emitter, ctx): emitter.inline("(len(list(") rule_select(emitter, ctx) emitter.inline(")) == 1)") -@call_rule +@collection_rule def rule_select(emitter, ctx): variables = [arg.text for arg in ctx.argExp().varnames] varnames = ", ".join(variables) @@ -374,7 +391,7 @@ def rule_select(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_reject(emitter, ctx): variables = [arg.text for arg in ctx.argExp().varnames] varnames = ", ".join(variables) @@ -391,35 +408,35 @@ def rule_reject(emitter, ctx): emitter.inline("))") -@call_rule +@collection_rule def rule_includes(emitter, ctx): emitter.visit(ctx.argExp()) emitter.inline(" in ") emitter.visit(ctx.expression) -@call_rule +@collection_rule def rule_excludes(emitter, ctx): emitter.visit(ctx.argExp()) emitter.inline(" not in ") emitter.visit(ctx.expression) -@call_rule +@collection_rule def rule_not_empty(emitter, ctx): emitter.inline("len(") emitter.visit(ctx.expression) emitter.inline(") > 0") -@call_rule +@collection_rule def rule_is_empty(emitter, ctx): emitter.inline("len(") emitter.visit(ctx.expression) emitter.inline(") == 0") -@call_rule +@collection_rule def rule_at(emitter, ctx): emitter.visit(ctx.expression) emitter.inline("[") @@ -427,42 +444,42 @@ def rule_at(emitter, ctx): emitter.inline("]") -@call_rule +@collection_rule def rule_size(emitter, ctx): emitter.inline("len(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_is_unique(emitter, ctx): emitter.inline("ocl.is_unique(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_sum(emitter, ctx): emitter.inline("sum(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_min(emitter, ctx): emitter.inline("min(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_max(emitter, ctx): emitter.inline("max(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_count(emitter, ctx): emitter.inline("list(") emitter.visit(ctx.expression) @@ -471,7 +488,7 @@ def rule_count(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_includes_all(emitter, ctx): emitter.inline("all(e in ") emitter.visit(ctx.expression) @@ -480,7 +497,7 @@ def rule_includes_all(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_excludes_all(emitter, ctx): emitter.inline("all(e not in ") emitter.visit(ctx.expression) @@ -489,35 +506,35 @@ def rule_excludes_all(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_as_sequence(emitter, ctx): emitter.inline("list(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_as_set(emitter, ctx): emitter.inline("set(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_as_bag(emitter, ctx): emitter.inline("list(") emitter.visit(ctx.expression) emitter.inline(")") -@call_rule +@collection_rule def rule_any(emitter, ctx): emitter.inline("next(iter(") emitter.visit(ctx.expression) emitter.inline("), None)") -@call_rule +@collection_rule def rule_sorted_by(emitter, ctx): emitter.inline("sorted(") emitter.visit(ctx.expression) @@ -526,7 +543,7 @@ def rule_sorted_by(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_iterate(emitter, ctx): emitter.inline("ocl.flatten(") emitter.visit(ctx.argExp()) @@ -537,7 +554,7 @@ def rule_iterate(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_including(emitter, ctx): emitter.inline("itertools.chain(") emitter.visit(ctx.expression) @@ -546,12 +563,12 @@ def rule_including(emitter, ctx): emitter.inline(",))") -@call_rule +@collection_rule def rule_append(emitter, ctx): rule_including(emitter, ctx) -@call_rule +@collection_rule def rule_prepend(emitter, ctx): emitter.inline("itertools.chain((") emitter.visit(ctx.argExp()) @@ -560,7 +577,7 @@ def rule_prepend(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_excluding(emitter, ctx): emitter.inline("_e for _e in ") emitter.visit(ctx.expression) @@ -568,7 +585,7 @@ def rule_excluding(emitter, ctx): emitter.visit(ctx.argExp()) -@call_rule +@collection_rule def rule_select_by_kind(emitter, ctx): emitter.inline("_e for _e in ") emitter.visit(ctx.expression) @@ -577,28 +594,29 @@ def rule_select_by_kind(emitter, ctx): emitter.inline(")") -@call_rule +@collection_rule def rule_select_by_type(emitter, ctx): emitter.inline("_e for _e in ") emitter.visit(ctx.expression) emitter.inline(" if type(_e) == ") emitter.visit(ctx.argExp()) -@call_rule + +@collection_rule def rule_first(emitter, ctx): emitter.inline("next(iter(") emitter.visit(ctx.expression) emitter.inline("))") -@call_rule +@collection_rule def rule_last(emitter, ctx): emitter.inline("list(") emitter.visit(ctx.expression) emitter.inline(")[-1]") -@call_rule +@collection_rule def rule_index_of(emitter, ctx): emitter.inline("next(_i for _i, _e in enumerate(") emitter.visit(ctx.expression) @@ -606,7 +624,8 @@ def rule_index_of(emitter, ctx): emitter.visit(ctx.argExp().body[0]) emitter.inline(")") -@call_rule + +@collection_rule def rule_closure(emitter, ctx): emitter.inline("ocl.closure(") emitter.visit(ctx.expression) diff --git a/tests/test_collection_lib.py b/tests/test_collection_lib.py index c894193..811dddb 100644 --- a/tests/test_collection_lib.py +++ b/tests/test_collection_lib.py @@ -203,5 +203,4 @@ class A: x1, x2 = A(children=[]), A(children=None) a = A(children=[x1, x2]) - print(!Sequence{a}->closure(e | e.children)->asSequence()!) assert !Sequence{a}->closure(e | e.children)->asSequence()! == [x1, x2] \ No newline at end of file