Skip to content

Commit

Permalink
Add new mechanism to register rules
Browse files Browse the repository at this point in the history
  • Loading branch information
aranega committed Sep 18, 2024
1 parent f374770 commit 992f3c6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 43 deletions.
103 changes: 61 additions & 42 deletions pyecoreocl/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def visitComparisonBinaryOperation(self, ctx):
def visitCollectionCall(self, ctx):
# We currently don't support implicity collection wrapping
operation = ctx.attname.text
rule = rules.get(operation, default_collection_call)
rule = RuleSet.rules_collections.get(operation, default_collection_call)
rule(self, ctx)

def visitBooleanBinaryOperation(self, ctx):
Expand Down Expand Up @@ -280,17 +280,34 @@ def visitUnreservedName(self, ctx):
return self.visitChildren(ctx)


rules = {}
class RuleSet:
rules_collections = {}
rules_primitives = {}


def call_rule(fun):
names = fun.__name__[5:].split("_")
function_name = "".join([names[0], *(f.capitalize() for f in names[1:])])
rules[function_name] = fun
def register_rule_set(decorator):
register_name = decorator.__name__.split("_")[0]

def inner(fun):
names = fun.__name__[5:].split("_")
function_name = "".join([names[0], *(f.capitalize() for f in names[1:])])
getattr(RuleSet, f"rules_{register_name}s")[function_name] = fun
return fun

return inner


@register_rule_set
def collection_rule(fun):
return fun


@call_rule
@register_rule_set
def primitive_rule(fun):
return fun


@collection_rule
def rule_collect_nested(emitter, ctx):
emitter.inline("(")
emitter.visit(ctx.argExp().body)
Expand All @@ -305,7 +322,7 @@ def rule_collect_nested(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_collect(emitter, ctx):
emitter.inline("ocl.flatten(")
emitter.visit(ctx.argExp().body)
Expand All @@ -320,7 +337,7 @@ def rule_collect(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_for_all(emitter, ctx):
emitter.inline(f"all(")
emitter.visit(ctx.argExp().body)
Expand All @@ -335,7 +352,7 @@ def rule_for_all(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_exists(emitter, ctx):
emitter.inline(f"any(")
emitter.visit(ctx.argExp().body)
Expand All @@ -350,14 +367,14 @@ def rule_exists(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_one(emitter, ctx):
emitter.inline("(len(list(")
rule_select(emitter, ctx)
emitter.inline(")) == 1)")


@call_rule
@collection_rule
def rule_select(emitter, ctx):
variables = [arg.text for arg in ctx.argExp().varnames]
varnames = ", ".join(variables)
Expand All @@ -374,7 +391,7 @@ def rule_select(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_reject(emitter, ctx):
variables = [arg.text for arg in ctx.argExp().varnames]
varnames = ", ".join(variables)
Expand All @@ -391,78 +408,78 @@ def rule_reject(emitter, ctx):
emitter.inline("))")


@call_rule
@collection_rule
def rule_includes(emitter, ctx):
emitter.visit(ctx.argExp())
emitter.inline(" in ")
emitter.visit(ctx.expression)


@call_rule
@collection_rule
def rule_excludes(emitter, ctx):
emitter.visit(ctx.argExp())
emitter.inline(" not in ")
emitter.visit(ctx.expression)


@call_rule
@collection_rule
def rule_not_empty(emitter, ctx):
emitter.inline("len(")
emitter.visit(ctx.expression)
emitter.inline(") > 0")


@call_rule
@collection_rule
def rule_is_empty(emitter, ctx):
emitter.inline("len(")
emitter.visit(ctx.expression)
emitter.inline(") == 0")


@call_rule
@collection_rule
def rule_at(emitter, ctx):
emitter.visit(ctx.expression)
emitter.inline("[")
emitter.visit(ctx.argExp().body[0])
emitter.inline("]")


@call_rule
@collection_rule
def rule_size(emitter, ctx):
emitter.inline("len(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_is_unique(emitter, ctx):
emitter.inline("ocl.is_unique(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_sum(emitter, ctx):
emitter.inline("sum(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_min(emitter, ctx):
emitter.inline("min(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_max(emitter, ctx):
emitter.inline("max(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_count(emitter, ctx):
emitter.inline("list(")
emitter.visit(ctx.expression)
Expand All @@ -471,7 +488,7 @@ def rule_count(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_includes_all(emitter, ctx):
emitter.inline("all(e in ")
emitter.visit(ctx.expression)
Expand All @@ -480,7 +497,7 @@ def rule_includes_all(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_excludes_all(emitter, ctx):
emitter.inline("all(e not in ")
emitter.visit(ctx.expression)
Expand All @@ -489,35 +506,35 @@ def rule_excludes_all(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_as_sequence(emitter, ctx):
emitter.inline("list(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_as_set(emitter, ctx):
emitter.inline("set(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_as_bag(emitter, ctx):
emitter.inline("list(")
emitter.visit(ctx.expression)
emitter.inline(")")


@call_rule
@collection_rule
def rule_any(emitter, ctx):
emitter.inline("next(iter(")
emitter.visit(ctx.expression)
emitter.inline("), None)")


@call_rule
@collection_rule
def rule_sorted_by(emitter, ctx):
emitter.inline("sorted(")
emitter.visit(ctx.expression)
Expand All @@ -526,7 +543,7 @@ def rule_sorted_by(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_iterate(emitter, ctx):
emitter.inline("ocl.flatten(")
emitter.visit(ctx.argExp())
Expand All @@ -537,7 +554,7 @@ def rule_iterate(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_including(emitter, ctx):
emitter.inline("itertools.chain(")
emitter.visit(ctx.expression)
Expand All @@ -546,12 +563,12 @@ def rule_including(emitter, ctx):
emitter.inline(",))")


@call_rule
@collection_rule
def rule_append(emitter, ctx):
rule_including(emitter, ctx)


@call_rule
@collection_rule
def rule_prepend(emitter, ctx):
emitter.inline("itertools.chain((")
emitter.visit(ctx.argExp())
Expand All @@ -560,15 +577,15 @@ def rule_prepend(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_excluding(emitter, ctx):
emitter.inline("_e for _e in ")
emitter.visit(ctx.expression)
emitter.inline(" if _e != ")
emitter.visit(ctx.argExp())


@call_rule
@collection_rule
def rule_select_by_kind(emitter, ctx):
emitter.inline("_e for _e in ")
emitter.visit(ctx.expression)
Expand All @@ -577,36 +594,38 @@ def rule_select_by_kind(emitter, ctx):
emitter.inline(")")


@call_rule
@collection_rule
def rule_select_by_type(emitter, ctx):
emitter.inline("_e for _e in ")
emitter.visit(ctx.expression)
emitter.inline(" if type(_e) == ")
emitter.visit(ctx.argExp())

@call_rule

@collection_rule
def rule_first(emitter, ctx):
emitter.inline("next(iter(")
emitter.visit(ctx.expression)
emitter.inline("))")


@call_rule
@collection_rule
def rule_last(emitter, ctx):
emitter.inline("list(")
emitter.visit(ctx.expression)
emitter.inline(")[-1]")


@call_rule
@collection_rule
def rule_index_of(emitter, ctx):
emitter.inline("next(_i for _i, _e in enumerate(")
emitter.visit(ctx.expression)
emitter.inline(") if _e == ")
emitter.visit(ctx.argExp().body[0])
emitter.inline(")")

@call_rule

@collection_rule
def rule_closure(emitter, ctx):
emitter.inline("ocl.closure(")
emitter.visit(ctx.expression)
Expand Down
1 change: 0 additions & 1 deletion tests/test_collection_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,4 @@ class A:
x1, x2 = A(children=[]), A(children=None)
a = A(children=[x1, x2])

print(!Sequence{a}->closure(e | e.children)->asSequence()!)
assert !Sequence{a}->closure(e | e.children)->asSequence()! == [x1, x2]

0 comments on commit 992f3c6

Please sign in to comment.