Skip to content

Commit

Permalink
Evolog Modules: Add parsing for list collection aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
madmike200590 committed Aug 1, 2024
1 parent 8bf77b3 commit d1f349e
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ enum AggregateFunctionSymbol {
COUNT,
MAX,
MIN,
SUM
SUM,
LIST
}

ComparisonOperator getLowerBoundOperator();
Expand All @@ -44,6 +45,11 @@ enum AggregateFunctionSymbol {
@Override
AggregateLiteral toLiteral(boolean positive);

@Override
default AggregateLiteral toLiteral() {
return toLiteral(true);
}

interface AggregateElement {

List<Term> getElementTerms();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ choice_elements : choice_element (SEMICOLON choice_elements)?;

choice_element : classical_literal (COLON naf_literals?)?;

aggregate : NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
aggregate : (classic_aggregate | list_aggregate);

list_aggregate: term EQUAL AGGREGATE_LIST CURLY_OPEN list_comprehension CURLY_CLOSE;

list_comprehension: term COLON naf_literals; // Note: Term is expected to be a function term or basic_term

classic_aggregate: NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;

aggregate_elements : aggregate_element (SEMICOLON aggregate_elements)?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ AGGREGATE_COUNT : '#count';
AGGREGATE_MAX : '#max';
AGGREGATE_MIN : '#min';
AGGREGATE_SUM : '#sum';
AGGREGATE_LIST : '#list';

DIRECTIVE_ENUM : 'enumeration_predicate_is';

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,19 @@ public Set<Literal> visitBody(ASPCore2Parser.BodyContext ctx) {

@Override
public AggregateLiteral visitAggregate(ASPCore2Parser.AggregateContext ctx) {
// aggregate : NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
// aggregate : (classic_aggregate | list_aggregate);
if (ctx.classic_aggregate() != null) {
return visitClassic_aggregate(ctx.classic_aggregate());
} else if (ctx.list_aggregate() != null) {
return visitList_aggregate(ctx.list_aggregate());
} else {
throw notSupported(ctx);
}
}

@Override
public AggregateLiteral visitClassic_aggregate(ASPCore2Parser.Classic_aggregateContext ctx) {
// classic_aggregate: NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
boolean isPositive = ctx.NAF() == null;
Term lt = null;
ComparisonOperator lop = null;
Expand All @@ -407,6 +419,23 @@ public AggregateLiteral visitAggregate(ASPCore2Parser.AggregateContext ctx) {
return Atoms.newAggregateAtom(lop, lt, uop, ut, aggregateFunction, aggregateElements).toLiteral(isPositive);
}

@Override
public AggregateLiteral visitList_aggregate(ASPCore2Parser.List_aggregateContext ctx) {
// list_aggregate: term EQUAL AGGREGATE_LIST CURLY_OPEN list_comprehension CURLY_CLOSE;
Term listResultTerm = (Term) visit(ctx.term());
ImmutablePair<Term, List<Literal>> listComprehension = visitList_comprehension(ctx.list_comprehension());
return Atoms.newAggregateAtom(ComparisonOperators.EQ, listResultTerm, AggregateAtom.AggregateFunctionSymbol.LIST,
List.of(Atoms.newAggregateElement(List.of(listComprehension.left), listComprehension.right))).toLiteral();
}

@Override
public ImmutablePair<Term, List<Literal>> visitList_comprehension(ASPCore2Parser.List_comprehensionContext ctx) {
// list_comprehension: term COLON naf_literals;
Term elementTerm = (Term) visit(ctx.term());
List<Literal> elementSelectors = visitNaf_literals(ctx.naf_literals());
return ImmutablePair.of(elementTerm, elementSelectors);
}

@Override
public List<AggregateAtom.AggregateElement> visitAggregate_elements(ASPCore2Parser.Aggregate_elementsContext ctx) {
// aggregate_elements : aggregate_element (SEMICOLON aggregate_elements)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ public class ParserTest {

private static final String MODULE_LITERAL_NO_OUTPUT_WITH_NUM_ANSWER_SETS = "a(X) :- #something{4}[X].";

private static final String LIST_AGGREGATE = "stuff_list(LST) :- LST = #list{X : stuff(X)}.";

private static final String LIST_AGGREGATE_TUPLE = "stuff_list(LST) :- LST = #list{stuff_tuple(X,Y) : stuff(X,Y)}.";

private final ProgramParserImpl parser = new ProgramParserImpl();

@Test
Expand Down Expand Up @@ -499,4 +503,48 @@ public void moduleLiteralNoOutputWithNumAnswerSets() {
assertEquals(4, moduleLiteral.getAtom().getInstantiationMode().requestedAnswerSets().get());
}

@Test
public void listAggregate() {
InputProgram prog = parser.parse(LIST_AGGREGATE);
assertEquals(1, prog.getRules().size());
Rule<?> rule = prog.getRules().get(0);
assertEquals(1, rule.getBody().size());
assertEquals(1, rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).count());
AggregateLiteral aggregateLiteral = (AggregateLiteral) rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).findFirst().get();
AggregateAtom aggregateAtom = aggregateLiteral.getAtom();
assertEquals(ComparisonOperators.EQ, aggregateAtom.getLowerBoundOperator());
assertEquals(Terms.newVariable("LST"), aggregateAtom.getLowerBoundTerm());
assertEquals(AggregateAtom.AggregateFunctionSymbol.LIST, aggregateAtom.getAggregateFunction());
assertEquals(1, aggregateAtom.getAggregateElements().size());
AggregateAtom.AggregateElement aggregateElement = aggregateAtom.getAggregateElements().get(0);
assertEquals(1, aggregateElement.getElementTerms().size());
Term elementTerm = aggregateElement.getElementTerms().get(0);
assertEquals(Terms.newVariable("X"), elementTerm);
assertEquals(1, aggregateElement.getElementLiterals().size());
Literal elementLiteral = aggregateElement.getElementLiterals().get(0);
assertEquals(Atoms.newBasicAtom(Predicates.getPredicate("stuff", 1), Terms.newVariable("X")).toLiteral(), elementLiteral);
}

@Test
public void listAggregateWithTuples() {
InputProgram prog = parser.parse(LIST_AGGREGATE_TUPLE);
assertEquals(1, prog.getRules().size());
Rule<?> rule = prog.getRules().get(0);
assertEquals(1, rule.getBody().size());
assertEquals(1, rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).count());
AggregateLiteral aggregateLiteral = (AggregateLiteral) rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).findFirst().get();
AggregateAtom aggregateAtom = aggregateLiteral.getAtom();
assertEquals(ComparisonOperators.EQ, aggregateAtom.getLowerBoundOperator());
assertEquals(Terms.newVariable("LST"), aggregateAtom.getLowerBoundTerm());
assertEquals(AggregateAtom.AggregateFunctionSymbol.LIST, aggregateAtom.getAggregateFunction());
assertEquals(1, aggregateAtom.getAggregateElements().size());
AggregateAtom.AggregateElement aggregateElement = aggregateAtom.getAggregateElements().get(0);
assertEquals(1, aggregateElement.getElementTerms().size());
Term elementTerm = aggregateElement.getElementTerms().get(0);
assertEquals(Terms.newFunctionTerm("stuff_tuple", Terms.newVariable("X"), Terms.newVariable("Y")), elementTerm);
assertEquals(1, aggregateElement.getElementLiterals().size());
Literal elementLiteral = aggregateElement.getElementLiterals().get(0);
assertEquals(Atoms.newBasicAtom(Predicates.getPredicate("stuff", 2), Terms.newVariable("X"), Terms.newVariable("Y")).toLiteral(), elementLiteral);
}

}

0 comments on commit d1f349e

Please sign in to comment.