From d1f349ebcad46e904811698adba664ea352e5fa8 Mon Sep 17 00:00:00 2001 From: Michael Langowski Date: Thu, 1 Aug 2024 20:31:57 +0200 Subject: [PATCH] Evolog Modules: Add parsing for list collection aggregate --- .../api/programs/atoms/AggregateAtom.java | 8 +++- .../ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 | 8 +++- .../ac/tuwien/kr/alpha/core/antlr/ASPLexer.g4 | 1 + .../alpha/core/parser/ParseTreeVisitor.java | 31 +++++++++++- .../kr/alpha/core/parser/ParserTest.java | 48 +++++++++++++++++++ 5 files changed, 93 insertions(+), 3 deletions(-) diff --git a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/AggregateAtom.java b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/AggregateAtom.java index bd5127cf2..f5de58e46 100644 --- a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/AggregateAtom.java +++ b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/AggregateAtom.java @@ -24,7 +24,8 @@ enum AggregateFunctionSymbol { COUNT, MAX, MIN, - SUM + SUM, + LIST } ComparisonOperator getLowerBoundOperator(); @@ -44,6 +45,11 @@ enum AggregateFunctionSymbol { @Override AggregateLiteral toLiteral(boolean positive); + @Override + default AggregateLiteral toLiteral() { + return toLiteral(true); + } + interface AggregateElement { List getElementTerms(); diff --git a/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 b/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 index 876c4e190..9eca2e92b 100644 --- a/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 +++ b/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 @@ -34,7 +34,13 @@ choice_elements : choice_element (SEMICOLON choice_elements)?; choice_element : classical_literal (COLON naf_literals?)?; -aggregate : NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?; +aggregate : (classic_aggregate | list_aggregate); + +list_aggregate: term EQUAL AGGREGATE_LIST CURLY_OPEN list_comprehension CURLY_CLOSE; + +list_comprehension: term COLON naf_literals; // Note: Term is expected to be a function term or basic_term + +classic_aggregate: NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?; aggregate_elements : aggregate_element (SEMICOLON aggregate_elements)?; diff --git a/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPLexer.g4 b/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPLexer.g4 index b0641ab1d..1825d3607 100644 --- a/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPLexer.g4 +++ b/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPLexer.g4 @@ -40,6 +40,7 @@ AGGREGATE_COUNT : '#count'; AGGREGATE_MAX : '#max'; AGGREGATE_MIN : '#min'; AGGREGATE_SUM : '#sum'; +AGGREGATE_LIST : '#list'; DIRECTIVE_ENUM : 'enumeration_predicate_is'; diff --git a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java index 16226a168..a28df0d9f 100644 --- a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java +++ b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java @@ -388,7 +388,19 @@ public Set visitBody(ASPCore2Parser.BodyContext ctx) { @Override public AggregateLiteral visitAggregate(ASPCore2Parser.AggregateContext ctx) { - // aggregate : NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?; + // aggregate : (classic_aggregate | list_aggregate); + if (ctx.classic_aggregate() != null) { + return visitClassic_aggregate(ctx.classic_aggregate()); + } else if (ctx.list_aggregate() != null) { + return visitList_aggregate(ctx.list_aggregate()); + } else { + throw notSupported(ctx); + } + } + + @Override + public AggregateLiteral visitClassic_aggregate(ASPCore2Parser.Classic_aggregateContext ctx) { + // classic_aggregate: NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?; boolean isPositive = ctx.NAF() == null; Term lt = null; ComparisonOperator lop = null; @@ -407,6 +419,23 @@ public AggregateLiteral visitAggregate(ASPCore2Parser.AggregateContext ctx) { return Atoms.newAggregateAtom(lop, lt, uop, ut, aggregateFunction, aggregateElements).toLiteral(isPositive); } + @Override + public AggregateLiteral visitList_aggregate(ASPCore2Parser.List_aggregateContext ctx) { + // list_aggregate: term EQUAL AGGREGATE_LIST CURLY_OPEN list_comprehension CURLY_CLOSE; + Term listResultTerm = (Term) visit(ctx.term()); + ImmutablePair> listComprehension = visitList_comprehension(ctx.list_comprehension()); + return Atoms.newAggregateAtom(ComparisonOperators.EQ, listResultTerm, AggregateAtom.AggregateFunctionSymbol.LIST, + List.of(Atoms.newAggregateElement(List.of(listComprehension.left), listComprehension.right))).toLiteral(); + } + + @Override + public ImmutablePair> visitList_comprehension(ASPCore2Parser.List_comprehensionContext ctx) { + // list_comprehension: term COLON naf_literals; + Term elementTerm = (Term) visit(ctx.term()); + List elementSelectors = visitNaf_literals(ctx.naf_literals()); + return ImmutablePair.of(elementTerm, elementSelectors); + } + @Override public List visitAggregate_elements(ASPCore2Parser.Aggregate_elementsContext ctx) { // aggregate_elements : aggregate_element (SEMICOLON aggregate_elements)?; diff --git a/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java b/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java index 624658a64..676ce858b 100644 --- a/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java +++ b/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java @@ -106,6 +106,10 @@ public class ParserTest { private static final String MODULE_LITERAL_NO_OUTPUT_WITH_NUM_ANSWER_SETS = "a(X) :- #something{4}[X]."; + private static final String LIST_AGGREGATE = "stuff_list(LST) :- LST = #list{X : stuff(X)}."; + + private static final String LIST_AGGREGATE_TUPLE = "stuff_list(LST) :- LST = #list{stuff_tuple(X,Y) : stuff(X,Y)}."; + private final ProgramParserImpl parser = new ProgramParserImpl(); @Test @@ -499,4 +503,48 @@ public void moduleLiteralNoOutputWithNumAnswerSets() { assertEquals(4, moduleLiteral.getAtom().getInstantiationMode().requestedAnswerSets().get()); } + @Test + public void listAggregate() { + InputProgram prog = parser.parse(LIST_AGGREGATE); + assertEquals(1, prog.getRules().size()); + Rule rule = prog.getRules().get(0); + assertEquals(1, rule.getBody().size()); + assertEquals(1, rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).count()); + AggregateLiteral aggregateLiteral = (AggregateLiteral) rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).findFirst().get(); + AggregateAtom aggregateAtom = aggregateLiteral.getAtom(); + assertEquals(ComparisonOperators.EQ, aggregateAtom.getLowerBoundOperator()); + assertEquals(Terms.newVariable("LST"), aggregateAtom.getLowerBoundTerm()); + assertEquals(AggregateAtom.AggregateFunctionSymbol.LIST, aggregateAtom.getAggregateFunction()); + assertEquals(1, aggregateAtom.getAggregateElements().size()); + AggregateAtom.AggregateElement aggregateElement = aggregateAtom.getAggregateElements().get(0); + assertEquals(1, aggregateElement.getElementTerms().size()); + Term elementTerm = aggregateElement.getElementTerms().get(0); + assertEquals(Terms.newVariable("X"), elementTerm); + assertEquals(1, aggregateElement.getElementLiterals().size()); + Literal elementLiteral = aggregateElement.getElementLiterals().get(0); + assertEquals(Atoms.newBasicAtom(Predicates.getPredicate("stuff", 1), Terms.newVariable("X")).toLiteral(), elementLiteral); + } + + @Test + public void listAggregateWithTuples() { + InputProgram prog = parser.parse(LIST_AGGREGATE_TUPLE); + assertEquals(1, prog.getRules().size()); + Rule rule = prog.getRules().get(0); + assertEquals(1, rule.getBody().size()); + assertEquals(1, rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).count()); + AggregateLiteral aggregateLiteral = (AggregateLiteral) rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).findFirst().get(); + AggregateAtom aggregateAtom = aggregateLiteral.getAtom(); + assertEquals(ComparisonOperators.EQ, aggregateAtom.getLowerBoundOperator()); + assertEquals(Terms.newVariable("LST"), aggregateAtom.getLowerBoundTerm()); + assertEquals(AggregateAtom.AggregateFunctionSymbol.LIST, aggregateAtom.getAggregateFunction()); + assertEquals(1, aggregateAtom.getAggregateElements().size()); + AggregateAtom.AggregateElement aggregateElement = aggregateAtom.getAggregateElements().get(0); + assertEquals(1, aggregateElement.getElementTerms().size()); + Term elementTerm = aggregateElement.getElementTerms().get(0); + assertEquals(Terms.newFunctionTerm("stuff_tuple", Terms.newVariable("X"), Terms.newVariable("Y")), elementTerm); + assertEquals(1, aggregateElement.getElementLiterals().size()); + Literal elementLiteral = aggregateElement.getElementLiterals().get(0); + assertEquals(Atoms.newBasicAtom(Predicates.getPredicate("stuff", 2), Terms.newVariable("X"), Terms.newVariable("Y")).toLiteral(), elementLiteral); + } + }