diff --git a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java index 0a8e8f11b..5e689ca4e 100644 --- a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java +++ b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java @@ -17,7 +17,7 @@ public final class Terms { public static final String LIST_TERM_SYMBOL = "lst"; - public static final ConstantTerm EMPTY_LIST = Terms.newSymbolicConstant("emptyList"); + public static final ConstantTerm EMPTY_LIST = Terms.newSymbolicConstant("lst_empty"); /** * Since this is purely a utility class, it may not be instantiated. diff --git a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/aggregates/AggregateRewriting.java b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/aggregates/AggregateRewriting.java index 4b97238d1..265ca735a 100644 --- a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/aggregates/AggregateRewriting.java +++ b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/aggregates/AggregateRewriting.java @@ -33,6 +33,7 @@ public class AggregateRewriting extends ProgramTransformation K. " + + "$id$_element_not_successor(ARGS, N, K) :- $id$_element_greater(ARGS, N, I), $id$_element_greater(ARGS, I, K). " + + "$id$_element_successor(ARGS, N, K) :- $id$_element_greater(ARGS, N, K), not $id$_element_not_successor(ARGS, N, K). " + + "$id$_element_has_successor(ARGS, N) :- $id$_element_successor(ARGS, _, N). " + + // Now build the list as a recursively nested function term + "$id$_lst_element(ARGS, IDX, lst(N, lst_empty)) :- $id$_element(ARGS, N), not $id$_element_has_successor(ARGS, N), IDX = 0. " + + "$id$_lst_element(ARGS, IDX, lst(N, lst(K, TAIL))) :- $id$_element(ARGS, N), $id$_element_successor(ARGS, K, N), $id$_lst_element(ARGS, PREV_IDX, lst(K, TAIL)), IDX = PREV_IDX + 1. " + + "has_next_$id$_element(ARGS, IDX) :- $id$_lst_element(ARGS, IDX, _), NEXT_IDX = IDX + 1, $id$_lst_element(ARGS, NEXT_IDX, _). " + + "$aggregate_result$(ARGS, LIST) :- $id$_lst_element(ARGS, IDX, LIST), not has_next_$id$_element(ARGS, IDX)."); + + private final ProgramParser parser; + + protected ListEncoder(ProgramParser parser) { + super(AggregateAtom.AggregateFunctionSymbol.LIST, Set.of(ComparisonOperators.EQ)); + this.parser = parser; + } + + @Override + protected InputProgram encodeAggregateResult(AggregateRewritingContext.AggregateInfo aggregateToEncode) { + ST encodingTemplate = new ST(LIST_AGGREGATION); + encodingTemplate.add("id", aggregateToEncode.getId()); + encodingTemplate.add("aggregate_result", aggregateToEncode.getOutputAtom().getPredicate().getName()); + return parser.parse(encodingTemplate.render()); + } + + @Override + protected BasicAtom buildElementRuleHead(String aggregateId, AggregateAtom.AggregateElement element, Term aggregateArguments) { + Predicate headPredicate = Predicates.getPredicate(this.getElementTuplePredicateSymbol(aggregateId), 2); + if (element.getElementTerms().size() != 1) { + throw new IllegalArgumentException("List elements may only consist of one term."); + } + Term value = element.getElementTerms().get(0); + return Atoms.newBasicAtom(headPredicate, aggregateArguments, value); + } + + @Override + protected String getElementTuplePredicateSymbol(String aggregateId) { + return aggregateId + "_element"; + } +} diff --git a/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/AggregateRewritingTest.java b/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/AggregateRewritingTest.java index 245304cc5..6fa33a5ba 100644 --- a/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/AggregateRewritingTest.java +++ b/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/AggregateRewritingTest.java @@ -5,6 +5,8 @@ import at.ac.tuwien.kr.alpha.api.impl.AlphaFactory; import at.ac.tuwien.kr.alpha.api.programs.InputProgram; import at.ac.tuwien.kr.alpha.api.programs.Predicate; +import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom; +import at.ac.tuwien.kr.alpha.api.programs.terms.Term; import at.ac.tuwien.kr.alpha.commons.Predicates; import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms; import at.ac.tuwien.kr.alpha.commons.programs.terms.Terms; @@ -12,6 +14,7 @@ import org.junit.jupiter.api.Test; import java.util.List; +import java.util.SortedSet; import java.util.function.Function; import java.util.stream.Collectors; @@ -74,6 +77,10 @@ public class AggregateRewritingTest { "p(1..10)." + "q :- X = #count { Y : p( Y ) }, X = #count { Z : p( Z ) }," + " Y = #count { X : p( X ) }, 1 <= #count { X : p( X ) }, Z = #max { W : p( W ) }."; + + private static final String LIST_COLLECT = + "p(1). p(2). p(3)." + + " q(X) :- X = #list{ Y : p(Y) }."; //@formatter:on // Use an alpha instance with default config for all test cases @@ -233,4 +240,18 @@ public void setComplexEqualityWithGlobals() { assertTrue(answerSet.getPredicateInstances(q).contains(Atoms.newBasicAtom(q))); } + @Test + public void listCollect() { + List answerSets = solve.apply(LIST_COLLECT); + assertEquals(1, answerSets.size()); + AnswerSet answerSet = answerSets.get(0); + Predicate q = Predicates.getPredicate("q", 1); + SortedSet instances = answerSet.getPredicateInstances(q); + assertEquals(1, instances.size()); + Atom instance = instances.first(); + assertEquals(1, instance.getTerms().size()); + Term term = instance.getTerms().get(0); + assertEquals(Terms.asListTerm(List.of(Terms.newConstant(1), Terms.newConstant(2), Terms.newConstant(3))), term); + } + } diff --git a/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/e2etests/End2EndTests.java b/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/e2etests/End2EndTests.java index a275ddfdb..15d8d1b98 100644 --- a/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/e2etests/End2EndTests.java +++ b/alpha-solver/src/test/java/at/ac/tuwien/kr/alpha/e2etests/End2EndTests.java @@ -64,7 +64,8 @@ private static DynamicTest alphaEnd2EndTest(String testName, String... fileset) Stream alphaEnd2EndTests() { return Stream.of( alphaEnd2EndTest("3-Coloring", E2E_TESTS_DIR + "3col.asp"), - alphaEnd2EndTest("modules-basic", E2E_TESTS_DIR + "modules-basic.evl") + alphaEnd2EndTest("modules-basic", E2E_TESTS_DIR + "modules-basic.evl"), + alphaEnd2EndTest("neighboring-vertices-list", E2E_TESTS_DIR + "neighboring-vertices-list.evl") ); } diff --git a/alpha-solver/src/test/resources/e2e-tests/neighboring-vertices-list.evl b/alpha-solver/src/test/resources/e2e-tests/neighboring-vertices-list.evl new file mode 100644 index 000000000..149a02a3f --- /dev/null +++ b/alpha-solver/src/test/resources/e2e-tests/neighboring-vertices-list.evl @@ -0,0 +1,45 @@ +% Graph is undirected +edge(Y, X) :- edge(X, Y). + +%% Generate list of neighboring vertices for each vertex in the graph +neighbor(V, N) :- vertex(V), vertex(N), edge(V, N). +neighbors(V, LST) :- vertex(V), LST = #list{ N : neighbor(V, N)}. + +#test lineGraph(expect: 1) { + given { + vertex(1). vertex(2). edge(1, 2). + } + assertForAll { + :- not neighbors(1, lst(2, lst_empty)). + :- not neighbors(2, lst(1, lst_empty)). + } +} + +#test network(expect: 1) { + given { + vertex(1..8). + + edge(1, 2). + edge(1, 3). + edge(1, 4). + edge(2, 5). + edge(1, 3). + edge(1, 4). + edge(2, 5). + edge(3, 6). + edge(4, 7). + edge(5, 8). + edge(6, 8). + edge(7, 8). + } + assertForAll { + :- not neighbors(1, lst(2, lst(3, lst(4, lst_empty)))). + :- not neighbors(2, lst(1, lst(5, lst_empty))). + :- not neighbors(3, lst(1, lst(6, lst_empty))). + :- not neighbors(4, lst(1, lst(7, lst_empty))). + :- not neighbors(5, lst(2, lst(8, lst_empty))). + :- not neighbors(6, lst(3, lst(8, lst_empty))). + :- not neighbors(7, lst(4, lst(8, lst_empty))). + :- not neighbors(8, lst(5, lst(6, lst(7, lst_empty)))). + } +} \ No newline at end of file