Skip to content

Commit

Permalink
Evolog Modules: Fix bugs where module definitions do not get passed o…
Browse files Browse the repository at this point in the history
…n. Add very simple end2end test with module-based 3-coloring implementation.
  • Loading branch information
madmike200590 committed Jul 29, 2024
1 parent 13010e0 commit 8bf77b3
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import at.ac.tuwien.kr.alpha.api.programs.Predicate;
import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom;
import at.ac.tuwien.kr.alpha.api.programs.atoms.AtomQuery;
import at.ac.tuwien.kr.alpha.api.programs.terms.Term;

import java.util.List;
import java.util.Set;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ public static InputProgramBuilder builder(InputProgram program) {
return new InputProgramBuilder(program);
}

public static NormalProgram newNormalProgram(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives) {
return new NormalProgramImpl(rules, facts, inlineDirectives, Collections.emptyList());
}
// public static NormalProgram newNormalProgram(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives) {
// return new NormalProgramImpl(rules, facts, inlineDirectives, Collections.emptyList());
// }

public static NormalProgram newNormalProgram(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives, List<Module> modules) {
return new NormalProgramImpl(rules, facts, inlineDirectives, modules);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public NormalProgram toNormalProgram() {
for (CompiledRule rule : getRules()) {
normalRules.add(Rules.newNormalRule(rule.getHead(), new LinkedHashSet<>(rule.getBody())));
}
return Programs.newNormalProgram(normalRules, getFacts(), getInlineDirectives());
return Programs.newNormalProgram(normalRules, getFacts(), getInlineDirectives(), Collections.emptyList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public NormalProgram apply(NormalProgram inputProgram) {
return inputProgram;
}
// Create new program with rewritten rules.
return Programs.newNormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives());
return Programs.newNormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives(), inputProgram.getModules());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public InputProgram apply(InputProgram inputProgram) {
}
}
return programBuilder.addRules(srcRules).addRules(additionalRules).addFacts(inputProgram.getFacts())
.addInlineDirectives(inputProgram.getInlineDirectives()).build();
.addInlineDirectives(inputProgram.getInlineDirectives()).addModules(inputProgram.getModules()).build();
}

private static boolean containsIntervalTerms(Atom atom) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public InputProgram apply(InputProgram inputProgram) {
programBuilder.addFacts(inputProgram.getFacts());

List<Rule<Head>> srcRules = new ArrayList<>(inputProgram.getRules());
programBuilder.addRules(rewriteRules(srcRules, enumPredicate));
programBuilder.addRules(rewriteRules(srcRules, enumPredicate)).addModules(inputProgram.getModules());
return programBuilder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ public class IntervalTermToIntervalAtom extends ProgramTransformation<NormalProg

/**
* Rewrites intervals into a new variable and special IntervalAtom.
*
* @return true if some interval occurs in the rule.
*/
private static NormalRule rewriteIntervalSpecifications(NormalRule rule) {
// Collect all intervals and replace them with variables.
Expand Down Expand Up @@ -180,6 +178,6 @@ public NormalProgram apply(NormalProgram inputProgram) {
if (!didChange) {
return inputProgram;
}
return Programs.newNormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives());
return Programs.newNormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives(), inputProgram.getModules());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import at.ac.tuwien.kr.alpha.api.Alpha;
import at.ac.tuwien.kr.alpha.api.AnswerSet;
import at.ac.tuwien.kr.alpha.api.common.fixedinterpretations.PredicateInterpretation;
import at.ac.tuwien.kr.alpha.api.programs.Predicate;
import at.ac.tuwien.kr.alpha.api.programs.NormalProgram;
import at.ac.tuwien.kr.alpha.api.programs.Predicate;
import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom;
import at.ac.tuwien.kr.alpha.api.programs.atoms.BasicAtom;
import at.ac.tuwien.kr.alpha.api.programs.atoms.ExternalAtom;
Expand All @@ -15,7 +15,6 @@
import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule;
import at.ac.tuwien.kr.alpha.api.programs.rules.Rule;
import at.ac.tuwien.kr.alpha.api.programs.rules.heads.NormalHead;
import at.ac.tuwien.kr.alpha.api.programs.terms.FunctionTerm;
import at.ac.tuwien.kr.alpha.api.programs.terms.Term;
import at.ac.tuwien.kr.alpha.commons.programs.Programs;
import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms;
Expand All @@ -37,7 +36,6 @@
public class ModuleLinker extends ProgramTransformation<NormalProgram, NormalProgram> {

// Note: References to a standard library of modules that are always available for linking should be member variables of a linker.

private final Alpha moduleRunner;

public ModuleLinker(Alpha moduleRunner) {
Expand All @@ -51,11 +49,10 @@ public NormalProgram apply(NormalProgram inputProgram) {
List<NormalRule> transformedRules = inputProgram.getRules().stream()
.map(rule -> containsModuleAtom(rule) ? linkModuleAtoms(rule, moduleTable) : rule)
.collect(Collectors.toList());
return null;
return Programs.newNormalProgram(transformedRules, inputProgram.getFacts(), inputProgram.getInlineDirectives(), Collections.emptyList());
}

private NormalRule linkModuleAtoms(NormalRule rule, Map<String, Module> moduleTable) {
NormalHead newHead = rule.getHead();
Set<Literal> newBody = rule.getBody().stream()
.map(literal -> {
if (literal instanceof ModuleLiteral) {
Expand All @@ -66,7 +63,7 @@ private NormalRule linkModuleAtoms(NormalRule rule, Map<String, Module> moduleTa
}
})
.collect(Collectors.toSet());
return Rules.newNormalRule(newHead, newBody);
return Rules.newNormalRule(rule.getHead(), newBody);
}

private ExternalAtom translateModuleAtom(ModuleAtom atom, Map<String, Module> moduleTable) {
Expand Down Expand Up @@ -95,7 +92,7 @@ private ExternalAtom translateModuleAtom(ModuleAtom atom, Map<String, Module> mo
PredicateInterpretation interpretation = terms -> {
BasicAtom inputAtom = Atoms.newBasicAtom(inputSpec, terms);
NormalProgram program = Programs.newNormalProgram(normalizedImplementation.getRules(),
ListUtils.union(List.of(inputAtom), normalizedImplementation.getFacts()), normalizedImplementation.getInlineDirectives());
ListUtils.union(List.of(inputAtom), normalizedImplementation.getFacts()), normalizedImplementation.getInlineDirectives(), Collections.emptyList());
java.util.function.Predicate<Predicate> filter = outputSpec.isEmpty() ? p -> true : outputSpec::contains;
Stream<AnswerSet> answerSets = moduleRunner.solve(program, filter);
if (atom.getInstantiationMode().requestedAnswerSets().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public InputProgram apply(InputProgram inputProgram) {
for (Rule<Head> rule : inputProgram.getRules()) {
rewrittenRules.add(findAndReplaceVariableEquality(rule));
}
return Programs.newInputProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives());
return Programs.newInputProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives(), Collections.emptyList(), inputProgram.getModules());
}

private Rule<Head> findAndReplaceVariableEquality(Rule<Head> rule) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public InputProgram apply(InputProgram inputProgram) {
}
// Substitute AggregateLiterals with generated result literals.
outputRules.addAll(rewriteRulesWithAggregates(ctx));
InputProgramBuilder resultBuilder = Programs.builder().addRules(outputRules).addFacts(inputProgram.getFacts())
InputProgramBuilder resultBuilder = Programs.builder().addRules(outputRules).addFacts(inputProgram.getFacts()).addModules(inputProgram.getModules())
.addInlineDirectives(inputProgram.getInlineDirectives());
// Add sub-programs deriving respective aggregate literals.
for (Map.Entry<ImmutablePair<AggregateFunctionSymbol, ComparisonOperator>, Set<AggregateInfo>> aggToRewrite : ctx.getAggregateFunctionsToRewrite()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import at.ac.tuwien.kr.alpha.core.programs.AnalyzedProgram;
import at.ac.tuwien.kr.alpha.core.programs.CompiledProgram;
import at.ac.tuwien.kr.alpha.core.programs.InternalProgram;
import at.ac.tuwien.kr.alpha.core.programs.transformation.ModuleLinker;
import at.ac.tuwien.kr.alpha.core.programs.transformation.ProgramTransformation;
import at.ac.tuwien.kr.alpha.core.programs.transformation.StratifiedEvaluation;
import at.ac.tuwien.kr.alpha.core.solver.SolverFactory;
Expand Down Expand Up @@ -168,11 +169,9 @@ public NormalProgram normalizeProgram(InputProgram program) {
@VisibleForTesting
InternalProgram performProgramPreprocessing(NormalProgram program) {
LOGGER.debug("Preprocessing InternalProgram!");
LOGGER.debug("Preprocessing InternalProgram!");
InternalProgram retVal = InternalProgram.fromNormalProgram(program);
AnalyzedProgram analyzed = new AnalyzedProgram(retVal.getRules(), retVal.getFacts());
retVal = stratifiedEvaluationFactory.get().apply(analyzed);
return retVal;
NormalProgram linkedProgram = new ModuleLinker(this).apply(program);
AnalyzedProgram analyzed = AnalyzedProgram.analyzeNormalProgram(linkedProgram);
return stratifiedEvaluationFactory.get().apply(analyzed);
}

/**
Expand Down Expand Up @@ -255,7 +254,8 @@ public DebugSolvingContext prepareDebugSolve(final InputProgram program, java.ut
public DebugSolvingContext prepareDebugSolve(final NormalProgram program, java.util.function.Predicate<Predicate> filter) {
final DependencyGraph depGraph;
final ComponentGraph compGraph;
final AnalyzedProgram analyzed = AnalyzedProgram.analyzeNormalProgram(program);
NormalProgram linkedProgram = new ModuleLinker(this).apply(program);
final AnalyzedProgram analyzed = AnalyzedProgram.analyzeNormalProgram(linkedProgram);
final NormalProgram preprocessed;
preprocessed = stratifiedEvaluationFactory.get().apply(analyzed).toNormalProgram();
depGraph = analyzed.getDependencyGraph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private TestResult.TestCaseResult runTestCase(NormalProgram programUnderTest, Te
LOGGER.info("Running test case " + testCase.getName());
List<Atom> facts = new ArrayList<>(programUnderTest.getFacts());
facts.addAll(testCase.getInput());
NormalProgram prog = Programs.newNormalProgram(programUnderTest.getRules(), facts, programUnderTest.getInlineDirectives());
NormalProgram prog = Programs.newNormalProgram(programUnderTest.getRules(), facts, programUnderTest.getInlineDirectives(), programUnderTest.getModules());
Set<AnswerSet> answerSets;
try {
answerSets = alpha.solve(prog).collect(Collectors.toSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ private static DynamicTest alphaEnd2EndTest(String testName, String... fileset)
@TestFactory
Stream<DynamicTest> alphaEnd2EndTests() {
return Stream.of(
alphaEnd2EndTest("3-Coloring", E2E_TESTS_DIR + "3col.asp")
alphaEnd2EndTest("3-Coloring", E2E_TESTS_DIR + "3col.asp"),
alphaEnd2EndTest("modules-basic", E2E_TESTS_DIR + "modules-basic.evl")
);
}

Expand Down
83 changes: 83 additions & 0 deletions alpha-solver/src/test/resources/e2e-tests/modules-basic.evl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
%%% Module 3col %%%
%
% Calculates 3-colorings of a given graph.
% Graph is expected to be represented as a function term with following
% structure: graph(list(V, TAIL), list(E, TAIL)), where list(V, TAIL)
% and list(E, TAIL) are vertex- and edge-lists.
%
%%%
#module threecol(graph/2 => {col/2}) {
% Unwrap input
vertex_element(V, TAIL) :- graph(list(V, TAIL), _).
vertex_element(V, TAIL) :- vertex_element(_, list(V, TAIL)).
vertex(V) :- vertex_element(V, _).
edge_element(E, TAIL) :- graph(_, list(E, TAIL)).
edge_element(E, TAIL) :- edge_element(_, list(E, TAIL)).
edge(V1, V2) :- edge_element(edge(V1, V2), _).

% Make sure edges are symmetric
edge(V2, V1) :- edge(V1, V2).

% Guess colors
red(V) :- vertex(V), not green(V), not blue(V).
green(V) :- vertex(V), not red(V), not blue(V).
blue(V) :- vertex(V), not red(V), not green(V).

% Filter invalid guesses
:- vertex(V1), vertex(V2), edge(V1, V2), red(V1), red(V2).
:- vertex(V1), vertex(V2), edge(V1, V2), green(V1), green(V2).
:- vertex(V1), vertex(V2), edge(V1, V2), blue(V1), blue(V2).

col(V, red) :- red(V).
col(V, blue) :- blue(V).
col(V, green) :- green(V).
}

%%% Main program %%%
%
% This program uses the module "3col" to determine 3-colorability and actual colorings of a graph
%
%%%

%% pack vertices into a vertex list
vertex_element(E) :- vertex(E).
% First, establish ordering of elements (which we need to establish the order within the list)
vertex_element_less(N, K) :- vertex_element(N), vertex_element(K), N < K.
vertex_element_not_predecessor(N, K) :- vertex_element_less(N, I), vertex_element_less(I, K).
vertex_element_predecessor(N, K) :- vertex_element_less(N, K), not vertex_element_not_predecessor(N, K).
vertex_element_has_predecessor(N) :- vertex_element_predecessor(_, N).
% Now build the list as a recursively nested function term
vertex_lst_element(IDX, list(N, list_empty)) :- vertex_element(N), not vertex_element_has_predecessor(N), IDX = 0.
vertex_lst_element(IDX, list(N, list(K, TAIL))) :- vertex_element(N), vertex_element_predecessor(K, N), vertex_lst_element(PREV_IDX, list(K, TAIL)), IDX = PREV_IDX + 1.
has_next_vertex_element(IDX) :- vertex_lst_element(IDX, _), NEXT_IDX = IDX + 1, vertex_lst_element(NEXT_IDX, _).
vertex_lst(LIST) :- vertex_lst_element(IDX, LIST), not has_next_vertex_element(IDX).

%% pack edges into an edge list
edge_element(edge(V1, V2)) :- edge(V1, V2).
% First, establish ordering of elements (which we need to establish the order within the list)
edge_element_less(N, K) :- edge_element(N), edge_element(K), N < K.
edge_element_not_predecessor(N, K) :- edge_element_less(N, I), edge_element_less(I, K).
edge_element_predecessor(N, K) :- edge_element_less(N, K), not edge_element_not_predecessor(N, K).
edge_element_has_predecessor(N) :- edge_element_predecessor(_, N).
% Now build the list as a recursively nested function term
edge_lst_element(IDX, list(N, list_empty)) :- edge_element(N), not edge_element_has_predecessor(N), IDX = 0.
edge_lst_element(IDX, list(N, list(K, TAIL))) :- edge_element(N), edge_element_predecessor(K, N), edge_lst_element(PREV_IDX, list(K, TAIL)), IDX = PREV_IDX + 1.
has_next_edge_element(IDX) :- edge_lst_element(IDX, _), NEXT_IDX = IDX + 1, edge_lst_element(NEXT_IDX, _).
edge_lst(LIST) :- edge_lst_element(IDX, LIST), not has_next_edge_element(IDX).

coloring(COL) :- vertex_lst(VERTEX_LST), edge_lst(EDGE_LST), #threecol[VERTEX_LST, EDGE_LST](COL).

#test smokeTest(expect: >0) {
given {
vertex(a).
vertex(b).
vertex(c).
edge(a, b).
edge(b, c).
edge(c, a).
}
assertForAll {
coloring_found :- coloring(_).
:- not coloring_found.
}
}

0 comments on commit 8bf77b3

Please sign in to comment.