diff --git a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/AnswerSet.java b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/AnswerSet.java index 8b80fc9ef..09ea75203 100644 --- a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/AnswerSet.java +++ b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/AnswerSet.java @@ -3,6 +3,7 @@ import at.ac.tuwien.kr.alpha.api.programs.Predicate; import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom; import at.ac.tuwien.kr.alpha.api.programs.atoms.AtomQuery; +import at.ac.tuwien.kr.alpha.api.programs.terms.Term; import java.util.List; import java.util.Set; diff --git a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/Atoms.java b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/Atoms.java index d11e77e49..92bbbbb66 100644 --- a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/Atoms.java +++ b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/Atoms.java @@ -11,8 +11,10 @@ import at.ac.tuwien.kr.alpha.api.programs.atoms.AggregateAtom.AggregateElement; import at.ac.tuwien.kr.alpha.api.programs.atoms.AggregateAtom.AggregateFunctionSymbol; import at.ac.tuwien.kr.alpha.api.programs.literals.Literal; +import at.ac.tuwien.kr.alpha.api.programs.terms.FunctionTerm; import at.ac.tuwien.kr.alpha.api.programs.terms.Term; import at.ac.tuwien.kr.alpha.commons.programs.atoms.AggregateAtomImpl.AggregateElementImpl; +import at.ac.tuwien.kr.alpha.commons.programs.terms.Terms; public final class Atoms { @@ -71,4 +73,8 @@ public static AtomQuery query(String predicateName, int predicateArity) { return AtomQueryImpl.forPredicate(predicateName, predicateArity); } + public static FunctionTerm toFunctionTerm(Atom atom) { + return Terms.newFunctionTerm(atom.getPredicate().getName(), atom.getTerms()); + } + } diff --git a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java index 7d0ac9971..0a8e8f11b 100644 --- a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java +++ b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/terms/Terms.java @@ -1,8 +1,6 @@ package at.ac.tuwien.kr.alpha.commons.programs.terms; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; +import java.util.*; import at.ac.tuwien.kr.alpha.api.grounder.Substitution; import at.ac.tuwien.kr.alpha.api.programs.terms.*; @@ -18,6 +16,9 @@ */ public final class Terms { + public static final String LIST_TERM_SYMBOL = "lst"; + public static final ConstantTerm EMPTY_LIST = Terms.newSymbolicConstant("emptyList"); + /** * Since this is purely a utility class, it may not be instantiated. * @@ -82,6 +83,20 @@ public static > List> asTermList(T... va return retVal; } + /** + * Constructs a single list term from a list of terms. + */ + public static Term asListTerm(Collection terms) { + List reversedTerms = new ArrayList<>(terms); + Collections.reverse(reversedTerms); + // iterate over the list in reverse order to build the list term from the back. + Term tail = EMPTY_LIST; + for (Term t : reversedTerms) { + tail = Terms.newFunctionTerm(LIST_TERM_SYMBOL, t, tail); + } + return tail; + } + public static List renameTerms(List terms, String prefix, int counterStartingValue) { List renamedTerms = new ArrayList<>(terms.size()); AbstractTerm.RenameCounterImpl renameCounter = new AbstractTerm.RenameCounterImpl(counterStartingValue); diff --git a/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/terms/TermsTest.java b/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/terms/TermsTest.java index 50e5afb51..3c3d8aba7 100644 --- a/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/terms/TermsTest.java +++ b/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/terms/TermsTest.java @@ -4,6 +4,7 @@ import java.util.List; +import at.ac.tuwien.kr.alpha.api.programs.terms.Term; import org.junit.jupiter.api.Test; import at.ac.tuwien.kr.alpha.api.programs.terms.ConstantTerm; @@ -42,4 +43,26 @@ public void functionTermVsActionSuccessTermHash() { assertEquals(funcTerm.hashCode(), actionSuccessTerm.hashCode()); } + @Test + public void asListTermSingleElement() { + List terms = List.of(Terms.newConstant(1)); + Term lstTerm = Terms.asListTerm(terms); + assertEquals(Terms.newFunctionTerm(Terms.LIST_TERM_SYMBOL, Terms.newConstant(1), Terms.EMPTY_LIST), lstTerm); + } + + @Test + public void asListTermEmptyList() { + assertEquals(Terms.asListTerm(List.of()), Terms.EMPTY_LIST); + } + + @Test + public void asListTermMultipleElements() { + List terms = List.of(Terms.newConstant(1), Terms.newConstant(2), Terms.newConstant(3)); + Term lstTerm = Terms.asListTerm(terms); + assertEquals(Terms.newFunctionTerm(Terms.LIST_TERM_SYMBOL, Terms.newConstant(1), + Terms.newFunctionTerm(Terms.LIST_TERM_SYMBOL, Terms.newConstant(2), + Terms.newFunctionTerm(Terms.LIST_TERM_SYMBOL, Terms.newConstant(3), + Terms.EMPTY_LIST))), lstTerm); + } + } diff --git a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java index 85653f48f..a5196feb0 100644 --- a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java +++ b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java @@ -15,17 +15,16 @@ import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule; import at.ac.tuwien.kr.alpha.api.programs.rules.Rule; import at.ac.tuwien.kr.alpha.api.programs.rules.heads.NormalHead; +import at.ac.tuwien.kr.alpha.api.programs.terms.FunctionTerm; import at.ac.tuwien.kr.alpha.api.programs.terms.Term; import at.ac.tuwien.kr.alpha.commons.programs.Programs; import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms; import at.ac.tuwien.kr.alpha.commons.programs.rules.Rules; +import at.ac.tuwien.kr.alpha.commons.programs.terms.Terms; import org.apache.commons.collections4.ListUtils; import org.apache.commons.collections4.SetUtils; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -83,13 +82,13 @@ private ExternalAtom translateModuleAtom(ModuleAtom atom, Map mo NormalProgram normalizedImplementation = moduleRunner.normalizeProgram(definition.getImplementation()); // verify outputs Set outputSpec = definition.getOutputSpec(); - int expectedOutputTerms; + Set expectedOutputPredicates; if (outputSpec.isEmpty()) { - expectedOutputTerms = calculateOutputPredicates(normalizedImplementation).size(); + expectedOutputPredicates = calculateOutputPredicates(normalizedImplementation); } else { - expectedOutputTerms = outputSpec.size(); + expectedOutputPredicates = outputSpec; } - if (atom.getOutput().size() != expectedOutputTerms) { + if (atom.getOutput().size() != expectedOutputPredicates.size()) { throw new IllegalArgumentException("Module " + atom.getModuleName() + " expects " + outputSpec.size() + " outputs, but " + atom.getOutput().size() + " were given."); } // create the actual interpretation @@ -102,7 +101,7 @@ private ExternalAtom translateModuleAtom(ModuleAtom atom, Map mo if (atom.getInstantiationMode().requestedAnswerSets().isPresent()) { answerSets = answerSets.limit(atom.getInstantiationMode().requestedAnswerSets().get()); } - return answerSets.map(ModuleLinker::answerSetToTerms).collect(Collectors.toSet()); + return answerSets.map(as -> answerSetToTerms(as, expectedOutputPredicates)).collect(Collectors.toSet()); }; return Atoms.newExternalAtom(atom.getPredicate(), interpretation, atom.getInput(), atom.getOutput()); } @@ -119,8 +118,18 @@ private static Set calculateOutputPredicates(NormalProgram program) { .collect(Collectors.toSet())); } - private static List answerSetToTerms(AnswerSet answerSet) { - return Collections.emptyList(); // TODO + private static List answerSetToTerms(AnswerSet answerSet, Set moduleOutputSpec) { + List terms = new ArrayList<>(); + for (Predicate predicate : moduleOutputSpec) { + if (!answerSet.getPredicates().contains(predicate)) { + terms.add(Terms.EMPTY_LIST); + } else { + terms.add(Terms.asListTerm(answerSet.getPredicateInstances(predicate).stream() + .map(Atoms::toFunctionTerm).collect(Collectors.toList()))); + } + } + return terms; } + }