Skip to content

Commit

Permalink
Evolog Modules: Permit external atom interpretations to return Set<Li…
Browse files Browse the repository at this point in the history
…st<Term>> rather than just Set<List<ConstantTerm>> so we can have module interpretations returning function terms
  • Loading branch information
madmike200590 committed Jul 25, 2024
1 parent 60b819f commit f338eeb
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@

@FunctionalInterface
public interface PredicateInterpretation {
Set<List<ConstantTerm<?>>> TRUE = singleton(emptyList());
Set<List<ConstantTerm<?>>> FALSE = emptySet();

Set<List<Term>> TRUE = singleton(emptyList());
Set<List<Term>> FALSE = emptySet();

String EVALUATE_RETURN_TYPE_NAME_PREFIX = Set.class.getName() + "<" + List.class.getName() + "<" + ConstantTerm.class.getName();

Set<List<ConstantTerm<?>>> evaluate(List<Term> terms);
Set<List<Term>> evaluate(List<Term> terms);
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public BindingMethodPredicateInterpretation(Method method) {

@Override
@SuppressWarnings("unchecked")
public Set<List<ConstantTerm<?>>> evaluate(List<Term> terms) {
public Set<List<Term>> evaluate(List<Term> terms) {
if (terms.size() != method.getParameterCount()) {
throw new IllegalArgumentException(
"Parameter count mismatch when calling " + method.getName() + ". " +
Expand Down Expand Up @@ -90,7 +90,7 @@ public Set<List<ConstantTerm<?>>> evaluate(List<Term> terms) {
}

try {
return (Set<List<ConstantTerm<?>>>) method.invoke(null, arguments);
return (Set<List<Term>>) method.invoke(null, arguments);
} catch (IllegalAccessException | InvocationTargetException ex) {
throw new RuntimeException("Error invoking method " + method + "with args [" + StringUtils.join(arguments) + "], expection is: " + ex.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import at.ac.tuwien.kr.alpha.api.externals.Predicate;
import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom;
import at.ac.tuwien.kr.alpha.api.programs.terms.ConstantTerm;
import at.ac.tuwien.kr.alpha.api.programs.terms.Term;
import at.ac.tuwien.kr.alpha.commons.Predicates;
import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms;
import at.ac.tuwien.kr.alpha.commons.programs.terms.Terms;
Expand Down Expand Up @@ -114,7 +115,7 @@ public static <T, U> PredicateInterpretation processPredicate(java.util.function
return new BinaryPredicateInterpretation<>(predicate);
}

public static PredicateInterpretation processPredicate(java.util.function.Supplier<Set<List<ConstantTerm<?>>>> supplier) {
public static PredicateInterpretation processPredicate(java.util.function.Supplier<Set<List<Term>>> supplier) {
return new SuppliedPredicateInterpretation(supplier);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public NonBindingPredicateInterpretation() {
}

@Override
public Set<List<ConstantTerm<?>>> evaluate(List<Term> terms) {
public Set<List<Term>> evaluate(List<Term> terms) {
if (terms.size() != arity) {
throw new IllegalArgumentException("Exactly " + arity + " term(s) required.");
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,21 @@
package at.ac.tuwien.kr.alpha.commons.externals;

import at.ac.tuwien.kr.alpha.api.common.fixedinterpretations.BindingPredicateInterpretation;
import at.ac.tuwien.kr.alpha.api.programs.terms.ConstantTerm;
import at.ac.tuwien.kr.alpha.api.programs.terms.Term;

import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

public class SuppliedPredicateInterpretation implements BindingPredicateInterpretation {
private final Supplier<Set<List<ConstantTerm<?>>>> supplier;
private final Supplier<Set<List<Term>>> supplier;

public SuppliedPredicateInterpretation(Supplier<Set<List<ConstantTerm<?>>>> supplier) {
public SuppliedPredicateInterpretation(Supplier<Set<List<Term>>> supplier) {
this.supplier = supplier;
}

@Override
public Set<List<ConstantTerm<?>>> evaluate(List<Term> terms) {
public Set<List<Term>> evaluate(List<Term> terms) {
if (!terms.isEmpty()) {
throw new IllegalArgumentException("Can only be used without any arguments.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ public List<Substitution> getSatisfyingSubstitutions(Substitution partialSubstit
for (Term t : input) {
substitutes.add(t.substitute(partialSubstitution));
}
Set<List<ConstantTerm<?>>> results = getAtom().getInterpretation().evaluate(substitutes);
Set<List<Term>> results = getAtom().getInterpretation().evaluate(substitutes);
// TODO verify all results are ground
if (results == null) {
throw new NullPointerException("Predicate " + getPredicate().getName() + " returned null. It must return a Set.");
}
Expand Down Expand Up @@ -160,10 +161,10 @@ public List<Substitution> getSatisfyingSubstitutions(Substitution partialSubstit
* @return true iff no list in externalMethodResult equals the external atom's output term
* list as substituted by the grounder, false otherwise
*/
private boolean isNegatedLiteralSatisfied(Set<List<ConstantTerm<?>>> externalMethodResult) {
private boolean isNegatedLiteralSatisfied(Set<List<Term>> externalMethodResult) {
List<Term> externalAtomOutTerms = this.getAtom().getOutput();
boolean outputMatches;
for (List<ConstantTerm<?>> resultTerms : externalMethodResult) {
for (List<Term> resultTerms : externalMethodResult) {
outputMatches = true;
for (int i = 0; i < externalAtomOutTerms.size(); i++) {
if (!resultTerms.get(i).equals(externalAtomOutTerms.get(i))) {
Expand All @@ -182,10 +183,10 @@ private boolean isNegatedLiteralSatisfied(Set<List<ConstantTerm<?>>> externalMet
return true;
}

private List<Substitution> buildSubstitutionsForOutputs(Substitution partialSubstitution, Set<List<ConstantTerm<?>>> outputs) {
private List<Substitution> buildSubstitutionsForOutputs(Substitution partialSubstitution, Set<List<Term>> outputs) {
List<Substitution> retVal = new ArrayList<>();
List<Term> externalAtomOutputTerms = this.getAtom().getOutput();
for (List<ConstantTerm<?>> bindings : outputs) {
for (List<Term> bindings : outputs) {
if (bindings.size() < externalAtomOutputTerms.size()) {
throw new RuntimeException(
"Predicate " + getPredicate().getName() + " returned " + bindings.size() + " terms when at least " + externalAtomOutputTerms.size()
Expand Down

0 comments on commit f338eeb

Please sign in to comment.