diff --git a/build.gradle.kts b/build.gradle.kts index e5be05ea9b..3b60d44bbc 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -29,7 +29,7 @@ buildscript { allprojects { group = "hu.bme.mit.theta" - version = "6.7.0" + version = "6.8.0" apply(from = rootDir.resolve("gradle/shared-with-buildSrc/mirrors.gradle.kts")) } diff --git a/lib/hu.bme.mit.delta-0.0.1-all.jar b/lib/hu.bme.mit.delta-0.0.1-all.jar index 4d58b7344a..9f764ba328 100644 Binary files a/lib/hu.bme.mit.delta-0.0.1-all.jar and b/lib/hu.bme.mit.delta-0.0.1-all.jar differ diff --git a/subprojects/cfa/cfa-analysis/src/main/kotlin/hu/bme/mit/theta/cfa/analysis/CfaToMonolithicExpr.kt b/subprojects/cfa/cfa-analysis/src/main/kotlin/hu/bme/mit/theta/cfa/analysis/CfaToMonolithicExpr.kt index 0ba48342d9..37529a6bfc 100644 --- a/subprojects/cfa/cfa-analysis/src/main/kotlin/hu/bme/mit/theta/cfa/analysis/CfaToMonolithicExpr.kt +++ b/subprojects/cfa/cfa-analysis/src/main/kotlin/hu/bme/mit/theta/cfa/analysis/CfaToMonolithicExpr.kt @@ -22,11 +22,22 @@ import hu.bme.mit.theta.cfa.CFA import hu.bme.mit.theta.core.decl.Decls import hu.bme.mit.theta.core.model.Valuation import hu.bme.mit.theta.core.stmt.* +import hu.bme.mit.theta.core.type.Expr +import hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq +import hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Neq import hu.bme.mit.theta.core.type.booltype.BoolExprs.And -import hu.bme.mit.theta.core.type.inttype.IntExprs.* +import hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool +import hu.bme.mit.theta.core.type.booltype.BoolType +import hu.bme.mit.theta.core.type.bvtype.BvType +import hu.bme.mit.theta.core.type.fptype.FpExprs.* +import hu.bme.mit.theta.core.type.fptype.FpType +import hu.bme.mit.theta.core.type.inttype.IntExprs.Int import hu.bme.mit.theta.core.type.inttype.IntLitExpr +import hu.bme.mit.theta.core.type.inttype.IntType +import hu.bme.mit.theta.core.utils.BvUtils import hu.bme.mit.theta.core.utils.StmtUtils import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory +import java.math.BigInteger import java.util.* fun CFA.toMonolithicExpr(): MonolithicExpr { @@ -36,7 +47,11 @@ fun CFA.toMonolithicExpr(): MonolithicExpr { for ((i, x) in this.locs.withIndex()) { map[x] = i } - val locVar = Decls.Var("__loc__", Int()) + val locVar = + Decls.Var( + "__loc__", + Int(), + ) // TODO: add edge var as well, to avoid parallel edges causing problems val tranList = this.edges .map { e -> @@ -49,15 +64,40 @@ fun CFA.toMonolithicExpr(): MonolithicExpr { ) } .toList() + + val defaultValues = + this.vars + .map { + when (it.type) { + is IntType -> Eq(it.ref, Int(0)) + is BoolType -> Eq(it.ref, Bool(false)) + is BvType -> + Eq( + it.ref, + BvUtils.bigIntegerToNeutralBvLitExpr(BigInteger.ZERO, (it.type as BvType).size), + ) + is FpType -> FpAssign(it.ref as Expr, NaN(it.type as FpType)) + else -> throw IllegalArgumentException("Unsupported type") + } + } + .toList() + .let { And(it) } + val trans = NonDetStmt.of(tranList) val transUnfold = StmtUtils.toExpr(trans, VarIndexingFactory.indexing(0)) val transExpr = And(transUnfold.exprs) - val initExpr = Eq(locVar.ref, Int(map[this.initLoc]!!)) + val initExpr = And(Eq(locVar.ref, Int(map[this.initLoc]!!)), defaultValues) val propExpr = Neq(locVar.ref, Int(map[this.errorLoc.orElseThrow()]!!)) val offsetIndex = transUnfold.indexing - return MonolithicExpr(initExpr, transExpr, propExpr, offsetIndex) + return MonolithicExpr( + initExpr, + transExpr, + propExpr, + offsetIndex, + vars = this.vars.toList() + listOf(locVar), + ) } fun CFA.valToAction(val1: Valuation, val2: Valuation): CfaAction { diff --git a/subprojects/cfa/cfa-analysis/src/test/java/hu/bme/mit/theta/cfa/analysis/CfaMddCheckerTest.java b/subprojects/cfa/cfa-analysis/src/test/java/hu/bme/mit/theta/cfa/analysis/CfaMddCheckerTest.java new file mode 100644 index 0000000000..ca6abc822d --- /dev/null +++ b/subprojects/cfa/cfa-analysis/src/test/java/hu/bme/mit/theta/cfa/analysis/CfaMddCheckerTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.cfa.analysis; + +import static hu.bme.mit.theta.cfa.analysis.config.CfaConfigBuilder.Domain.*; +import static hu.bme.mit.theta.cfa.analysis.config.CfaConfigBuilder.Refinement.*; + +import hu.bme.mit.theta.analysis.algorithm.SafetyResult; +import hu.bme.mit.theta.analysis.algorithm.mdd.MddCex; +import hu.bme.mit.theta.analysis.algorithm.mdd.MddChecker; +import hu.bme.mit.theta.analysis.algorithm.mdd.MddProof; +import hu.bme.mit.theta.analysis.expr.ExprAction; +import hu.bme.mit.theta.cfa.CFA; +import hu.bme.mit.theta.cfa.dsl.CfaDslManager; +import hu.bme.mit.theta.common.OsHelper; +import hu.bme.mit.theta.common.logging.ConsoleLogger; +import hu.bme.mit.theta.common.logging.Logger; +import hu.bme.mit.theta.common.logging.NullLogger; +import hu.bme.mit.theta.core.type.Expr; +import hu.bme.mit.theta.core.type.booltype.BoolType; +import hu.bme.mit.theta.core.utils.indexings.VarIndexing; +import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory; +import hu.bme.mit.theta.solver.SolverFactory; +import hu.bme.mit.theta.solver.SolverManager; +import hu.bme.mit.theta.solver.SolverPool; +import hu.bme.mit.theta.solver.smtlib.SmtLibSolverManager; +import hu.bme.mit.theta.solver.z3legacy.Z3SolverManager; +import java.io.FileInputStream; +import java.util.Arrays; +import java.util.Collection; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +public class CfaMddCheckerTest { + + @Parameterized.Parameter(value = 0) + public String filePath; + + @Parameterized.Parameter(value = 1) + public boolean isSafe; + + @Parameterized.Parameters(name = "{index}: {0}, {1}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {"src/test/resources/arithmetic-bool00.cfa", false}, + {"src/test/resources/arithmetic-bool01.cfa", false}, + {"src/test/resources/arithmetic-bool10.cfa", false}, + {"src/test/resources/arithmetic-bool11.cfa", false}, + {"src/test/resources/arithmetic-mod.cfa", true}, + {"src/test/resources/counter5_true.cfa", true}, + {"src/test/resources/counter_bv_true.cfa", true}, + {"src/test/resources/counter_bv_false.cfa", false}, + {"src/test/resources/ifelse.cfa", false}, + }); + } + + @Test + public void test() throws Exception { + final Logger logger = new ConsoleLogger(Logger.Level.SUBSTEP); + + SolverManager.registerSolverManager(Z3SolverManager.create()); + if (OsHelper.getOs().equals(OsHelper.OperatingSystem.LINUX)) { + SolverManager.registerSolverManager( + SmtLibSolverManager.create(SmtLibSolverManager.HOME, NullLogger.getInstance())); + } + + final SolverFactory solverFactory; + try { + solverFactory = SolverManager.resolveSolverFactory("Z3"); + } catch (Exception e) { + Assume.assumeNoException(e); + return; + } + + try { + CFA cfa = CfaDslManager.createCfa(new FileInputStream(filePath)); + var monolithicExpr = CfaToMonolithicExprKt.toMonolithicExpr(cfa); + + final SafetyResult status; + try (var solverPool = new SolverPool(solverFactory)) { + final MddChecker checker = + MddChecker.create( + monolithicExpr.getInitExpr(), + VarIndexingFactory.indexing(0), + new ExprAction() { + @Override + public Expr toExpr() { + return monolithicExpr.getTransExpr(); + } + + @Override + public VarIndexing nextIndexing() { + return VarIndexingFactory.indexing(1); + } + }, + monolithicExpr.getPropExpr(), + monolithicExpr.getVars(), + solverPool, + logger, + MddChecker.IterationStrategy.GSAT); + status = checker.check(null); + } + + Assert.assertEquals(isSafe, status.isSafe()); + } finally { + SolverManager.closeAll(); + } + } +} diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/AbstractMonolithicExpr.kt b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/AbstractMonolithicExpr.kt new file mode 100644 index 0000000000..c3a7bcd2ce --- /dev/null +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/AbstractMonolithicExpr.kt @@ -0,0 +1,88 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.analysis.algorithm.bounded + +import hu.bme.mit.theta.analysis.pred.PredPrec +import hu.bme.mit.theta.analysis.pred.PredState +import hu.bme.mit.theta.core.decl.Decl +import hu.bme.mit.theta.core.decl.Decls +import hu.bme.mit.theta.core.decl.VarDecl +import hu.bme.mit.theta.core.model.Valuation +import hu.bme.mit.theta.core.type.Expr +import hu.bme.mit.theta.core.type.anytype.Exprs +import hu.bme.mit.theta.core.type.booltype.BoolExprs +import hu.bme.mit.theta.core.type.booltype.BoolLitExpr +import hu.bme.mit.theta.core.type.booltype.BoolType +import hu.bme.mit.theta.core.type.booltype.IffExpr +import hu.bme.mit.theta.core.type.booltype.SmartBoolExprs.And +import hu.bme.mit.theta.core.type.booltype.SmartBoolExprs.Not +import hu.bme.mit.theta.core.utils.ExprUtils +import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory +import java.util.HashMap + +fun MonolithicExpr.createAbstract(prec: PredPrec): MonolithicExpr { + // TODO: handle initOffsetIndex in abstract initExpr + val lambdaList = ArrayList() + val lambdaPrimeList = ArrayList() + val activationLiterals = ArrayList>() + val literalToPred = HashMap, Expr>() + + prec.preds.forEachIndexed { index, expr -> + run { + val v = Decls.Var("v$index", BoolType.getInstance()) + activationLiterals.add(v) + literalToPred[v] = expr + lambdaList.add(IffExpr.of(v.ref, expr)) + lambdaPrimeList.add( + BoolExprs.Iff(Exprs.Prime(v.ref), ExprUtils.applyPrimes(expr, this.transOffsetIndex)) + ) + } + } + + var indexingBuilder = VarIndexingFactory.indexingBuilder(1) + this.vars + .filter { it !in ctrlVars } + .forEach { decl -> + repeat(transOffsetIndex.get(decl)) { indexingBuilder = indexingBuilder.inc(decl) } + } + + return MonolithicExpr( + initExpr = And(And(lambdaList), initExpr), + transExpr = And(And(lambdaList), And(lambdaPrimeList), transExpr), + propExpr = Not(And(And(lambdaList), Not(propExpr))), + transOffsetIndex = indexingBuilder.build(), + initOffsetIndex = VarIndexingFactory.indexing(0), + vars = activationLiterals + ctrlVars, + valToState = { valuation: Valuation -> + PredState.of( + valuation + .toMap() + .entries + .stream() + .filter { it.key !in ctrlVars } + .map { + when ((it.value as BoolLitExpr).value) { + true -> literalToPred[it.key] + false -> Not(literalToPred[it.key]) + } + } + .toList() + ) + }, + biValToAction = this.biValToAction, + ctrlVars = ctrlVars, + ) +} diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/BoundedChecker.kt b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/BoundedChecker.kt index 2c1273f87f..ed9405aaa4 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/BoundedChecker.kt +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/BoundedChecker.kt @@ -76,7 +76,7 @@ constructor( private val logger: Logger, ) : SafetyChecker, UnitPrec> { - private val vars = monolithicExpr.vars() + private val vars = monolithicExpr.vars private val unfoldedInitExpr = PathUtils.unfold(monolithicExpr.initExpr, VarIndexingFactory.indexing(0)) private val unfoldedPropExpr = { i: VarIndexing -> PathUtils.unfold(monolithicExpr.propExpr, i) } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExpr.kt b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExpr.kt index 65df7dfb65..f1f9146cd2 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExpr.kt +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExpr.kt @@ -13,25 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package hu.bme.mit.theta.analysis.algorithm.bounded +import hu.bme.mit.theta.analysis.expl.ExplState +import hu.bme.mit.theta.analysis.expr.ExprAction +import hu.bme.mit.theta.analysis.expr.ExprState import hu.bme.mit.theta.core.decl.VarDecl +import hu.bme.mit.theta.core.model.Valuation import hu.bme.mit.theta.core.type.Expr import hu.bme.mit.theta.core.type.booltype.BoolType import hu.bme.mit.theta.core.utils.ExprUtils.getVars import hu.bme.mit.theta.core.utils.indexings.VarIndexing import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory -data class MonolithicExpr( - val initExpr: Expr, - val transExpr: Expr, - val propExpr: Expr, - val transOffsetIndex: VarIndexing = VarIndexingFactory.indexing(1), - val initOffsetIndex: VarIndexing = VarIndexingFactory.indexing(0) -) { +data class MonolithicExpr +@JvmOverloads +constructor( + val initExpr: Expr, + val transExpr: Expr, + val propExpr: Expr, + val transOffsetIndex: VarIndexing = VarIndexingFactory.indexing(1), + val initOffsetIndex: VarIndexing = VarIndexingFactory.indexing(0), + val vars: List> = + (getVars(initExpr) union getVars(transExpr) union getVars(propExpr)).toList(), + val valToState: (Valuation) -> ExprState = ExplState::of, + val biValToAction: (Valuation, Valuation) -> ExprAction = { _: Valuation, _: Valuation -> + object : ExprAction { + override fun toExpr(): Expr = transExpr - fun vars(): Collection> { - return getVars(initExpr) union getVars(transExpr) union getVars(propExpr) + override fun nextIndexing(): VarIndexing = transOffsetIndex } -} \ No newline at end of file + }, + val ctrlVars: Collection> = listOf(), +) diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExprCegarChecker.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExprCegarChecker.java new file mode 100644 index 0000000000..2457715f91 --- /dev/null +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/MonolithicExprCegarChecker.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.analysis.algorithm.bounded; + +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; + +import com.google.common.base.Preconditions; +import hu.bme.mit.theta.analysis.*; +import hu.bme.mit.theta.analysis.algorithm.Proof; +import hu.bme.mit.theta.analysis.algorithm.SafetyChecker; +import hu.bme.mit.theta.analysis.algorithm.SafetyResult; +import hu.bme.mit.theta.analysis.expr.ExprAction; +import hu.bme.mit.theta.analysis.expr.ExprState; +import hu.bme.mit.theta.analysis.expr.refinement.ExprTraceChecker; +import hu.bme.mit.theta.analysis.expr.refinement.ExprTraceFwBinItpChecker; +import hu.bme.mit.theta.analysis.expr.refinement.ExprTraceStatus; +import hu.bme.mit.theta.analysis.expr.refinement.ItpRefutation; +import hu.bme.mit.theta.analysis.pred.PredPrec; +import hu.bme.mit.theta.analysis.unit.UnitPrec; +import hu.bme.mit.theta.common.logging.Logger; +import hu.bme.mit.theta.solver.SolverFactory; +import java.util.List; +import java.util.function.Function; + +public class MonolithicExprCegarChecker + implements SafetyChecker, PredPrec> { + private final MonolithicExpr model; + private final Function< + MonolithicExpr, + SafetyChecker< + W, + ? extends Trace, + UnitPrec>> + checkerFactory; + + private final SolverFactory solverFactory; + + private final Logger logger; + + public MonolithicExprCegarChecker( + MonolithicExpr model, + Function< + MonolithicExpr, + SafetyChecker< + W, + ? extends Trace, + UnitPrec>> + checkerFactory, + Logger logger, + SolverFactory solverFactory) { + this.model = model; + this.checkerFactory = checkerFactory; + this.logger = logger; + this.solverFactory = solverFactory; + } + + public SafetyResult> check( + PredPrec initPrec) { + var predPrec = + initPrec == null + ? PredPrec.of(List.of(model.getInitExpr(), model.getPropExpr())) + : initPrec; + + while (true) { + logger.write(Logger.Level.SUBSTEP, "Current prec: %s\n", predPrec); + + final var abstractMonolithicExpr = + AbstractMonolithicExprKt.createAbstract(model, predPrec); + final var checker = checkerFactory.apply(abstractMonolithicExpr); + + final var result = checker.check(); + if (result.isSafe()) { + logger.write(Logger.Level.MAINSTEP, "Model is safe, stopping CEGAR"); + return SafetyResult.safe(result.getProof()); + } else { + Preconditions.checkState(result.isUnsafe()); + final Trace trace = + result.asUnsafe().getCex(); + + final ExprTraceChecker exprTraceFwBinItpChecker = + ExprTraceFwBinItpChecker.create( + True(), True(), solverFactory.createItpSolver()); + + if (trace != null) { + logger.write(Logger.Level.VERBOSE, "\tFound trace: %s\n", trace); + final ExprTraceStatus concretizationResult = + exprTraceFwBinItpChecker.check(trace); + if (concretizationResult.isFeasible()) { + logger.write(Logger.Level.MAINSTEP, "Model is unsafe, stopping CEGAR\n"); + + return SafetyResult.unsafe(trace, result.getProof()); + } else { + final var ref = concretizationResult.asInfeasible().getRefutation(); + final var newPred = ref.get(ref.getPruneIndex()); + final var newPrec = PredPrec.of(newPred); + predPrec = predPrec.join(newPrec); + logger.write(Logger.Level.INFO, "Added new predicate " + newPrec + "\n"); + } + } + } + } + } +} diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/ReversedMonolithicExpr.kt b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/ReversedMonolithicExpr.kt new file mode 100644 index 0000000000..6e8481e55e --- /dev/null +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/bounded/ReversedMonolithicExpr.kt @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.analysis.algorithm.bounded + +import hu.bme.mit.theta.core.type.booltype.BoolExprs.Not +import hu.bme.mit.theta.core.utils.ExprUtils +import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory + +fun MonolithicExpr.createReversed(): MonolithicExpr { + return MonolithicExpr( + initExpr = Not(propExpr), + transExpr = ExprUtils.reverse(transExpr, transOffsetIndex), + propExpr = Not(initExpr), + transOffsetIndex = transOffsetIndex, + initOffsetIndex = VarIndexingFactory.indexing(0), + vars = vars, + ) +} diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddChecker.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddChecker.java index cd81d5a697..de0cae2016 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddChecker.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddChecker.java @@ -15,8 +15,12 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd; +import static com.google.common.base.Preconditions.checkArgument; +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; +import static hu.bme.mit.theta.core.type.booltype.SmartBoolExprs.And; import static hu.bme.mit.theta.core.type.booltype.SmartBoolExprs.Not; +import com.google.common.collect.Lists; import hu.bme.mit.delta.java.mdd.JavaMddFactory; import hu.bme.mit.delta.java.mdd.MddGraph; import hu.bme.mit.delta.java.mdd.MddHandle; @@ -35,19 +39,20 @@ import hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint.SimpleSaturationProvider; import hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint.StateSpaceEnumerationProvider; import hu.bme.mit.theta.analysis.expr.ExprAction; +import hu.bme.mit.theta.common.container.Containers; import hu.bme.mit.theta.common.logging.Logger; import hu.bme.mit.theta.common.logging.Logger.Level; import hu.bme.mit.theta.core.decl.Decl; import hu.bme.mit.theta.core.decl.VarDecl; import hu.bme.mit.theta.core.type.Expr; import hu.bme.mit.theta.core.type.booltype.BoolType; -import hu.bme.mit.theta.core.utils.ExprUtils; import hu.bme.mit.theta.core.utils.PathUtils; import hu.bme.mit.theta.core.utils.indexings.VarIndexing; import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory; import hu.bme.mit.theta.solver.SolverPool; + +import java.util.ArrayList; import java.util.List; -import java.util.Set; public class MddChecker implements SafetyChecker { @@ -55,6 +60,7 @@ public class MddChecker implements SafetyChecker safetyProperty; + private final List> variableOrdering; private final SolverPool solverPool; private final Logger logger; private IterationStrategy iterationStrategy = IterationStrategy.GSAT; @@ -70,12 +76,14 @@ private MddChecker( VarIndexing initIndexing, A transRel, Expr safetyProperty, + List> variableOrdering, SolverPool solverPool, Logger logger, IterationStrategy iterationStrategy) { this.initRel = initRel; this.initIndexing = initIndexing; this.transRel = transRel; + this.variableOrdering = variableOrdering; this.safetyProperty = safetyProperty; this.solverPool = solverPool; this.logger = logger; @@ -87,6 +95,7 @@ public static MddChecker create( VarIndexing initIndexing, A transRel, Expr safetyProperty, + List> variableOrdering, SolverPool solverPool, Logger logger) { return new MddChecker( @@ -94,6 +103,7 @@ public static MddChecker create( initIndexing, transRel, safetyProperty, + variableOrdering, solverPool, logger, IterationStrategy.GSAT); @@ -104,6 +114,7 @@ public static MddChecker create( VarIndexing initIndexing, A transRel, Expr safetyProperty, + List> variableOrdering, SolverPool solverPool, Logger logger, IterationStrategy iterationStrategy) { @@ -112,6 +123,7 @@ public static MddChecker create( initIndexing, transRel, safetyProperty, + variableOrdering, solverPool, logger, iterationStrategy); @@ -128,17 +140,31 @@ public SafetyResult check(Void input) { final MddVariableOrder transOrder = JavaMddFactory.getDefault().createMddVariableOrder(mddGraph); - final Set> vars = - ExprUtils.getVars(List.of(initRel, transRel.toExpr(), safetyProperty)); - for (var v : vars) { - final var domainSize = v.getType() instanceof BoolType ? 2 : 0; + checkArgument(variableOrdering.size() == Containers.createSet(variableOrdering).size(), "Variable ordering contains duplicates"); + final var identityExprs = new ArrayList>(); + for (var v : Lists.reverse(variableOrdering)) { + var domainSize = Math.max(v.getType().getDomainSize().getFiniteSize().intValue(), 0); + + if (domainSize > 100) { + domainSize = 0; + } stateOrder.createOnTop( MddVariableDescriptor.create(v.getConstDecl(initIndexing.get(v)), domainSize)); - transOrder.createOnTop( - MddVariableDescriptor.create( - v.getConstDecl(transRel.nextIndexing().get(v)), domainSize)); + final var index = transRel.nextIndexing().get(v); + if(index > 0) { + transOrder.createOnTop( + MddVariableDescriptor.create( + v.getConstDecl(transRel.nextIndexing().get(v)), domainSize)); + } else { + transOrder.createOnTop( + MddVariableDescriptor.create( + v.getConstDecl(1), domainSize)); + identityExprs.add(Eq(v.getConstDecl(0).getRef(), v.getConstDecl(1).getRef())); + } + + transOrder.createOnTop(MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); } @@ -150,10 +176,10 @@ public SafetyResult check(Void input) { stateSig.getTopVariableHandle() .checkInNode(MddExpressionTemplate.of(initExpr, o -> (Decl) o, solverPool)); - logger.write(Level.INFO, "Created initial node"); + logger.write(Level.INFO, "Created initial node\n"); final Expr transExpr = - PathUtils.unfold(transRel.toExpr(), VarIndexingFactory.indexing(0)); + And(PathUtils.unfold(transRel.toExpr(), VarIndexingFactory.indexing(0)), And(identityExprs)); final MddHandle transitionNode = transSig.getTopVariableHandle() .checkInNode( @@ -209,6 +235,8 @@ public SafetyResult check(Void input) { stateSpaceProvider.getQueryCount(), stateSpaceProvider.getCacheSize()); + logger.write(Level.MAINSTEP, "%s\n", statistics); + final SafetyResult result; if (violatingSize != 0) { result = @@ -217,7 +245,7 @@ public SafetyResult check(Void input) { } else { result = SafetyResult.safe(MddProof.of(stateSpace), statistics); } - logger.write(Level.RESULT, "%s%n", result); + logger.write(Level.RESULT, "%s\n", result); return result; } } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/AbstractNextStateDescriptor.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/AbstractNextStateDescriptor.java index 9268baf638..0428edeeed 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/AbstractNextStateDescriptor.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/AbstractNextStateDescriptor.java @@ -64,23 +64,27 @@ default IntObjMapView> getOffDiagonal return IntObjMapView.empty(getValuations(localStateSpace)); } + final class TerminalEmpty implements AbstractNextStateDescriptor.Postcondition { + @Override + public IntObjMapView getValuations(StateSpaceInfo localStateSpace) { + return IntObjMapView.empty(terminalEmpty()); + } + + @Override + public Optional> split() { + return Optional.empty(); + } + + @Override + public boolean evaluate() { + return false; + } + } + + TerminalEmpty TERMINAL_EMPTY = new TerminalEmpty(); + static AbstractNextStateDescriptor.Postcondition terminalEmpty() { - return new AbstractNextStateDescriptor.Postcondition() { - @Override - public IntObjMapView getValuations(StateSpaceInfo localStateSpace) { - return IntObjMapView.empty(terminalEmpty()); - } - - @Override - public Optional> split() { - return Optional.empty(); - } - - @Override - public boolean evaluate() { - return false; - } - }; + return TERMINAL_EMPTY; } } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeInitializer.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeInitializer.java index c2cd39380d..2aad9748c8 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeInitializer.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeInitializer.java @@ -24,6 +24,8 @@ import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.AbstractNextStateDescriptor; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.StateSpaceInfo; +import java.util.Objects; + public class MddNodeInitializer implements AbstractNextStateDescriptor.Postcondition { private final MddNode node; @@ -58,4 +60,18 @@ public boolean evaluate() { public IntObjMapView getValuations(StateSpaceInfo localStateSpace) { return new IntObjMapViews.Transforming<>(node, n -> of(n, variableHandle.getLower().orElseThrow())); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MddNodeInitializer that = (MddNodeInitializer) o; + return Objects.equals(node, that.node) + && Objects.equals(variableHandle, that.variableHandle); + } + + @Override + public int hashCode() { + return Objects.hash(node, variableHandle); + } } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeNextStateDescriptor.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeNextStateDescriptor.java index 83378ca593..543dec5082 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeNextStateDescriptor.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/ansd/impl/MddNodeNextStateDescriptor.java @@ -25,7 +25,6 @@ import hu.bme.mit.delta.java.mdd.MddVariableHandle; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.AbstractNextStateDescriptor; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.StateSpaceInfo; - import java.util.List; import java.util.Objects; import java.util.Optional; @@ -41,7 +40,8 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; MddNodeNextStateDescriptor that = (MddNodeNextStateDescriptor) o; - return Objects.equals(node, that.node) && Objects.equals(variableHandle, that.variableHandle); + return Objects.equals(node, that.node) + && Objects.equals(variableHandle, that.variableHandle); } @Override @@ -49,14 +49,23 @@ public int hashCode() { return Objects.hash(node, variableHandle); } + @Override + public String toString() { + return node + ", " + variableHandle; + } + private MddNodeNextStateDescriptor(MddNode node, MddVariableHandle variableHandle) { this.node = Preconditions.checkNotNull(node); this.variableHandle = Preconditions.checkNotNull(variableHandle); - Preconditions.checkArgument((variableHandle.isTerminal() && node.isTerminal()) || node.isOn(variableHandle.getVariable().orElseThrow())); + Preconditions.checkArgument( + (variableHandle.isTerminal() && node.isTerminal()) + || node.isOn(variableHandle.getVariable().orElseThrow())); } private static AbstractNextStateDescriptor of(MddNode node, MddVariableHandle variableHandle) { - return (node == null || node == variableHandle.getMddGraph().getTerminalZeroNode()) ? AbstractNextStateDescriptor.terminalEmpty() : new MddNodeNextStateDescriptor(node, variableHandle); + return (node == null || node == variableHandle.getMddGraph().getTerminalZeroNode()) + ? AbstractNextStateDescriptor.terminalEmpty() + : new MddNodeNextStateDescriptor(node, variableHandle); } public static AbstractNextStateDescriptor of(MddHandle handle) { @@ -71,18 +80,42 @@ public boolean evaluate() { @Override public IntObjMapView getDiagonal(StateSpaceInfo localStateSpace) { final MddNode constraint = localStateSpace.toStructuralRepresentation(); - return new ConstrainedIntObjMapView<>(new IntObjMapViews.Transforming<>(node, (n, key) -> { - if (key == null) return AbstractNextStateDescriptor.terminalEmpty(); - else - return MddNodeNextStateDescriptor.of(n.get(key), variableHandle.getLower().orElseThrow().getLower().orElseThrow()); - }), constraint); + return new ConstrainedIntObjMapView<>( + new IntObjMapViews.Transforming<>( + node, + (n, key) -> { + if (key == null) return AbstractNextStateDescriptor.terminalEmpty(); + else + return MddNodeNextStateDescriptor.of( + n.get(key), + variableHandle + .getLower() + .orElseThrow() + .getLower() + .orElseThrow()); + }), + constraint); } @Override - public IntObjMapView> getOffDiagonal(StateSpaceInfo localStateSpace) { + public IntObjMapView> getOffDiagonal( + StateSpaceInfo localStateSpace) { final MddNode constraint = localStateSpace.toStructuralRepresentation(); - return new IntObjMapViews.Transforming<>(node, - outerNode -> new ConstrainedIntObjMapView<>(new IntObjMapViews.Transforming<>(outerNode, mddNode -> MddNodeNextStateDescriptor.of(mddNode, variableHandle.getLower().orElseThrow().getLower().orElseThrow())), constraint)); + return new IntObjMapViews.Transforming<>( + node, + outerNode -> + new ConstrainedIntObjMapView<>( + new IntObjMapViews.Transforming<>( + outerNode, + mddNode -> + MddNodeNextStateDescriptor.of( + mddNode, + variableHandle + .getLower() + .orElseThrow() + .getLower() + .orElseThrow())), + constraint)); } @Override @@ -97,15 +130,19 @@ public class Cursor implements AbstractNextStateDescriptor.Cursor { private final Runnable closer; - private Cursor(RecursiveIntObjCursor wrapped, MddVariableHandle variableHandle, Runnable closer) { + private Cursor( + RecursiveIntObjCursor wrapped, + MddVariableHandle variableHandle, + Runnable closer) { this.wrapped = wrapped; this.variableHandle = variableHandle; this.closer = closer; } - private Cursor(RecursiveIntObjCursor wrapped, MddVariableHandle variableHandle) { - this(wrapped, variableHandle, () -> { - }); + private Cursor( + RecursiveIntObjCursor wrapped, + MddVariableHandle variableHandle) { + this(wrapped, variableHandle, () -> {}); } @Override @@ -128,21 +165,32 @@ public boolean moveTo(int key) { return wrapped.moveTo(key); } -// @Override -// public AbstractNextStateDescriptor.Cursor valueCursor(int from) { -// var fromCursor = wrapped.valueCursor(); -// if (fromCursor.moveTo(from)) { -// return new Cursor(fromCursor.valueCursor(), Cursor.this.variableHandle.getLower().orElseThrow().getLower().orElseThrow(), () -> fromCursor.close()); -// } else return new EmptyCursor(() -> fromCursor.close()); -// } + // @Override + // public AbstractNextStateDescriptor.Cursor valueCursor(int from) { + // var fromCursor = wrapped.valueCursor(); + // if (fromCursor.moveTo(from)) { + // return new Cursor(fromCursor.valueCursor(), + // Cursor.this.variableHandle.getLower().orElseThrow().getLower().orElseThrow(), () -> + // fromCursor.close()); + // } else return new EmptyCursor(() -> fromCursor.close()); + // } @Override - public AbstractNextStateDescriptor.Cursor valueCursor(int from, StateSpaceInfo localStateSpace) { + public AbstractNextStateDescriptor.Cursor valueCursor( + int from, StateSpaceInfo localStateSpace) { final MddNode constraint = localStateSpace.toStructuralRepresentation(); // TODO the valueCursor call of the wrapped cursor has to propagate the constraint var fromCursor = wrapped.valueCursor(); if (fromCursor.moveTo(from)) { - return new Cursor(fromCursor.valueCursor(), Cursor.this.variableHandle.getLower().orElseThrow().getLower().orElseThrow(), () -> fromCursor.close()); + return new Cursor( + fromCursor.valueCursor(), + Cursor.this + .variableHandle + .getLower() + .orElseThrow() + .getLower() + .orElseThrow(), + () -> fromCursor.close()); } else return new EmptyCursor(() -> fromCursor.close()); } @@ -158,70 +206,73 @@ public Optional> split() { } } -// public class Cursor extends CursorBase { -// -// private final RecursiveIntObjCursor fromCursor; -// -// private Cursor(RecursiveIntObjCursor wrapped, RecursiveIntObjCursor fromCursor){ -// super(wrapped); -// this.fromCursor = fromCursor; -// } -// -// @Override -// public void close() { -// super.close(); -// fromCursor.close(); -// } -// } -// -// public class RootCursor extends CursorBase { -// -// private final MddNodeNextStateDescriptor descriptor; -// -// private int currentPosition; -// -// public RootCursor(MddNodeNextStateDescriptor descriptor) { -// super(descriptor.node.cursor()); -// this.descriptor = descriptor; -// this.currentPosition = -1; -// } -// -// @Override -// public int key() { -// throw new UnsupportedOperationException("This operation is not supported on the root cursor"); -// } -// -// @Override -// public AbstractNextStateDescriptor value() { -// return descriptor; -// } -// -// @Override -// public boolean moveNext() { -// currentPosition++; -// return currentPosition == 0; -// } -// -// @Override -// public boolean moveTo(int key) { -// throw new UnsupportedOperationException("This operation is not supported on the root cursor"); -// } -// -// @Override -// public AbstractNextStateDescriptor.Cursor valueCursor(int from) { -// var fromCursor = descriptor.node.cursor(); -// if(fromCursor.moveTo(from)) { -// return new Cursor(fromCursor.valueCursor(), fromCursor); -// } else { -// return EmptyCursor.INSTANCE; -// } -// -// } -// -// @Override -// public void close() {} -// -// } + // public class Cursor extends CursorBase { + // + // private final RecursiveIntObjCursor fromCursor; + // + // private Cursor(RecursiveIntObjCursor wrapped, + // RecursiveIntObjCursor fromCursor){ + // super(wrapped); + // this.fromCursor = fromCursor; + // } + // + // @Override + // public void close() { + // super.close(); + // fromCursor.close(); + // } + // } + // + // public class RootCursor extends CursorBase { + // + // private final MddNodeNextStateDescriptor descriptor; + // + // private int currentPosition; + // + // public RootCursor(MddNodeNextStateDescriptor descriptor) { + // super(descriptor.node.cursor()); + // this.descriptor = descriptor; + // this.currentPosition = -1; + // } + // + // @Override + // public int key() { + // throw new UnsupportedOperationException("This operation is not supported on the + // root cursor"); + // } + // + // @Override + // public AbstractNextStateDescriptor value() { + // return descriptor; + // } + // + // @Override + // public boolean moveNext() { + // currentPosition++; + // return currentPosition == 0; + // } + // + // @Override + // public boolean moveTo(int key) { + // throw new UnsupportedOperationException("This operation is not supported on the + // root cursor"); + // } + // + // @Override + // public AbstractNextStateDescriptor.Cursor valueCursor(int from) { + // var fromCursor = descriptor.node.cursor(); + // if(fromCursor.moveTo(from)) { + // return new Cursor(fromCursor.valueCursor(), fromCursor); + // } else { + // return EmptyCursor.INSTANCE; + // } + // + // } + // + // @Override + // public void close() {} + // + // } public static class EmptyCursor implements AbstractNextStateDescriptor.Cursor { @@ -231,15 +282,16 @@ public EmptyCursor(Runnable closer) { this.closer = closer; } - @Override public int key() { - throw new UnsupportedOperationException("This operation is not supported on the root cursor"); + throw new UnsupportedOperationException( + "This operation is not supported on the root cursor"); } @Override public AbstractNextStateDescriptor value() { - throw new UnsupportedOperationException("This operation is not supported on the root cursor"); + throw new UnsupportedOperationException( + "This operation is not supported on the root cursor"); } @Override @@ -253,8 +305,10 @@ public boolean moveTo(int key) { } @Override - public AbstractNextStateDescriptor.Cursor valueCursor(int from, StateSpaceInfo localStateSpace) { - throw new UnsupportedOperationException("This operation is not supported on the root cursor"); + public AbstractNextStateDescriptor.Cursor valueCursor( + int from, StateSpaceInfo localStateSpace) { + throw new UnsupportedOperationException( + "This operation is not supported on the root cursor"); } @Override @@ -266,15 +320,16 @@ public void close() { public Optional> split() { return Optional.of(List.of(this)); } - } - private class ConstrainedIntObjMapView extends IntObjMapViews.ForwardingBase implements IntObjMapView { + private class ConstrainedIntObjMapView extends IntObjMapViews.ForwardingBase + implements IntObjMapView { private final IntObjMapView target; private final IntObjMapView constraint; - public ConstrainedIntObjMapView(IntObjMapView target, IntObjMapView constraint) { + public ConstrainedIntObjMapView( + IntObjMapView target, IntObjMapView constraint) { this.target = target; this.constraint = constraint; } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/LitExprConverter.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/LitExprConverter.java index 3be2042cf3..34d32df077 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/LitExprConverter.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/LitExprConverter.java @@ -15,6 +15,8 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool; + import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import hu.bme.mit.theta.core.type.LitExpr; @@ -24,22 +26,22 @@ import hu.bme.mit.theta.core.type.arraytype.ArrayType; import hu.bme.mit.theta.core.type.booltype.BoolLitExpr; import hu.bme.mit.theta.core.type.booltype.BoolType; +import hu.bme.mit.theta.core.type.bvtype.BvLitExpr; +import hu.bme.mit.theta.core.type.bvtype.BvType; import hu.bme.mit.theta.core.type.enumtype.EnumLitExpr; import hu.bme.mit.theta.core.type.enumtype.EnumType; +import hu.bme.mit.theta.core.type.fptype.FpLitExpr; +import hu.bme.mit.theta.core.type.fptype.FpType; import hu.bme.mit.theta.core.type.inttype.IntLitExpr; import hu.bme.mit.theta.core.type.inttype.IntType; - +import hu.bme.mit.theta.core.utils.BvUtils; import java.math.BigInteger; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool; - -/** - * Util class for converting between integers and {@link LitExpr} - */ +/** Util class for converting between integers and {@link LitExpr} */ public class LitExprConverter { private static int cnt = 0; - private final static BiMap objToInt = HashBiMap.create(); + private static final BiMap objToInt = HashBiMap.create(); public static int toInt(LitExpr litExpr) { if (litExpr instanceof IntLitExpr) { @@ -48,7 +50,11 @@ public static int toInt(LitExpr litExpr) { if (litExpr instanceof BoolLitExpr) { return ((BoolLitExpr) litExpr).getValue() ? 1 : 0; } - if (litExpr instanceof ArrayLitExpr) { + if (litExpr instanceof BvLitExpr bvLitExpr) { + var ret = BvUtils.neutralBvLitExprToBigInteger(bvLitExpr).intValue(); + return ret; + } + if (litExpr instanceof ArrayLitExpr || litExpr instanceof FpLitExpr) { if (objToInt.get(litExpr) != null) { return objToInt.get(litExpr); } @@ -72,7 +78,11 @@ public static LitExpr toLitExpr(int integer, Type type) { } return BoolLitExpr.of(integer != 0); } - if (type instanceof ArrayType) { + if (type instanceof BvType) { + return BvUtils.bigIntegerToNeutralBvLitExpr( + BigInteger.valueOf(integer), ((BvType) type).getSize()); + } + if (type instanceof ArrayType || type instanceof FpType) { return (LitExpr) objToInt.inverse().get(integer); } if (type instanceof EnumType) { @@ -80,5 +90,4 @@ public static LitExpr toLitExpr(int integer, Type type) { } throw new UnsupportedOperationException("Unsupported type"); } - } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/MddExpressionRepresentation.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/MddExpressionRepresentation.java index 9816236b4d..e60c1b7a5f 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/MddExpressionRepresentation.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/expressionnode/MddExpressionRepresentation.java @@ -15,6 +15,10 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode; +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.*; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; +import static hu.bme.mit.theta.core.type.booltype.SmartBoolExprs.And; + import com.google.common.base.Preconditions; import com.koloboke.collect.map.hash.HashIntObjMap; import com.koloboke.collect.map.hash.HashIntObjMaps; @@ -22,7 +26,6 @@ import hu.bme.mit.delta.java.mdd.MddGraph; import hu.bme.mit.delta.java.mdd.MddNode; import hu.bme.mit.delta.java.mdd.MddVariable; -import hu.bme.mit.theta.solver.SolverPool; import hu.bme.mit.theta.common.GrowingIntArray; import hu.bme.mit.theta.core.decl.Decl; import hu.bme.mit.theta.core.model.ImmutableValuation; @@ -32,19 +35,14 @@ import hu.bme.mit.theta.core.type.LitExpr; import hu.bme.mit.theta.core.type.booltype.BoolType; import hu.bme.mit.theta.core.type.booltype.FalseExpr; -import hu.bme.mit.theta.core.type.enumtype.EnumType; import hu.bme.mit.theta.core.utils.ExprUtils; import hu.bme.mit.theta.solver.Solver; +import hu.bme.mit.theta.solver.SolverPool; import hu.bme.mit.theta.solver.SolverStatus; import hu.bme.mit.theta.solver.utils.WithPushPop; - import java.io.Closeable; import java.util.*; -import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.*; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; -import static hu.bme.mit.theta.core.type.booltype.SmartBoolExprs.And; - public class MddExpressionRepresentation implements RecursiveIntObjMapView { private final Expr expr; @@ -56,7 +54,11 @@ public class MddExpressionRepresentation implements RecursiveIntObjMapView expr, final Decl decl, final MddVariable mddVariable, final SolverPool solverPool) { + private MddExpressionRepresentation( + final Expr expr, + final Decl decl, + final MddVariable mddVariable, + final SolverPool solverPool) { this.expr = expr; this.decl = decl; this.mddVariable = mddVariable; @@ -64,17 +66,27 @@ private MddExpressionRepresentation(final Expr expr, final Decl dec this.explicitRepresentation = new ExplicitRepresentation(); } - //TODO only for debugging + // TODO only for debugging public ExplicitRepresentation getExplicitRepresentation() { return explicitRepresentation; } - public static MddExpressionRepresentation of(final Expr expr, final Decl decl, final MddVariable mddVariable, final SolverPool solverPool) { + public static MddExpressionRepresentation of( + final Expr expr, + final Decl decl, + final MddVariable mddVariable, + final SolverPool solverPool) { return new MddExpressionRepresentation(expr, decl, mddVariable, solverPool); } - public static MddExpressionRepresentation ofDefault(final Expr expr, final Decl decl, final MddVariable mddVariable, final SolverPool solverPool, final MddNode defaultValue) { - final MddExpressionRepresentation representation = new MddExpressionRepresentation(expr, decl, mddVariable, solverPool); + public static MddExpressionRepresentation ofDefault( + final Expr expr, + final Decl decl, + final MddVariable mddVariable, + final SolverPool solverPool, + final MddNode defaultValue) { + final MddExpressionRepresentation representation = + new MddExpressionRepresentation(expr, decl, mddVariable, solverPool); representation.explicitRepresentation.cacheDefault(defaultValue); representation.explicitRepresentation.setComplete(); return representation; @@ -99,7 +111,7 @@ private Traverser getLazyTraverser() { @Override public boolean isEmpty() { -// return false; + // return false; return explicitRepresentation.isComplete() && size() == 0; } @@ -132,10 +144,12 @@ public MddNode get(int key) { final MddNode childNode; if (mddVariable.getLower().isPresent()) { - final MddExpressionTemplate template = MddExpressionTemplate.of(simplifiedExpr, o -> (Decl) o, solverPool); + final MddExpressionTemplate template = + MddExpressionTemplate.of(simplifiedExpr, o -> (Decl) o, solverPool); childNode = mddVariable.getLower().get().checkInNode(template); } else { - final Expr canonizedExpr = ExprUtils.canonize(ExprUtils.simplify(simplifiedExpr)); + final Expr canonizedExpr = + ExprUtils.canonize(ExprUtils.simplify(simplifiedExpr)); MddGraph mddGraph = (MddGraph) mddVariable.getMddGraph(); if (canonizedExpr instanceof FalseExpr) { @@ -172,37 +186,48 @@ public RecursiveIntObjCursor cursor(RecursiveIntObjMapView final MddNode mddNodeConstraint = (MddNode) constraint; final List> exprs = new ArrayList<>(); - if (mddVariable.getLower().isPresent()) { + if (mddVariable.getLower().isPresent() && !mddNodeConstraint.isTerminal()) { MddVariable variable = mddVariable.getLower().get(); - MddNode mddNode = mddNodeConstraint.get(mddNodeConstraint.statistics().lowestValue()); + MddNode mddNode = + mddNodeConstraint.defaultValue() == null + ? mddNodeConstraint.get(mddNodeConstraint.statistics().lowestValue()) + : mddNodeConstraint.defaultValue(); while (true) { - // This is needed because the constraint node might contain level-skips: of the domain is bounded, then default values are detected + // This is needed because the constraint node might contain level-skips: if the + // domain is bounded, then default values are detected if (mddNode.isTerminal()) break; final IntStatistics statistics = mddNode.statistics(); final Decl decl = variable.getTraceInfo(Decl.class); - final LitExpr lowerBound = LitExprConverter.toLitExpr(statistics.lowestValue(), decl.getType()); - final LitExpr upperBound = LitExprConverter.toLitExpr(statistics.highestValue(), decl.getType()); - if (!decl.getType().equals(BoolType.getInstance()) && !(decl.getType() instanceof EnumType)) { // TODO delete + final LitExpr lowerBound = + LitExprConverter.toLitExpr(statistics.lowestValue(), decl.getType()); + final LitExpr upperBound = + LitExprConverter.toLitExpr(statistics.highestValue(), decl.getType()); + if (decl.getType().getDomainSize().isInfinite()) { // TODO delete if (lowerBound.equals(upperBound)) { exprs.add(Eq(decl.getRef(), lowerBound)); } else { - exprs.add(And(Geq(decl.getRef(), lowerBound), Leq(decl.getRef(), upperBound))); + exprs.add( + And( + Geq(decl.getRef(), lowerBound), + Leq(decl.getRef(), upperBound))); } } - - if (variable.getLower().isEmpty() || variable.getLower().get().getLower().isEmpty()) { + if (variable.getLower().isEmpty() + || variable.getLower().get().getLower().isEmpty()) { break; } else { variable = variable.getLower().get().getLower().get(); - mddNode = mddNode.get(statistics.lowestValue()); // TODO we assume here that all edges point to the same node + mddNode = + mddNode.get( + statistics.lowestValue()); // TODO we assume here that all edges + // point to the same node } - } } -// System.out.println(exprs); + // System.out.println(exprs); return new Cursor(null, Traverser.createConstrained(this, And(exprs), solverPool)); } @@ -222,9 +247,9 @@ public String toString() { public boolean equals(Object that) { if (this == that) return true; if (that instanceof MddExpressionRepresentation) { - return Objects.equals(expr, ((MddExpressionRepresentation) that).expr) && - Objects.equals(decl, ((MddExpressionRepresentation) that).decl) && - Objects.equals(mddVariable, ((MddExpressionRepresentation) that).mddVariable); + return Objects.equals(mddVariable, ((MddExpressionRepresentation) that).mddVariable) + && Objects.equals(decl, ((MddExpressionRepresentation) that).decl) + && (Objects.equals(expr, ((MddExpressionRepresentation) that).expr) || semanticEquals(that)); } if (that instanceof MddNode) { return this.equals(((MddNode) that).getRepresentation()); @@ -232,10 +257,23 @@ public boolean equals(Object that) { return false; } + private boolean semanticEquals(Object that) { + + if(that instanceof MddExpressionRepresentation thatRepresentation) { + if(this.explicitRepresentation.complete && thatRepresentation.explicitRepresentation.complete) { + return IntObjMapView.equals(this.explicitRepresentation.getCacheView(), thatRepresentation.explicitRepresentation.getCacheView()); + } + } else if (that instanceof IntObjMapView thatView) { + if(this.explicitRepresentation.complete) { + return IntObjMapView.equals(thatView, this.explicitRepresentation.getCacheView()); + } + } + return false; + } + @Override public int hashCode() { return Objects.hash(expr, decl, mddVariable); - } public static class ExplicitRepresentation { @@ -296,7 +334,11 @@ private static class Traverser implements Closeable { private final boolean constrained; - private Traverser(MddExpressionRepresentation rootRepresentation, Expr constraint, SolverPool solverPool, boolean constrained) { + private Traverser( + MddExpressionRepresentation rootRepresentation, + Expr constraint, + SolverPool solverPool, + boolean constrained) { this.solverPool = solverPool; this.solver = solverPool.requestSolver(); this.stack = new Stack<>(); @@ -308,11 +350,15 @@ private Traverser(MddExpressionRepresentation rootRepresentation, Expr setCurrentRepresentation(Preconditions.checkNotNull(rootRepresentation)); } - public static Traverser createConstrained(MddExpressionRepresentation rootRepresentation, Expr constraint, SolverPool solverPool) { + public static Traverser createConstrained( + MddExpressionRepresentation rootRepresentation, + Expr constraint, + SolverPool solverPool) { return new Traverser(rootRepresentation, constraint, solverPool, true); } - public static Traverser create(MddExpressionRepresentation rootRepresentation, SolverPool solverPool) { + public static Traverser create( + MddExpressionRepresentation rootRepresentation, SolverPool solverPool) { return new Traverser(rootRepresentation, True(), solverPool, false); } @@ -325,12 +371,19 @@ public MddExpressionRepresentation moveUp() { } public boolean queryEdge(int assignment) { - if (currentRepresentation.explicitRepresentation.getCacheView().keySet().contains(assignment) || currentRepresentation.explicitRepresentation.getCacheView().defaultValue() != null) - return true; + if (currentRepresentation + .explicitRepresentation + .getCacheView() + .keySet() + .contains(assignment) + || currentRepresentation.explicitRepresentation.getCacheView().defaultValue() + != null) return true; else if (!currentRepresentation.explicitRepresentation.isComplete()) { final SolverStatus status; final Valuation model; - final LitExpr litExpr = LitExprConverter.toLitExpr(assignment, currentRepresentation.decl.getType()); + final LitExpr litExpr = + LitExprConverter.toLitExpr( + assignment, currentRepresentation.decl.getType()); try (WithPushPop wpp = new WithPushPop(solver)) { solver.add(Eq(currentRepresentation.decl.getRef(), litExpr)); solver.check(); @@ -355,7 +408,12 @@ public MddNode peekDown(int assignment) { public QueryResult queryEdge() { if (!currentRepresentation.explicitRepresentation.isComplete()) { - if (pushedNegatedAssignments != currentRepresentation.explicitRepresentation.getCacheView().keySet().size()) { + if (pushedNegatedAssignments + != currentRepresentation + .explicitRepresentation + .getCacheView() + .keySet() + .size()) { popNegatedAssignments(); pushNegatedAssignments(); } @@ -373,8 +431,15 @@ public QueryResult queryEdge() { } else { final int newValue; if (currentRepresentation.mddVariable.isBounded()) { - final IntSetView domain = IntSetView.range(0, currentRepresentation.mddVariable.getDomainSize()); - final IntSetView remaining = domain.minus(currentRepresentation.explicitRepresentation.getCacheView().keySet()); + final IntSetView domain = + IntSetView.range( + 0, currentRepresentation.mddVariable.getDomainSize()); + final IntSetView remaining = + domain.minus( + currentRepresentation + .explicitRepresentation + .getCacheView() + .keySet()); if (remaining.isEmpty()) { currentRepresentation.explicitRepresentation.setComplete(); return QueryResult.failed(); @@ -384,9 +449,17 @@ public QueryResult queryEdge() { newValue = cur.elem(); } } else { - // only visited once per node, because of the negated assignment that is pushed to the solver - final IntSetView cachedKeys = currentRepresentation.explicitRepresentation.getCacheView().keySet(); - newValue = cachedKeys.isEmpty() ? 0 : cachedKeys.statistics().highestValue() + 1; + // only visited once per node, because of the negated assignment that is + // pushed to the solver + final IntSetView cachedKeys = + currentRepresentation + .explicitRepresentation + .getCacheView() + .keySet(); + newValue = + cachedKeys.isEmpty() + ? 0 + : cachedKeys.statistics().highestValue() + 1; } literal = LitExprConverter.toLitExpr(newValue, decl.getType()); final var extendedModel = MutableValuation.copyOf(model); @@ -414,11 +487,18 @@ public MddNode moveDown(int assignment) { if (queryEdge(assignment)) { popNegatedAssignments(); solver.push(); - solver.add(Eq(currentRepresentation.decl.getRef(), LitExprConverter.toLitExpr(assignment, currentRepresentation.decl.getType()))); + solver.add( + Eq( + currentRepresentation.decl.getRef(), + LitExprConverter.toLitExpr( + assignment, currentRepresentation.decl.getType()))); stack.push(currentRepresentation); - final MddNode childNode = currentRepresentation.explicitRepresentation.getCacheView().get(assignment); - Preconditions.checkArgument(childNode.getRepresentation() instanceof MddExpressionRepresentation); - setCurrentRepresentation((MddExpressionRepresentation) childNode.getRepresentation()); + final MddNode childNode = + currentRepresentation.explicitRepresentation.getCacheView().get(assignment); + Preconditions.checkArgument( + childNode.getRepresentation() instanceof MddExpressionRepresentation); + setCurrentRepresentation( + (MddExpressionRepresentation) childNode.getRepresentation()); return childNode; } else return null; } @@ -426,8 +506,13 @@ public MddNode moveDown(int assignment) { private void pushNegatedAssignments() { solver.push(); final var negatedAssignments = new ArrayList>(); - for (var cur = currentRepresentation.explicitRepresentation.getCacheView().cursor(); cur.moveNext(); ) { - negatedAssignments.add(Neq(currentRepresentation.decl.getRef(), LitExprConverter.toLitExpr(cur.key(), currentRepresentation.decl.getType()))); + for (var cur = currentRepresentation.explicitRepresentation.getCacheView().cursor(); + cur.moveNext(); ) { + negatedAssignments.add( + Neq( + currentRepresentation.decl.getRef(), + LitExprConverter.toLitExpr( + cur.key(), currentRepresentation.decl.getType()))); pushedNegatedAssignments++; } solver.add(And(negatedAssignments)); @@ -450,73 +535,69 @@ private void cacheModel(Valuation valuation) { } else { - final Optional lower = representation.mddVariable.getLower(); - final LitExpr literalToCache = determineLiteralToCache(representation, valuation); - - if (representation.explicitRepresentation.getCacheView().containsKey(LitExprConverter.toInt(literalToCache))) { - - childNode = representation.explicitRepresentation.getCacheView().get(LitExprConverter.toInt(literalToCache)); - assert lower.isEmpty() || childNode.isOn(lower.get()); - + // Substitute literal if available + final Optional> literal = + valuation.eval(representation.getDecl()); + final Expr substitutedExpr; + if (literal.isPresent()) { + substitutedExpr = + ExprUtils.simplify( + representation.expr, + ImmutableValuation.builder() + .put(representation.getDecl(), literal.get()) + .build()); } else { + substitutedExpr = representation.expr; + } - final Expr substitutedExpr = ExprUtils.simplify(representation.expr, ImmutableValuation.builder().put(representation.getDecl(), literalToCache).build()); + if (literal.isPresent() + && representation + .explicitRepresentation + .getCacheView() + .containsKey(LitExprConverter.toInt(literal.get()))) { + // Return cached if available + childNode = + representation + .explicitRepresentation + .getCacheView() + .get(LitExprConverter.toInt(literal.get())); + } else { + final Optional lower = + representation.mddVariable.getLower(); if (lower.isPresent()) { - final MddExpressionTemplate template = MddExpressionTemplate.of(substitutedExpr, o -> (Decl) o, representation.solverPool); + final MddExpressionTemplate template = + MddExpressionTemplate.of( + substitutedExpr, + o -> (Decl) o, + representation.solverPool); childNode = lower.get().checkInNode(template); } else { - final Expr canonizedExpr = ExprUtils.canonize(substitutedExpr); - MddGraph mddGraph = (MddGraph) representation.mddVariable.getMddGraph(); + final Expr canonizedExpr = + ExprUtils.canonize(substitutedExpr); + MddGraph mddGraph = + (MddGraph) representation.mddVariable.getMddGraph(); assert !(canonizedExpr instanceof FalseExpr); childNode = mddGraph.getNodeFor(canonizedExpr); } - assert !representation.mddVariable.isNullOrZero(childNode) : "This would mean the model returned by the solver is incorrect"; - representation.explicitRepresentation.cacheNode(LitExprConverter.toInt(literalToCache), childNode); + assert !representation.mddVariable.isNullOrZero(childNode) + : "This would mean the model returned by the solver is incorrect"; + if (literal.isPresent()) + representation.explicitRepresentation.cacheNode( + LitExprConverter.toInt(literal.get()), childNode); // TODO update domainSize } } if (childNode.isTerminal()) break; - //Preconditions.checkArgument(childNode.getRepresentation() instanceof MddExpressionRepresentation); + // Preconditions.checkArgument(childNode.getRepresentation() instanceof + // MddExpressionRepresentation); // TODO assert representation = (MddExpressionRepresentation) childNode.getRepresentation(); } } - private static LitExpr determineLiteralToCache(MddExpressionRepresentation representation, Valuation valuation) { - final Decl decl = representation.getDecl(); - final Optional> literal = valuation.eval(decl); - - if (literal.isPresent()) { - return literal.get(); - } else { - return LitExprConverter.toLitExpr(generateMissingLiteral(representation), decl.getType()); - } - } - - private static int generateMissingLiteral(MddExpressionRepresentation representation) { - final int newValue; - if (representation.mddVariable.isBounded()) { - final IntSetView domain = IntSetView.range(0, representation.mddVariable.getDomainSize()); - final IntSetView remaining = domain.minus(representation.explicitRepresentation.getCacheView().keySet()); - if (remaining.isEmpty()) { - representation.explicitRepresentation.setComplete(); - // Return the first element - newValue = representation.explicitRepresentation.getCacheView().keySet().statistics().lowestValue(); - } else { - final var cur = remaining.cursor(); - Preconditions.checkState(cur.moveNext()); - newValue = cur.elem(); - } - } else { - final IntSetView cachedKeys = representation.explicitRepresentation.getCacheView().keySet(); - newValue = cachedKeys.isEmpty() ? 0 : cachedKeys.statistics().highestValue() + 1; - } - return newValue; - } - private void setCurrentRepresentation(MddExpressionRepresentation representation) { this.currentRepresentation = representation; pushNegatedAssignments(); @@ -565,15 +646,15 @@ public QueryResult.Status getStatus() { } /** - * The status of the result. - * FAILED: no further edges are possible - * SINGLE_EDGE: a single edge was found - * DEFAULT_EDGE: the node is a level-skip and has a default value + * The status of the result. FAILED: no further edges are possible SINGLE_EDGE: a single + * edge was found DEFAULT_EDGE: the node is a level-skip and has a default value */ public enum Status { - FAILED, SINGLE_EDGE, DEFAULT_EDGE, CONSTRAINED_FAILED + FAILED, + SINGLE_EDGE, + DEFAULT_EDGE, + CONSTRAINED_FAILED } - } } @@ -582,13 +663,11 @@ private static class Cursor implements RecursiveIntObjCursor { // Fields for node enumeration private final Traverser traverser; - // Fields for the recursive cursor structure private final Cursor parent; private boolean blocked = false; private boolean closed = false; - // Common cursor fields private int index; private int key; @@ -644,18 +723,20 @@ public boolean moveNext() { @Override public boolean moveTo(int key) { - Preconditions.checkState(!blocked, "Cursor can't be moved until its children are disposed of"); + Preconditions.checkState( + !blocked, "Cursor can't be moved until its children are disposed of"); Preconditions.checkState(!closed, "Cursor can't be moved if it was closed"); var currentRepresentation = traverser.currentRepresentation; - if (currentRepresentation.explicitRepresentation.getCacheView().containsKey(key) || !currentRepresentation.explicitRepresentation.isComplete() && traverser.queryEdge(key)) { + if (currentRepresentation.explicitRepresentation.getCacheView().containsKey(key) + || !currentRepresentation.explicitRepresentation.isComplete() + && traverser.queryEdge(key)) { this.key = key; this.value = currentRepresentation.get(key); this.initialized = true; return true; } return false; - } @Override @@ -687,7 +768,8 @@ public MddNode value() { @Override public boolean moveNext() { - Preconditions.checkState(!blocked, "Cursor can't be moved until its children are not closed"); + Preconditions.checkState( + !blocked, "Cursor can't be moved until its children are not closed"); Preconditions.checkState(!closed, "Cursor can't be moved if it was closed"); var currentRepresentation = traverser.currentRepresentation; @@ -697,15 +779,19 @@ public boolean moveNext() { value = currentRepresentation.explicitRepresentation.getCacheView().get(key); initialized = true; return true; - } else if (!currentRepresentation.explicitRepresentation.isComplete() && !constrainedFailed) { - final MddExpressionRepresentation.Traverser.QueryResult queryResult = traverser.queryEdge(); - if (queryResult.getStatus() == MddExpressionRepresentation.Traverser.QueryResult.Status.SINGLE_EDGE) { + } else if (!currentRepresentation.explicitRepresentation.isComplete() + && !constrainedFailed) { + final MddExpressionRepresentation.Traverser.QueryResult queryResult = + traverser.queryEdge(); + if (queryResult.getStatus() + == MddExpressionRepresentation.Traverser.QueryResult.Status.SINGLE_EDGE) { index++; key = queryResult.getKey(); value = currentRepresentation.explicitRepresentation.getCacheView().get(key); initialized = true; return true; - } else if (queryResult.getStatus() == Traverser.QueryResult.Status.CONSTRAINED_FAILED) { + } else if (queryResult.getStatus() + == Traverser.QueryResult.Status.CONSTRAINED_FAILED) { this.constrainedFailed = true; } } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/GeneralizedSaturationProvider.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/GeneralizedSaturationProvider.java index c42aae07c9..b9e65d7c44 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/GeneralizedSaturationProvider.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/GeneralizedSaturationProvider.java @@ -15,13 +15,14 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint; +import com.google.common.base.Preconditions; import com.koloboke.collect.set.hash.HashObjSets; import hu.bme.mit.delta.collections.IntObjCursor; import hu.bme.mit.delta.collections.IntObjMapView; +import hu.bme.mit.delta.collections.RecursiveIntObjMapView; import hu.bme.mit.delta.java.mdd.*; import hu.bme.mit.delta.java.mdd.impl.MddStructuralTemplate; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.AbstractNextStateDescriptor; - import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -32,7 +33,8 @@ public final class GeneralizedSaturationProvider implements StateSpaceEnumeratio private MddVariableOrder variableOrder; private RelationalProductProvider relProdProvider; - private final CacheManager cacheManager = new CacheManager<>(v -> new SaturationCache()); + private final CacheManager cacheManager = + new CacheManager<>(v -> new SaturationCache()); private MddNode terminalZeroNode; public GeneralizedSaturationProvider(final MddVariableOrder variableOrder) { @@ -40,8 +42,7 @@ public GeneralizedSaturationProvider(final MddVariableOrder variableOrder) { } public GeneralizedSaturationProvider( - final MddVariableOrder variableOrder, final RelationalProductProvider relProdProvider - ) { + final MddVariableOrder variableOrder, final RelationalProductProvider relProdProvider) { this.variableOrder = variableOrder; this.relProdProvider = relProdProvider; this.variableOrder.getMddGraph().registerCleanupListener(this); @@ -52,9 +53,12 @@ public GeneralizedSaturationProvider( public MddHandle compute( AbstractNextStateDescriptor.Postcondition initializer, AbstractNextStateDescriptor nextStateRelation, - MddVariableHandle highestAffectedVariable - ) { - final MddHandle initialStates = relProdProvider.compute(variableOrder.getMddGraph().getHandleForTop(), initializer, highestAffectedVariable); + MddVariableHandle highestAffectedVariable) { + final MddHandle initialStates = + relProdProvider.compute( + variableOrder.getMddGraph().getHandleForTop(), + initializer, + highestAffectedVariable); MddNode result; @@ -62,10 +66,11 @@ public MddHandle compute( final MddVariable variable = highestAffectedVariable.getVariable().get(); result = this.compute(initialStates.getNode(), nextStateRelation, variable); } else { - result = this.computeTerminal(initialStates.getNode(), - nextStateRelation, - highestAffectedVariable.getMddGraph() - ); + result = + this.computeTerminal( + initialStates.getNode(), + nextStateRelation, + highestAffectedVariable.getMddGraph()); } return highestAffectedVariable.getHandleFor(result); @@ -75,8 +80,9 @@ private MddNode recurse( final MddNode mddNode, final AbstractNextStateDescriptor nextState, MddVariable currentVariable, - final CacheManager>.CacheHolder cache - ) { + final CacheManager> + .CacheHolder + cache) { if (currentVariable.getLower().isPresent()) { return compute(mddNode, nextState, currentVariable.getLower().get()); } else { @@ -85,8 +91,7 @@ private MddNode recurse( } private MddNode unionChildren( - final MddNode lhs, final MddNode rhs, MddVariable currentVariable - ) { + final MddNode lhs, final MddNode rhs, MddVariable currentVariable) { if (currentVariable.getLower().isPresent()) { return currentVariable.getLower().get().union(lhs, rhs); } else { @@ -96,8 +101,9 @@ private MddNode unionChildren( @Override public MddNode compute( - final MddNode mddNode, final AbstractNextStateDescriptor nextState, final MddVariable mddVariable - ) { + final MddNode mddNode, + final AbstractNextStateDescriptor nextState, + final MddVariable mddVariable) { return saturate(mddNode, nextState, mddVariable, cacheManager.getCacheFor(mddVariable)); } @@ -105,11 +111,10 @@ private MddNode saturate( final MddNode n, AbstractNextStateDescriptor d, MddVariable variable, - CacheManager.CacheHolder cache - ) { - if (n.isTerminal() || - d == AbstractNextStateDescriptor.terminalIdentity() || - d == AbstractNextStateDescriptor.terminalEmpty()) { + CacheManager.CacheHolder cache) { + if (n.isTerminal() + || d == AbstractNextStateDescriptor.terminalIdentity() + || d == AbstractNextStateDescriptor.terminalEmpty()) { // TODO this does not handle level skips return n; } @@ -122,41 +127,43 @@ private MddNode saturate( if (verbose) { printIndent(); System.out.println("Saturating on level " + variable.getTraceInfo() + " with " + d); - } // indent++; final MddStateSpaceInfo stateSpaceInfo = new MddStateSpaceInfo(variable, n); -// -// IntObjMapView satTemplate = new IntObjMapViews.Transforming(n, -// (node, key) -> node == null ? null : terminalZeroToNull(saturate(node, -// d.getDiagonal(stateSpaceInfo).get(key), -// variable.getLower().orElse(null), -// cache.getLower() -// )) -// ); -// -// MddNode nsat = variable.checkInNode(MddStructuralTemplate.of(satTemplate)); - + // + // IntObjMapView satTemplate = new IntObjMapViews.Transforming(n, + // (node, key) -> node == null ? null : terminalZeroToNull(saturate(node, + // d.getDiagonal(stateSpaceInfo).get(key), + // variable.getLower().orElse(null), + // cache.getLower() + // )) + // ); + // + // MddNode nsat = variable.checkInNode(MddStructuralTemplate.of(satTemplate)); - MddUnsafeTemplateBuilder templateBuilder = JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); + MddUnsafeTemplateBuilder templateBuilder = + JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); for (IntObjCursor cFrom = n.cursor(); cFrom.moveNext(); ) { - MddNode s = saturate(cFrom.value(), - d.getDiagonal(stateSpaceInfo).get(cFrom.key()), - variable.getLower().orElse(null), - cache.getLower() - ); - - templateBuilder.set(cFrom.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cFrom.key()), s, variable)) - ); - + MddNode s = + saturate( + cFrom.value(), + d.getDiagonal(stateSpaceInfo).get(cFrom.key()), + variable.getLower().orElse(null), + cache.getLower()); + + templateBuilder.set( + cFrom.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cFrom.key()), s, variable))); } - MddNode nsat = variable.checkInNode(MddStructuralTemplate.of(templateBuilder.buildAndReset())); + MddNode nsat = + variable.checkInNode(MddStructuralTemplate.of(templateBuilder.buildAndReset())); boolean changed; @@ -166,7 +173,7 @@ private MddNode saturate( final Optional> splitNS = d.split(); if (splitNS.isPresent()) { for (AbstractNextStateDescriptor dfire : splitNS.get()) { - //System.out.println("Applying transition: " + dfire); + // System.out.println("Applying transition: " + dfire); if (dfire.isLocallyIdentity(stateSpaceInfo)) { continue; } @@ -179,7 +186,7 @@ private MddNode saturate( } } } else if (!d.isLocallyIdentity(stateSpaceInfo)) { - //System.out.println("Applying transition: " + d); + // System.out.println("Applying transition: " + d); MddNode nfire = satFire(nsat, d, d, variable, cache); nfire = variable.union(nsat, nfire); @@ -195,12 +202,17 @@ private MddNode saturate( if (verbose) { indent--; printIndent(); - System.out.println("Done Saturating on level " + variable.getTraceInfo() + " resulting in " + nsat); + System.out.println( + "Done Saturating on level " + + variable.getTraceInfo() + + " resulting in " + + nsat); } // indent--; // printIndent(); - // System.out.println("Saturated level " + variable.getTraceInfo() + ", domain size is " + variable.getDomainSize()); + // System.out.println("Saturated level " + variable.getTraceInfo() + ", domain size is " + + // variable.getDomainSize()); // return nsat; } @@ -210,8 +222,7 @@ private MddNode satFire( AbstractNextStateDescriptor dsat, AbstractNextStateDescriptor dfire, MddVariable variable, - CacheManager.CacheHolder cache - ) { + CacheManager.CacheHolder cache) { if (n == terminalZeroNode || dfire == AbstractNextStateDescriptor.terminalEmpty()) { return terminalZeroNode; } @@ -220,31 +231,48 @@ private MddNode satFire( return n; } + boolean lhsSkipped = !n.isOn(variable); + if (verbose) { printIndent(); - System.out.println("SatFire on level " + - variable.getTraceInfo() + - " with dsat=" + - dsat + - "; dfire=" + - dfire); + System.out.println( + "SatFire on level " + + variable.getTraceInfo() + + " with dsat=" + + dsat + + "; dfire=" + + dfire); indent++; } - MddUnsafeTemplateBuilder templateBuilder = JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); + MddUnsafeTemplateBuilder templateBuilder = + JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); -// final IntObjMapView diagonal = dfire.getDiagonal( -// stateSpaceInfo); -// var c = diagonal.cursor(); + // final IntObjMapView diagonal = dfire.getDiagonal( + // stateSpaceInfo); + // var c = diagonal.cursor(); final var stateSpaceInfo = new MddStateSpaceInfo(variable, n); - final IntObjMapView> offDiagonal = dfire.getOffDiagonal( - stateSpaceInfo); + final IntObjMapView> offDiagonal = + dfire.getOffDiagonal(stateSpaceInfo); - for (IntObjCursor cFrom = n.cursor(); cFrom.moveNext(); ) { - for (IntObjCursor cTo = offDiagonal.get( - cFrom.key()).cursor(); cTo.moveNext(); ) { + final RecursiveIntObjMapView lhsInterpreter; + if ((lhsSkipped || (n.defaultValue() != null && n.isEmpty())) && !variable.isBounded()) { + final MddNode childCandidate = lhsSkipped ? n : n.defaultValue(); + // We use the keyset of the ANSD to trim + lhsInterpreter = + RecursiveIntObjMapView.of( + IntObjMapView.empty(childCandidate).trim(offDiagonal.keySet())); + } else { + lhsInterpreter = + variable.getNodeInterpreter( + n); // using the interpreter might cause a performance overhead + } + for (IntObjCursor cFrom = lhsInterpreter.cursor(); cFrom.moveNext(); ) { + for (IntObjCursor cTo = + offDiagonal.get(cFrom.key()).cursor(); + cTo.moveNext(); ) { if (cFrom.key() == cTo.key()) { continue; } @@ -256,29 +284,36 @@ private MddNode satFire( assert cFrom.value() != terminalZeroNode; assert cTo.value() != AbstractNextStateDescriptor.terminalEmpty(); - MddNode s = relProd(cFrom.value(), - dsat.getDiagonal(stateSpaceInfo).get(cTo.key()), - cTo.value(), - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + dsat.getDiagonal(stateSpaceInfo).get(cTo.key()), + cTo.value(), + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { confirm(variable, cTo.key()); - templateBuilder.set(cTo.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cTo.key()), s, variable)) - ); + templateBuilder.set( + cTo.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cTo.key()), s, variable))); } } } - MddNode ret = variable.checkInNode(MddStructuralTemplate.of(templateBuilder.buildAndReset())); + final var template = templateBuilder.buildAndReset(); + if (!template.isEmpty()) + Preconditions.checkArgument( + n.defaultValue() == null, "Default value is not supported with explicit edges"); + MddNode ret = variable.checkInNode(MddStructuralTemplate.of(template)); if (verbose) { indent--; printIndent(); - System.out.println("Done SatFire on level " + variable.getTraceInfo() + " resulting in " + ret); + System.out.println( + "Done SatFire on level " + variable.getTraceInfo() + " resulting in " + ret); } return ret; @@ -289,8 +324,7 @@ private MddNode relProd( AbstractNextStateDescriptor dsat, AbstractNextStateDescriptor dfire, MddVariable variable, - CacheManager.CacheHolder cache - ) { + CacheManager.CacheHolder cache) { if (n == terminalZeroNode || dfire == AbstractNextStateDescriptor.terminalEmpty()) { return terminalZeroNode; } @@ -303,6 +337,8 @@ private MddNode relProd( return n; } + boolean lhsSkipped = !n.isOn(variable); + final MddStateSpaceInfo stateSpaceInfo = new MddStateSpaceInfo(variable, n); MddNode ret = cache.getCache().getRelProdCache().getOrNull(n, dsat, dfire); @@ -312,25 +348,40 @@ private MddNode relProd( if (verbose) { printIndent(); - System.out.println("SatRelProd on level " + - variable.getTraceInfo() + - ", node=" + - n + - ", with dsat=" + - dsat + - "; dfire" + - "=" + - dfire); + System.out.println( + "SatRelProd on level " + + variable.getTraceInfo() + + ", node=" + + n + + ", with dsat=" + + dsat + + "; dfire" + + "=" + + dfire); indent++; } - MddUnsafeTemplateBuilder templateBuilder = JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); - - final IntObjMapView diagonal = dfire.getDiagonal(stateSpaceInfo); - final IntObjMapView> offDiagonal = dfire.getOffDiagonal( - stateSpaceInfo); - - for (IntObjCursor cFrom = n.cursor(); cFrom.moveNext(); ) { + MddUnsafeTemplateBuilder templateBuilder = + JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); + + final IntObjMapView diagonal = + dfire.getDiagonal(stateSpaceInfo); + final IntObjMapView> offDiagonal = + dfire.getOffDiagonal(stateSpaceInfo); + + final RecursiveIntObjMapView lhsInterpreter; + if ((lhsSkipped || (n.defaultValue() != null && n.isEmpty())) && !variable.isBounded()) { + final MddNode childCandidate = lhsSkipped ? n : n.defaultValue(); + // We use the keyset of the ANSD to trim + lhsInterpreter = + RecursiveIntObjMapView.of( + IntObjMapView.empty(childCandidate).trim(offDiagonal.keySet())); + } else { + lhsInterpreter = + variable.getNodeInterpreter( + n); // using the interpreter might cause a performance overhead + } + for (IntObjCursor cFrom = lhsInterpreter.cursor(); cFrom.moveNext(); ) { // Identity step final AbstractNextStateDescriptor diagonalContinuation = diagonal.get(cFrom.key()); if (!AbstractNextStateDescriptor.isNullOrEmpty(diagonalContinuation)) { @@ -339,24 +390,27 @@ private MddNode relProd( System.out.println("Potential step: " + cFrom.key() + "->" + cFrom.key()); } - MddNode s = relProd(cFrom.value(), - dsat.getDiagonal(stateSpaceInfo).get(cFrom.key()), - diagonalContinuation, - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + dsat.getDiagonal(stateSpaceInfo).get(cFrom.key()), + diagonalContinuation, + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { // confirm(variable, cFrom.key()); - templateBuilder.set(cFrom.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cFrom.key()), s, variable)) - ); + templateBuilder.set( + cFrom.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cFrom.key()), s, variable))); } } - for (IntObjCursor cTo = offDiagonal.get(cFrom.key()).cursor(); - cTo.moveNext(); ) { + for (IntObjCursor cTo = + offDiagonal.get(cFrom.key()).cursor(); + cTo.moveNext(); ) { if (cFrom.key() == cTo.key()) { continue; } @@ -368,24 +422,30 @@ private MddNode relProd( assert cFrom.value() != terminalZeroNode; assert cTo.value() != AbstractNextStateDescriptor.terminalEmpty(); - MddNode s = relProd(cFrom.value(), - dsat.getDiagonal(stateSpaceInfo).get(cTo.key()), - cTo.value(), - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + dsat.getDiagonal(stateSpaceInfo).get(cTo.key()), + cTo.value(), + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { confirm(variable, cTo.key()); - templateBuilder.set(cTo.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cTo.key()), s, variable)) - ); + templateBuilder.set( + cTo.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cTo.key()), s, variable))); } } } - ret = variable.checkInNode(MddStructuralTemplate.of(templateBuilder.buildAndReset())); + final var template = templateBuilder.buildAndReset(); + if (!template.isEmpty()) + Preconditions.checkArgument( + n.defaultValue() == null, "Default value is not supported with explicit edges"); + ret = variable.checkInNode(MddStructuralTemplate.of(template)); ret = saturate(ret, dsat, variable, cache); @@ -394,20 +454,20 @@ private MddNode relProd( if (verbose) { indent--; printIndent(); - System.out.println("Done SatRelProd on level " + variable.getTraceInfo() + " resulting in " + ret); + System.out.println( + "Done SatRelProd on level " + variable.getTraceInfo() + " resulting in " + ret); } return ret; } - private void confirm(final MddVariable variable, final int key) { - - } + private void confirm(final MddVariable variable, final int key) {} @Override public MddNode computeTerminal( - final MddNode mddNode, final AbstractNextStateDescriptor nextState, final MddGraph mddGraph - ) { + final MddNode mddNode, + final AbstractNextStateDescriptor nextState, + final MddGraph mddGraph) { return mddNode; } @@ -435,12 +495,19 @@ public void clear() { @Override public void cleanup() { - this.cacheManager.forEachCache((cache) -> { - cache.getSaturateCache().clearSelectively((source, ns1, result) -> source.getReferenceCount() == 0 || - result.getReferenceCount() == 0); - cache.getRelProdCache().clearSelectively((source, ns1, ns2, result) -> source.getReferenceCount() == 0 || - result.getReferenceCount() == 0); - }); + this.cacheManager.forEachCache( + (cache) -> { + cache.getSaturateCache() + .clearSelectively( + (source, ns1, result) -> + source.getReferenceCount() == 0 + || result.getReferenceCount() == 0); + cache.getRelProdCache() + .clearSelectively( + (source, ns1, ns2, result) -> + source.getReferenceCount() == 0 + || result.getReferenceCount() == 0); + }); } private class Aggregator implements Consumer { @@ -456,7 +523,6 @@ private Aggregator(final ToLongFunction extractor) { public void accept(final SaturationCache cache) { result += extractor.applyAsLong(cache); } - } public Cache getSaturateCache() { @@ -501,10 +567,14 @@ public long getHitCount() { public Set getSaturatedNodes() { final Set ret = HashObjSets.newUpdatableSet(); - cacheManager.forEachCache((c) -> c.getSaturateCache().clearSelectively((source, ns, result) -> { - ret.add(result); - return false; - })); + cacheManager.forEachCache( + (c) -> + c.getSaturateCache() + .clearSelectively( + (source, ns, result) -> { + ret.add(result); + return false; + })); return ret; } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/LegacyRelationalProductProvider.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/LegacyRelationalProductProvider.java index 807360e81b..c849b410f5 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/LegacyRelationalProductProvider.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/LegacyRelationalProductProvider.java @@ -15,19 +15,20 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint; +import com.google.common.base.Preconditions; import hu.bme.mit.delta.collections.IntObjCursor; import hu.bme.mit.delta.collections.IntObjMapView; +import hu.bme.mit.delta.collections.RecursiveIntObjMapView; import hu.bme.mit.delta.collections.impl.IntObjMapViews; import hu.bme.mit.delta.java.mdd.*; import hu.bme.mit.delta.java.mdd.impl.MddStructuralTemplate; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.AbstractNextStateDescriptor; - import java.util.function.Consumer; import java.util.function.ToLongFunction; public final class LegacyRelationalProductProvider implements RelationalProductProvider { - private final CacheManager> cacheManager = new CacheManager<>( - v -> new BinaryOperationCache<>()); + private final CacheManager> + cacheManager = new CacheManager<>(v -> new BinaryOperationCache<>()); private final MddVariableOrder variableOrder; public LegacyRelationalProductProvider(final MddVariableOrder variableOrder) { @@ -35,18 +36,23 @@ public LegacyRelationalProductProvider(final MddVariableOrder variableOrder) { this.variableOrder.getMddGraph().registerCleanupListener(this); } - private MddNode recurse(final MddNode mddNode, final AbstractNextStateDescriptor nextState, - MddVariable currentVariable, final CacheManager>.CacheHolder currentCache) { + private MddNode recurse( + final MddNode mddNode, + final AbstractNextStateDescriptor nextState, + MddVariable currentVariable, + final CacheManager> + .CacheHolder + currentCache) { if (currentVariable.getLower().isPresent()) { - return doCompute(mddNode, nextState, currentVariable.getLower().get(), currentCache.getLower()); + return doCompute( + mddNode, nextState, currentVariable.getLower().get(), currentCache.getLower()); } else { return computeTerminal(mddNode, nextState, currentVariable.getMddGraph()); } } - private MddNode unionChildren(final MddNode lhs, final MddNode rhs, - MddVariable currentVariable) { + private MddNode unionChildren( + final MddNode lhs, final MddNode rhs, MddVariable currentVariable) { if (currentVariable.getLower().isPresent()) { return currentVariable.getLower().get().union(lhs, rhs); } else { @@ -58,19 +64,26 @@ private MddNode unionChildren(final MddNode lhs, final MddNode rhs, public MddNode compute( final MddNode mddNode, final AbstractNextStateDescriptor abstractNextStateDescriptor, - final MddVariable mddVariable - ) { - return doCompute(mddNode, abstractNextStateDescriptor, mddVariable, cacheManager.getCacheFor(mddVariable)); + final MddVariable mddVariable) { + return doCompute( + mddNode, + abstractNextStateDescriptor, + mddVariable, + cacheManager.getCacheFor(mddVariable)); } private MddNode doCompute( final MddNode lhs, final AbstractNextStateDescriptor nextState, final MddVariable variable, - final CacheManager>.CacheHolder cache - ) { - assert cache != null : "Invalid behavior for CacheManager: should have assigned a cache to every variable."; - if (variable.isNullOrZero(lhs) || nextState == AbstractNextStateDescriptor.terminalIdentity()) { + final CacheManager> + .CacheHolder + cache) { + assert cache != null + : "Invalid behavior for CacheManager: should have assigned a cache to every" + + " variable."; + if (variable.isNullOrZero(lhs) + || nextState == AbstractNextStateDescriptor.terminalIdentity()) { return lhs; } if (nextState == null || nextState == AbstractNextStateDescriptor.terminalEmpty()) { @@ -79,11 +92,6 @@ private MddNode doCompute( boolean lhsSkipped = !lhs.isOn(variable); - if ((lhsSkipped || !variable.isNullOrZero(lhs.defaultValue())) && - !(lhs.isTerminal() && nextState instanceof AbstractNextStateDescriptor.Postcondition)) { - throw new UnsupportedOperationException("Default values are not yet supported in relational product."); - } - MddNode ret = cache.getCache().getOrNull(lhs, nextState); if (ret != null) { return ret; @@ -91,20 +99,29 @@ private MddNode doCompute( final MddStateSpaceInfo stateSpaceInfo = new MddStateSpaceInfo(variable, lhs); - final IntObjMapView diagonal = nextState.getDiagonal(stateSpaceInfo); - final IntObjMapView> offDiagonal = nextState.getOffDiagonal( - stateSpaceInfo); + final IntObjMapView diagonal = + nextState.getDiagonal(stateSpaceInfo); + final IntObjMapView> offDiagonal = + nextState.getOffDiagonal(stateSpaceInfo); IntObjMapView template; // Patch to enable initializers if (lhs.isTerminal() && nextState instanceof AbstractNextStateDescriptor.Postcondition) { - template = new IntObjMapViews.Transforming(nextState.getDiagonal( - stateSpaceInfo), ns -> ns == null ? null : terminalZeroToNull(recurse(lhs, ns, variable, cache), - variable.getMddGraph().getTerminalZeroNode())); - // } else if (diagonal.isEmpty() && offDiagonal.isEmpty() && AbstractNextStateDescriptor.isNullOrEmpty( + template = + new IntObjMapViews.Transforming( + nextState.getDiagonal(stateSpaceInfo), + ns -> + ns == null + ? null + : terminalZeroToNull( + recurse(lhs, ns, variable, cache), + variable.getMddGraph().getTerminalZeroNode())); + // } else if (diagonal.isEmpty() && offDiagonal.isEmpty() && + // AbstractNextStateDescriptor.isNullOrEmpty( // offDiagonal.defaultValue())) { - // // Either the ANSD does not affect this level or it is not fireable - will be evaluated in the next call + // // Either the ANSD does not affect this level or it is not fireable - will be + // evaluated in the next call // // TODO: THIS IS GONNA BE TERRIBLY SLOW // template = new IntObjMapViews.Transforming(lhs, // (child) -> child == null ? null : terminalZeroToNull( @@ -112,26 +129,47 @@ private MddNode doCompute( // variable.getMddGraph().getTerminalZeroNode() // )); } else { - MddUnsafeTemplateBuilder templateBuilder = JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); - for (IntObjCursor c = lhs.cursor(); c.moveNext(); ) { + MddUnsafeTemplateBuilder templateBuilder = + JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); + RecursiveIntObjMapView lhsInterpreter; + if ((lhsSkipped || (lhs.defaultValue() != null && lhs.isEmpty())) + && !variable.isBounded()) { + final MddNode childCandidate = lhsSkipped ? lhs : lhs.defaultValue(); + // We use the keyset of the ANSD to trim + lhsInterpreter = + RecursiveIntObjMapView.of( + IntObjMapView.empty(childCandidate).trim(offDiagonal.keySet())); + } else { + lhsInterpreter = + variable.getNodeInterpreter( + lhs); // using the interpreter might cause a performance overhead + } + for (IntObjCursor c = lhsInterpreter.cursor(); c.moveNext(); ) { final MddNode res = recurse(c.value(), diagonal.get(c.key()), variable, cache); final MddNode unioned = unionChildren(res, templateBuilder.get(c.key()), variable); - templateBuilder.set(c.key(), - terminalZeroToNull(unioned, variable.getMddGraph().getTerminalZeroNode()) - ); + templateBuilder.set( + c.key(), + terminalZeroToNull(unioned, variable.getMddGraph().getTerminalZeroNode())); - for (IntObjCursor next = offDiagonal.get(c.key()).cursor(); - next.moveNext(); ) { + for (IntObjCursor next = + offDiagonal.get(c.key()).cursor(); + next.moveNext(); ) { final MddNode res1 = recurse(c.value(), next.value(), variable, cache); - final MddNode unioned1 = unionChildren(res1, templateBuilder.get(next.key()), variable); + final MddNode unioned1 = + unionChildren(res1, templateBuilder.get(next.key()), variable); - templateBuilder.set(next.key(), - terminalZeroToNull(unioned1, variable.getMddGraph().getTerminalZeroNode()) - ); + templateBuilder.set( + next.key(), + terminalZeroToNull( + unioned1, variable.getMddGraph().getTerminalZeroNode())); } } template = templateBuilder.buildAndReset(); + if (!template.isEmpty()) + Preconditions.checkArgument( + lhs.defaultValue() == null, + "Default value is not supported with explicit edges"); } ret = variable.checkInNode(MddStructuralTemplate.of(template)); @@ -143,8 +181,9 @@ private MddNode doCompute( @Override public MddNode computeTerminal( - final MddNode mddNode, final AbstractNextStateDescriptor abstractNextStateDescriptor, final MddGraph mddGraph - ) { + final MddNode mddNode, + final AbstractNextStateDescriptor abstractNextStateDescriptor, + final MddGraph mddGraph) { if (mddNode == mddGraph.getTerminalZeroNode() || !abstractNextStateDescriptor.evaluate()) { return mddGraph.getTerminalZeroNode(); } @@ -167,31 +206,47 @@ public void clear() { @Override public void cleanup() { - this.cacheManager.forEachCache((cache) -> cache.clearSelectively((source, ns, result) -> source.getReferenceCount() == - 0 || - result.getReferenceCount() == - 0)); + this.cacheManager.forEachCache( + (cache) -> + cache.clearSelectively( + (source, ns, result) -> + source.getReferenceCount() == 0 + || result.getReferenceCount() == 0)); } - private class Aggregator implements Consumer> { + private class Aggregator + implements Consumer< + BinaryOperationCache> { public long result = 0; - private final ToLongFunction> extractor; - - private Aggregator(final ToLongFunction> extractor) { + private final ToLongFunction< + BinaryOperationCache> + extractor; + + private Aggregator( + final ToLongFunction< + BinaryOperationCache> + extractor) { this.extractor = extractor; } @Override - public void accept(final BinaryOperationCache cache) { + public void accept( + final BinaryOperationCache cache) { result += extractor.applyAsLong(cache); } } public Cache getRelProdCache() { class RelProdCache implements Cache { - private final CacheManager> cacheManager; - - RelProdCache(final CacheManager> cacheManager) { + private final CacheManager< + BinaryOperationCache> + cacheManager; + + RelProdCache( + final CacheManager< + BinaryOperationCache< + MddNode, AbstractNextStateDescriptor, MddNode>> + cacheManager) { this.cacheManager = cacheManager; } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/MddStateSpaceInfo.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/MddStateSpaceInfo.java index c41c4d5d47..a3dce0b2c1 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/MddStateSpaceInfo.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/MddStateSpaceInfo.java @@ -15,11 +15,14 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; + import com.google.common.base.Preconditions; import com.koloboke.collect.map.ObjIntMap; import com.koloboke.collect.map.hash.HashObjIntMaps; import com.koloboke.collect.set.ObjSet; import com.koloboke.collect.set.hash.HashObjSets; +import hu.bme.mit.delta.Pair; import hu.bme.mit.delta.collections.IntObjMapView; import hu.bme.mit.delta.collections.IntSetView; import hu.bme.mit.delta.collections.IntStatistics; @@ -30,17 +33,12 @@ import hu.bme.mit.delta.java.mdd.impl.MddStructuralTemplate; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.StateSpaceInfo; import hu.bme.mit.theta.common.container.Containers; -import hu.bme.mit.delta.Pair; import hu.bme.mit.theta.core.type.Expr; import hu.bme.mit.theta.core.type.booltype.BoolType; - import java.util.Objects; import java.util.Optional; - import java.util.Set; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; - public final class MddStateSpaceInfo implements StateSpaceInfo { private final MddVariable variable; private final MddNode mddNode; @@ -51,8 +49,7 @@ public MddStateSpaceInfo(final MddVariable variable, final MddNode mddNode) { this.variable = variable; this.mddNode = mddNode; - for (var c = mddNode.cursor(); c.moveNext(); ) { - } // TODO delete later + for (var c = mddNode.cursor(); c.moveNext(); ) {} // TODO delete later } @Override @@ -109,7 +106,7 @@ public IntSetView getLocalStateSpace() { public StateSpaceInfo getLocalStateSpace(final Object someLowerComponent) { // TODO: Auto-generated method stub. throw new UnsupportedOperationException("Not (yet) implemented."); - //return null; + // return null; } @Override @@ -119,7 +116,6 @@ public MddNode toStructuralRepresentation() { structuralRepresentation = representBounds(variable, boundsCollector); } return structuralRepresentation; - } private MddNode representBounds(MddVariable variable, BoundsCollector boundsCollector) { @@ -127,7 +123,8 @@ private MddNode representBounds(MddVariable variable, BoundsCollector boundsColl if (variable.getLower().isPresent()) { continuation = representBounds(variable.getLower().get(), boundsCollector); } else { - final MddGraph> mddGraph = (MddGraph>) variable.getMddGraph(); + final MddGraph> mddGraph = + (MddGraph>) variable.getMddGraph(); continuation = mddGraph.getNodeFor(True()); } final var bounds = boundsCollector.getBoundsFor(variable); @@ -137,28 +134,28 @@ private MddNode representBounds(MddVariable variable, BoundsCollector boundsColl template = IntObjMapView.singleton(bounds.get().first, continuation); } else { // TODO: canonization of trimmed intobjmapviews could be improved - template = new IntObjMapViews.Trimmed<>( - IntObjMapView.empty(continuation), - IntSetView.range(bounds.get().first, bounds.get().second + 1) - ); + template = + new IntObjMapViews.Trimmed<>( + IntObjMapView.empty(continuation), + IntSetView.range(bounds.get().first, bounds.get().second + 1)); } } else { template = IntObjMapView.empty(continuation); } return variable.checkInNode(MddStructuralTemplate.of(template)); - } -// private MddNode collapseEdges(MddNode parent) { -// -// IntSetView setView = IntSetView.empty(); -// for (var c = parent.cursor(); c.moveNext(); ) { -// setView = setView.union(c.value().keySet()); -// } -// -// } - private class BoundsCollector { + // private MddNode collapseEdges(MddNode parent) { + // + // IntSetView setView = IntSetView.empty(); + // for (var c = parent.cursor(); c.moveNext(); ) { + // setView = setView.union(c.value().keySet()); + // } + // + // } + + private static class BoundsCollector { private final ObjIntMap lowerBounds; private final ObjIntMap upperBounds; @@ -174,29 +171,36 @@ public BoundsCollector(MddNode rootNode, MddVariable variable) { traverse(rootNode, variable, traversed); } - private void traverse(final MddNode node, final MddVariable variable, - final Set traversed) { + private void traverse( + final MddNode node, final MddVariable variable, final Set traversed) { if (traversed.contains(node) || node.isTerminal()) { return; } else { traversed.add(node); } + Preconditions.checkNotNull(variable); - for (var c = node.cursor(); c.moveNext(); ) { - } // TODO delete later + for (var c = node.cursor(); c.moveNext(); ) {} // TODO delete later - if (node.defaultValue() != null) { - final MddNode defaultValue = node.defaultValue(); + final var nodeInterpreter = variable.getNodeInterpreter(node); + if (nodeInterpreter.defaultValue() != null) { + final MddNode defaultValue = nodeInterpreter.defaultValue(); traverse(defaultValue, variable.getLower().orElse(null), traversed); hasDefaultValue.add(variable); } else { - final IntStatistics statistics = node.statistics(); - if (variable != null) { - lowerBounds.put(variable, Math.min(lowerBounds.getOrDefault(variable, Integer.MAX_VALUE), statistics.lowestValue())); - upperBounds.put(variable, Math.max(upperBounds.getOrDefault(variable, Integer.MIN_VALUE), statistics.highestValue())); - } - - for (var cur = node.cursor(); cur.moveNext(); ) { + final IntStatistics statistics = nodeInterpreter.statistics(); + lowerBounds.put( + variable, + Math.min( + lowerBounds.getOrDefault(variable, Integer.MAX_VALUE), + statistics.lowestValue())); + upperBounds.put( + variable, + Math.max( + upperBounds.getOrDefault(variable, Integer.MIN_VALUE), + statistics.highestValue())); + + for (var cur = nodeInterpreter.cursor(); cur.moveNext(); ) { if (cur.value() != null) { traverse(cur.value(), variable.getLower().orElse(null), traversed); } @@ -204,13 +208,12 @@ private void traverse(final MddNode node, final MddVariable variable, } } - public Optional> getBoundsFor(MddVariable variable) { if (hasDefaultValue.contains(variable)) return Optional.empty(); if (!lowerBounds.containsKey(variable) || !upperBounds.containsKey(variable)) return Optional.empty(); - return Optional.of(new Pair<>(lowerBounds.getInt(variable), upperBounds.getInt(variable))); + return Optional.of( + new Pair<>(lowerBounds.getInt(variable), upperBounds.getInt(variable))); } } - } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/SimpleSaturationProvider.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/SimpleSaturationProvider.java index f9ad7278f3..582972ca21 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/SimpleSaturationProvider.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/fixedpoint/SimpleSaturationProvider.java @@ -15,14 +15,15 @@ */ package hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint; +import com.google.common.base.Preconditions; import com.koloboke.collect.set.hash.HashObjSets; import hu.bme.mit.delta.collections.IntObjCursor; import hu.bme.mit.delta.collections.IntObjMapView; +import hu.bme.mit.delta.collections.RecursiveIntObjMapView; import hu.bme.mit.delta.collections.impl.IntObjMapViews; import hu.bme.mit.delta.java.mdd.*; import hu.bme.mit.delta.java.mdd.impl.MddStructuralTemplate; import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.AbstractNextStateDescriptor; - import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -33,7 +34,8 @@ public final class SimpleSaturationProvider implements StateSpaceEnumerationProv private MddVariableOrder variableOrder; private RelationalProductProvider relProdProvider; - private final CacheManager cacheManager = new CacheManager<>(v -> new SaturationCache()); + private final CacheManager cacheManager = + new CacheManager<>(v -> new SaturationCache()); private MddNode terminalZeroNode; public SimpleSaturationProvider(final MddVariableOrder variableOrder) { @@ -41,8 +43,7 @@ public SimpleSaturationProvider(final MddVariableOrder variableOrder) { } public SimpleSaturationProvider( - final MddVariableOrder variableOrder, final RelationalProductProvider relProdProvider - ) { + final MddVariableOrder variableOrder, final RelationalProductProvider relProdProvider) { this.variableOrder = variableOrder; this.relProdProvider = relProdProvider; this.variableOrder.getMddGraph().registerCleanupListener(this); @@ -52,9 +53,12 @@ public SimpleSaturationProvider( public MddHandle compute( AbstractNextStateDescriptor.Postcondition initializer, AbstractNextStateDescriptor nextStateRelation, - MddVariableHandle highestAffectedVariable - ) { - final MddHandle initialStates = relProdProvider.compute(variableOrder.getMddGraph().getHandleForTop(), initializer, highestAffectedVariable); + MddVariableHandle highestAffectedVariable) { + final MddHandle initialStates = + relProdProvider.compute( + variableOrder.getMddGraph().getHandleForTop(), + initializer, + highestAffectedVariable); MddNode result; @@ -62,10 +66,11 @@ public MddHandle compute( final MddVariable variable = highestAffectedVariable.getVariable().get(); result = this.compute(initialStates.getNode(), nextStateRelation, variable); } else { - result = this.computeTerminal(initialStates.getNode(), - nextStateRelation, - highestAffectedVariable.getMddGraph() - ); + result = + this.computeTerminal( + initialStates.getNode(), + nextStateRelation, + highestAffectedVariable.getMddGraph()); } return highestAffectedVariable.getHandleFor(result); @@ -75,8 +80,9 @@ private MddNode recurse( final MddNode mddNode, final AbstractNextStateDescriptor nextState, MddVariable currentVariable, - final CacheManager>.CacheHolder cache - ) { + final CacheManager> + .CacheHolder + cache) { if (currentVariable.getLower().isPresent()) { return compute(mddNode, nextState, currentVariable.getLower().get()); } else { @@ -85,8 +91,7 @@ private MddNode recurse( } private MddNode unionChildren( - final MddNode lhs, final MddNode rhs, MddVariable currentVariable - ) { + final MddNode lhs, final MddNode rhs, MddVariable currentVariable) { if (currentVariable.getLower().isPresent()) { return currentVariable.getLower().get().union(lhs, rhs); } else { @@ -96,8 +101,9 @@ private MddNode unionChildren( @Override public MddNode compute( - final MddNode mddNode, final AbstractNextStateDescriptor nextState, final MddVariable mddVariable - ) { + final MddNode mddNode, + final AbstractNextStateDescriptor nextState, + final MddVariable mddVariable) { return saturate(mddNode, nextState, mddVariable, cacheManager.getCacheFor(mddVariable)); } @@ -105,11 +111,10 @@ private MddNode saturate( final MddNode n, AbstractNextStateDescriptor d, MddVariable variable, - CacheManager.CacheHolder cache - ) { - if (n.isTerminal() || - d == AbstractNextStateDescriptor.terminalIdentity() || - d == AbstractNextStateDescriptor.terminalEmpty()) { + CacheManager.CacheHolder cache) { + if (n.isTerminal() + || d == AbstractNextStateDescriptor.terminalIdentity() + || d == AbstractNextStateDescriptor.terminalEmpty()) { // TODO this does not handle level skips return n; } @@ -127,13 +132,18 @@ private MddNode saturate( final MddStateSpaceInfo stateSpaceInfo = new MddStateSpaceInfo(variable, n); - IntObjMapView satTemplate = new IntObjMapViews.Transforming(n, - (node, key) -> node == null ? null : terminalZeroToNull(saturate(node, - d.getDiagonal(stateSpaceInfo).get(key), - variable.getLower().orElse(null), - cache.getLower() - )) - ); + IntObjMapView satTemplate = + new IntObjMapViews.Transforming( + n, + (node, key) -> + node == null + ? null + : terminalZeroToNull( + saturate( + node, + d.getDiagonal(stateSpaceInfo).get(key), + variable.getLower().orElse(null), + cache.getLower()))); MddNode nsat = variable.checkInNode(MddStructuralTemplate.of(satTemplate)); @@ -145,11 +155,11 @@ private MddNode saturate( final Optional> splitNS = d.split(); if (splitNS.isPresent()) { for (AbstractNextStateDescriptor dfire : splitNS.get()) { - //System.out.println("Applying transition: " + dfire); + // System.out.println("Applying transition: " + dfire); if (dfire.isLocallyIdentity(stateSpaceInfo)) { continue; } - MddNode nfire = satFire(nsat, d, dfire, variable, cache, stateSpaceInfo); + MddNode nfire = satFire(nsat, d, dfire, variable, cache); nfire = variable.union(nsat, nfire); @@ -159,8 +169,8 @@ private MddNode saturate( } } } else if (!d.isLocallyIdentity(stateSpaceInfo)) { - //System.out.println("Applying transition: " + d); - MddNode nfire = satFire(nsat, d, d, variable, cache, stateSpaceInfo); + // System.out.println("Applying transition: " + d); + MddNode nfire = satFire(nsat, d, d, variable, cache); nfire = variable.union(nsat, nfire); @@ -176,12 +186,17 @@ private MddNode saturate( if (verbose) { indent--; printIndent(); - System.out.println("Done Saturating on level " + variable.getTraceInfo() + " resulting in " + nsat); + System.out.println( + "Done Saturating on level " + + variable.getTraceInfo() + + " resulting in " + + nsat); } // indent--; // printIndent(); - // System.out.println("Saturated level " + variable.getTraceInfo() + ", domain size is " + variable.getDomainSize()); + // System.out.println("Saturated level " + variable.getTraceInfo() + ", domain size is " + + // variable.getDomainSize()); // return nsat; } @@ -191,9 +206,7 @@ private MddNode satFire( AbstractNextStateDescriptor dsat, AbstractNextStateDescriptor dfire, MddVariable variable, - CacheManager.CacheHolder cache, - final MddStateSpaceInfo stateSpaceInfo - ) { + CacheManager.CacheHolder cache) { if (n == terminalZeroNode || dfire == AbstractNextStateDescriptor.terminalEmpty()) { return terminalZeroNode; } @@ -202,24 +215,43 @@ private MddNode satFire( return n; } + boolean lhsSkipped = !n.isOn(variable); + if (verbose) { printIndent(); - System.out.println("SatFire on level " + - variable.getTraceInfo() + - " with dsat=" + - dsat + - "; dfire=" + - dfire); + System.out.println( + "SatFire on level " + + variable.getTraceInfo() + + " with dsat=" + + dsat + + "; dfire=" + + dfire); indent++; } - MddUnsafeTemplateBuilder templateBuilder = JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); + MddUnsafeTemplateBuilder templateBuilder = + JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); - final IntObjMapView diagonal = dfire.getDiagonal(stateSpaceInfo); - final IntObjMapView> offDiagonal = dfire.getOffDiagonal( - stateSpaceInfo); + final MddStateSpaceInfo stateSpaceInfo = new MddStateSpaceInfo(variable, n); - for (IntObjCursor cFrom = n.cursor(); cFrom.moveNext(); ) { + final IntObjMapView diagonal = + dfire.getDiagonal(stateSpaceInfo); + final IntObjMapView> offDiagonal = + dfire.getOffDiagonal(stateSpaceInfo); + + final RecursiveIntObjMapView lhsInterpreter; + if ((lhsSkipped || (n.defaultValue() != null && n.isEmpty())) && !variable.isBounded()) { + final MddNode childCandidate = lhsSkipped ? n : n.defaultValue(); + // We use the keyset of the ANSD to trim + lhsInterpreter = + RecursiveIntObjMapView.of( + IntObjMapView.empty(childCandidate).trim(offDiagonal.keySet())); + } else { + lhsInterpreter = + variable.getNodeInterpreter( + n); // using the interpreter might cause a performance overhead + } + for (IntObjCursor cFrom = lhsInterpreter.cursor(); cFrom.moveNext(); ) { // Identity step final AbstractNextStateDescriptor diagonalContinuation = diagonal.get(cFrom.key()); @@ -229,24 +261,27 @@ private MddNode satFire( System.out.println("Potential step: " + cFrom.key() + "->" + cFrom.key()); } - MddNode s = relProd(cFrom.value(), - dsat.getDiagonal(stateSpaceInfo).defaultValue(), - diagonalContinuation, - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + dsat.getDiagonal(stateSpaceInfo).defaultValue(), + diagonalContinuation, + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { // confirm(variable, cFrom.key()); - templateBuilder.set(cFrom.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cFrom.key()), s, variable)) - ); + templateBuilder.set( + cFrom.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cFrom.key()), s, variable))); } } - for (IntObjCursor cTo = offDiagonal.get( - cFrom.key()).cursor(); cTo.moveNext(); ) { + for (IntObjCursor cTo = + offDiagonal.get(cFrom.key()).cursor(); + cTo.moveNext(); ) { if (cFrom.key() == cTo.key()) { continue; } @@ -258,30 +293,37 @@ private MddNode satFire( assert cFrom.value() != terminalZeroNode; assert cTo.value() != AbstractNextStateDescriptor.terminalEmpty(); - MddNode s = relProd(cFrom.value(), - // Level skip will be encoded as default value - dsat.getDiagonal(stateSpaceInfo).defaultValue(), - cTo.value(), - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + // Level skip will be encoded as default value + dsat.getDiagonal(stateSpaceInfo).defaultValue(), + cTo.value(), + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { confirm(variable, cTo.key()); - templateBuilder.set(cTo.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cTo.key()), s, variable)) - ); + templateBuilder.set( + cTo.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cTo.key()), s, variable))); } } } - MddNode ret = variable.checkInNode(MddStructuralTemplate.of(templateBuilder.buildAndReset())); + final var template = templateBuilder.buildAndReset(); + if (!template.isEmpty()) + Preconditions.checkArgument( + n.defaultValue() == null, "Default value is not supported with explicit edges"); + MddNode ret = variable.checkInNode(MddStructuralTemplate.of(template)); if (verbose) { indent--; printIndent(); - System.out.println("Done SatFire on level " + variable.getTraceInfo() + " resulting in " + ret); + System.out.println( + "Done SatFire on level " + variable.getTraceInfo() + " resulting in " + ret); } return ret; @@ -292,8 +334,7 @@ private MddNode relProd( AbstractNextStateDescriptor dsat, AbstractNextStateDescriptor dfire, MddVariable variable, - CacheManager.CacheHolder cache - ) { + CacheManager.CacheHolder cache) { if (n == terminalZeroNode || dfire == AbstractNextStateDescriptor.terminalEmpty()) { return terminalZeroNode; } @@ -306,6 +347,8 @@ private MddNode relProd( return n; } + boolean lhsSkipped = !n.isOn(variable); + final MddStateSpaceInfo stateSpaceInfo = new MddStateSpaceInfo(variable, n); MddNode ret = cache.getCache().getRelProdCache().getOrNull(n, dsat, dfire); @@ -315,25 +358,40 @@ private MddNode relProd( if (verbose) { printIndent(); - System.out.println("SatRelProd on level " + - variable.getTraceInfo() + - ", node=" + - n + - ", with dsat=" + - dsat + - "; dfire" + - "=" + - dfire); + System.out.println( + "SatRelProd on level " + + variable.getTraceInfo() + + ", node=" + + n + + ", with dsat=" + + dsat + + "; dfire" + + "=" + + dfire); indent++; } - MddUnsafeTemplateBuilder templateBuilder = JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); - - final IntObjMapView diagonal = dfire.getDiagonal(stateSpaceInfo); - final IntObjMapView> offDiagonal = dfire.getOffDiagonal( - stateSpaceInfo); - - for (IntObjCursor cFrom = n.cursor(); cFrom.moveNext(); ) { + MddUnsafeTemplateBuilder templateBuilder = + JavaMddFactory.getDefault().createUnsafeTemplateBuilder(); + + final IntObjMapView diagonal = + dfire.getDiagonal(stateSpaceInfo); + final IntObjMapView> offDiagonal = + dfire.getOffDiagonal(stateSpaceInfo); + + final RecursiveIntObjMapView lhsInterpreter; + if ((lhsSkipped || (n.defaultValue() != null && n.isEmpty())) && !variable.isBounded()) { + final MddNode childCandidate = lhsSkipped ? n : n.defaultValue(); + // We use the keyset of the ANSD to trim + lhsInterpreter = + RecursiveIntObjMapView.of( + IntObjMapView.empty(childCandidate).trim(offDiagonal.keySet())); + } else { + lhsInterpreter = + variable.getNodeInterpreter( + n); // using the interpreter might cause a performance overhead + } + for (IntObjCursor cFrom = lhsInterpreter.cursor(); cFrom.moveNext(); ) { // Identity step final AbstractNextStateDescriptor diagonalContinuation = diagonal.get(cFrom.key()); if (!AbstractNextStateDescriptor.isNullOrEmpty(diagonalContinuation)) { @@ -342,24 +400,27 @@ private MddNode relProd( System.out.println("Potential step: " + cFrom.key() + "->" + cFrom.key()); } - MddNode s = relProd(cFrom.value(), - dsat.getDiagonal(stateSpaceInfo).defaultValue(), - diagonalContinuation, - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + dsat.getDiagonal(stateSpaceInfo).defaultValue(), + diagonalContinuation, + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { // confirm(variable, cFrom.key()); - templateBuilder.set(cFrom.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cFrom.key()), s, variable)) - ); + templateBuilder.set( + cFrom.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cFrom.key()), s, variable))); } } - for (IntObjCursor cTo = offDiagonal.get(cFrom.key()).cursor(); - cTo.moveNext(); ) { + for (IntObjCursor cTo = + offDiagonal.get(cFrom.key()).cursor(); + cTo.moveNext(); ) { if (cFrom.key() == cTo.key()) { continue; } @@ -371,24 +432,30 @@ private MddNode relProd( assert cFrom.value() != terminalZeroNode; assert cTo.value() != AbstractNextStateDescriptor.terminalEmpty(); - MddNode s = relProd(cFrom.value(), - dsat.getDiagonal(stateSpaceInfo).defaultValue(), - cTo.value(), - variable.getLower().orElse(null), - cache.getLower() - ); + MddNode s = + relProd( + cFrom.value(), + dsat.getDiagonal(stateSpaceInfo).defaultValue(), + cTo.value(), + variable.getLower().orElse(null), + cache.getLower()); if (s != terminalZeroNode) { confirm(variable, cTo.key()); - templateBuilder.set(cTo.key(), - terminalZeroToNull(unionChildren(templateBuilder.get(cTo.key()), s, variable)) - ); + templateBuilder.set( + cTo.key(), + terminalZeroToNull( + unionChildren(templateBuilder.get(cTo.key()), s, variable))); } } } - ret = variable.checkInNode(MddStructuralTemplate.of(templateBuilder.buildAndReset())); + final var template = templateBuilder.buildAndReset(); + if (!template.isEmpty()) + Preconditions.checkArgument( + n.defaultValue() == null, "Default value is not supported with explicit edges"); + ret = variable.checkInNode(MddStructuralTemplate.of(template)); ret = saturate(ret, dsat, variable, cache); @@ -397,20 +464,20 @@ private MddNode relProd( if (verbose) { indent--; printIndent(); - System.out.println("Done SatRelProd on level " + variable.getTraceInfo() + " resulting in " + ret); + System.out.println( + "Done SatRelProd on level " + variable.getTraceInfo() + " resulting in " + ret); } return ret; } - private void confirm(final MddVariable variable, final int key) { - - } + private void confirm(final MddVariable variable, final int key) {} @Override public MddNode computeTerminal( - final MddNode mddNode, final AbstractNextStateDescriptor nextState, final MddGraph mddGraph - ) { + final MddNode mddNode, + final AbstractNextStateDescriptor nextState, + final MddGraph mddGraph) { return mddNode; } @@ -438,12 +505,19 @@ public void clear() { @Override public void cleanup() { - this.cacheManager.forEachCache((cache) -> { - cache.getSaturateCache().clearSelectively((source, ns1, result) -> source.getReferenceCount() == 0 || - result.getReferenceCount() == 0); - cache.getRelProdCache().clearSelectively((source, ns1, ns2, result) -> source.getReferenceCount() == 0 || - result.getReferenceCount() == 0); - }); + this.cacheManager.forEachCache( + (cache) -> { + cache.getSaturateCache() + .clearSelectively( + (source, ns1, result) -> + source.getReferenceCount() == 0 + || result.getReferenceCount() == 0); + cache.getRelProdCache() + .clearSelectively( + (source, ns1, ns2, result) -> + source.getReferenceCount() == 0 + || result.getReferenceCount() == 0); + }); } private class Aggregator implements Consumer { @@ -498,14 +572,17 @@ public long getHitCount() { return new SaturateCache(cacheManager); } - // TODO: HAXXXX DON'T DO THIS EVER AGAIN public Set getSaturatedNodes() { final Set ret = HashObjSets.newUpdatableSet(); - cacheManager.forEachCache((c) -> c.getSaturateCache().clearSelectively((source, ns, result) -> { - ret.add(result); - return false; - })); + cacheManager.forEachCache( + (c) -> + c.getSaturateCache() + .clearSelectively( + (source, ns, result) -> { + ret.add(result); + return false; + })); return ret; } diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/varordering/ForceVarOrdering.kt b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/varordering/ForceVarOrdering.kt new file mode 100644 index 0000000000..9a1cdfcd30 --- /dev/null +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/algorithm/mdd/varordering/ForceVarOrdering.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hu.bme.mit.theta.analysis.algorithm.mdd.varordering + +import hu.bme.mit.theta.core.decl.VarDecl +import hu.bme.mit.theta.core.stmt.Stmt +import hu.bme.mit.theta.core.utils.StmtUtils +import kotlin.random.Random + +/** + * Variable ordering based on the 'FORCE' variable ordering heuristic. + * https://doi.org/10.1145/764808.764839 + */ +fun orderVarsFromRandomStartingPoints(vars: List>, events: Set, numStartingPoints: Int = 5): List> { + val random = Random(0) + val startingPoints = (0 until numStartingPoints).map { vars.shuffled(random) } + val orderings = startingPoints.map { orderVars(it, events) } + return orderings.minBy { eventSpans(it, events) } +} + +fun orderVars(vars: List>, events: Set): List> { + + val affectedVars = events.associateWith { event -> + StmtUtils.getVars(event) + } + + val affectingEvents = vars.associateWith { varDecl -> + events.filter { varDecl in affectedVars[it]!! }.toSet() + } + + var currentVarOrdering = vars.toList() + var currentEventSpans = eventSpans(currentVarOrdering, events) + + while (true) { + val cogs = events.associateWith { + affectedVars[it]!!.map { varDecl -> currentVarOrdering.indexOf(varDecl) }.fold(0, Int::plus) + .toDouble() / affectedVars[it]!!.size.toDouble() + } + val newLocations = vars.associateWith { varDecl -> + affectingEvents[varDecl]!!.map { cogs[it]!! }.fold(0.0, Double::plus) / affectingEvents[varDecl]!!.size.toDouble() + } + + val newVarOrdering = currentVarOrdering.sortedBy { newLocations[it]!! } + val newEventSpans = eventSpans(newVarOrdering, events) + + if (newEventSpans >= currentEventSpans) { + break + } else { + currentVarOrdering = newVarOrdering + currentEventSpans = newEventSpans + } + } + + return currentVarOrdering + +} + +private fun eventSpans(vars: List>, events: Set) = events.map { event -> + StmtUtils.getVars(event).let { + when(it.isEmpty()) { + true -> 0 + else -> { + val firstVar = it.minOf { vars.indexOf(it) } + val lastVar = it.maxOf { vars.indexOf(it) } + lastVar - firstVar + } + } + } + }.fold(0, Int::plus) + + diff --git a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/utils/MddNodeVisualizer.java b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/utils/MddNodeVisualizer.java index dcaed9e3f9..bf348b1e32 100644 --- a/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/utils/MddNodeVisualizer.java +++ b/subprojects/common/analysis/src/main/java/hu/bme/mit/theta/analysis/utils/MddNodeVisualizer.java @@ -15,6 +15,9 @@ */ package hu.bme.mit.theta.analysis.utils; +import static hu.bme.mit.theta.common.visualization.Alignment.LEFT; +import static hu.bme.mit.theta.common.visualization.Shape.RECTANGLE; + import hu.bme.mit.delta.collections.RecursiveIntObjCursor; import hu.bme.mit.delta.collections.impl.RecursiveIntObjMapViews; import hu.bme.mit.delta.java.mdd.MddNode; @@ -23,16 +26,12 @@ import hu.bme.mit.theta.common.visualization.Graph; import hu.bme.mit.theta.common.visualization.LineStyle; import hu.bme.mit.theta.common.visualization.NodeAttributes; - import java.awt.*; import java.util.IdentityHashMap; import java.util.Map; import java.util.Set; import java.util.function.Function; -import static hu.bme.mit.theta.common.visualization.Alignment.LEFT; -import static hu.bme.mit.theta.common.visualization.Shape.RECTANGLE; - public class MddNodeVisualizer { private static final LineStyle CHILD_EDGE_STYLE = LineStyle.NORMAL; @@ -51,25 +50,23 @@ public class MddNodeVisualizer { public static long idFor(MddNode n) { Long l = registry.get(n); - if (l == null) - registry.put(n, l = nextId++); + if (l == null) registry.put(n, l = nextId++); return l; } private static class LazyHolderDefault { - static final MddNodeVisualizer INSTANCE = new MddNodeVisualizer(n -> n.toString()); + static final MddNodeVisualizer INSTANCE = create(); } private static class LazyHolderStructureOnly { - static final MddNodeVisualizer INSTANCE = new MddNodeVisualizer(n -> ""); + static final MddNodeVisualizer INSTANCE = create(n -> ""); } private MddNodeVisualizer(final Function nodeToString) { this.nodeToString = nodeToString; } - public static MddNodeVisualizer create( - final Function nodeToString) { + public static MddNodeVisualizer create(final Function nodeToString) { return new MddNodeVisualizer(nodeToString); } @@ -95,8 +92,11 @@ public Graph visualize(final MddNode rootNode) { return graph; } - private void traverse(final Graph graph, final MddNode node, RecursiveIntObjCursor cursor, - final Set traversed) { + private void traverse( + final Graph graph, + final MddNode node, + RecursiveIntObjCursor cursor, + final Set traversed) { if (traversed.contains(node)) { return; } else { @@ -106,11 +106,19 @@ private void traverse(final Graph graph, final MddNode node, RecursiveIntObjCurs final LineStyle lineStyle = CHILD_EDGE_STYLE; final int peripheries = 1; -// final int peripheries = node.isComplete()?2:1; - - final NodeAttributes nAttributes = NodeAttributes.builder().label(nodeToString.apply(node)) - .alignment(LEFT).shape(RECTANGLE).font(FONT).fillColor(FILL_COLOR).lineColor(LINE_COLOR) - .peripheries(peripheries).lineStyle(lineStyle).build(); + // final int peripheries = node.isComplete()?2:1; + + final NodeAttributes nAttributes = + NodeAttributes.builder() + .label(nodeToString.apply(node)) + .alignment(LEFT) + .shape(RECTANGLE) + .font(FONT) + .fillColor(FILL_COLOR) + .lineColor(LINE_COLOR) + .peripheries(peripheries) + .lineStyle(lineStyle) + .build(); graph.addNode(nodeId, nAttributes); @@ -121,8 +129,13 @@ private void traverse(final Graph graph, final MddNode node, RecursiveIntObjCurs } final String sourceId = NODE_ID_PREFIX + idFor(node); final String targetId = NODE_ID_PREFIX + idFor(defaultValue); - final EdgeAttributes eAttributes = EdgeAttributes.builder() - .alignment(LEFT).font(FONT).color(LINE_COLOR).lineStyle(DEFAULT_EDGE_STYLE).build(); + final EdgeAttributes eAttributes = + EdgeAttributes.builder() + .alignment(LEFT) + .font(FONT) + .color(LINE_COLOR) + .lineStyle(DEFAULT_EDGE_STYLE) + .build(); graph.addEdge(sourceId, targetId, eAttributes); } else { while (cursor.moveNext()) { @@ -132,20 +145,25 @@ private void traverse(final Graph graph, final MddNode node, RecursiveIntObjCurs } final String sourceId = NODE_ID_PREFIX + idFor(node); final String targetId = NODE_ID_PREFIX + idFor(cursor.value()); - final EdgeAttributes eAttributes = EdgeAttributes.builder().label(cursor.key() + "") - .alignment(LEFT).font(FONT).color(LINE_COLOR).lineStyle(CHILD_EDGE_STYLE).build(); + final EdgeAttributes eAttributes = + EdgeAttributes.builder() + .label(cursor.key() + "") + .alignment(LEFT) + .font(FONT) + .color(LINE_COLOR) + .lineStyle(CHILD_EDGE_STYLE) + .build(); graph.addEdge(sourceId, targetId, eAttributes); } - } } - } private static String nodeToString(MddNode node) { if (node.getRepresentation() instanceof RecursiveIntObjMapViews.OfIntObjMapView) return ""; - return node instanceof MddNode.Terminal ? ((MddNode.Terminal) node).getTerminalData().toString() : node.getRepresentation().toString(); + return node instanceof MddNode.Terminal + ? ((MddNode.Terminal) node).getTerminalData().toString() + : node.getRepresentation().toString(); } - } diff --git a/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddCheckerTest.java b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddCheckerTest.java index 0f9f524354..d4f5f02b8d 100644 --- a/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddCheckerTest.java +++ b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddCheckerTest.java @@ -33,12 +33,14 @@ import hu.bme.mit.theta.core.type.booltype.BoolType; import hu.bme.mit.theta.core.type.inttype.IntExprs; import hu.bme.mit.theta.core.type.inttype.IntType; +import hu.bme.mit.theta.core.utils.ExprUtils; import hu.bme.mit.theta.core.utils.indexings.VarIndexing; import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory; import hu.bme.mit.theta.solver.SolverPool; import hu.bme.mit.theta.solver.z3legacy.Z3LegacySolverFactory; import java.util.Arrays; import java.util.Collection; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -187,6 +189,7 @@ public VarIndexing nextIndexing() { } }, propExpr, + List.copyOf(ExprUtils.getVars(List.of(initExpr, tranExpr, propExpr))), solverPool, logger, iterationStrategy); diff --git a/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddConstrainedCursorTest.java b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddConstrainedCursorTest.java new file mode 100644 index 0000000000..26070d5dd4 --- /dev/null +++ b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddConstrainedCursorTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.analysis.algorithm.mdd; + +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Add; +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; +import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Or; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; +import static org.junit.Assert.assertEquals; + +import hu.bme.mit.delta.java.mdd.JavaMddFactory; +import hu.bme.mit.delta.java.mdd.MddGraph; +import hu.bme.mit.delta.java.mdd.MddHandle; +import hu.bme.mit.delta.java.mdd.MddVariableOrder; +import hu.bme.mit.delta.mdd.MddVariableDescriptor; +import hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode.ExprLatticeDefinition; +import hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode.MddExpressionTemplate; +import hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint.MddStateSpaceInfo; +import hu.bme.mit.theta.core.decl.Decl; +import hu.bme.mit.theta.core.decl.Decls; +import hu.bme.mit.theta.core.decl.VarDecl; +import hu.bme.mit.theta.core.type.Expr; +import hu.bme.mit.theta.core.type.LitExpr; +import hu.bme.mit.theta.core.type.booltype.BoolExprs; +import hu.bme.mit.theta.core.type.booltype.BoolType; +import hu.bme.mit.theta.core.type.enumtype.EnumType; +import hu.bme.mit.theta.core.type.inttype.IntType; +import hu.bme.mit.theta.core.utils.PathUtils; +import hu.bme.mit.theta.solver.SolverPool; +import hu.bme.mit.theta.solver.z3legacy.Z3LegacySolverFactory; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +public class MddConstrainedCursorTest { + + private static final VarDecl X = Decls.Var("x", IntType.getInstance()); + private static final VarDecl Y = Decls.Var("y", IntType.getInstance()); + + private static final VarDecl A = Decls.Var("a", BoolType.getInstance()); + private static final VarDecl B = Decls.Var("b", BoolType.getInstance()); + + private static final EnumType colorType = EnumType.of("color", List.of("red", "green", "blue")); + private static final VarDecl C = Decls.Var("c", colorType); + private static final LitExpr RED = colorType.litFromShortName("red"); + private static final LitExpr GREEN = colorType.litFromShortName("green"); + private static final LitExpr BLUE = colorType.litFromShortName("blue"); + + @Parameterized.Parameter(value = 0) + public List> varOrder; + + @Parameterized.Parameter(value = 1) + public Expr constraintExpr; + + @Parameterized.Parameter(value = 2) + public Expr transExpr; + + @Parameterized.Parameter(value = 3) + public Integer topLevelCursorExpectedSize; + + @Parameterized.Parameters(name = "{index}: {0}, {1}, {2}, {3}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + { + List.of(X, Y), + BoolExprs.And( + Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + BoolExprs.And( + Eq(Prime(X.getRef()), X.getRef()), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x, y'=y + 1 + }, + { + List.of(X, Y), + BoolExprs.And( + Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + BoolExprs.And( + Eq(Prime(X.getRef()), Add(X.getRef(), Int(1))), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x + 1, y'=y + 1 + }, + { + List.of(X, Y), + Or( + BoolExprs.And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), + BoolExprs.And( + Eq(X.getRef(), Int(1)), + Eq(Y.getRef(), Int(1)))), // x = 0, y = 0 or x = 1, y = 1 + BoolExprs.And( + Eq(Prime(X.getRef()), X.getRef()), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x, y'=y + 2 + }, + { + List.of(X, Y), + BoolExprs.And( + Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + BoolExprs.And( + Eq(Prime(X.getRef()), Y.getRef()), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=y, y'=y + 1 + }, + { + List.of(X, Y), + BoolExprs.And( + Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + BoolExprs.And( + Eq(Prime(X.getRef()), Add(Y.getRef(), Int(1))), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=y + 1, y'=y + 1 + }, + }); + } + + @Test + public void test() throws Exception { + + try (final SolverPool solverPool = new SolverPool(Z3LegacySolverFactory.getInstance())) { + final MddGraph mddGraph = + JavaMddFactory.getDefault().createMddGraph(ExprLatticeDefinition.forExpr()); + + final MddVariableOrder stateOrder = + JavaMddFactory.getDefault().createMddVariableOrder(mddGraph); + final MddVariableOrder transOrder = + JavaMddFactory.getDefault().createMddVariableOrder(mddGraph); + + varOrder.forEach( + v -> { + final var domainSize = + Math.max(v.getType().getDomainSize().getFiniteSize().intValue(), 0); + stateOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); + transOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(1), domainSize)); + transOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); + }); + + final var stateSig = stateOrder.getDefaultSetSignature(); + final var transSig = transOrder.getDefaultSetSignature(); + + final var constraintUnfolded = PathUtils.unfold(constraintExpr, 0); + final var transUnfolded = PathUtils.unfold(transExpr, 0); + + final MddHandle constraintHandle = + stateSig.getTopVariableHandle() + .checkInNode( + MddExpressionTemplate.of( + constraintUnfolded, o -> (Decl) o, solverPool)); + final MddHandle transHandle = + transSig.getTopVariableHandle() + .checkInNode( + MddExpressionTemplate.of( + transUnfolded, o -> (Decl) o, solverPool)); + + final var stateSpaceInfo = + new MddStateSpaceInfo( + stateSig.getTopVariableHandle().getVariable().orElseThrow(), + constraintHandle.getNode()); + final var structuralRepresentation = stateSpaceInfo.toStructuralRepresentation(); + // final var structuralHandle = + // stateSig.getTopVariableHandle().getHandleFor(structuralRepresentation); + + Integer size = 0; + for (var cursor = transHandle.cursor(structuralRepresentation); cursor.moveNext(); ) { + size++; + } + + assertEquals(topLevelCursorExpectedSize, size); + } + } +} diff --git a/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddRelProdTest.java b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddRelProdTest.java new file mode 100644 index 0000000000..ffa47d709d --- /dev/null +++ b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddRelProdTest.java @@ -0,0 +1,331 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.analysis.algorithm.mdd; + +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.*; +import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.*; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; +import static org.junit.Assert.assertEquals; + +import hu.bme.mit.delta.java.mdd.*; +import hu.bme.mit.delta.mdd.MddInterpreter; +import hu.bme.mit.delta.mdd.MddVariableDescriptor; +import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.AbstractNextStateDescriptor; +import hu.bme.mit.theta.analysis.algorithm.mdd.ansd.impl.MddNodeNextStateDescriptor; +import hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode.ExprLatticeDefinition; +import hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode.MddExpressionTemplate; +import hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint.LegacyRelationalProductProvider; +import hu.bme.mit.theta.core.decl.Decl; +import hu.bme.mit.theta.core.decl.Decls; +import hu.bme.mit.theta.core.decl.VarDecl; +import hu.bme.mit.theta.core.type.Expr; +import hu.bme.mit.theta.core.type.LitExpr; +import hu.bme.mit.theta.core.type.booltype.BoolType; +import hu.bme.mit.theta.core.type.enumtype.EnumType; +import hu.bme.mit.theta.core.type.inttype.IntType; +import hu.bme.mit.theta.core.utils.PathUtils; +import hu.bme.mit.theta.solver.SolverPool; +import hu.bme.mit.theta.solver.z3legacy.Z3LegacySolverFactory; +import java.util.*; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +public class MddRelProdTest { + + private static final VarDecl X = Decls.Var("x", IntType.getInstance()); + private static final VarDecl Y = Decls.Var("y", IntType.getInstance()); + + private static final VarDecl A = Decls.Var("a", BoolType.getInstance()); + private static final VarDecl B = Decls.Var("b", BoolType.getInstance()); + + private static final EnumType colorType = EnumType.of("color", List.of("red", "green", "blue")); + private static final VarDecl C = Decls.Var("c", colorType); + private static final LitExpr RED = colorType.litFromShortName("red"); + private static final LitExpr GREEN = colorType.litFromShortName("green"); + private static final LitExpr BLUE = colorType.litFromShortName("blue"); + + @Parameterized.Parameter(value = 0) + public List> varOrder; + + @Parameterized.Parameter(value = 1) + public Expr stateExpr; + + @Parameterized.Parameter(value = 2) + public Expr transExpr; + + @Parameterized.Parameter(value = 3) + public Long expectedSize; + + @Parameterized.Parameters(name = "{index}: {0}, {1}, {2}, {3}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + { + List.of(X, Y), + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + And( + Eq(Prime(X.getRef()), X.getRef()), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x, y'=y + 1L + }, + { + List.of(X, Y), + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + And( + Eq(Prime(X.getRef()), Add(X.getRef(), Int(1))), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x + 1, y'=y + 1L + }, + { + List.of(X, Y), + Or( + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), + And( + Eq(X.getRef(), Int(1)), + Eq(Y.getRef(), Int(1)))), // x = 0, y = 0 or x = 1, y = 1 + And( + Eq(Prime(X.getRef()), X.getRef()), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x, y'=y + 2L + }, + { + List.of(X, Y), + Or( + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), + And( + Eq(X.getRef(), Int(1)), + Eq(Y.getRef(), Int(1)))), // x = 0, y = 0 or x = 1, y = 1 + And( + Eq(Prime(X.getRef()), Add(X.getRef(), Int(1))), + Eq(Prime(Y.getRef()), Y.getRef())), // x'=x + 1, y'=y + 2L + }, + { + List.of(X, Y), + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + And( + Or( + Eq(Prime(X.getRef()), X.getRef()), + Eq(Prime(X.getRef()), Add(X.getRef(), Int(1)))), + Or( + Eq(Prime(Y.getRef()), Y.getRef()), + Eq( + Prime(Y.getRef()), + Add( + Y.getRef(), + Int(1))))), // (x'=x or x'=x+1), (y'=y + // or y'=y+1) + 4L + }, + + // These won't ever be supported + // {List.of(X, Y), + // Eq(X.getRef(), Int(0)), // x = 0 + // Eq(Prime(X.getRef()), X.getRef()), // x'=x + // 1L}, + // + // {List.of(X, Y), + // Eq(X.getRef(), Int(0)), // x = 0, y = 0 + // Eq(Prime(X.getRef()), Add(X.getRef(), Int(1))), // + // x'=x + 1, y'=y + // 1L}, + // + // {List.of(X, Y), + // And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), + // // x = 0, y = 0 + // True(), // true + // 0L}, + + { + List.of(A, B), + And(A.getRef(), B.getRef()), + And( + Eq(A.getRef(), Prime(A.getRef())), + Eq(B.getRef(), Prime(B.getRef()))), // a'=a, b'=b + 1L + }, + { + List.of(A, B), + And(A.getRef(), B.getRef()), + And( + Eq(A.getRef(), Prime(A.getRef())), + Eq(A.getRef(), Prime(B.getRef()))), // a'=a, b'=a + 1L + }, + { + List.of(A, B), + And(A.getRef(), B.getRef()), + And( + Eq(B.getRef(), Prime(A.getRef())), + Eq(B.getRef(), Prime(B.getRef()))), // a'=b, b'=b + 1L + }, + { + List.of(A, B), + And(A.getRef(), B.getRef()), + Eq(A.getRef(), Prime(A.getRef())), // a'=a + 2L + }, + { + List.of(A, B), + And(A.getRef(), B.getRef()), + Eq(B.getRef(), Prime(B.getRef())), // b'=b + 2L + }, + { + List.of(A, B), + And(A.getRef(), B.getRef()), + True(), // true + 4L + }, + { + List.of(A, B), + True(), + And( + Eq(A.getRef(), Prime(A.getRef())), + Eq(B.getRef(), Prime(B.getRef()))), // a'=a, b'=b + 4L + }, + { + List.of(A, B), + True(), + True(), // true + 4L + }, + { + List.of(A, B), + True(), + And(Prime(A.getRef()), Prime(B.getRef())), // a', b' + 1L + }, + { + List.of(A, B), + True(), + Prime(A.getRef()), // a' + 2L + }, + { + List.of(A, B), + True(), + Prime(B.getRef()), // b' + 2L + }, + { + List.of(A, C), + And(A.getRef(), Eq(C.getRef(), RED)), + And( + Eq(A.getRef(), Prime(A.getRef())), + Eq(C.getRef(), Prime(C.getRef()))), // a'=a, c'=c + 1L + }, + { + List.of(A, C), + And(A.getRef(), Eq(C.getRef(), RED)), + Eq(A.getRef(), Prime(A.getRef())), // a'=a + 3L + }, + { + List.of(A, C), + And(A.getRef(), Eq(C.getRef(), RED)), + And( + Eq(A.getRef(), Prime(A.getRef())), + Or( + Eq(Prime(C.getRef()), RED), + Eq(Prime(C.getRef()), GREEN), + Eq(Prime(C.getRef()), BLUE))), // a'=a + 3L + }, + { + List.of(A, C), + True(), + And( + Eq(A.getRef(), Prime(A.getRef())), + Eq(C.getRef(), Prime(C.getRef()))), // a'=a, c'=c + 6L + }, + { + List.of(A, C), + And(A.getRef(), Eq(C.getRef(), RED)), + True(), // true + 6L + }, + { + List.of(A, C), + True(), + True(), // true + 6L + }, + }); + } + + @Test + public void test() throws Exception { + + try (final SolverPool solverPool = new SolverPool(Z3LegacySolverFactory.getInstance())) { + final MddGraph mddGraph = + JavaMddFactory.getDefault().createMddGraph(ExprLatticeDefinition.forExpr()); + + final MddVariableOrder stateOrder = + JavaMddFactory.getDefault().createMddVariableOrder(mddGraph); + final MddVariableOrder transOrder = + JavaMddFactory.getDefault().createMddVariableOrder(mddGraph); + + varOrder.forEach( + v -> { + final var domainSize = + Math.max(v.getType().getDomainSize().getFiniteSize().intValue(), 0); + stateOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); + transOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(1), domainSize)); + transOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); + }); + + final var stateSig = stateOrder.getDefaultSetSignature(); + final var transSig = transOrder.getDefaultSetSignature(); + + final var stateUnfolded = PathUtils.unfold(stateExpr, 0); + final var transUnfolded = PathUtils.unfold(transExpr, 0); + + final MddHandle stateHandle = + stateSig.getTopVariableHandle() + .checkInNode( + MddExpressionTemplate.of( + stateUnfolded, o -> (Decl) o, solverPool)); + final MddHandle transHandle = + transSig.getTopVariableHandle() + .checkInNode( + MddExpressionTemplate.of( + transUnfolded, o -> (Decl) o, solverPool)); + + final AbstractNextStateDescriptor nextStateDescriptor = + MddNodeNextStateDescriptor.of(transHandle); + + final var provider = new LegacyRelationalProductProvider(stateSig.getVariableOrder()); + final var result = + provider.compute( + stateHandle, nextStateDescriptor, stateSig.getTopVariableHandle()); + + final Long resultSize = MddInterpreter.calculateNonzeroCount(result); + + assertEquals(expectedSize, resultSize); + } + } +} diff --git a/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddStateSpaceInfoTest.java b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddStateSpaceInfoTest.java new file mode 100644 index 0000000000..c8086fecd4 --- /dev/null +++ b/subprojects/common/analysis/src/test/java/hu/bme/mit/theta/analysis/algorithm/mdd/MddStateSpaceInfoTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.analysis.algorithm.mdd; + +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.*; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; +import static org.junit.Assert.assertEquals; + +import hu.bme.mit.delta.java.mdd.JavaMddFactory; +import hu.bme.mit.delta.java.mdd.MddGraph; +import hu.bme.mit.delta.java.mdd.MddHandle; +import hu.bme.mit.delta.java.mdd.MddVariableOrder; +import hu.bme.mit.delta.mdd.MddInterpreter; +import hu.bme.mit.delta.mdd.MddVariableDescriptor; +import hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode.ExprLatticeDefinition; +import hu.bme.mit.theta.analysis.algorithm.mdd.expressionnode.MddExpressionTemplate; +import hu.bme.mit.theta.analysis.algorithm.mdd.fixedpoint.MddStateSpaceInfo; +import hu.bme.mit.theta.core.decl.Decl; +import hu.bme.mit.theta.core.decl.Decls; +import hu.bme.mit.theta.core.decl.VarDecl; +import hu.bme.mit.theta.core.type.Expr; +import hu.bme.mit.theta.core.type.LitExpr; +import hu.bme.mit.theta.core.type.booltype.BoolType; +import hu.bme.mit.theta.core.type.enumtype.EnumType; +import hu.bme.mit.theta.core.type.inttype.IntType; +import hu.bme.mit.theta.core.utils.PathUtils; +import hu.bme.mit.theta.solver.SolverPool; +import hu.bme.mit.theta.solver.z3legacy.Z3LegacySolverFactory; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +public class MddStateSpaceInfoTest { + + private static final VarDecl X = Decls.Var("x", IntType.getInstance()); + private static final VarDecl Y = Decls.Var("y", IntType.getInstance()); + + private static final VarDecl A = Decls.Var("a", BoolType.getInstance()); + private static final VarDecl B = Decls.Var("b", BoolType.getInstance()); + + private static final EnumType colorType = EnumType.of("color", List.of("red", "green", "blue")); + private static final VarDecl C = Decls.Var("c", colorType); + private static final LitExpr RED = colorType.litFromShortName("red"); + private static final LitExpr GREEN = colorType.litFromShortName("green"); + private static final LitExpr BLUE = colorType.litFromShortName("blue"); + + @Parameterized.Parameter(value = 0) + public List> varOrder; + + @Parameterized.Parameter(value = 1) + public Expr stateSpaceExpr; + + @Parameterized.Parameter(value = 2) + public Long expectedSize; + + @Parameterized.Parameters(name = "{index}: {0}, {1}, {2}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + { + List.of(X, Y), + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), // x = 0, y = 0 + 1L + }, + { + List.of(A, B), + Eq(A.getRef(), False()), // a = 0 + 2L + }, + { + List.of(A, B), + Eq(B.getRef(), False()), // y = 0 + 2L + }, + { + List.of(A, B), + True(), // true + 4L + }, + { + List.of(X, Y), + Or( + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), + And( + Eq(X.getRef(), Int(1)), + Eq(Y.getRef(), Int(1)))), // x = 0, y = 0 or x = 1, y = 1 + 4L + }, + { + List.of(X, Y), + Or( + And(Eq(X.getRef(), Int(0)), Eq(Y.getRef(), Int(0))), + And(Eq(X.getRef(), Int(1)), Eq(Y.getRef(), Int(1))), + And( + Eq(X.getRef(), Int(2)), + Eq( + Y.getRef(), + Int(2)))), // x = 0, y = 0 or x = 1, y = 1 or x + // = 2, y = 3 + 9L + }, + {List.of(A, C), And(A.getRef(), Eq(C.getRef(), RED)), 1L}, + {List.of(A, C), A.getRef(), 3L}, + {List.of(A, C), True(), 6L}, + {List.of(C, A), True(), 6L}, + }); + } + + @Test + public void test() throws Exception { + + try (final SolverPool solverPool = new SolverPool(Z3LegacySolverFactory.getInstance())) { + final MddGraph mddGraph = + JavaMddFactory.getDefault().createMddGraph(ExprLatticeDefinition.forExpr()); + + final MddVariableOrder variableOrder = + JavaMddFactory.getDefault().createMddVariableOrder(mddGraph); + varOrder.forEach( + v -> { + final var domainSize = + Math.max(v.getType().getDomainSize().getFiniteSize().intValue(), 0); + variableOrder.createOnTop( + MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); + }); + final var signature = variableOrder.getDefaultSetSignature(); + + final var stateUnfolded = PathUtils.unfold(stateSpaceExpr, 0); + final MddHandle stateHandle = + signature + .getTopVariableHandle() + .checkInNode( + MddExpressionTemplate.of( + stateUnfolded, o -> (Decl) o, solverPool)); + + final var stateSpaceInfo = + new MddStateSpaceInfo( + signature.getTopVariableHandle().getVariable().orElseThrow(), + stateHandle.getNode()); + final var structuralRepresentation = stateSpaceInfo.toStructuralRepresentation(); + final var structuralHandle = + signature.getTopVariableHandle().getHandleFor(structuralRepresentation); + + final Long resultSize = MddInterpreter.calculateNonzeroCount(structuralHandle); + + assertEquals(expectedSize, resultSize); + } + } +} diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/anytype/Exprs.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/anytype/Exprs.java index 4a55648279..04df53164b 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/anytype/Exprs.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/anytype/Exprs.java @@ -15,25 +15,23 @@ */ package hu.bme.mit.theta.core.type.anytype; +import static com.google.common.base.Preconditions.checkArgument; + import hu.bme.mit.theta.core.decl.Decl; import hu.bme.mit.theta.core.type.Expr; import hu.bme.mit.theta.core.type.Type; import hu.bme.mit.theta.core.type.booltype.BoolType; -import static com.google.common.base.Preconditions.checkArgument; - public final class Exprs { - private Exprs() { - } + private Exprs() {} public static RefExpr Ref(final Decl decl) { return RefExpr.of(decl); } - public static IteExpr Ite(final Expr cond, - final Expr then, - final Expr elze) { + public static IteExpr Ite( + final Expr cond, final Expr then, final Expr elze) { return IteExpr.of(cond, then, elze); } @@ -42,12 +40,13 @@ public static PrimeExpr Prime(final Expr - Dereference Dereference(final Expr arr, final Expr offset, final ExprType type) { + Dereference Dereference( + final Expr arr, final Expr offset, final ExprType type) { return Dereference.of(arr, offset, type); } public static - Reference Reference(final Expr expr, final ArrType type) { + Reference Reference(final Expr expr, final ArrType type) { return Reference.of(expr, type); } @@ -55,14 +54,15 @@ Reference Reference(final Expr expr, final ArrType * Convenience methods */ - public static PrimeExpr Prime(final Expr op, - final int i) { - checkArgument(i > 0); - if (i == 1) { + public static Expr Prime( + final Expr op, final int i) { + checkArgument(i >= 0); + if (i == 0) { + return op; + } else if (i == 1) { return Prime(op); } else { return Prime(Prime(op, i - 1)); } } - } diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/enumtype/EnumType.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/enumtype/EnumType.java index b8c8ebb5c0..b8367c430c 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/enumtype/EnumType.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/type/enumtype/EnumType.java @@ -15,6 +15,8 @@ */ package hu.bme.mit.theta.core.type.enumtype; +import static com.google.common.base.Preconditions.checkArgument; + import hu.bme.mit.theta.core.type.DomainSize; import hu.bme.mit.theta.core.type.Expr; import hu.bme.mit.theta.core.type.LitExpr; @@ -23,7 +25,6 @@ import hu.bme.mit.theta.core.type.abstracttype.Equational; import hu.bme.mit.theta.core.type.abstracttype.NeqExpr; import hu.bme.mit.theta.core.type.anytype.InvalidLitExpr; - import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; @@ -31,8 +32,6 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import static com.google.common.base.Preconditions.checkArgument; - public final class EnumType implements Equational, Type { public static final String FULLY_QUALIFIED_NAME_SEPARATOR = "."; @@ -59,9 +58,10 @@ public static String makeLongName(EnumType type, String literal) { } public static String getShortName(String longName) { - if (!longName.contains(FULLY_QUALIFIED_NAME_SEPARATOR)) - return longName; - return longName.substring(longName.indexOf(FULLY_QUALIFIED_NAME_SEPARATOR) + FULLY_QUALIFIED_NAME_SEPARATOR.length()); + if (!longName.contains(FULLY_QUALIFIED_NAME_SEPARATOR)) return longName; + return longName.substring( + longName.indexOf(FULLY_QUALIFIED_NAME_SEPARATOR) + + FULLY_QUALIFIED_NAME_SEPARATOR.length()); } @Override @@ -84,7 +84,9 @@ public Set getValues() { } public Set getLongValues() { - return literals.keySet().stream().map(val -> makeLongName(this, val)).collect(Collectors.toSet()); + return literals.keySet().stream() + .map(val -> makeLongName(this, val)) + .collect(Collectors.toSet()); } public String getName() { @@ -96,21 +98,29 @@ public int getIntValue(EnumLitExpr literal) { } public int getIntValue(String literal) { - checkArgument(literals.containsKey(literal), String.format("Enum type %s does not contain literal '%s'", name, literal)); + checkArgument( + literals.containsKey(literal), + String.format("Enum type %s does not contain literal '%s'", name, literal)); return literals.get(literal); } + public LitExpr litFromShortName(String shortName) { + try { + return EnumLitExpr.of(this, shortName); + } catch (Exception e) { + throw new RuntimeException( + String.format("%s is not valid for type %s", shortName, name), e); + } + } + public LitExpr litFromLongName(String longName) { if (!longName.contains(FULLY_QUALIFIED_NAME_SEPARATOR)) throw new RuntimeException(String.format("%s is an invalid enum longname")); String[] parts = longName.split(Pattern.quote(FULLY_QUALIFIED_NAME_SEPARATOR)); String type = parts[0]; - checkArgument(name.equals(type), String.format("%s does not belong to type %s", type, name)); - try { - return EnumLitExpr.of(this, parts[1]); - } catch (Exception e) { - throw new RuntimeException(String.format("%s is not valid for type %s", longName, name), e); - } + checkArgument( + name.equals(type), String.format("%s does not belong to type %s", type, name)); + return litFromShortName(parts[1]); } public LitExpr litFromIntValue(int value) { diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/BvUtils.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/BvUtils.java index b11cb7b4a7..cde6781908 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/BvUtils.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/BvUtils.java @@ -13,20 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package hu.bme.mit.theta.core.utils; -import hu.bme.mit.theta.core.type.bvtype.BvLitExpr; +import static hu.bme.mit.theta.core.type.bvtype.BvExprs.Bv; +import hu.bme.mit.theta.core.type.bvtype.BvLitExpr; import java.math.BigInteger; -import static hu.bme.mit.theta.core.type.bvtype.BvExprs.Bv; - public final class BvUtils { - private BvUtils() { - - } + private BvUtils() {} public static BigInteger neutralBvLitExprToBigInteger(final BvLitExpr expr) { return unsignedBvLitExprToBigInteger(expr); @@ -84,6 +80,7 @@ public static BigInteger fitBigIntegerIntoNeutralDomain(BigInteger integer, fina return fitBigIntegerIntoUnsignedDomain(integer, size); } + // TODO: is this correct? See modifications below in unsigned public static BigInteger fitBigIntegerIntoSignedDomain(BigInteger integer, final int size) { while (integer.compareTo(BigInteger.TWO.pow(size - 1).negate()) < 0) { integer = integer.add(BigInteger.TWO.pow(size)); @@ -98,11 +95,11 @@ public static BigInteger fitBigIntegerIntoSignedDomain(BigInteger integer, final public static BigInteger fitBigIntegerIntoUnsignedDomain(BigInteger integer, final int size) { while (integer.compareTo(BigInteger.ZERO) < 0) { - integer = integer.add(BigInteger.TWO.pow(size)); + integer = integer.mod(BigInteger.TWO.pow(size)); } while (integer.compareTo(BigInteger.TWO.pow(size)) >= 0) { - integer = integer.subtract(BigInteger.TWO.pow(size)); + integer = integer.mod(BigInteger.TWO.pow(size)); } return integer; diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprCanonizer.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprCanonizer.java index 263c9c938a..1377383849 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprCanonizer.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprCanonizer.java @@ -15,6 +15,8 @@ */ package hu.bme.mit.theta.core.utils; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Not; + import hu.bme.mit.theta.common.DispatchTable; import hu.bme.mit.theta.core.type.BinaryExpr; import hu.bme.mit.theta.core.type.Expr; @@ -32,41 +34,6 @@ import hu.bme.mit.theta.core.type.booltype.NotExpr; import hu.bme.mit.theta.core.type.booltype.OrExpr; import hu.bme.mit.theta.core.type.booltype.XorExpr; -import hu.bme.mit.theta.core.type.bvtype.BvAddExpr; -import hu.bme.mit.theta.core.type.bvtype.BvAndExpr; -import hu.bme.mit.theta.core.type.bvtype.BvArithShiftRightExpr; -import hu.bme.mit.theta.core.type.bvtype.BvConcatExpr; -import hu.bme.mit.theta.core.type.bvtype.BvEqExpr; -import hu.bme.mit.theta.core.type.bvtype.BvExtractExpr; -import hu.bme.mit.theta.core.type.bvtype.BvLogicShiftRightExpr; -import hu.bme.mit.theta.core.type.bvtype.BvMulExpr; -import hu.bme.mit.theta.core.type.bvtype.BvNegExpr; -import hu.bme.mit.theta.core.type.bvtype.BvNeqExpr; -import hu.bme.mit.theta.core.type.bvtype.BvNotExpr; -import hu.bme.mit.theta.core.type.bvtype.BvOrExpr; -import hu.bme.mit.theta.core.type.bvtype.BvPosExpr; -import hu.bme.mit.theta.core.type.bvtype.BvRotateLeftExpr; -import hu.bme.mit.theta.core.type.bvtype.BvRotateRightExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSDivExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSExtExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSGeqExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSGtExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSLeqExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSLtExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSModExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSRemExpr; -import hu.bme.mit.theta.core.type.bvtype.BvShiftLeftExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSignChangeExpr; -import hu.bme.mit.theta.core.type.bvtype.BvSubExpr; -import hu.bme.mit.theta.core.type.bvtype.BvType; -import hu.bme.mit.theta.core.type.bvtype.BvUDivExpr; -import hu.bme.mit.theta.core.type.bvtype.BvUGeqExpr; -import hu.bme.mit.theta.core.type.bvtype.BvUGtExpr; -import hu.bme.mit.theta.core.type.bvtype.BvULeqExpr; -import hu.bme.mit.theta.core.type.bvtype.BvULtExpr; -import hu.bme.mit.theta.core.type.bvtype.BvURemExpr; -import hu.bme.mit.theta.core.type.bvtype.BvXorExpr; -import hu.bme.mit.theta.core.type.bvtype.BvZExtExpr; import hu.bme.mit.theta.core.type.inttype.IntAddExpr; import hu.bme.mit.theta.core.type.inttype.IntDivExpr; import hu.bme.mit.theta.core.type.inttype.IntEqExpr; @@ -98,182 +65,77 @@ import hu.bme.mit.theta.core.type.rattype.RatSubExpr; import hu.bme.mit.theta.core.type.rattype.RatToIntExpr; import hu.bme.mit.theta.core.type.rattype.RatType; - import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Not; - public final class ExprCanonizer { - private static final DispatchTable> TABLE = DispatchTable.>builder() - - // Boolean - - .addCase(NotExpr.class, ExprCanonizer::canonizeNot) - - .addCase(ImplyExpr.class, ExprCanonizer::canonizeImply) - - .addCase(IffExpr.class, ExprCanonizer::canonizeIff) - - .addCase(XorExpr.class, ExprCanonizer::canonizeXor) - - .addCase(AndExpr.class, ExprCanonizer::canonizeAnd) - - .addCase(OrExpr.class, ExprCanonizer::canonizeOr) - - // Rational - - .addCase(RatAddExpr.class, ExprCanonizer::canonizeRatAdd) - - .addCase(RatSubExpr.class, ExprCanonizer::canonizeRatSub) - - .addCase(RatPosExpr.class, ExprCanonizer::canonizeRatPos) - - .addCase(RatNegExpr.class, ExprCanonizer::canonizeRatNeg) - - .addCase(RatMulExpr.class, ExprCanonizer::canonizeRatMul) - - .addCase(RatDivExpr.class, ExprCanonizer::canonizeRatDiv) - - .addCase(RatEqExpr.class, ExprCanonizer::canonizeRatEq) - - .addCase(RatNeqExpr.class, ExprCanonizer::canonizeRatNeq) - - .addCase(RatGeqExpr.class, ExprCanonizer::canonizeRatGeq) - - .addCase(RatGtExpr.class, ExprCanonizer::canonizeRatGt) - - .addCase(RatLeqExpr.class, ExprCanonizer::canonizeRatLeq) - - .addCase(RatLtExpr.class, ExprCanonizer::canonizeRatLt) - - .addCase(RatToIntExpr.class, ExprCanonizer::canonizeRatToInt) - - // Integer - - .addCase(IntToRatExpr.class, ExprCanonizer::canonizeIntToRat) - - .addCase(IntAddExpr.class, ExprCanonizer::canonizeIntAdd) - - .addCase(IntSubExpr.class, ExprCanonizer::canonizeIntSub) - - .addCase(IntPosExpr.class, ExprCanonizer::canonizeIntPos) - - .addCase(IntNegExpr.class, ExprCanonizer::canonizeIntNeg) - - .addCase(IntMulExpr.class, ExprCanonizer::canonizeIntMul) - - .addCase(IntDivExpr.class, ExprCanonizer::canonizeIntDiv) - - .addCase(IntModExpr.class, ExprCanonizer::canonizeMod) - - .addCase(IntEqExpr.class, ExprCanonizer::canonizeIntEq) - - .addCase(IntNeqExpr.class, ExprCanonizer::canonizeIntNeq) - - .addCase(IntGeqExpr.class, ExprCanonizer::canonizeIntGeq) - - .addCase(IntGtExpr.class, ExprCanonizer::canonizeIntGt) - - .addCase(IntLeqExpr.class, ExprCanonizer::canonizeIntLeq) - - .addCase(IntLtExpr.class, ExprCanonizer::canonizeIntLt) - - // Array - - .addCase(ArrayReadExpr.class, ExprCanonizer::canonizeArrayRead) - - .addCase(ArrayWriteExpr.class, ExprCanonizer::canonizeArrayWrite) - - // Bitvectors - - .addCase(BvConcatExpr.class, ExprCanonizer::canonizeBvConcat) - - .addCase(BvExtractExpr.class, ExprCanonizer::canonizeBvExtract) - - .addCase(BvZExtExpr.class, ExprCanonizer::canonizeBvZExt) - - .addCase(BvSExtExpr.class, ExprCanonizer::canonizeBvSExt) - - .addCase(BvAddExpr.class, ExprCanonizer::canonizeBvAdd) - - .addCase(BvSubExpr.class, ExprCanonizer::canonizeBvSub) - - .addCase(BvPosExpr.class, ExprCanonizer::canonizeBvPos) - - .addCase(BvSignChangeExpr.class, ExprCanonizer::canonizeBvSignChange) - - .addCase(BvNegExpr.class, ExprCanonizer::canonizeBvNeg) - - .addCase(BvMulExpr.class, ExprCanonizer::canonizeBvMul) - - .addCase(BvUDivExpr.class, ExprCanonizer::canonizeBvUDiv) - - .addCase(BvSDivExpr.class, ExprCanonizer::canonizeBvSDiv) - - .addCase(BvSModExpr.class, ExprCanonizer::canonizeBvSMod) - - .addCase(BvURemExpr.class, ExprCanonizer::canonizeBvURem) - - .addCase(BvSRemExpr.class, ExprCanonizer::canonizeBvSRem) - - .addCase(BvAndExpr.class, ExprCanonizer::canonizeBvAnd) - - .addCase(BvOrExpr.class, ExprCanonizer::canonizeBvOr) - - .addCase(BvXorExpr.class, ExprCanonizer::canonizeBvXor) - - .addCase(BvNotExpr.class, ExprCanonizer::canonizeBvNot) - - .addCase(BvShiftLeftExpr.class, ExprCanonizer::canonizeBvShiftLeft) - - .addCase(BvArithShiftRightExpr.class, ExprCanonizer::canonizeBvArithShiftRight) - - .addCase(BvLogicShiftRightExpr.class, ExprCanonizer::canonizeBvLogicShiftRight) - - .addCase(BvRotateLeftExpr.class, ExprCanonizer::canonizeBvRotateLeft) - - .addCase(BvRotateRightExpr.class, ExprCanonizer::canonizeBvRotateRight) - - .addCase(BvEqExpr.class, ExprCanonizer::canonizeBvEq) - - .addCase(BvNeqExpr.class, ExprCanonizer::canonizeBvNeq) - - .addCase(BvUGeqExpr.class, ExprCanonizer::canonizeBvUGeq) - - .addCase(BvUGtExpr.class, ExprCanonizer::canonizeBvUGt) - - .addCase(BvULeqExpr.class, ExprCanonizer::canonizeBvULeq) - - .addCase(BvULtExpr.class, ExprCanonizer::canonizeBvULt) - - .addCase(BvSGeqExpr.class, ExprCanonizer::canonizeBvSGeq) - - .addCase(BvSGtExpr.class, ExprCanonizer::canonizeBvSGt) - - .addCase(BvSLeqExpr.class, ExprCanonizer::canonizeBvSLeq) - - .addCase(BvSLtExpr.class, ExprCanonizer::canonizeBvSLt) - - // General - - .addCase(RefExpr.class, ExprCanonizer::canonizeRef) - - .addCase(IteExpr.class, ExprCanonizer::canonizeIte) - - // Default - - .addDefault((o) -> { - final Expr expr = (Expr) o; - return expr.map(e -> canonize(e)); - }) - - .build(); - - private ExprCanonizer() { - } + private static final DispatchTable> TABLE = + DispatchTable.>builder() + + // Boolean + + .addCase(NotExpr.class, ExprCanonizer::canonizeNot) + .addCase(ImplyExpr.class, ExprCanonizer::canonizeImply) + .addCase(IffExpr.class, ExprCanonizer::canonizeIff) + .addCase(XorExpr.class, ExprCanonizer::canonizeXor) + .addCase(AndExpr.class, ExprCanonizer::canonizeAnd) + .addCase(OrExpr.class, ExprCanonizer::canonizeOr) + + // Rational + + .addCase(RatAddExpr.class, ExprCanonizer::canonizeRatAdd) + .addCase(RatSubExpr.class, ExprCanonizer::canonizeRatSub) + .addCase(RatPosExpr.class, ExprCanonizer::canonizeRatPos) + .addCase(RatNegExpr.class, ExprCanonizer::canonizeRatNeg) + .addCase(RatMulExpr.class, ExprCanonizer::canonizeRatMul) + .addCase(RatDivExpr.class, ExprCanonizer::canonizeRatDiv) + .addCase(RatEqExpr.class, ExprCanonizer::canonizeRatEq) + .addCase(RatNeqExpr.class, ExprCanonizer::canonizeRatNeq) + .addCase(RatGeqExpr.class, ExprCanonizer::canonizeRatGeq) + .addCase(RatGtExpr.class, ExprCanonizer::canonizeRatGt) + .addCase(RatLeqExpr.class, ExprCanonizer::canonizeRatLeq) + .addCase(RatLtExpr.class, ExprCanonizer::canonizeRatLt) + .addCase(RatToIntExpr.class, ExprCanonizer::canonizeRatToInt) + + // Integer + + .addCase(IntToRatExpr.class, ExprCanonizer::canonizeIntToRat) + .addCase(IntAddExpr.class, ExprCanonizer::canonizeIntAdd) + .addCase(IntSubExpr.class, ExprCanonizer::canonizeIntSub) + .addCase(IntPosExpr.class, ExprCanonizer::canonizeIntPos) + .addCase(IntNegExpr.class, ExprCanonizer::canonizeIntNeg) + .addCase(IntMulExpr.class, ExprCanonizer::canonizeIntMul) + .addCase(IntDivExpr.class, ExprCanonizer::canonizeIntDiv) + .addCase(IntModExpr.class, ExprCanonizer::canonizeMod) + .addCase(IntEqExpr.class, ExprCanonizer::canonizeIntEq) + .addCase(IntNeqExpr.class, ExprCanonizer::canonizeIntNeq) + .addCase(IntGeqExpr.class, ExprCanonizer::canonizeIntGeq) + .addCase(IntGtExpr.class, ExprCanonizer::canonizeIntGt) + .addCase(IntLeqExpr.class, ExprCanonizer::canonizeIntLeq) + .addCase(IntLtExpr.class, ExprCanonizer::canonizeIntLt) + + // Array + + .addCase(ArrayReadExpr.class, ExprCanonizer::canonizeArrayRead) + .addCase(ArrayWriteExpr.class, ExprCanonizer::canonizeArrayWrite) + + // General + + .addCase(RefExpr.class, ExprCanonizer::canonizeRef) + .addCase(IteExpr.class, ExprCanonizer::canonizeIte) + + // Default + + .addDefault( + (o) -> { + final Expr expr = (Expr) o; + return expr.map(e -> canonize(e)); + }) + .build(); + + private ExprCanonizer() {} @SuppressWarnings("unchecked") public static Expr canonize(final Expr expr) { @@ -307,8 +169,8 @@ private static Expr canonizeArrayRead(final ArrayReadExpr expr) { return canonizeGenericArrayRead(expr); } - private static Expr - canonizeGenericArrayRead(final ArrayReadExpr expr) { + private static Expr canonizeGenericArrayRead( + final ArrayReadExpr expr) { Expr> arr = canonize(expr.getArray()); Expr index = canonize(expr.getIndex()); @@ -319,8 +181,8 @@ private static Expr canonizeArrayWrite(final ArrayWriteExpr expr) { return canonizeGenericArrayWrite(expr); } - private static Expr> - canonizeGenericArrayWrite(final ArrayWriteExpr expr) { + private static + Expr> canonizeGenericArrayWrite(final ArrayWriteExpr expr) { Expr> arr = canonize(expr.getArray()); Expr index = canonize(expr.getIndex()); Expr elem = canonize(expr.getElem()); @@ -345,8 +207,9 @@ private static Expr canonizeImply(final ImplyExpr expr) { return expr.with(leftOp, rightOp); } - private static Expr - canonizeGenericCommutativeBinaryExpr(final BinaryExpr expr) { + private static + Expr canonizeGenericCommutativeBinaryExpr( + final BinaryExpr expr) { final Expr leftOp = canonize(expr.getLeftOp()); final Expr rightOp = canonize(expr.getRightOp()); @@ -368,12 +231,14 @@ private static Expr canonizeXor(final XorExpr expr) { return canonizeGenericCommutativeBinaryExpr(expr); } - private static Expr - canonizeGenericCommutativeMultiaryExpr(final MultiaryExpr expr) { - final List> orderedCanonizedOps = expr.getOps().stream() - .map(ExprCanonizer::canonize) - .sorted(Comparator.comparingInt(Object::hashCode)) - .collect(Collectors.toList()); + private static + Expr canonizeGenericCommutativeMultiaryExpr( + final MultiaryExpr expr) { + final List> orderedCanonizedOps = + expr.getOps().stream() + .map(ExprCanonizer::canonize) + .sorted(Comparator.comparingInt(Object::hashCode)) + .collect(Collectors.toList()); return expr.withOps(orderedCanonizedOps); } @@ -542,145 +407,4 @@ private static Expr canonizeIntLt(final IntLtExpr expr) { return expr.with(leftOp, rightOp); } - - /* - * Bitvectors - */ - - private static Expr canonizeBvConcat(final BvConcatExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvExtract(final BvExtractExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvZExt(final BvZExtExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSExt(final BvSExtExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvAdd(final BvAddExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSub(final BvSubExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvPos(final BvPosExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSignChange(final BvSignChangeExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvNeg(final BvNegExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvMul(final BvMulExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvUDiv(final BvUDivExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSDiv(final BvSDivExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSMod(final BvSModExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvURem(final BvURemExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSRem(final BvSRemExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvAnd(final BvAndExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvOr(final BvOrExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvXor(final BvXorExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvNot(final BvNotExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvShiftLeft(final BvShiftLeftExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvArithShiftRight(final BvArithShiftRightExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvLogicShiftRight(final BvLogicShiftRightExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvRotateLeft(final BvRotateLeftExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvRotateRight(final BvRotateRightExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvEq(final BvEqExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvNeq(final BvNeqExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvUGeq(final BvUGeqExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvUGt(final BvUGtExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvULeq(final BvULeqExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvULt(final BvULtExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSGeq(final BvSGeqExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSGt(final BvSGtExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSLeq(final BvSLeqExpr expr) { - throw new UnsupportedOperationException(); - } - - private static Expr canonizeBvSLt(final BvSLtExpr expr) { - throw new UnsupportedOperationException(); - } - } diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprReverser.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprReverser.java new file mode 100644 index 0000000000..d69449e845 --- /dev/null +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprReverser.java @@ -0,0 +1,135 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.core.utils; + +import static com.google.common.base.Preconditions.checkArgument; +import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; + +import hu.bme.mit.theta.common.DispatchTable; +import hu.bme.mit.theta.common.DispatchTable2; +import hu.bme.mit.theta.core.decl.VarDecl; +import hu.bme.mit.theta.core.type.Expr; +import hu.bme.mit.theta.core.type.Type; +import hu.bme.mit.theta.core.type.anytype.PrimeExpr; +import hu.bme.mit.theta.core.type.anytype.RefExpr; +import hu.bme.mit.theta.core.utils.indexings.VarIndexing; + +public class ExprReverser { + + private final VarIndexing indexing; + + private final DispatchTable> TABLE = + DispatchTable.>builder() + .addCase(RefExpr.class, this::reverseRef) + .addCase(PrimeExpr.class, this::reversePrime) + + // Default + + .addDefault( + (o) -> { + final Expr expr = (Expr) o; + return expr.map(e -> reverseInner(e)); + }) + .build(); + + public ExprReverser(VarIndexing indexing) { + this.indexing = indexing; + } + + public Expr reverse(final Expr expr) { + final var transformed = PrimeToLeaves.transform(expr); + return (Expr) TABLE.dispatch(transformed); + } + + @SuppressWarnings("unchecked") + private Expr reverseInner(final Expr expr) { + return (Expr) TABLE.dispatch(expr); + } + + /* + * General + */ + + private Expr reverseRef(final RefExpr expr) { + final VarDecl varDecl = extractVarDecl(expr); + return reverse(varDecl, 0); + } + + private Expr reversePrime(final PrimeExpr expr) { + final int primeDepth = primeDepth(expr); + final VarDecl varDecl = extractVarDecl(expr); + return reverse(varDecl, primeDepth); + } + + private Expr reverse(final VarDecl decl, int primeDepth) { + checkArgument(primeDepth >= 0 && primeDepth <= indexing.get(decl)); + return Prime(decl.getRef(), indexing.get(decl) - primeDepth); + } + + private static int primeDepth(final Expr expr) { + if (expr instanceof PrimeExpr) { + return 1 + primeDepth(((PrimeExpr) expr).getOp()); + } else { + return 0; + } + } + + private static VarDecl extractVarDecl(final Expr expr) { + if (expr instanceof RefExpr refExpr) { + checkArgument(refExpr.getDecl() instanceof VarDecl); + return (VarDecl) refExpr.getDecl(); + } else if (expr instanceof PrimeExpr primeExpr) { + return extractVarDecl(primeExpr.getOp()); + } else { + throw new IllegalArgumentException( + "Cannot extract variable declaration from expression: " + expr); + } + } + + private static class PrimeToLeaves { + + private static final DispatchTable2> TABLE = + DispatchTable2.>builder() + .addCase(RefExpr.class, PrimeToLeaves::transformRef) + .addCase(PrimeExpr.class, PrimeToLeaves::transformPrime) + + // Default + + .addDefault( + (o, primeDepth) -> { + final Expr expr = (Expr) o; + return expr.map(e -> transform(e, primeDepth)); + }) + .build(); + + public static Expr transform(final Expr expr) { + return transform(expr, 0); + } + + @SuppressWarnings("unchecked") + private static Expr transform(final Expr expr, int primeDepth) { + return (Expr) TABLE.dispatch(expr, primeDepth); + } + + private static Expr transformRef(final Expr expr, Integer primeDepth) { + return Prime(expr, primeDepth); + } + + private static Expr transformPrime(final Expr expr, Integer primeDepth) { + return transform(((PrimeExpr) expr).getOp(), primeDepth + 1); + } + } +} diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprSimplifier.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprSimplifier.java index f46173b44b..86cbd9659a 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprSimplifier.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprSimplifier.java @@ -15,6 +15,12 @@ */ package hu.bme.mit.theta.core.utils; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.*; +import static hu.bme.mit.theta.core.type.bvtype.BvExprs.Bv; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; +import static hu.bme.mit.theta.core.type.rattype.RatExprs.Rat; +import static hu.bme.mit.theta.core.utils.SimplifierLevel.LITERAL_ONLY; + import hu.bme.mit.theta.common.DispatchTable2; import hu.bme.mit.theta.common.Tuple2; import hu.bme.mit.theta.common.Utils; @@ -39,19 +45,12 @@ import hu.bme.mit.theta.core.type.fptype.*; import hu.bme.mit.theta.core.type.inttype.*; import hu.bme.mit.theta.core.type.rattype.*; -import org.kframework.mpfr.BigFloat; - import java.math.BigInteger; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; - -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.*; -import static hu.bme.mit.theta.core.type.bvtype.BvExprs.Bv; -import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; -import static hu.bme.mit.theta.core.type.rattype.RatExprs.Rat; -import static hu.bme.mit.theta.core.utils.SimplifierLevel.LITERAL_ONLY; +import org.kframework.mpfr.BigFloat; public final class ExprSimplifier { @@ -77,233 +76,149 @@ public Expr simplify(final Expr expr, final Valuation val return (Expr) TABLE.dispatch(expr, valuation); } - private final DispatchTable2> TABLE = DispatchTable2.>builder() - - // Boolean - - .addCase(NotExpr.class, this::simplifyNot) - - .addCase(ImplyExpr.class, this::simplifyImply) - - .addCase(IffExpr.class, this::simplifyIff) - - .addCase(XorExpr.class, this::simplifyXor) - - .addCase(AndExpr.class, this::simplifyAnd) - - .addCase(OrExpr.class, this::simplifyOr) - - // Rational - - .addCase(RatAddExpr.class, this::simplifyRatAdd) - - .addCase(RatSubExpr.class, this::simplifyRatSub) - - .addCase(RatPosExpr.class, this::simplifyRatPos) - - .addCase(RatNegExpr.class, this::simplifyRatNeg) - - .addCase(RatMulExpr.class, this::simplifyRatMul) - - .addCase(RatDivExpr.class, this::simplifyRatDiv) - - .addCase(RatEqExpr.class, this::simplifyRatEq) - - .addCase(RatNeqExpr.class, this::simplifyRatNeq) - - .addCase(RatGeqExpr.class, this::simplifyRatGeq) - - .addCase(RatGtExpr.class, this::simplifyRatGt) - - .addCase(RatLeqExpr.class, this::simplifyRatLeq) - - .addCase(RatLtExpr.class, this::simplifyRatLt) - - .addCase(RatToIntExpr.class, this::simplifyRatToInt) - - // Integer - - .addCase(IntToRatExpr.class, this::simplifyIntToRat) - - .addCase(IntAddExpr.class, this::simplifyIntAdd) - - .addCase(IntSubExpr.class, this::simplifyIntSub) - - .addCase(IntPosExpr.class, this::simplifyIntPos) - - .addCase(IntNegExpr.class, this::simplifyIntNeg) - - .addCase(IntMulExpr.class, this::simplifyIntMul) - - .addCase(IntDivExpr.class, this::simplifyIntDiv) - - .addCase(IntModExpr.class, this::simplifyMod) - - .addCase(IntRemExpr.class, this::simplifyRem) - - .addCase(IntEqExpr.class, this::simplifyIntEq) - - .addCase(IntNeqExpr.class, this::simplifyIntNeq) - - .addCase(IntGeqExpr.class, this::simplifyIntGeq) - - .addCase(IntGtExpr.class, this::simplifyIntGt) - - .addCase(IntLeqExpr.class, this::simplifyIntLeq) - - .addCase(IntLtExpr.class, this::simplifyIntLt) - - // Enum - - .addCase(EnumEqExpr.class, this::simplifyEnumEqExpr) - - .addCase(EnumNeqExpr.class, this::simplifyEnumNeqExpr) - - // Array - - .addCase(ArrayReadExpr.class, this::simplifyArrayRead) - - .addCase(ArrayWriteExpr.class, this::simplifyArrayWrite) - - //.addCase(ArrayInitExpr.class, this::simplifyArrayInit) - .addCase(ArrayInitExpr.class, (arrayInitExpr, valuation) -> this.simplifyArrayInit(arrayInitExpr, valuation)) - - // Bitvectors - - .addCase(BvConcatExpr.class, this::simplifyBvConcat) - - .addCase(BvExtractExpr.class, this::simplifyBvExtract) - - .addCase(BvZExtExpr.class, this::simplifyBvZExt) - - .addCase(BvSExtExpr.class, this::simplifyBvSExt) - - .addCase(BvAddExpr.class, this::simplifyBvAdd) - - .addCase(BvSubExpr.class, this::simplifyBvSub) - - .addCase(BvPosExpr.class, this::simplifyBvPos) - - .addCase(BvSignChangeExpr.class, this::simplifyBvSignChange) - - .addCase(BvNegExpr.class, this::simplifyBvNeg) - - .addCase(BvMulExpr.class, this::simplifyBvMul) - - .addCase(BvUDivExpr.class, this::simplifyBvUDiv) - - .addCase(BvSDivExpr.class, this::simplifyBvSDiv) - - .addCase(BvSModExpr.class, this::simplifyBvSMod) - - .addCase(BvURemExpr.class, this::simplifyBvURem) - - .addCase(BvSRemExpr.class, this::simplifyBvSRem) - - .addCase(BvAndExpr.class, this::simplifyBvAnd) - - .addCase(BvOrExpr.class, this::simplifyBvOr) - - .addCase(BvXorExpr.class, this::simplifyBvXor) - - .addCase(BvNotExpr.class, this::simplifyBvNot) - - .addCase(BvShiftLeftExpr.class, this::simplifyBvShiftLeft) - - .addCase(BvArithShiftRightExpr.class, this::simplifyBvArithShiftRight) - - .addCase(BvLogicShiftRightExpr.class, this::simplifyBvLogicShiftRight) - - .addCase(BvRotateLeftExpr.class, this::simplifyBvRotateLeft) - - .addCase(BvRotateRightExpr.class, this::simplifyBvRotateRight) - - .addCase(BvEqExpr.class, this::simplifyBvEq) - - .addCase(BvNeqExpr.class, this::simplifyBvNeq) - - .addCase(BvUGeqExpr.class, this::simplifyBvUGeq) - - .addCase(BvUGtExpr.class, this::simplifyBvUGt) - - .addCase(BvULeqExpr.class, this::simplifyBvULeq) - - .addCase(BvULtExpr.class, this::simplifyBvULt) - - .addCase(BvSGeqExpr.class, this::simplifyBvSGeq) - - .addCase(BvSGtExpr.class, this::simplifyBvSGt) - - .addCase(BvSLeqExpr.class, this::simplifyBvSLeq) - - .addCase(BvSLtExpr.class, this::simplifyBvSLt) - - // Floating points - - .addCase(FpAddExpr.class, this::simplifyFpAdd) - - .addCase(FpSubExpr.class, this::simplifyFpSub) - - .addCase(FpPosExpr.class, this::simplifyFpPos) - - .addCase(FpNegExpr.class, this::simplifyFpNeg) - - .addCase(FpMulExpr.class, this::simplifyFpMul) - - .addCase(FpDivExpr.class, this::simplifyFpDiv) - - .addCase(FpEqExpr.class, this::simplifyFpEq) - - .addCase(FpAssignExpr.class, this::simplifyFpAssign) - - .addCase(FpGeqExpr.class, this::simplifyFpGeq) - - .addCase(FpLeqExpr.class, this::simplifyFpLeq) - - .addCase(FpGtExpr.class, this::simplifyFpGt) - - .addCase(FpLtExpr.class, this::simplifyFpLt) - - .addCase(FpNeqExpr.class, this::simplifyFpNeq) - - .addCase(FpAbsExpr.class, this::simplifyFpAbs) - - .addCase(FpRoundToIntegralExpr.class, this::simplifyFpRoundToIntegral) - - .addCase(FpMaxExpr.class, this::simplifyFpMax) - - .addCase(FpMinExpr.class, this::simplifyFpMin) - - .addCase(FpSqrtExpr.class, this::simplifyFpSqrt) - - .addCase(FpIsNanExpr.class, this::simplifyFpIsNan) - - .addCase(FpFromBvExpr.class, this::simplifyFpFromBv) - - .addCase(FpToBvExpr.class, this::simplifyFpToBv) - - .addCase(FpToFpExpr.class, this::simplifyFpToFp) - - // General - - .addCase(RefExpr.class, this::simplifyRef) - - .addCase(IteExpr.class, this::simplifyIte) - - // Reference - - .addCase(Dereference.class, this::simplifyDereference) - -// .addCase(Reference.class, this::simplifyReference) - - // Default - - .addDefault((o, val) -> { - final Expr expr = (Expr) o; - return expr.map(e -> simplify(e, val)); - }) - - .build(); + private final DispatchTable2> TABLE = + DispatchTable2.>builder() + + // Boolean + + .addCase(NotExpr.class, this::simplifyNot) + .addCase(ImplyExpr.class, this::simplifyImply) + .addCase(IffExpr.class, this::simplifyIff) + .addCase(XorExpr.class, this::simplifyXor) + .addCase(AndExpr.class, this::simplifyAnd) + .addCase(OrExpr.class, this::simplifyOr) + + // Rational + + .addCase(RatAddExpr.class, this::simplifyRatAdd) + .addCase(RatSubExpr.class, this::simplifyRatSub) + .addCase(RatPosExpr.class, this::simplifyRatPos) + .addCase(RatNegExpr.class, this::simplifyRatNeg) + .addCase(RatMulExpr.class, this::simplifyRatMul) + .addCase(RatDivExpr.class, this::simplifyRatDiv) + .addCase(RatEqExpr.class, this::simplifyRatEq) + .addCase(RatNeqExpr.class, this::simplifyRatNeq) + .addCase(RatGeqExpr.class, this::simplifyRatGeq) + .addCase(RatGtExpr.class, this::simplifyRatGt) + .addCase(RatLeqExpr.class, this::simplifyRatLeq) + .addCase(RatLtExpr.class, this::simplifyRatLt) + .addCase(RatToIntExpr.class, this::simplifyRatToInt) + + // Integer + + .addCase(IntToRatExpr.class, this::simplifyIntToRat) + .addCase(IntAddExpr.class, this::simplifyIntAdd) + .addCase(IntSubExpr.class, this::simplifyIntSub) + .addCase(IntPosExpr.class, this::simplifyIntPos) + .addCase(IntNegExpr.class, this::simplifyIntNeg) + .addCase(IntMulExpr.class, this::simplifyIntMul) + .addCase(IntDivExpr.class, this::simplifyIntDiv) + .addCase(IntModExpr.class, this::simplifyMod) + .addCase(IntRemExpr.class, this::simplifyRem) + .addCase(IntEqExpr.class, this::simplifyIntEq) + .addCase(IntNeqExpr.class, this::simplifyIntNeq) + .addCase(IntGeqExpr.class, this::simplifyIntGeq) + .addCase(IntGtExpr.class, this::simplifyIntGt) + .addCase(IntLeqExpr.class, this::simplifyIntLeq) + .addCase(IntLtExpr.class, this::simplifyIntLt) + + // Enum + + .addCase(EnumEqExpr.class, this::simplifyEnumEqExpr) + .addCase(EnumNeqExpr.class, this::simplifyEnumNeqExpr) + + // Array + + .addCase(ArrayReadExpr.class, this::simplifyArrayRead) + .addCase(ArrayWriteExpr.class, this::simplifyArrayWrite) + + // .addCase(ArrayInitExpr.class, this::simplifyArrayInit) + .addCase( + ArrayInitExpr.class, + (arrayInitExpr, valuation) -> + this.simplifyArrayInit(arrayInitExpr, valuation)) + + // Bitvectors + + .addCase(BvConcatExpr.class, this::simplifyBvConcat) + .addCase(BvExtractExpr.class, this::simplifyBvExtract) + .addCase(BvZExtExpr.class, this::simplifyBvZExt) + .addCase(BvSExtExpr.class, this::simplifyBvSExt) + .addCase(BvAddExpr.class, this::simplifyBvAdd) + .addCase(BvSubExpr.class, this::simplifyBvSub) + .addCase(BvPosExpr.class, this::simplifyBvPos) + .addCase(BvSignChangeExpr.class, this::simplifyBvSignChange) + .addCase(BvNegExpr.class, this::simplifyBvNeg) + .addCase(BvMulExpr.class, this::simplifyBvMul) + .addCase(BvUDivExpr.class, this::simplifyBvUDiv) + .addCase(BvSDivExpr.class, this::simplifyBvSDiv) + .addCase(BvSModExpr.class, this::simplifyBvSMod) + .addCase(BvURemExpr.class, this::simplifyBvURem) + .addCase(BvSRemExpr.class, this::simplifyBvSRem) + .addCase(BvAndExpr.class, this::simplifyBvAnd) + .addCase(BvOrExpr.class, this::simplifyBvOr) + .addCase(BvXorExpr.class, this::simplifyBvXor) + .addCase(BvNotExpr.class, this::simplifyBvNot) + .addCase(BvShiftLeftExpr.class, this::simplifyBvShiftLeft) + .addCase(BvArithShiftRightExpr.class, this::simplifyBvArithShiftRight) + .addCase(BvLogicShiftRightExpr.class, this::simplifyBvLogicShiftRight) + .addCase(BvRotateLeftExpr.class, this::simplifyBvRotateLeft) + .addCase(BvRotateRightExpr.class, this::simplifyBvRotateRight) + .addCase(BvEqExpr.class, this::simplifyBvEq) + .addCase(BvNeqExpr.class, this::simplifyBvNeq) + .addCase(BvUGeqExpr.class, this::simplifyBvUGeq) + .addCase(BvUGtExpr.class, this::simplifyBvUGt) + .addCase(BvULeqExpr.class, this::simplifyBvULeq) + .addCase(BvULtExpr.class, this::simplifyBvULt) + .addCase(BvSGeqExpr.class, this::simplifyBvSGeq) + .addCase(BvSGtExpr.class, this::simplifyBvSGt) + .addCase(BvSLeqExpr.class, this::simplifyBvSLeq) + .addCase(BvSLtExpr.class, this::simplifyBvSLt) + + // Floating points + + .addCase(FpAddExpr.class, this::simplifyFpAdd) + .addCase(FpSubExpr.class, this::simplifyFpSub) + .addCase(FpPosExpr.class, this::simplifyFpPos) + .addCase(FpNegExpr.class, this::simplifyFpNeg) + .addCase(FpMulExpr.class, this::simplifyFpMul) + .addCase(FpDivExpr.class, this::simplifyFpDiv) + .addCase(FpEqExpr.class, this::simplifyFpEq) + .addCase(FpAssignExpr.class, this::simplifyFpAssign) + .addCase(FpGeqExpr.class, this::simplifyFpGeq) + .addCase(FpLeqExpr.class, this::simplifyFpLeq) + .addCase(FpGtExpr.class, this::simplifyFpGt) + .addCase(FpLtExpr.class, this::simplifyFpLt) + .addCase(FpNeqExpr.class, this::simplifyFpNeq) + .addCase(FpAbsExpr.class, this::simplifyFpAbs) + .addCase(FpRoundToIntegralExpr.class, this::simplifyFpRoundToIntegral) + .addCase(FpMaxExpr.class, this::simplifyFpMax) + .addCase(FpMinExpr.class, this::simplifyFpMin) + .addCase(FpSqrtExpr.class, this::simplifyFpSqrt) + .addCase(FpIsNanExpr.class, this::simplifyFpIsNan) + .addCase(FpFromBvExpr.class, this::simplifyFpFromBv) + .addCase(FpToBvExpr.class, this::simplifyFpToBv) + .addCase(FpToFpExpr.class, this::simplifyFpToFp) + + // General + + .addCase(RefExpr.class, this::simplifyRef) + .addCase(IteExpr.class, this::simplifyIte) + + // Reference + + .addCase(Dereference.class, this::simplifyDereference) + + // .addCase(Reference.class, this::simplifyReference) + + // Default + + .addDefault( + (o, val) -> { + final Expr expr = (Expr) o; + return expr.map(e -> simplify(e, val)); + }) + .build(); private Expr simplifyRef(final RefExpr expr, final Valuation val) { return simplifyGenericRef(expr, val); @@ -315,8 +230,8 @@ private Expr simplifyRef(final RefExpr expr, final Valuation val) { // TODO Eliminate helper method once the Java compiler is able to handle // this kind of type inference - private Expr simplifyGenericRef(final RefExpr expr, - final Valuation val) { + private Expr simplifyGenericRef( + final RefExpr expr, final Valuation val) { final Optional> eval = val.eval(expr.getDecl()); if (eval.isPresent()) { return eval.get(); @@ -331,8 +246,8 @@ private Expr simplifyIte(final IteExpr expr, final Valuation val) { // TODO Eliminate helper method once the Java compiler is able to handle // this kind of type inference - private Expr simplifyGenericIte(final IteExpr expr, - final Valuation val) { + private Expr simplifyGenericIte( + final IteExpr expr, final Valuation val) { final Expr cond = simplify(expr.getCond(), val); if (cond instanceof TrueExpr) { @@ -358,11 +273,15 @@ private Expr simplifyArrayRead(final ArrayReadExpr expr, final Valuatio return simplifyGenericArrayRead(expr, val); } - private Expr - simplifyGenericArrayRead(final ArrayReadExpr expr, final Valuation val) { + private Expr simplifyGenericArrayRead( + final ArrayReadExpr expr, final Valuation val) { Expr> arr = simplify(expr.getArray(), val); Expr index = simplify(expr.getIndex(), val); - if (arr instanceof LitExpr && index instanceof LitExpr) { //The index is required to be a literal so that we can use 'equals' to compare it against existing keys in the array + if (arr instanceof LitExpr + && index + instanceof + LitExpr) { // The index is required to be a literal so that we can use + // 'equals' to compare it against existing keys in the array return expr.eval(val); } return expr.with(arr, index); @@ -372,19 +291,21 @@ private Expr simplifyArrayWrite(final ArrayWriteExpr expr, final Valuat return simplifyGenericArrayWrite(expr, val); } - private Expr> - simplifyGenericArrayWrite(final ArrayWriteExpr expr, final Valuation val) { + private Expr> simplifyGenericArrayWrite( + final ArrayWriteExpr expr, final Valuation val) { Expr> arr = simplify(expr.getArray(), val); Expr index = simplify(expr.getIndex(), val); Expr elem = simplify(expr.getElem(), val); - if (arr instanceof LitExpr && index instanceof LitExpr && elem instanceof LitExpr) { + if (arr instanceof LitExpr + && index instanceof LitExpr + && elem instanceof LitExpr) { return expr.eval(val); } return expr.with(arr, index, elem); } - private Expr> - simplifyArrayInit(final ArrayInitExpr t, final Valuation val) { + private Expr> simplifyArrayInit( + final ArrayInitExpr t, final Valuation val) { boolean nonLiteralFound = false; List, Expr>> newElements = new ArrayList<>(); Expr newElseElem = simplify(t.getElseElem(), val); @@ -946,6 +867,9 @@ private Expr simplifyIntDiv(final IntDivExpr expr, final Valuation val) if (leftOp instanceof IntLitExpr && rightOp instanceof IntLitExpr) { final IntLitExpr leftLit = (IntLitExpr) leftOp; final IntLitExpr rightLit = (IntLitExpr) rightOp; + if (rightLit.getValue().compareTo(BigInteger.ZERO) == 0) { + return expr.with(leftOp, rightOp); + } return leftLit.div(rightLit); } @@ -959,8 +883,12 @@ private Expr simplifyMod(final IntModExpr expr, final Valuation val) { if (leftOp instanceof IntLitExpr && rightOp instanceof IntLitExpr) { final IntLitExpr leftLit = (IntLitExpr) leftOp; final IntLitExpr rightLit = (IntLitExpr) rightOp; + if (rightLit.getValue().compareTo(BigInteger.ZERO) == 0) { + return expr.with(leftOp, rightOp); + } return leftLit.mod(rightLit); - } else if (leftOp instanceof IntModExpr && ((IntModExpr) leftOp).getRightOp().equals(rightOp)) { + } else if (leftOp instanceof IntModExpr + && ((IntModExpr) leftOp).getRightOp().equals(rightOp)) { return leftOp; } @@ -975,7 +903,8 @@ private Expr simplifyRem(final IntRemExpr expr, final Valuation val) { final IntLitExpr leftLit = (IntLitExpr) leftOp; final IntLitExpr rightLit = (IntLitExpr) rightOp; return leftLit.rem(rightLit); - } else if (leftOp instanceof IntRemExpr && ((IntRemExpr) leftOp).getRightOp().equals(rightOp)) { + } else if (leftOp instanceof IntRemExpr + && ((IntRemExpr) leftOp).getRightOp().equals(rightOp)) { return simplify(leftOp, val); } @@ -1096,11 +1025,13 @@ private Expr simplifyEnumEqExpr(final EnumEqExpr expr, final Valuation return Bool(leftLit.equals(rightLit)); } - if (leftOp instanceof RefExpr && rightOp instanceof RefExpr && level != LITERAL_ONLY && leftOp.equals(rightOp)) { + if (leftOp instanceof RefExpr + && rightOp instanceof RefExpr + && level != LITERAL_ONLY + && leftOp.equals(rightOp)) { return True(); } - return expr.with(leftOp, rightOp); } @@ -1116,11 +1047,13 @@ private Expr simplifyEnumNeqExpr(final EnumNeqExpr expr, final Valuati return Bool(!leftLit.equals(rightLit)); } - if (leftOp instanceof RefExpr && rightOp instanceof RefExpr && level != LITERAL_ONLY && leftOp.equals(rightOp)) { + if (leftOp instanceof RefExpr + && rightOp instanceof RefExpr + && level != LITERAL_ONLY + && leftOp.equals(rightOp)) { return False(); } - return expr.with(leftOp, rightOp); } @@ -1147,7 +1080,7 @@ private Expr simplifyBvConcat(final BvConcatExpr expr, final Valuation v } else { value = value.concat(litOp); } -// iterator.remove(); + // iterator.remove(); } else { return expr.withOps(ops); } @@ -1553,7 +1486,8 @@ private Expr simplifyBvShiftLeft(final BvShiftLeftExpr expr, final Valua return expr.with(leftOp, rightOp); } - private Expr simplifyBvArithShiftRight(final BvArithShiftRightExpr expr, final Valuation val) { + private Expr simplifyBvArithShiftRight( + final BvArithShiftRightExpr expr, final Valuation val) { final Expr leftOp = simplify(expr.getLeftOp(), val); final Expr rightOp = simplify(expr.getRightOp(), val); @@ -1566,7 +1500,8 @@ private Expr simplifyBvArithShiftRight(final BvArithShiftRightExpr expr, return expr.with(leftOp, rightOp); } - private Expr simplifyBvLogicShiftRight(final BvLogicShiftRightExpr expr, final Valuation val) { + private Expr simplifyBvLogicShiftRight( + final BvLogicShiftRightExpr expr, final Valuation val) { final Expr leftOp = simplify(expr.getLeftOp(), val); final Expr rightOp = simplify(expr.getRightOp(), val); @@ -1799,7 +1734,9 @@ private Expr simplifyFpAdd(final FpAddExpr expr, final Valuation val) { ops.add(opVisited); } } - final FpLitExpr zero = FpUtils.bigFloatToFpLitExpr(BigFloat.zero(expr.getType().getSignificand()), expr.getType()); + final FpLitExpr zero = + FpUtils.bigFloatToFpLitExpr( + BigFloat.zero(expr.getType().getSignificand()), expr.getType()); FpLitExpr value = zero; for (final Iterator> iterator = ops.iterator(); iterator.hasNext(); ) { @@ -1836,7 +1773,8 @@ private Expr simplifyFpSub(final FpSubExpr expr, final Valuation val) { if (leftOp instanceof RefExpr && rightOp instanceof RefExpr) { if (leftOp.equals(rightOp)) { - return FpUtils.bigFloatToFpLitExpr(BigFloat.zero(expr.getType().getSignificand()), expr.getType()); + return FpUtils.bigFloatToFpLitExpr( + BigFloat.zero(expr.getType().getSignificand()), expr.getType()); } } @@ -1886,13 +1824,16 @@ private Expr simplifyFpIsInfinite(final FpIsInfiniteExpr expr, final V final Expr op = simplify(expr.getOp(), val); if (op instanceof FpLitExpr) { - return Bool((((FpLitExpr) op).isNegativeInfinity() || ((FpLitExpr) op).isPositiveInfinity())); + return Bool( + (((FpLitExpr) op).isNegativeInfinity() + || ((FpLitExpr) op).isPositiveInfinity())); } return expr.with(op); } - private Expr simplifyFpRoundToIntegral(final FpRoundToIntegralExpr expr, final Valuation val) { + private Expr simplifyFpRoundToIntegral( + final FpRoundToIntegralExpr expr, final Valuation val) { final Expr op = simplify(expr.getOp(), val); if (op instanceof FpRoundToIntegralExpr) { @@ -1917,8 +1858,15 @@ private Expr simplifyFpMul(final FpMulExpr expr, final Valuation val) { } } - final FpLitExpr ZERO = FpUtils.bigFloatToFpLitExpr(BigFloat.zero(expr.getType().getSignificand()), expr.getType()); - final FpLitExpr ONE = FpUtils.bigFloatToFpLitExpr(new BigFloat(1.0f, FpUtils.getMathContext(expr.getType(), expr.getRoundingMode())), expr.getType()); + final FpLitExpr ZERO = + FpUtils.bigFloatToFpLitExpr( + BigFloat.zero(expr.getType().getSignificand()), expr.getType()); + final FpLitExpr ONE = + FpUtils.bigFloatToFpLitExpr( + new BigFloat( + 1.0f, + FpUtils.getMathContext(expr.getType(), expr.getRoundingMode())), + expr.getType()); FpLitExpr value = ONE; for (final Iterator> iterator = ops.iterator(); iterator.hasNext(); ) { @@ -2103,6 +2051,4 @@ private Expr simplifyFpToFp(final FpToFpExpr expr, final Valuation val) return expr.with(op); } - - } diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprUtils.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprUtils.java index 6873033f60..e429afb05e 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprUtils.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/ExprUtils.java @@ -15,6 +15,9 @@ */ package hu.bme.mit.theta.core.utils; +import static com.google.common.base.Preconditions.checkNotNull; +import static hu.bme.mit.theta.core.utils.TypeUtils.cast; + import com.google.common.collect.ImmutableList; import hu.bme.mit.theta.common.Tuple2; import hu.bme.mit.theta.common.container.Containers; @@ -36,7 +39,7 @@ import hu.bme.mit.theta.core.type.functype.FuncAppExpr; import hu.bme.mit.theta.core.utils.IndexedVars.Builder; import hu.bme.mit.theta.core.utils.indexings.VarIndexing; - +import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -48,26 +51,21 @@ import java.util.Set; import java.util.stream.Collectors; -import static com.google.common.base.Preconditions.checkNotNull; -import static hu.bme.mit.theta.core.utils.TypeUtils.cast; - -/** - * Utility functions related to expressions. - */ +/** Utility functions related to expressions. */ public final class ExprUtils { private static final ExprSimplifier exprSimplifier = ExprSimplifier.create(); - private ExprUtils() { - } + private ExprUtils() {} /** * Collect atoms from a Boolean expression into a given collection. * - * @param expr Expression + * @param expr Expression * @param collectTo Collection where the atoms should be put */ - public static void collectAtoms(final Expr expr, final Collection> collectTo) { + public static void collectAtoms( + final Expr expr, final Collection> collectTo) { ExprAtomCollector.collectAtoms(expr, collectTo); } @@ -114,7 +112,9 @@ public static Collection> getConjuncts(final Expr expr) if (expr instanceof AndExpr) { final AndExpr andExpr = (AndExpr) expr; - return andExpr.getOps().stream().map(ExprUtils::getConjuncts).flatMap(Collection::stream) + return andExpr.getOps().stream() + .map(ExprUtils::getConjuncts) + .flatMap(Collection::stream) .collect(Collectors.toSet()); } else { return Collections.singleton(expr); @@ -124,7 +124,7 @@ public static Collection> getConjuncts(final Expr expr) /** * Collect params of an expression into a given collection. * - * @param expr Expression + * @param expr Expression * @param collectTo Collection where the params should be put */ public static void collectParams(final Expr expr, final Collection> collectTo) { @@ -152,10 +152,11 @@ public static void collectParams(final Expr expr, final Collection> exprs, final Collection> collectTo) { + public static void collectParams( + final Iterable> exprs, final Collection> collectTo) { exprs.forEach(e -> collectParams(e, collectTo)); } @@ -183,11 +184,10 @@ public static Set> getParams(final Iterable> expr return vars; } - /** * Collect variables of an expression into a given collection. * - * @param expr Expression + * @param expr Expression * @param collectTo Collection where the variables should be put */ public static void collectVars(final Expr expr, final Collection> collectTo) { @@ -206,10 +206,11 @@ public static void collectVars(final Expr expr, final Collection> /** * Collect variables from expressions into a given collection. * - * @param exprs Expressions + * @param exprs Expressions * @param collectTo Collection where the variables should be put */ - public static void collectVars(final Iterable> exprs, final Collection> collectTo) { + public static void collectVars( + final Iterable> exprs, final Collection> collectTo) { exprs.forEach(e -> collectVars(e, collectTo)); } @@ -240,10 +241,11 @@ public static Set> getVars(final Iterable> exprs) { /** * Collect indexed constants of an expression into a given collection. * - * @param expr Expression + * @param expr Expression * @param collectTo Collection where the constants should be put */ - public static void collectIndexedConstants(final Expr expr, final Collection> collectTo) { + public static void collectIndexedConstants( + final Expr expr, final Collection> collectTo) { if (expr instanceof RefExpr) { final RefExpr refExpr = (RefExpr) expr; final Decl decl = refExpr.getDecl(); @@ -259,10 +261,12 @@ public static void collectIndexedConstants(final Expr expr, final Collection< /** * Collect indexed constants from expressions into a given collection. * - * @param exprs Expressions + * @param exprs Expressions * @param collectTo Collection where the constants should be put */ - public static void collectIndexedConstants(final Iterable> exprs, final Collection> collectTo) { + public static void collectIndexedConstants( + final Iterable> exprs, + final Collection> collectTo) { exprs.forEach(e -> collectIndexedConstants(e, collectTo)); } @@ -284,7 +288,8 @@ public static Set> getIndexedConstants(final Expr expr) { * @param exprs Expressions * @return Set of constants appearing in the expressions */ - public static Set> getIndexedConstants(final Iterable> exprs) { + public static Set> getIndexedConstants( + final Iterable> exprs) { final Set> consts = new HashSet<>(); collectIndexedConstants(exprs, consts); return consts; @@ -293,10 +298,11 @@ public static Set> getIndexedConstants(final Iterable expr, final Collection> collectTo) { + public static void collectConstants( + final Expr expr, final Collection> collectTo) { if (expr instanceof RefExpr) { final RefExpr refExpr = (RefExpr) expr; final Decl decl = refExpr.getDecl(); @@ -312,10 +318,11 @@ public static void collectConstants(final Expr expr, final Collection> exprs, final Collection> collectTo) { + public static void collectConstants( + final Iterable> exprs, final Collection> collectTo) { exprs.forEach(e -> collectConstants(e, collectTo)); } @@ -368,8 +375,7 @@ public static IndexedVars getVarsIndexed(final Iterable> exprs } /** - * Transform expression into an equivalent new expression without - * if-then-else constructs. + * Transform expression into an equivalent new expression without if-then-else constructs. * * @param expr Original expression * @return Transformed expression @@ -382,10 +388,11 @@ public static Expr eliminateIte(final Expr expr) { * Simplify expression and substitute the valuation. * * @param expr Original expression - * @param val Valuation + * @param val Valuation * @return Simplified expression */ - public static Expr simplify(final Expr expr, final Valuation val) { + public static Expr simplify( + final Expr expr, final Valuation val) { return exprSimplifier.simplify(expr, val); } @@ -439,6 +446,29 @@ public static List> canonizeAll(final List> exprs) { return canonizedArgs; } + /** + * Reverses the given expression (swaps primed variables with unprimed variables and + * vice-versa). Also works if variables can have multiple primes. + * + * @param expr Original expression + * @return Reversed form + */ + public static Expr reverse( + final Expr expr, final VarIndexing indexing) { + return new ExprReverser(indexing).reverse(expr); + } + + /** + * Reverses the given expression (swaps primed variables with unprimed variables and + * vice-versa). + * + * @param expr Original expression + * @return Reversed form + */ + public static Expr reverse(final Expr expr) { + return new ExprReverser(VarIndexingFactory.indexing(1)).reverse(expr); + } + /** * Transform an expression into a ponated one. * @@ -457,29 +487,29 @@ public static Expr ponate(final Expr expr) { /** * Transform an expression by universally quantifying certain variables. * - * @param expr Original expression + * @param expr Original expression * @param mapping Quantifying * @return Transformed expression */ - public static Expr close(final Expr expr, final Map, ParamDecl> mapping) { + public static Expr close( + final Expr expr, final Map, ParamDecl> mapping) { return ExprCloser.close(expr, mapping); } /** - * Transform an expression by applying primes to an expression based on an - * indexing. + * Transform an expression by applying primes to an expression based on an indexing. * - * @param expr Original expression + * @param expr Original expression * @param indexing Indexing * @return Transformed expression */ - public static Expr applyPrimes(final Expr expr, final VarIndexing indexing) { + public static Expr applyPrimes( + final Expr expr, final VarIndexing indexing) { return ExprPrimeApplier.applyPrimes(expr, indexing); } /** - * Get the size of an expression by counting the nodes in its tree - * representation. + * Get the size of an expression by counting the nodes in its tree representation. * * @param expr Expression * @return Node count @@ -491,11 +521,12 @@ public static int nodeCountSize(final Expr expr) { /** * Change fixed subexpressions using a lookup * - * @param expr the expr to change subexpressions in + * @param expr the expr to change subexpressions in * @param lookup the lookup mapping subexpression to replacements * @return the changed expression */ - public static Expr changeSubexpr(Expr expr, Map, Expr> lookup) { + public static Expr changeSubexpr( + Expr expr, Map, Expr> lookup) { if (lookup.containsKey(expr)) { return cast(lookup.get(expr), expr.getType()); } else { @@ -503,13 +534,16 @@ public static Expr changeSubexpr(Expr expr, Map, } } - public static Expr changeDecls(Expr expr, Map, ? extends Decl> lookup) { - return changeSubexpr(expr, lookup.entrySet().stream().map(entry -> Map.entry(entry.getKey().getRef(), entry.getValue().getRef())).collect(Collectors.toMap(Entry::getKey, Entry::getValue))); + public static Expr changeDecls( + Expr expr, Map, ? extends Decl> lookup) { + return changeSubexpr( + expr, + lookup.entrySet().stream() + .map(entry -> Map.entry(entry.getKey().getRef(), entry.getValue().getRef())) + .collect(Collectors.toMap(Entry::getKey, Entry::getValue))); } - /** - * Extracts function and its arguments from a nested expression - */ + /** Extracts function and its arguments from a nested expression */ public static Tuple2, List>> extractFuncAndArgs(final FuncAppExpr expr) { final Expr func = expr.getFunc(); final Expr arg = expr.getParam(); @@ -518,8 +552,8 @@ public static Tuple2, List>> extractFuncAndArgs(final FuncAppExp final Tuple2, List>> funcAndArgs = extractFuncAndArgs(funcApp); final Expr resFunc = funcAndArgs.get1(); final List> args = funcAndArgs.get2(); - final List> resArgs = ImmutableList.>builder().addAll(args).add(arg) - .build(); + final List> resArgs = + ImmutableList.>builder().addAll(args).add(arg).build(); return Tuple2.of(resFunc, resArgs); } else { return Tuple2.of(func, ImmutableList.of(arg)); diff --git a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/StmtToExprTransformer.java b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/StmtToExprTransformer.java index 1345d096f8..f278fa5009 100644 --- a/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/StmtToExprTransformer.java +++ b/subprojects/common/core/src/main/java/hu/bme/mit/theta/core/utils/StmtToExprTransformer.java @@ -15,6 +15,16 @@ */ package hu.bme.mit.theta.core.utils; +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Ite; +import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Or; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; +import static hu.bme.mit.theta.core.type.fptype.FpExprs.FpAssign; +import static hu.bme.mit.theta.core.utils.TypeUtils.cast; + import com.google.common.collect.ImmutableList; import hu.bme.mit.theta.core.decl.VarDecl; import hu.bme.mit.theta.core.stmt.AssignStmt; @@ -35,29 +45,15 @@ import hu.bme.mit.theta.core.type.booltype.BoolType; import hu.bme.mit.theta.core.type.booltype.SmartBoolExprs; import hu.bme.mit.theta.core.type.fptype.FpType; -import hu.bme.mit.theta.core.type.inttype.IntType; import hu.bme.mit.theta.core.utils.indexings.VarIndexing; - import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Set; -import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; -import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Ite; -import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Or; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; -import static hu.bme.mit.theta.core.type.fptype.FpExprs.FpAssign; -import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; -import static hu.bme.mit.theta.core.utils.TypeUtils.cast; - final class StmtToExprTransformer { - private StmtToExprTransformer() { - } + private StmtToExprTransformer() {} static StmtUnfoldResult toExpr(final Stmt stmt, final VarIndexing indexing) { return stmt.accept(StmtToExprVisitor.INSTANCE, indexing); @@ -82,8 +78,7 @@ private static class StmtToExprVisitor implements StmtVisitor StmtUnfoldResult visit(final HavocStmt stmt, - final VarIndexing indexing) { + public StmtUnfoldResult visit( + final HavocStmt stmt, final VarIndexing indexing) { final VarDecl varDecl = stmt.getVarDecl(); final VarIndexing newIndexing = indexing.inc(varDecl); return StmtUnfoldResult.of(ImmutableList.of(True()), newIndexing); } @Override - public StmtUnfoldResult visit(final AssignStmt stmt, - final VarIndexing indexing) { + public StmtUnfoldResult visit( + final AssignStmt stmt, final VarIndexing indexing) { final VarDecl varDecl = stmt.getVarDecl(); final VarIndexing newIndexing = indexing.inc(varDecl); final Expr rhs = ExprUtils.applyPrimes(stmt.getExpr(), indexing); @@ -115,8 +110,10 @@ public StmtUnfoldResult visit(final AssignStmt final Expr expr; if (varDecl.getType() instanceof FpType) { - expr = FpAssign(TypeUtils.cast(lhs, (FpType) varDecl.getType()), - TypeUtils.cast(rhs, (FpType) varDecl.getType())); + expr = + FpAssign( + TypeUtils.cast(lhs, (FpType) varDecl.getType()), + TypeUtils.cast(rhs, (FpType) varDecl.getType())); } else { expr = Eq(lhs, rhs); } @@ -124,9 +121,14 @@ public StmtUnfoldResult visit(final AssignStmt } @Override - public StmtUnfoldResult visit(MemoryAssignStmt stmt, VarIndexing indexing) { + public + StmtUnfoldResult visit( + MemoryAssignStmt stmt, + VarIndexing indexing) { final Expr rhs = ExprUtils.applyPrimes(stmt.getExpr(), indexing); - final Dereference lhs = (Dereference) ExprUtils.applyPrimes(stmt.getDeref(), indexing); + final Dereference lhs = + (Dereference) + ExprUtils.applyPrimes(stmt.getDeref(), indexing); final var retExpr = Eq(lhs, rhs); return StmtUnfoldResult.of(ImmutableList.of(retExpr), indexing); @@ -135,8 +137,8 @@ public St @Override public StmtUnfoldResult visit(SequenceStmt sequenceStmt, VarIndexing indexing) { final StmtUnfoldResult result = toExpr(sequenceStmt.getStmts(), indexing); - return StmtUnfoldResult.of(ImmutableList.of(And(result.getExprs())), - result.getIndexing()); + return StmtUnfoldResult.of( + ImmutableList.of(And(result.getExprs())), result.getIndexing()); } @Override @@ -146,12 +148,13 @@ public StmtUnfoldResult visit(NonDetStmt nonDetStmt, VarIndexing indexing) { final List indexings = new ArrayList<>(); VarIndexing jointIndexing = indexing; int count = 0; - VarDecl tempVar = VarPoolUtil.requestInt(); + // VarDecl tempVar = VarPoolUtil.requestInt(); for (Stmt stmt : nonDetStmt.getStmts()) { - final Expr tempExpr = Eq( - ExprUtils.applyPrimes(tempVar.getRef(), indexing), Int(count++)); - final StmtUnfoldResult result = toExpr(stmt, indexing.inc(tempVar)); - choices.add(And(tempExpr, And(result.exprs))); + // final Expr tempExpr = Eq( + // ExprUtils.applyPrimes(tempVar.getRef(), indexing), + // Int(count++)); + final StmtUnfoldResult result = toExpr(stmt, indexing /*.inc(tempVar)*/); + choices.add(/*And(tempExpr, */ And(result.exprs) /*)*/); indexings.add(result.indexing); jointIndexing = jointIndexing.join(result.indexing); } @@ -165,8 +168,10 @@ public StmtUnfoldResult visit(NonDetStmt nonDetStmt, VarIndexing indexing) { int jointIndex = jointIndexing.get(decl); if (currentBranchIndex < jointIndex) { if (currentBranchIndex > 0) { - exprs.add(Eq(Prime(decl.getRef(), currentBranchIndex), - Prime(decl.getRef(), jointIndex))); + exprs.add( + Eq( + Prime(decl.getRef(), currentBranchIndex), + Prime(decl.getRef(), jointIndex))); } else { exprs.add(Eq(decl.getRef(), Prime(decl.getRef(), jointIndex))); } @@ -175,7 +180,7 @@ public StmtUnfoldResult visit(NonDetStmt nonDetStmt, VarIndexing indexing) { branchExprs.add(And(exprs)); } final Expr expr = Or(branchExprs); - VarPoolUtil.returnInt(tempVar); + // VarPoolUtil.returnInt(tempVar); return StmtUnfoldResult.of(ImmutableList.of(expr), jointIndexing); } @@ -184,10 +189,10 @@ public StmtUnfoldResult visit(IfStmt ifStmt, VarIndexing indexing) { final Expr cond = ifStmt.getCond(); final Expr condExpr = ExprUtils.applyPrimes(cond, indexing); - final StmtUnfoldResult thenResult = toExpr(ifStmt.getThen(), - indexing.transform().build()); - final StmtUnfoldResult elzeResult = toExpr(ifStmt.getElze(), - indexing.transform().build()); + final StmtUnfoldResult thenResult = + toExpr(ifStmt.getThen(), indexing.transform().build()); + final StmtUnfoldResult elzeResult = + toExpr(ifStmt.getElze(), indexing.transform().build()); final VarIndexing thenIndexing = thenResult.indexing; final VarIndexing elzeIndexing = elzeResult.indexing; @@ -205,14 +210,18 @@ public StmtUnfoldResult visit(IfStmt ifStmt, VarIndexing indexing) { if (thenIndex < elzeIndex) { if (thenIndex > 0) { thenAdditions.add( - Eq(Prime(decl.getRef(), thenIndex), Prime(decl.getRef(), elzeIndex))); + Eq( + Prime(decl.getRef(), thenIndex), + Prime(decl.getRef(), elzeIndex))); } else { thenAdditions.add(Eq(decl.getRef(), Prime(decl.getRef(), elzeIndex))); } } else if (elzeIndex < thenIndex) { if (elzeIndex > 0) { elzeAdditions.add( - Eq(Prime(decl.getRef(), elzeIndex), Prime(decl.getRef(), thenIndex))); + Eq( + Prime(decl.getRef(), elzeIndex), + Prime(decl.getRef(), thenIndex))); } else { elzeAdditions.add(Eq(decl.getRef(), Prime(decl.getRef(), thenIndex))); } @@ -220,14 +229,16 @@ public StmtUnfoldResult visit(IfStmt ifStmt, VarIndexing indexing) { } final Expr thenExprExtended = - thenAdditions.size() > 0 ? SmartBoolExprs.And(thenExpr, And(thenAdditions)) + thenAdditions.size() > 0 + ? SmartBoolExprs.And(thenExpr, And(thenAdditions)) : thenExpr; final Expr elzeExprExtended = - elzeAdditions.size() > 0 ? SmartBoolExprs.And(elzeExpr, And(elzeAdditions)) + elzeAdditions.size() > 0 + ? SmartBoolExprs.And(elzeExpr, And(elzeAdditions)) : elzeExpr; - final Expr ite = cast(Ite(condExpr, thenExprExtended, elzeExprExtended), - Bool()); + final Expr ite = + cast(Ite(condExpr, thenExprExtended, elzeExprExtended), Bool()); return StmtUnfoldResult.of(ImmutableList.of(ite), jointIndexing); } diff --git a/subprojects/common/core/src/test/java/hu/bme/mit/theta/core/utils/StmtToExprTransformerTest.java b/subprojects/common/core/src/test/java/hu/bme/mit/theta/core/utils/StmtToExprTransformerTest.java index 4b01823c64..9b5c00f6ee 100644 --- a/subprojects/common/core/src/test/java/hu/bme/mit/theta/core/utils/StmtToExprTransformerTest.java +++ b/subprojects/common/core/src/test/java/hu/bme/mit/theta/core/utils/StmtToExprTransformerTest.java @@ -15,6 +15,13 @@ */ package hu.bme.mit.theta.core.utils; +import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.False; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Eq; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; + import com.google.common.collect.ImmutableList; import hu.bme.mit.theta.core.decl.Decls; import hu.bme.mit.theta.core.decl.VarDecl; @@ -24,6 +31,8 @@ import hu.bme.mit.theta.core.type.booltype.BoolType; import hu.bme.mit.theta.core.type.inttype.IntType; import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory; +import java.util.Arrays; +import java.util.Collection; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -31,17 +40,6 @@ import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; -import java.util.Arrays; -import java.util.Collection; - -import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.False; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Or; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True; -import static hu.bme.mit.theta.core.type.inttype.IntExprs.Eq; -import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; - @RunWith(Parameterized.class) public class StmtToExprTransformerTest { @@ -56,47 +54,61 @@ public class StmtToExprTransformerTest { @Parameters public static Collection data() { - return Arrays.asList(new Object[][]{ - - {Stmts.Assume(And(True(), False())), ImmutableList.of(And(True(), False()))}, - - {Stmts.Havoc(VX), ImmutableList.of(True())}, - - {Stmts.Assign(VX, Int(2)), ImmutableList.of(Eq(Prime(VX.getRef()), Int(2)))}, - - {Stmts.SequenceStmt(ImmutableList.of(Stmts.Assume(And(True(), False())))), - ImmutableList.of(And(ImmutableList.of(And(True(), False()))))}, - - {Stmts.SequenceStmt( - ImmutableList.of(Stmts.Assign(VX, Int(2)), Stmts.Assign(VX, Int(2)))), + return Arrays.asList( + new Object[][] { + {Stmts.Assume(And(True(), False())), ImmutableList.of(And(True(), False()))}, + {Stmts.Havoc(VX), ImmutableList.of(True())}, + {Stmts.Assign(VX, Int(2)), ImmutableList.of(Eq(Prime(VX.getRef()), Int(2)))}, + { + Stmts.SequenceStmt(ImmutableList.of(Stmts.Assume(And(True(), False())))), + ImmutableList.of(And(ImmutableList.of(And(True(), False())))) + }, + { + Stmts.SequenceStmt( + ImmutableList.of( + Stmts.Assign(VX, Int(2)), Stmts.Assign(VX, Int(2)))), ImmutableList.of( - And(Eq(Prime(VX.getRef()), Int(2)), Eq(Prime(Prime(VX.getRef())), Int(2))))}, - - {Stmts.NonDetStmt(ImmutableList.of(Stmts.Assume(And(True(), False())))), - ImmutableList.of(Or(ImmutableList.of(And(ImmutableList.of( - And(Eq(TEMP0.getRef(), Int(0)), - And(ImmutableList.of(And(True(), False())))))))))}, - - {Stmts.NonDetStmt(ImmutableList.of(Stmts.Assign(VX, Int(2)))), ImmutableList.of( - Or(ImmutableList.of(And(ImmutableList.of(And(Eq(TEMP0.getRef(), Int(0)), - And(ImmutableList.of(Eq(Prime(VX.getRef()), Int(2))))))))))}, - - {Stmts.NonDetStmt(ImmutableList.of(Stmts.Assume(True()), Stmts.Assign(VX, Int(2)))), - ImmutableList.of(Or(ImmutableList.of(And(ImmutableList.of( - And(Eq(TEMP0.getRef(), Int(0)), And(ImmutableList.of(True()))), - Eq(VX.getRef(), Prime(VX.getRef())))), And(ImmutableList.of( - And(Eq(TEMP0.getRef(), Int(1)), - And(ImmutableList.of(Eq(Prime(VX.getRef()), Int(2))))))))))} - - }); + And( + Eq(Prime(VX.getRef()), Int(2)), + Eq(Prime(Prime(VX.getRef())), Int(2)))) + }, + + // {Stmts.NonDetStmt(ImmutableList.of(Stmts.Assume(And(True(), + // False())))), + // + // ImmutableList.of(Or(ImmutableList.of(And(ImmutableList.of( + // And(Eq(TEMP0.getRef(), Int(0)), + // And(ImmutableList.of(And(True(), + // False())))))))))}, + // + // {Stmts.NonDetStmt(ImmutableList.of(Stmts.Assign(VX, Int(2)))), + // ImmutableList.of( + // + // Or(ImmutableList.of(And(ImmutableList.of(And(Eq(TEMP0.getRef(), Int(0)), + // And(ImmutableList.of(Eq(Prime(VX.getRef()), + // Int(2))))))))))}, + // + // {Stmts.NonDetStmt(ImmutableList.of(Stmts.Assume(True()), + // Stmts.Assign(VX, Int(2)))), + // + // ImmutableList.of(Or(ImmutableList.of(And(ImmutableList.of( + // And(Eq(TEMP0.getRef(), Int(0)), + // And(ImmutableList.of(True()))), + // Eq(VX.getRef(), Prime(VX.getRef())))), + // And(ImmutableList.of( + // And(Eq(TEMP0.getRef(), Int(1)), + // + // And(ImmutableList.of(Eq(Prime(VX.getRef()), Int(2))))))))))} + + }); } @Test public void test() { VarPoolUtil.returnInt(TEMP0); - final StmtUnfoldResult unfoldResult = StmtUtils.toExpr(stmt, - VarIndexingFactory.indexing(0)); + final StmtUnfoldResult unfoldResult = + StmtUtils.toExpr(stmt, VarIndexingFactory.indexing(0)); final Collection> actualExprs = unfoldResult.getExprs(); Assert.assertEquals(expectedExprs, actualExprs); } diff --git a/subprojects/frontends/c-frontend/src/main/java/hu/bme/mit/theta/frontend/transformation/grammar/function/FunctionVisitor.java b/subprojects/frontends/c-frontend/src/main/java/hu/bme/mit/theta/frontend/transformation/grammar/function/FunctionVisitor.java index d45fe00e83..c4ae8ea4de 100644 --- a/subprojects/frontends/c-frontend/src/main/java/hu/bme/mit/theta/frontend/transformation/grammar/function/FunctionVisitor.java +++ b/subprojects/frontends/c-frontend/src/main/java/hu/bme/mit/theta/frontend/transformation/grammar/function/FunctionVisitor.java @@ -481,6 +481,10 @@ public CStatement visitBodyDeclaration(CParser.BodyDeclarationContext ctx) { ctx.declaration().declarationSpecifiers(), ctx.declaration().initDeclaratorList()); CCompound compound = new CCompound(parseContext); + final var preCompound = new CCompound(parseContext); + final var postCompound = new CCompound(parseContext); + compound.setPreStatements(preCompound); + compound.setPostStatements(postCompound); for (CDeclaration declaration : declarations) { if (declaration.getInitExpr() != null) { createVars(declaration); @@ -546,16 +550,12 @@ public CStatement visitBodyDeclaration(CParser.BodyDeclarationContext ctx) { recordMetadata(ctx, cAssignment); compound.getcStatementList().add(cAssignment); if (declaration.getInitExpr() instanceof CCompound compoundInitExpr) { - final var preCompound = new CCompound(parseContext); - final var postCompound = new CCompound(parseContext); final var preStatements = collectPreStatements(compoundInitExpr); preCompound.getcStatementList().addAll(preStatements); final var postStatements = collectPostStatements(compoundInitExpr); postCompound.getcStatementList().addAll(postStatements); resetPreStatements(compoundInitExpr); resetPostStatements(compoundInitExpr); - compound.setPreStatements(preCompound); - compound.setPostStatements(postCompound); } } } diff --git a/subprojects/solver/solver-z3-legacy/src/main/java/hu/bme/mit/theta/solver/z3legacy/Z3TermTransformer.java b/subprojects/solver/solver-z3-legacy/src/main/java/hu/bme/mit/theta/solver/z3legacy/Z3TermTransformer.java index a6944812b5..ad0ae49302 100644 --- a/subprojects/solver/solver-z3-legacy/src/main/java/hu/bme/mit/theta/solver/z3legacy/Z3TermTransformer.java +++ b/subprojects/solver/solver-z3-legacy/src/main/java/hu/bme/mit/theta/solver/z3legacy/Z3TermTransformer.java @@ -15,6 +15,20 @@ */ package hu.bme.mit.theta.solver.z3legacy; +import static com.google.common.base.Preconditions.*; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static hu.bme.mit.theta.common.Utils.head; +import static hu.bme.mit.theta.common.Utils.tail; +import static hu.bme.mit.theta.core.decl.Decls.Param; +import static hu.bme.mit.theta.core.type.arraytype.ArrayExprs.Array; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.*; +import static hu.bme.mit.theta.core.type.bvtype.BvExprs.BvType; +import static hu.bme.mit.theta.core.type.functype.FuncExprs.App; +import static hu.bme.mit.theta.core.type.functype.FuncExprs.Func; +import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; +import static hu.bme.mit.theta.core.type.rattype.RatExprs.Rat; +import static java.lang.String.format; + import com.google.common.collect.ImmutableList; import com.microsoft.z3legacy.*; import com.microsoft.z3legacy.enumerations.Z3_decl_kind; @@ -27,13 +41,14 @@ import hu.bme.mit.theta.core.decl.ParamDecl; import hu.bme.mit.theta.core.type.Expr; import hu.bme.mit.theta.core.type.Type; -import hu.bme.mit.theta.core.type.anytype.Exprs; import hu.bme.mit.theta.core.type.abstracttype.*; +import hu.bme.mit.theta.core.type.anytype.Exprs; import hu.bme.mit.theta.core.type.anytype.IteExpr; import hu.bme.mit.theta.core.type.anytype.PrimeExpr; import hu.bme.mit.theta.core.type.arraytype.ArrayReadExpr; import hu.bme.mit.theta.core.type.arraytype.ArrayType; import hu.bme.mit.theta.core.type.arraytype.ArrayWriteExpr; +import hu.bme.mit.theta.core.type.booltype.*; import hu.bme.mit.theta.core.type.bvtype.BvAddExpr; import hu.bme.mit.theta.core.type.bvtype.BvAndExpr; import hu.bme.mit.theta.core.type.bvtype.BvArithShiftRightExpr; @@ -67,6 +82,8 @@ import hu.bme.mit.theta.core.type.bvtype.BvURemExpr; import hu.bme.mit.theta.core.type.bvtype.BvXorExpr; import hu.bme.mit.theta.core.type.bvtype.BvZExtExpr; +import hu.bme.mit.theta.core.type.enumtype.EnumLitExpr; +import hu.bme.mit.theta.core.type.enumtype.EnumType; import hu.bme.mit.theta.core.type.fptype.FpAbsExpr; import hu.bme.mit.theta.core.type.fptype.FpAddExpr; import hu.bme.mit.theta.core.type.fptype.FpDivExpr; @@ -86,9 +103,6 @@ import hu.bme.mit.theta.core.type.fptype.FpPosExpr; import hu.bme.mit.theta.core.type.fptype.FpRemExpr; import hu.bme.mit.theta.core.type.fptype.FpRoundToIntegralExpr; -import hu.bme.mit.theta.core.type.booltype.*; -import hu.bme.mit.theta.core.type.enumtype.EnumLitExpr; -import hu.bme.mit.theta.core.type.enumtype.EnumType; import hu.bme.mit.theta.core.type.fptype.FpRoundingMode; import hu.bme.mit.theta.core.type.fptype.FpSqrtExpr; import hu.bme.mit.theta.core.type.fptype.FpSubExpr; @@ -107,8 +121,6 @@ import hu.bme.mit.theta.core.utils.BvUtils; import hu.bme.mit.theta.core.utils.FpUtils; import hu.bme.mit.theta.core.utils.TypeUtils; -import org.kframework.mpfr.BigFloat; - import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; @@ -123,27 +135,17 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; - -import static com.google.common.base.Preconditions.*; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static hu.bme.mit.theta.common.Utils.head; -import static hu.bme.mit.theta.common.Utils.tail; -import static hu.bme.mit.theta.core.decl.Decls.Param; -import static hu.bme.mit.theta.core.type.arraytype.ArrayExprs.Array; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.*; -import static hu.bme.mit.theta.core.type.bvtype.BvExprs.BvType; -import static hu.bme.mit.theta.core.type.functype.FuncExprs.App; -import static hu.bme.mit.theta.core.type.functype.FuncExprs.Func; -import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int; -import static hu.bme.mit.theta.core.type.rattype.RatExprs.Rat; -import static java.lang.String.format; +import org.kframework.mpfr.BigFloat; final class Z3TermTransformer { private static final String PARAM_NAME_FORMAT = "_p%d"; private final Z3SymbolTable symbolTable; - private final Map, TriFunction>, Expr>> environment; + private final Map< + Tuple2, + TriFunction>, Expr>> + environment; public Z3TermTransformer(final Z3SymbolTable symbolTable) { this.symbolTable = symbolTable; @@ -240,92 +242,138 @@ public Z3TermTransformer(final Z3SymbolTable symbolTable) { this.addFunc("write", this.exprTernaryOperator(ArrayWriteExpr::create)); this.addFunc("select", this.exprBinaryOperator(ArrayReadExpr::create)); this.addFunc("store", this.exprTernaryOperator(ArrayWriteExpr::create)); - this.environment.put(Tuple2.of("fp.frombv", 1), (term, model, vars) -> { - FpType type = (FpType) transformSort(term.getSort()); - FpRoundingMode roundingmode = this.getRoundingMode((term.getArgs()[0]).toString()); - Expr op = (Expr) this.transform(term.getArgs()[1], model, vars); - return FpFromBvExpr.of(roundingmode, op, type, true); - }); - this.environment.put(Tuple2.of("fp.to_sbv", 2), (term, model, vars) -> { - BvType type = (BvType) transformSort(term.getSort()); - FpRoundingMode roundingmode = this.getRoundingMode((term.getArgs()[0]).toString()); - Expr op = (Expr) this.transform(term.getArgs()[1], model, vars); - return FpToBvExpr.of(roundingmode, op, type.getSize(), true); - }); - this.environment.put(Tuple2.of("fp.to_ubv", 2), (term, model, vars) -> { - BvType type = (BvType) transformSort(term.getSort()); - FpRoundingMode roundingmode = this.getRoundingMode((term.getArgs()[0]).toString()); - Expr op = (Expr) this.transform(term.getArgs()[1], model, vars); - return FpToBvExpr.of(roundingmode, op, type.getSize(), false); - }); - this.environment.put(Tuple2.of("to_fp", 2), (term, model, vars) -> { - FpType type = (FpType) transformSort(term.getSort()); - FpRoundingMode roundingmode = this.getRoundingMode((term.getArgs()[0]).toString()); - Expr op = this.transform(term.getArgs()[1], model, vars); - if (op.getType() instanceof FpType) { - return FpToFpExpr.of(roundingmode, (Expr) op, type.getExponent(), type.getSignificand()); - } else if (op.getType() instanceof BvType) { - return FpFromBvExpr.of(roundingmode, (Expr) op, FpType.of(type.getExponent(), type.getSignificand()), false); - } else { - throw new Z3Exception("Unsupported:" + op.getType()); - } - }); - this.environment.put(Tuple2.of("to_fp", 1), (term, model, vars) -> { - FpType type = (FpType) transformSort(term.getSort()); - Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); - return FpFromBvExpr.of(FpRoundingMode.getDefaultRoundingMode(), op, FpType.of(type.getExponent(), type.getSignificand()), true); - }); - this.environment.put(Tuple2.of("extract", 1), (term, model, vars) -> { - Pattern pattern = Pattern.compile("extract ([0-9]+) ([0-9]+)"); - String termStr = term.toString(); - Matcher match = pattern.matcher(termStr); - if (match.find()) { - int to = Integer.parseInt(match.group(1)) + 1; - int from = Integer.parseInt(match.group(2)); - Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); - return BvExtractExpr.of(op, IntExprs.Int(from), IntExprs.Int(to)); - } else { - throw new Z3Exception("Not supported: " + term); - } - }); - this.environment.put(Tuple2.of("zero_extend", 1), (term, model, vars) -> { - BvType type = (BvType) transformSort(term.getSort()); - Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); - return BvZExtExpr.of(op, BvType.of(type.getSize())); - }); - this.environment.put(Tuple2.of("sign_extend", 1), (term, model, vars) -> { - BvType type = (BvType) transformSort(term.getSort()); - Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); - return BvSExtExpr.of(op, BvType.of(type.getSize())); - }); - this.environment.put(Tuple2.of("EqZero", 1), (term, model, vars) -> { - Expr op = this.transform(term.getArgs()[0], model, vars); - return AbstractExprs.Eq(op, TypeUtils.getDefaultValue(op.getType())); - }); - this.environment.put(Tuple2.of("fp", 3), (term, model, vars) -> { - Expr op1 = (Expr) this.transform(term.getArgs()[0], model, vars); - Expr op2 = (Expr) this.transform(term.getArgs()[1], model, vars); - Expr op3 = (Expr) this.transform(term.getArgs()[2], model, vars); - return FpLitExpr.of(((BvLitExpr) op1).getValue()[0], (BvLitExpr) op2, (BvLitExpr) op3); - }); - this.environment.put(Tuple2.of("const", 1), (term, model, vars) -> { - return this.transform(term.getArgs()[0], model, vars); - }); - } - - private void addFunc(String name, Tuple2>, Expr>> func) { - checkArgument(!environment.containsKey(Tuple2.of(name, func.get1())), "Duplicate key: " + Tuple2.of(name, func.get1())); + this.environment.put( + Tuple2.of("fp.frombv", 1), + (term, model, vars) -> { + FpType type = (FpType) transformSort(term.getSort()); + FpRoundingMode roundingmode = + this.getRoundingMode((term.getArgs()[0]).toString()); + Expr op = (Expr) this.transform(term.getArgs()[1], model, vars); + return FpFromBvExpr.of(roundingmode, op, type, true); + }); + this.environment.put( + Tuple2.of("fp.to_sbv", 2), + (term, model, vars) -> { + BvType type = (BvType) transformSort(term.getSort()); + FpRoundingMode roundingmode = + this.getRoundingMode((term.getArgs()[0]).toString()); + Expr op = (Expr) this.transform(term.getArgs()[1], model, vars); + return FpToBvExpr.of(roundingmode, op, type.getSize(), true); + }); + this.environment.put( + Tuple2.of("fp.to_ubv", 2), + (term, model, vars) -> { + BvType type = (BvType) transformSort(term.getSort()); + FpRoundingMode roundingmode = + this.getRoundingMode((term.getArgs()[0]).toString()); + Expr op = (Expr) this.transform(term.getArgs()[1], model, vars); + return FpToBvExpr.of(roundingmode, op, type.getSize(), false); + }); + this.environment.put( + Tuple2.of("to_fp", 2), + (term, model, vars) -> { + FpType type = (FpType) transformSort(term.getSort()); + FpRoundingMode roundingmode = + this.getRoundingMode((term.getArgs()[0]).toString()); + Expr op = this.transform(term.getArgs()[1], model, vars); + if (op.getType() instanceof FpType) { + return FpToFpExpr.of( + roundingmode, + (Expr) op, + type.getExponent(), + type.getSignificand()); + } else if (op.getType() instanceof BvType) { + return FpFromBvExpr.of( + roundingmode, + (Expr) op, + FpType.of(type.getExponent(), type.getSignificand()), + false); + } else { + throw new Z3Exception("Unsupported:" + op.getType()); + } + }); + this.environment.put( + Tuple2.of("to_fp", 1), + (term, model, vars) -> { + FpType type = (FpType) transformSort(term.getSort()); + Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); + return FpFromBvExpr.of( + FpRoundingMode.getDefaultRoundingMode(), + op, + FpType.of(type.getExponent(), type.getSignificand()), + true); + }); + this.environment.put( + Tuple2.of("extract", 1), + (term, model, vars) -> { + Pattern pattern = Pattern.compile("extract ([0-9]+) ([0-9]+)"); + String termStr = term.toString(); + Matcher match = pattern.matcher(termStr); + if (match.find()) { + int to = Integer.parseInt(match.group(1)) + 1; + int from = Integer.parseInt(match.group(2)); + Expr op = + (Expr) this.transform(term.getArgs()[0], model, vars); + return BvExtractExpr.of(op, IntExprs.Int(from), IntExprs.Int(to)); + } else { + throw new Z3Exception("Not supported: " + term); + } + }); + this.environment.put( + Tuple2.of("zero_extend", 1), + (term, model, vars) -> { + BvType type = (BvType) transformSort(term.getSort()); + Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); + return BvZExtExpr.of(op, BvType.of(type.getSize())); + }); + this.environment.put( + Tuple2.of("sign_extend", 1), + (term, model, vars) -> { + BvType type = (BvType) transformSort(term.getSort()); + Expr op = (Expr) this.transform(term.getArgs()[0], model, vars); + return BvSExtExpr.of(op, BvType.of(type.getSize())); + }); + this.environment.put( + Tuple2.of("EqZero", 1), + (term, model, vars) -> { + Expr op = this.transform(term.getArgs()[0], model, vars); + return AbstractExprs.Eq(op, TypeUtils.getDefaultValue(op.getType())); + }); + this.environment.put( + Tuple2.of("fp", 3), + (term, model, vars) -> { + Expr op1 = + (Expr) this.transform(term.getArgs()[0], model, vars); + Expr op2 = + (Expr) this.transform(term.getArgs()[1], model, vars); + Expr op3 = + (Expr) this.transform(term.getArgs()[2], model, vars); + return FpLitExpr.of( + ((BvLitExpr) op1).getValue()[0], (BvLitExpr) op2, (BvLitExpr) op3); + }); + this.environment.put( + Tuple2.of("const", 1), + (term, model, vars) -> { + return this.transform(term.getArgs()[0], model, vars); + }); + } + + private void addFunc( + String name, + Tuple2>, Expr>> + func) { + assert !environment.containsKey(Tuple2.of(name, func.get1())); environment.put(Tuple2.of(name, func.get1()), func.get2()); } - public Expr toExpr(final com.microsoft.z3legacy.Expr term) { return transform(term, null, new ArrayList<>()); } - public Expr toFuncLitExpr(final FuncDecl funcDecl, final Model model, - final List> vars) { - checkNotNull(model, + public Expr toFuncLitExpr( + final FuncDecl funcDecl, final Model model, final List> vars) { + checkNotNull( + model, "Unsupported function '" + funcDecl.getName() + "' in Z3 back-transformation."); final com.microsoft.z3legacy.FuncInterp funcInterp = model.getFuncInterp(funcDecl); final List> paramDecls = transformParams(vars, funcDecl.getDomain()); @@ -335,11 +383,11 @@ public Expr toFuncLitExpr(final FuncDecl funcDecl, final Model model, return funcLitExpr; } - public Expr toArrayLitExpr(final FuncDecl funcDecl, final Model model, - final List> vars) { + public Expr toArrayLitExpr( + final FuncDecl funcDecl, final Model model, final List> vars) { final com.microsoft.z3legacy.FuncInterp funcInterp = model.getFuncInterp(funcDecl); - final List>, Expr>> entryExprs = createEntryExprs(funcInterp, model, - vars); + final List>, Expr>> entryExprs = + createEntryExprs(funcInterp, model, vars); final Expr elseExpr = transform(funcInterp.getElse(), model, vars); final ArraySort sort = (ArraySort) funcDecl.getRange(); @@ -347,27 +395,38 @@ public Expr toArrayLitExpr(final FuncDecl funcDecl, final Model model, return createArrayLitExpr(sort, entryExprs, elseExpr); } - private Expr createArrayLitExpr(ArraySort sort, - List>, Expr>> entryExprs, Expr elseExpr) { - return this.createIndexValueArrayLitExpr(transformSort(sort.getDomain()), - transformSort(sort.getRange()), entryExprs, elseExpr); + private Expr createArrayLitExpr( + ArraySort sort, List>, Expr>> entryExprs, Expr elseExpr) { + return this.createIndexValueArrayLitExpr( + transformSort(sort.getDomain()), + transformSort(sort.getRange()), + entryExprs, + elseExpr); } @SuppressWarnings("unchecked") - private Expr createIndexValueArrayLitExpr(I indexType, - E elemType, List>, Expr>> entryExprs, Expr elseExpr) { - return Array(entryExprs.stream().map(entry -> { - checkState(entry.get1().size() == 1); - return Tuple2.of((Expr) entry.get1().get(0), (Expr) entry.get2()); - }).collect(Collectors.toUnmodifiableList()), + private Expr createIndexValueArrayLitExpr( + I indexType, + E elemType, + List>, Expr>> entryExprs, + Expr elseExpr) { + return Array( + entryExprs.stream() + .map( + entry -> { + checkState(entry.get1().size() == 1); + return Tuple2.of( + (Expr) entry.get1().get(0), (Expr) entry.get2()); + }) + .collect(Collectors.toUnmodifiableList()), (Expr) elseExpr, ArrayType.of(indexType, elemType)); } //////// - private Expr transform(final com.microsoft.z3legacy.Expr term, final Model model, - final List> vars) { + private Expr transform( + final com.microsoft.z3legacy.Expr term, final Model model, final List> vars) { if (term.isIntNum()) { return transformIntLit(term); @@ -375,7 +434,8 @@ private Expr transform(final com.microsoft.z3legacy.Expr term, final Model mo return transformRatLit(term); // BitVecNum is not BVNumeral? Potential bug? - } else if (/* term.isBVNumeral() */ term instanceof com.microsoft.z3legacy.BitVecNum) { + } else if ( + /* term.isBVNumeral() */ term instanceof com.microsoft.z3legacy.BitVecNum) { return transformBvLit(term); } else if (term instanceof FPNum) { @@ -388,7 +448,8 @@ private Expr transform(final com.microsoft.z3legacy.Expr term, final Model mo return transformApp(term, model, vars); } else if (term.isQuantifier()) { - final com.microsoft.z3legacy.Quantifier quantifier = (com.microsoft.z3legacy.Quantifier) term; + final com.microsoft.z3legacy.Quantifier quantifier = + (com.microsoft.z3legacy.Quantifier) term; return transformQuantifier(quantifier, model, vars); } else if (term.isVar()) { @@ -419,12 +480,12 @@ private Expr transformRatLit(final com.microsoft.z3legacy.Expr term) { return Rat(num, denom); } - private Expr transformArrLit(final com.microsoft.z3legacy.Expr term, final Model model, - final List> vars) { + private Expr transformArrLit( + final com.microsoft.z3legacy.Expr term, final Model model, final List> vars) { final ArrayExpr arrayExpr = (ArrayExpr) term; final ArraySort sort = (ArraySort) arrayExpr.getSort(); - return createArrayLitExpr(sort, Arrays.asList(), - transform(arrayExpr.getArgs()[0], model, vars)); + return createArrayLitExpr( + sort, Arrays.asList(), transform(arrayExpr.getArgs()[0], model, vars)); } private Expr transformBvLit(final com.microsoft.z3legacy.Expr term) { @@ -440,11 +501,11 @@ private Expr transformFpLit(final com.microsoft.z3legacy.Expr term) { FpType type = FpType.of((fpTerm).getEBits(), (fpTerm).getSBits()); String printed = term.toString(); if (printed.equals("+oo")) { - return FpUtils.bigFloatToFpLitExpr(BigFloat.positiveInfinity(type.getSignificand()), - type); + return FpUtils.bigFloatToFpLitExpr( + BigFloat.positiveInfinity(type.getSignificand()), type); } else if (printed.equals("-oo")) { - return FpUtils.bigFloatToFpLitExpr(BigFloat.negativeInfinity(type.getSignificand()), - type); + return FpUtils.bigFloatToFpLitExpr( + BigFloat.negativeInfinity(type.getSignificand()), type); } else if (printed.equals("NaN")) { return FpUtils.bigFloatToFpLitExpr(BigFloat.NaN(type.getSignificand()), type); } else if (printed.equals("+zero")) { @@ -452,24 +513,31 @@ private Expr transformFpLit(final com.microsoft.z3legacy.Expr term) { } else if (printed.equals("-zero")) { return FpUtils.bigFloatToFpLitExpr(BigFloat.negativeZero(type.getSignificand()), type); } - BigFloat bigFloat = new BigFloat((fpTerm).getSignificand(), - FpUtils.getMathContext(type, FpRoundingMode.RNE)).multiply( - new BigFloat("2", FpUtils.getMathContext(type, FpRoundingMode.RNE)).pow( - new BigFloat((fpTerm).getExponent(), - FpUtils.getMathContext(type, FpRoundingMode.RNE)), - FpUtils.getMathContext(type, FpRoundingMode.RNE)), - FpUtils.getMathContext(type, FpRoundingMode.RNE)); + BigFloat bigFloat = + new BigFloat( + (fpTerm).getSignificand(), + FpUtils.getMathContext(type, FpRoundingMode.RNE)) + .multiply( + new BigFloat("2", FpUtils.getMathContext(type, FpRoundingMode.RNE)) + .pow( + new BigFloat( + (fpTerm).getExponent(), + FpUtils.getMathContext( + type, FpRoundingMode.RNE)), + FpUtils.getMathContext(type, FpRoundingMode.RNE)), + FpUtils.getMathContext(type, FpRoundingMode.RNE)); return FpUtils.bigFloatToFpLitExpr(bigFloat, type); } - private Expr transformEnumLit(final com.microsoft.z3legacy.Expr term, final EnumType enumType) { + private Expr transformEnumLit( + final com.microsoft.z3legacy.Expr term, final EnumType enumType) { String longName = term.getFuncDecl().getName().toString(); String literal = EnumType.getShortName(longName); return EnumLitExpr.of(enumType, literal); } - private Expr transformApp(final com.microsoft.z3legacy.Expr term, final Model model, - final List> vars) { + private Expr transformApp( + final com.microsoft.z3legacy.Expr term, final Model model, final List> vars) { final FuncDecl funcDecl = term.getFuncDecl(); final String symbol = funcDecl.getName().toString(); @@ -491,41 +559,50 @@ private Expr transformApp(final com.microsoft.z3legacy.Expr term, final Model } } - private Expr transformFuncInterp(final com.microsoft.z3legacy.FuncInterp funcInterp, - final Model model, final List> vars) { + private Expr transformFuncInterp( + final com.microsoft.z3legacy.FuncInterp funcInterp, + final Model model, + final List> vars) { checkArgument(funcInterp.getArity() >= 1); final ParamDecl paramDecl = (ParamDecl) vars.get(vars.size() - 1); - final Expr op = createFuncLitExprBody( - vars.subList(vars.size() - funcInterp.getArity(), vars.size()).stream() - .map(decl -> (ParamDecl) decl).collect(Collectors.toList()), funcInterp, model, - vars); + final Expr op = + createFuncLitExprBody( + vars.subList(vars.size() - funcInterp.getArity(), vars.size()).stream() + .map(decl -> (ParamDecl) decl) + .collect(Collectors.toList()), + funcInterp, + model, + vars); return Func(paramDecl, op); } - private Expr createFuncLitExprBody(final List> paramDecl, - final com.microsoft.z3legacy.FuncInterp funcInterp, - final Model model, final List> vars) { - final List>, Expr>> entryExprs = createEntryExprs(funcInterp, model, - vars); + private Expr createFuncLitExprBody( + final List> paramDecl, + final com.microsoft.z3legacy.FuncInterp funcInterp, + final Model model, + final List> vars) { + final List>, Expr>> entryExprs = + createEntryExprs(funcInterp, model, vars); final Expr elseExpr = transform(funcInterp.getElse(), model, vars); return createNestedIteExpr(paramDecl, entryExprs, elseExpr); } - private Expr createNestedIteExpr(final List> paramDecl, - final List>, Expr>> entryExprs, - final Expr elseExpr) { + private Expr createNestedIteExpr( + final List> paramDecl, + final List>, Expr>> entryExprs, + final Expr elseExpr) { if (entryExprs.isEmpty()) { return elseExpr; } else { final Tuple2>, Expr> head = head(entryExprs); - checkState(paramDecl.size() == head.get1().size(), - "Mismatched argument-parameter size!"); + checkState( + paramDecl.size() == head.get1().size(), "Mismatched argument-parameter size!"); final List>, Expr>> tail = tail(entryExprs); Expr cond = null; for (int i = 0; i < paramDecl.size(); i++) { - final Expr newTerm = EqExpr.create2(paramDecl.get(i).getRef(), - head.get1().get(i)); + final Expr newTerm = + EqExpr.create2(paramDecl.get(i).getRef(), head.get1().get(i)); cond = cond == null ? newTerm : And(cond, newTerm); } @@ -537,8 +614,10 @@ private Expr createNestedIteExpr(final List> paramDecl, private List>, Expr>> createEntryExprs( final com.microsoft.z3legacy.FuncInterp funcInterp, - final Model model, final List> vars) { - final ImmutableList.Builder>, Expr>> builder = ImmutableList.builder(); + final Model model, + final List> vars) { + final ImmutableList.Builder>, Expr>> builder = + ImmutableList.builder(); for (final com.microsoft.z3legacy.FuncInterp.Entry entry : funcInterp.getEntries()) { checkArgument(entry.getArgs().length >= 1); final List> args = new ArrayList<>(); @@ -553,8 +632,10 @@ private List>, Expr>> createEntryExprs( return builder.build(); } - private Expr transformQuantifier(final com.microsoft.z3legacy.Quantifier term, final Model model, - final List> vars) { + private Expr transformQuantifier( + final com.microsoft.z3legacy.Quantifier term, + final Model model, + final List> vars) { if (term.isUniversal()) { return transformForall(term, model, vars); @@ -572,30 +653,37 @@ private Expr transformVar(final com.microsoft.z3legacy.Expr term, final List< return decl.getRef(); } - private

Expr transformFuncApp(final Expr expr, - final com.microsoft.z3legacy.Expr[] argTerms, final Model model, final List> vars) { - final List terms = Arrays.stream(argTerms) - .collect(Collectors.toList()); + private

Expr transformFuncApp( + final Expr expr, + final com.microsoft.z3legacy.Expr[] argTerms, + final Model model, + final List> vars) { + final List terms = + Arrays.stream(argTerms).collect(Collectors.toList()); return toApp((Expr>) expr, terms, model, vars); } - private

Expr toApp(Expr> expr, - List terms, Model model, List> vars) { + private

Expr toApp( + Expr> expr, + List terms, + Model model, + List> vars) { if (terms.size() == 0) { return expr; } final com.microsoft.z3legacy.Expr term = terms.get(0); terms.remove(0); final Expr

transformed = (Expr

) transform(term, model, vars); - return toApp((Expr, R>>) App(expr, transformed), terms, model, - vars); + return toApp( + (Expr, R>>) App(expr, transformed), terms, model, vars); } //// - private Expr transformForall(final com.microsoft.z3legacy.Expr term, final Model model, - final List> vars) { - final com.microsoft.z3legacy.Quantifier quantifier = (com.microsoft.z3legacy.Quantifier) term; + private Expr transformForall( + final com.microsoft.z3legacy.Expr term, final Model model, final List> vars) { + final com.microsoft.z3legacy.Quantifier quantifier = + (com.microsoft.z3legacy.Quantifier) term; final com.microsoft.z3legacy.BoolExpr opTerm = quantifier.getBody(); final com.microsoft.z3legacy.Sort[] sorts = quantifier.getBoundVariableSorts(); final List> paramDecls = transformParams(vars, sorts); @@ -607,9 +695,10 @@ private Expr transformForall(final com.microsoft.z3legacy.Expr term, final Mo return Forall(paramDecls, op); } - private Expr transformExists(final com.microsoft.z3legacy.Expr term, final Model model, - final List> vars) { - final com.microsoft.z3legacy.Quantifier quantifier = (com.microsoft.z3legacy.Quantifier) term; + private Expr transformExists( + final com.microsoft.z3legacy.Expr term, final Model model, final List> vars) { + final com.microsoft.z3legacy.Quantifier quantifier = + (com.microsoft.z3legacy.Quantifier) term; final com.microsoft.z3legacy.BoolExpr opTerm = quantifier.getBody(); final com.microsoft.z3legacy.Sort[] sorts = quantifier.getBoundVariableSorts(); final List> paramDecls = transformParams(vars, sorts); @@ -621,8 +710,8 @@ private Expr transformExists(final com.microsoft.z3legacy.Expr term, final Mo return Exists(paramDecls, op); } - private List> transformParams(final List> vars, - final com.microsoft.z3legacy.Sort[] sorts) { + private List> transformParams( + final List> vars, final com.microsoft.z3legacy.Sort[] sorts) { final ImmutableList.Builder> builder = ImmutableList.builder(); for (final com.microsoft.z3legacy.Sort sort : sorts) { final ParamDecl param = transformParam(vars, sort); @@ -632,8 +721,8 @@ private List> transformParams(final List> vars, return paramDecls; } - private ParamDecl transformParam(final List> vars, - final com.microsoft.z3legacy.Sort sort) { + private ParamDecl transformParam( + final List> vars, final com.microsoft.z3legacy.Sort sort) { final Type type = transformSort(sort); final ParamDecl param = Param(format(PARAM_NAME_FORMAT, vars.size()), type); return param; @@ -647,7 +736,8 @@ private Type transformSort(final com.microsoft.z3legacy.Sort sort) { } else if (sort instanceof com.microsoft.z3legacy.RealSort) { return Rat(); } else if (sort instanceof com.microsoft.z3legacy.BitVecSort) { - final com.microsoft.z3legacy.BitVecSort bvSort = (com.microsoft.z3legacy.BitVecSort) sort; + final com.microsoft.z3legacy.BitVecSort bvSort = + (com.microsoft.z3legacy.BitVecSort) sort; return BvType(bvSort.getSize()); } else { throw new AssertionError("Unsupported sort: " + sort); @@ -664,141 +754,178 @@ private void popParams(final List> vars, final List> paramD } } - private Expr transformUnsupported(final com.microsoft.z3legacy.Expr term, final Model model, - final List> vars) { + private Expr transformUnsupported( + final com.microsoft.z3legacy.Expr term, final Model model, final List> vars) { throw new UnsupportedOperationException("Unsupported term: " + term); } //// - - private Tuple2>, Expr>> exprFpUnaryOperator( - final BiFunction, Expr> function) { - return Tuple2.of(2, (term, model, vars) -> { - checkArgument(term.getArgs().length == 2, "Number of arguments must be two"); - final var roundingmode = getRoundingMode(term.getArgs()[0].toString()); - final Expr op2 = transform(term.getArgs()[1], model, vars); - return function.apply(roundingmode, op2); - }); - } - - private Tuple2>, Expr>> exprFpBinaryOperator( - final TriFunction, Expr, Expr> function) { - return Tuple2.of(3, (term, model, vars) -> { - checkArgument(term.getArgs().length == 3, "Number of arguments must be three"); - final var roundingmode = getRoundingMode(term.getArgs()[0].toString()); - final Expr op1 = transform(term.getArgs()[1], model, vars); - final Expr op2 = transform(term.getArgs()[2], model, vars); - return function.apply(roundingmode, op1, op2); - }); - } - - private Tuple2>, Expr>> exprFpMultiaryOperator( - final BiFunction>, Expr> function) { - return Tuple2.of(-1, (term, model, vars) -> { - final var roundingmode = getRoundingMode(term.getArgs()[0].toString()); - final List> ops = Arrays.stream(term.getArgs()).skip(1).map(arg -> transform(arg, model, vars)) - .collect(toImmutableList()); - return function.apply(roundingmode, ops); - }); - } - - private Tuple2>, Expr>> exprFpLitUnaryOperator( - final BiFunction> function) { - return Tuple2.of(3, (term, model, vars) -> { - final BvLitExpr op1 = (BvLitExpr) transform(term.getArgs()[0], model, vars); - final IntLitExpr op2 = (IntLitExpr) transform(term.getArgs()[1], model, vars); - final IntLitExpr op3 = (IntLitExpr) transform(term.getArgs()[2], model, vars); - return function.apply(op1, FpType.of(op2.getValue().intValue(), op3.getValue().intValue() + 1)); - }); - } - - private Tuple2>, Expr>> exprNullaryOperator( - final Supplier> function) { - return Tuple2.of(0, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - checkArgument(args.length == 0, "Number of arguments must be zero"); - return function.get(); - }); - } - - private Tuple2>, Expr>> exprUnaryOperator( - final UnaryOperator> function) { - return Tuple2.of(1, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - checkArgument(args.length == 1, "Number of arguments must be one"); - final Expr op = transform(args[0], model, vars); - return function.apply(op); - }); - } - - private Tuple2>, Expr>> exprBinaryOperator( - final BinaryOperator> function) { - return Tuple2.of(2, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - checkArgument(args.length == 2, "Number of arguments must be two"); - if (args[0].getSort().getSortKind().equals(Z3_sort_kind.Z3_DATATYPE_SORT)) { - // binary operator is on enum types - // if either arg is a literal, we need special handling to get its type - // (references' decl kind is Z3_OP_UNINTERPRETED, literals' decl kind is Z3_OP_DT_CONSTRUCTOR) - int litIndex = -1; - for (int i = 0; i < 2; i++) { - if (args[i].getFuncDecl().getDeclKind().equals(Z3_decl_kind.Z3_OP_DT_CONSTRUCTOR)) - litIndex = i; - } - if (litIndex > -1) { - int refIndex = Math.abs(litIndex - 1); - final Expr refOp = transform(args[refIndex], model, vars); - final Expr litExpr = transformEnumLit(args[litIndex], (EnumType) refOp.getType()); - return function.apply(refOp, litExpr); - } - } - final Expr op1 = transform(args[0], model, vars); - final Expr op2 = transform(args[1], model, vars); - return function.apply(op1, op2); - }); - } - - private Tuple2>, Expr>> reference() { - return Tuple2.of(1, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - checkArgument(args.length == 1, "Number of arguments must be one"); - final Expr op = transform(args[0], model, vars); - return Exprs.Reference(op, transformSort(term.getSort())); - }); - } - - private Tuple2>, Expr>> dereference() { - return Tuple2.of(3, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - checkArgument(args.length == 3, "Number of arguments must be three"); - final Expr op1 = (Expr) transform(args[0], model, vars); - final Expr op2 = (Expr) transform(args[1], model, vars); - final Expr op3 = (Expr) transform(args[2], model, vars); - return Exprs.Dereference(op1, op2, transformSort(term.getSort())).withUniquenessExpr(op3); - }); - } - - private Tuple2>, Expr>> exprTernaryOperator( - final TernaryOperator> function) { - return Tuple2.of(3, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - checkArgument(args.length == 3, "Number of arguments must be three"); - final Expr op1 = transform(args[0], model, vars); - final Expr op2 = transform(args[1], model, vars); - final Expr op3 = transform(args[2], model, vars); - return function.apply(op1, op2, op3); - }); - } - - private Tuple2>, Expr>> exprMultiaryOperator( - final Function>, Expr> function) { - return Tuple2.of(-1, (term, model, vars) -> { - final com.microsoft.z3legacy.Expr[] args = term.getArgs(); - final List> ops = Stream.of(args).map(arg -> transform(arg, model, vars)) - .collect(toImmutableList()); - return function.apply(ops); - }); + private Tuple2>, Expr>> + exprFpUnaryOperator(final BiFunction, Expr> function) { + return Tuple2.of( + 2, + (term, model, vars) -> { + checkArgument(term.getArgs().length == 2, "Number of arguments must be two"); + final var roundingmode = getRoundingMode(term.getArgs()[0].toString()); + final Expr op2 = transform(term.getArgs()[1], model, vars); + return function.apply(roundingmode, op2); + }); + } + + private Tuple2>, Expr>> + exprFpBinaryOperator( + final TriFunction, Expr, Expr> function) { + return Tuple2.of( + 3, + (term, model, vars) -> { + checkArgument(term.getArgs().length == 3, "Number of arguments must be three"); + final var roundingmode = getRoundingMode(term.getArgs()[0].toString()); + final Expr op1 = transform(term.getArgs()[1], model, vars); + final Expr op2 = transform(term.getArgs()[2], model, vars); + return function.apply(roundingmode, op1, op2); + }); + } + + private Tuple2>, Expr>> + exprFpMultiaryOperator( + final BiFunction>, Expr> function) { + return Tuple2.of( + -1, + (term, model, vars) -> { + final var roundingmode = getRoundingMode(term.getArgs()[0].toString()); + final List> ops = + Arrays.stream(term.getArgs()) + .skip(1) + .map(arg -> transform(arg, model, vars)) + .collect(toImmutableList()); + return function.apply(roundingmode, ops); + }); + } + + private Tuple2>, Expr>> + exprFpLitUnaryOperator(final BiFunction> function) { + return Tuple2.of( + 3, + (term, model, vars) -> { + final BvLitExpr op1 = (BvLitExpr) transform(term.getArgs()[0], model, vars); + final IntLitExpr op2 = (IntLitExpr) transform(term.getArgs()[1], model, vars); + final IntLitExpr op3 = (IntLitExpr) transform(term.getArgs()[2], model, vars); + return function.apply( + op1, + FpType.of(op2.getValue().intValue(), op3.getValue().intValue() + 1)); + }); + } + + private Tuple2>, Expr>> + exprNullaryOperator(final Supplier> function) { + return Tuple2.of( + 0, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + checkArgument(args.length == 0, "Number of arguments must be zero"); + return function.get(); + }); + } + + private Tuple2>, Expr>> + exprUnaryOperator(final UnaryOperator> function) { + return Tuple2.of( + 1, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + checkArgument(args.length == 1, "Number of arguments must be one"); + final Expr op = transform(args[0], model, vars); + return function.apply(op); + }); + } + + private Tuple2>, Expr>> + exprBinaryOperator(final BinaryOperator> function) { + return Tuple2.of( + 2, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + checkArgument(args.length == 2, "Number of arguments must be two"); + if (args[0].getSort().getSortKind().equals(Z3_sort_kind.Z3_DATATYPE_SORT)) { + // binary operator is on enum types + // if either arg is a literal, we need special handling to get its type + // (references' decl kind is Z3_OP_UNINTERPRETED, literals' decl kind is + // Z3_OP_DT_CONSTRUCTOR) + int litIndex = -1; + for (int i = 0; i < 2; i++) { + if (args[i].getFuncDecl() + .getDeclKind() + .equals(Z3_decl_kind.Z3_OP_DT_CONSTRUCTOR)) litIndex = i; + } + if (litIndex > -1) { + int refIndex = Math.abs(litIndex - 1); + final Expr refOp = transform(args[refIndex], model, vars); + final Expr litExpr = + transformEnumLit(args[litIndex], (EnumType) refOp.getType()); + return function.apply(refOp, litExpr); + } + } + final Expr op1 = transform(args[0], model, vars); + final Expr op2 = transform(args[1], model, vars); + return function.apply(op1, op2); + }); + } + + private Tuple2>, Expr>> + reference() { + return Tuple2.of( + 1, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + checkArgument(args.length == 1, "Number of arguments must be one"); + final Expr op = transform(args[0], model, vars); + return Exprs.Reference(op, transformSort(term.getSort())); + }); + } + + private + Tuple2>, Expr>> + dereference() { + return Tuple2.of( + 3, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + checkArgument(args.length == 3, "Number of arguments must be three"); + final Expr op1 = (Expr) transform(args[0], model, vars); + final Expr op2 = (Expr) transform(args[1], model, vars); + final Expr op3 = (Expr) transform(args[2], model, vars); + return Exprs.Dereference(op1, op2, transformSort(term.getSort())) + .withUniquenessExpr(op3); + }); + } + + private Tuple2>, Expr>> + exprTernaryOperator(final TernaryOperator> function) { + return Tuple2.of( + 3, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + checkArgument(args.length == 3, "Number of arguments must be three"); + final Expr op1 = transform(args[0], model, vars); + final Expr op2 = transform(args[1], model, vars); + final Expr op3 = transform(args[2], model, vars); + return function.apply(op1, op2, op3); + }); + } + + private Tuple2>, Expr>> + exprMultiaryOperator(final Function>, Expr> function) { + return Tuple2.of( + -1, + (term, model, vars) -> { + final com.microsoft.z3legacy.Expr[] args = term.getArgs(); + final List> ops = + Stream.of(args) + .map(arg -> transform(arg, model, vars)) + .collect(toImmutableList()); + return function.apply(ops); + }); } private FpRoundingMode getRoundingMode(String s) { @@ -809,6 +936,4 @@ private FpRoundingMode getRoundingMode(String s) { default -> throw new Z3Exception("Unexpected value: " + s); }; } - } - diff --git a/subprojects/sts/sts-analysis/src/test/java/hu/bme/mit/theta/sts/analysis/StsMddCheckerTest.java b/subprojects/sts/sts-analysis/src/test/java/hu/bme/mit/theta/sts/analysis/StsMddCheckerTest.java index 929a01e400..347dcade0e 100644 --- a/subprojects/sts/sts-analysis/src/test/java/hu/bme/mit/theta/sts/analysis/StsMddCheckerTest.java +++ b/subprojects/sts/sts-analysis/src/test/java/hu/bme/mit/theta/sts/analysis/StsMddCheckerTest.java @@ -40,6 +40,7 @@ import java.io.FileInputStream; import java.util.Arrays; import java.util.Collection; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -104,6 +105,7 @@ public VarIndexing nextIndexing() { } }, sts.getProp(), + List.copyOf(sts.getVars()), solverPool, logger, IterationStrategy.GSAT); diff --git a/subprojects/xcfa/c2xcfa/src/main/java/hu/bme/mit/theta/c2xcfa/XcfaStatistics.kt b/subprojects/xcfa/c2xcfa/src/main/java/hu/bme/mit/theta/c2xcfa/XcfaStatistics.kt index 1b045b335c..0c5708d8cd 100644 --- a/subprojects/xcfa/c2xcfa/src/main/java/hu/bme/mit/theta/c2xcfa/XcfaStatistics.kt +++ b/subprojects/xcfa/c2xcfa/src/main/java/hu/bme/mit/theta/c2xcfa/XcfaStatistics.kt @@ -19,48 +19,47 @@ import hu.bme.mit.theta.xcfa.collectHavocs import hu.bme.mit.theta.xcfa.model.XCFA import hu.bme.mit.theta.xcfa.model.XcfaBuilder -data class XcfaStatistics( - val globalVars: Int, - val procedures: Collection -) +data class XcfaStatistics(val globalVars: Int, val procedures: Collection) data class XcfaProcedureStatistics( - val localVariables: Int, - val locations: Int, - val edges: Int, - val havocs: Int, - val cyclComplexity: Int, - val hasFinalLoc: Boolean, + val localVariables: Int, + val locations: Int, + val edges: Int, + val havocs: Int, + val cyclComplexity: Int, + val hasFinalLoc: Boolean, ) fun XCFA.getStatistics(): XcfaStatistics { - return XcfaStatistics( - globalVars = vars.size, - procedures = procedures.map { - XcfaProcedureStatistics( - localVariables = it.vars.size, - locations = it.locs.size, - edges = it.edges.size, - havocs = it.edges.map { it.label.collectHavocs().size }.reduce(Int::plus), - cyclComplexity = it.edges.size - it.locs.size + 2, - hasFinalLoc = it.finalLoc.isPresent - ) - } - ) + return XcfaStatistics( + globalVars = globalVars.size, + procedures = + procedures.map { + XcfaProcedureStatistics( + localVariables = it.vars.size, + locations = it.locs.size, + edges = it.edges.size, + havocs = it.edges.map { it.label.collectHavocs().size }.reduce(Int::plus), + cyclComplexity = it.edges.size - it.locs.size + 2, + hasFinalLoc = it.finalLoc.isPresent, + ) + }, + ) } fun XcfaBuilder.getStatistics(): XcfaStatistics { - return XcfaStatistics( - globalVars = this.getVars().size, - procedures = getProcedures().map { - XcfaProcedureStatistics( - localVariables = it.getVars().size, - locations = it.getLocs().size, - edges = it.getEdges().size, - havocs = it.getEdges().map { it.label.collectHavocs().size }.reduce(Int::plus), - cyclComplexity = it.getEdges().size - it.getLocs().size + 2, - hasFinalLoc = it.finalLoc.isPresent - ) - } - ) -} \ No newline at end of file + return XcfaStatistics( + globalVars = this.getVars().size, + procedures = + getProcedures().map { + XcfaProcedureStatistics( + localVariables = it.getVars().size, + locations = it.getLocs().size, + edges = it.getEdges().size, + havocs = it.getEdges().map { it.label.collectHavocs().size }.reduce(Int::plus), + cyclComplexity = it.getEdges().size - it.getLocs().size + 2, + hasFinalLoc = it.finalLoc.isPresent, + ) + }, + ) +} diff --git a/subprojects/xcfa/litmus2xcfa/src/test/java/hu/bme/mit/theta/fronted/litmus2xcfa/LitmusTest.java b/subprojects/xcfa/litmus2xcfa/src/test/java/hu/bme/mit/theta/fronted/litmus2xcfa/LitmusTest.java index f08f197803..32186db3f9 100644 --- a/subprojects/xcfa/litmus2xcfa/src/test/java/hu/bme/mit/theta/fronted/litmus2xcfa/LitmusTest.java +++ b/subprojects/xcfa/litmus2xcfa/src/test/java/hu/bme/mit/theta/fronted/litmus2xcfa/LitmusTest.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package hu.bme.mit.theta.fronted.litmus2xcfa; import hu.bme.mit.theta.core.type.Expr; @@ -22,17 +21,16 @@ import hu.bme.mit.theta.solver.z3legacy.Z3LegacySolverFactory; import hu.bme.mit.theta.xcfa.model.XCFA; import hu.bme.mit.theta.xcfa.model.XcfaProcedure; -import kotlin.Pair; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; +import kotlin.Pair; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; @RunWith(Parameterized.class) public class LitmusTest { @@ -51,19 +49,19 @@ public class LitmusTest { @Parameterized.Parameter(4) public String mcmFilename; - @Parameterized.Parameters public static Collection data() { - return Arrays.asList(new Object[][]{ - {"/LB.litmus", 2, 2, List.of(11, 7), "/aarch64.cat"}, - }); + return Arrays.asList( + new Object[][] { + {"/LB.litmus", 2, 2, List.of(11, 7), "/aarch64.cat"}, + }); } @Test public void parse() throws IOException { final XCFA xcfa = LitmusInterpreter.getXcfa(getClass().getResourceAsStream(filepath)); - Assert.assertEquals(globalsNum, xcfa.getVars().size()); + Assert.assertEquals(globalsNum, xcfa.getGlobalVars().size()); Assert.assertEquals(threadNum, xcfa.getInitProcedures().size()); final List>>> processes = xcfa.getInitProcedures(); for (int i = 0; i < processes.size(); i++) { @@ -85,15 +83,33 @@ public void check() throws IOException { throw new RuntimeException(e); } -// final XcfaProcessMemEventProvider memEventProvider = new XcfaProcessMemEventProvider<>(processes.size()); -// final MultiprocLTS, XcfaProcessAction> multiprocLTS = new MultiprocLTS<>(processIds.stream().map(id -> Map.entry(id, new XcfaProcessLTS())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); -// final MultiprocInitFunc, ExplPrec> multiprocInitFunc = new MultiprocInitFunc<>(processIds.stream().map(id -> Map.entry(id, new XcfaProcessInitFunc<>(processes.get(id*-1-1), ExplInitFunc.create(solver, True())))).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); -// final MultiprocTransFunc, XcfaProcessAction, ExplPrec> multiprocTransFunc = new MultiprocTransFunc<>(processIds.stream().map(id -> Map.entry(id, new XcfaProcessTransFunc<>(ExplStmtTransFunc.create(solver, 0)))).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); -// final XcfaProcessPartialOrd partialOrd = new XcfaProcessPartialOrd<>(ExplOrd.getInstance()); -// final MCM mcm = CatDslManager.createMCM(new File(getClass().getResource(mcmFilename).getFile())); -// final List initialWrites = xcfa.getvars().stream().filter(it -> xcfa.getInitValue(it).isPresent()).map(it -> new MemoryEvent.Write(memEventProvider.getVarId(it), it, null, Set.of(), null)).collect(Collectors.toList()); -// -// final MCMChecker, XcfaProcessAction, ExplPrec> mcmChecker = new MCMChecker<>(memEventProvider, multiprocLTS, multiprocInitFunc, multiprocTransFunc, processIds, initialWrites, partialOrd, ExplState.top(), solver, mcm, NullLogger.getInstance()); -// mcmChecker.check(ExplPrec.empty()); + // final XcfaProcessMemEventProvider memEventProvider = new + // XcfaProcessMemEventProvider<>(processes.size()); + // final MultiprocLTS, XcfaProcessAction> multiprocLTS = + // new MultiprocLTS<>(processIds.stream().map(id -> Map.entry(id, new + // XcfaProcessLTS())).collect(Collectors.toMap(Map.Entry::getKey, + // Map.Entry::getValue))); + // final MultiprocInitFunc, ExplPrec> multiprocInitFunc = + // new MultiprocInitFunc<>(processIds.stream().map(id -> Map.entry(id, new + // XcfaProcessInitFunc<>(processes.get(id*-1-1), ExplInitFunc.create(solver, + // True())))).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); + // final MultiprocTransFunc, XcfaProcessAction, ExplPrec> + // multiprocTransFunc = new MultiprocTransFunc<>(processIds.stream().map(id -> Map.entry(id, + // new XcfaProcessTransFunc<>(ExplStmtTransFunc.create(solver, + // 0)))).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); + // final XcfaProcessPartialOrd partialOrd = new + // XcfaProcessPartialOrd<>(ExplOrd.getInstance()); + // final MCM mcm = CatDslManager.createMCM(new + // File(getClass().getResource(mcmFilename).getFile())); + // final List initialWrites = xcfa.getvars().stream().filter(it -> + // xcfa.getInitValue(it).isPresent()).map(it -> new + // MemoryEvent.Write(memEventProvider.getVarId(it), it, null, Set.of(), + // null)).collect(Collectors.toList()); + // + // final MCMChecker, XcfaProcessAction, ExplPrec> + // mcmChecker = new MCMChecker<>(memEventProvider, multiprocLTS, multiprocInitFunc, + // multiprocTransFunc, processIds, initialWrites, partialOrd, ExplState.top(), solver, mcm, + // NullLogger.getInstance()); + // mcmChecker.check(ExplPrec.empty()); } } diff --git a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaState.kt b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaState.kt index c3f66a25fe..41b35eccc3 100644 --- a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaState.kt +++ b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaState.kt @@ -43,6 +43,18 @@ constructor( val bottom: Boolean = false, ) : ExprState { + constructor( + xcfa: XCFA, + loc: XcfaLocation, + state: S, + ) : this( + xcfa = xcfa, + processes = + mapOf(Pair(0, XcfaProcessState(locs = LinkedList(listOf(loc)), varLookup = LinkedList()))), + state, + mutexes = emptyMap(), + ) + override fun isBottom(): Boolean { return bottom || sGlobal.isBottom } diff --git a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaToMonolithicExpr.kt b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaToMonolithicExpr.kt index d73f5644f8..8af9d3ade3 100644 --- a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaToMonolithicExpr.kt +++ b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/XcfaToMonolithicExpr.kt @@ -26,21 +26,52 @@ import hu.bme.mit.theta.core.stmt.AssignStmt import hu.bme.mit.theta.core.stmt.AssumeStmt import hu.bme.mit.theta.core.stmt.NonDetStmt import hu.bme.mit.theta.core.stmt.SequenceStmt -import hu.bme.mit.theta.core.type.booltype.BoolExprs.And -import hu.bme.mit.theta.core.type.inttype.IntExprs -import hu.bme.mit.theta.core.type.inttype.IntExprs.Eq -import hu.bme.mit.theta.core.type.inttype.IntExprs.Neq +import hu.bme.mit.theta.core.type.Expr +import hu.bme.mit.theta.core.type.LitExpr +import hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq +import hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Neq +import hu.bme.mit.theta.core.type.booltype.BoolExprs.* +import hu.bme.mit.theta.core.type.booltype.BoolType +import hu.bme.mit.theta.core.type.bvtype.BvLitExpr +import hu.bme.mit.theta.core.type.bvtype.BvType +import hu.bme.mit.theta.core.type.fptype.FpExprs.FpAssign +import hu.bme.mit.theta.core.type.fptype.FpType +import hu.bme.mit.theta.core.type.inttype.IntExprs.Int import hu.bme.mit.theta.core.type.inttype.IntLitExpr +import hu.bme.mit.theta.core.type.inttype.IntType +import hu.bme.mit.theta.core.utils.BvUtils +import hu.bme.mit.theta.core.utils.FpUtils import hu.bme.mit.theta.core.utils.StmtUtils +import hu.bme.mit.theta.core.utils.TypeUtils.cast import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory +import hu.bme.mit.theta.frontend.ParseContext +import hu.bme.mit.theta.frontend.transformation.model.types.complex.integer.cint.CInt import hu.bme.mit.theta.xcfa.getFlatLabels -import hu.bme.mit.theta.xcfa.model.StmtLabel -import hu.bme.mit.theta.xcfa.model.XCFA -import hu.bme.mit.theta.xcfa.model.XcfaEdge -import hu.bme.mit.theta.xcfa.model.XcfaLocation +import hu.bme.mit.theta.xcfa.model.* +import java.math.BigInteger import java.util.* +import org.kframework.mpfr.BigFloat + +private val LitExpr<*>.value: Int + get() = + when (this) { + is IntLitExpr -> value.toInt() + is BvLitExpr -> BvUtils.neutralBvLitExprToBigInteger(this).toInt() + else -> error("Unknown integer type: $type") + } + +fun XCFA.toMonolithicExpr(parseContext: ParseContext, initValues: Boolean = false): MonolithicExpr { + val intType = CInt.getUnsignedInt(parseContext).smtType + + fun int(value: Int): Expr<*> = + when (intType) { + is IntType -> Int(value) + is BvType -> + BvUtils.bigIntegerToNeutralBvLitExpr(BigInteger.valueOf(value.toLong()), intType.size) + + else -> error("Unknown integer type: $intType") + } -fun XCFA.toMonolithicExpr(): MonolithicExpr { Preconditions.checkArgument(this.initProcedures.size == 1) val proc = this.initProcedures.stream().findFirst().orElse(null).first Preconditions.checkArgument( @@ -48,19 +79,26 @@ fun XCFA.toMonolithicExpr(): MonolithicExpr { ) Preconditions.checkArgument(proc.errorLoc.isPresent) - val map = mutableMapOf() + val locMap = mutableMapOf() for ((i, x) in proc.locs.withIndex()) { - map[x] = i + locMap[x] = i + } + val edgeMap = mutableMapOf() + for ((i, x) in proc.edges.withIndex()) { + edgeMap[x] = i } - val locVar = Decls.Var("__loc_", IntExprs.Int()) + val locVar = Decls.Var("__loc_", intType) + val edgeVar = Decls.Var("__edge_", intType) val tranList = proc.edges - .map { (source, target, label): XcfaEdge -> + .map { edge: XcfaEdge -> + val (source, target, label) = edge SequenceStmt.of( listOf( - AssumeStmt.of(Eq(locVar.ref, IntExprs.Int(map[source]!!))), + AssumeStmt.of(Eq(locVar.ref, int(locMap[source]!!))), label.toStmt(), - AssignStmt.of(locVar, IntExprs.Int(map[target]!!)), + AssignStmt.of(locVar, cast(int(locMap[target]!!), locVar.type)), + AssignStmt.of(edgeVar, cast(int(edgeMap[edge]!!), edgeVar.type)), ) ) } @@ -68,21 +106,56 @@ fun XCFA.toMonolithicExpr(): MonolithicExpr { val trans = NonDetStmt.of(tranList) val transUnfold = StmtUtils.toExpr(trans, VarIndexingFactory.indexing(0)) + val defaultValues = + if (initValues) + StmtUtils.getVars(trans).filter { !it.equals(locVar) and !it.equals(edgeVar) } + .map { + when (it.type) { + is IntType -> Eq(it.ref, int(0)) + is BoolType -> Eq(it.ref, Bool(false)) + is BvType -> + Eq( + it.ref, + BvUtils.bigIntegerToNeutralBvLitExpr(BigInteger.ZERO, (it.type as BvType).size), + ) + is FpType -> + FpAssign( + it.ref as Expr, + FpUtils.bigFloatToFpLitExpr( + BigFloat.zero((it.type as FpType).significand), + it.type as FpType, + ), + ) + else -> throw IllegalArgumentException("Unsupported type") + } + } + .toList() + .let { And(it) } + else True() + return MonolithicExpr( - initExpr = Eq(locVar.ref, IntExprs.Int(map[proc.initLoc]!!)), + initExpr = + And(Eq(locVar.ref, int(locMap[proc.initLoc]!!)), Eq(edgeVar.ref, int(-1)), defaultValues), transExpr = And(transUnfold.exprs), - propExpr = Neq(locVar.ref, IntExprs.Int(map[proc.errorLoc.get()]!!)), + propExpr = Neq(locVar.ref, int(locMap[proc.errorLoc.get()]!!)), transOffsetIndex = transUnfold.indexing, + vars = StmtUtils.getVars(trans).filter { !it.equals(locVar) and !it.equals(edgeVar) }.toList() + edgeVar + locVar, + valToState = { valToState(it) }, + biValToAction = { val1, val2 -> valToAction(val1, val2) }, + ctrlVars = listOf(locVar, edgeVar), ) } fun XCFA.valToAction(val1: Valuation, val2: Valuation): XcfaAction { - val val1Map = val1.toMap() val val2Map = val2.toMap() - var i = 0 - val map: MutableMap = HashMap() - for (x in this.procedures.first { it.name == "main" }.locs) { - map[x] = i++ + val proc = this.procedures.first { it.name == "main" } + val locMap = mutableMapOf() + for ((i, x) in proc.locs.withIndex()) { + locMap[x] = i + } + val edgeMap = mutableMapOf() + for ((i, x) in proc.edges.withIndex()) { + edgeMap[x] = i } return XcfaAction( pid = 0, @@ -91,39 +164,21 @@ fun XCFA.valToAction(val1: Valuation, val2: Valuation): XcfaAction { .first { it.name == "main" } .edges .first { edge -> - map[edge.source] == - (val1Map[val1Map.keys.first { it.name == "__loc_" }] as IntLitExpr).value.toInt() && - map[edge.target] == - (val2Map[val2Map.keys.first { it.name == "__loc_" }] as IntLitExpr).value.toInt() + edgeMap[edge] == (val2Map[val2Map.keys.first { it.name == "__edge_" }]?.value ?: -1) }, ) } fun XCFA.valToState(val1: Valuation): XcfaState> { val valMap = val1.toMap() - var i = 0 - val map: MutableMap = HashMap() - for (x in this.procedures.first { it.name == "main" }.locs) { - map[i++] = x + val proc = this.procedures.first { it.name == "main" } + val locMap = mutableMapOf() + for ((i, x) in proc.locs.withIndex()) { + locMap[i] = x } return XcfaState( - xcfa = this, - processes = - mapOf( - Pair( - 0, - XcfaProcessState( - locs = - LinkedList( - listOf( - map[ - (valMap[valMap.keys.first { it.name == "__loc_" }] as IntLitExpr).value.toInt()] - ) - ), - varLookup = LinkedList(), - ), - ) - ), + this, + locMap[(valMap[valMap.keys.first { it.name == "__loc_" }])?.value ?: -1]!!, PtrState( ExplState.of( ImmutableValuation.from( @@ -135,8 +190,5 @@ fun XCFA.valToState(val1: Valuation): XcfaState> { ) ) ), - mutexes = emptyMap(), - threadLookup = emptyMap(), - bottom = false, ) } diff --git a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaFurtherOptimizer.kt b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaFurtherOptimizer.kt index b62313c358..ea8748528d 100644 --- a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaFurtherOptimizer.kt +++ b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaFurtherOptimizer.kt @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package hu.bme.mit.theta.xcfa.analysis.oc import hu.bme.mit.theta.xcfa.model.XCFA @@ -23,25 +22,26 @@ import hu.bme.mit.theta.xcfa.passes.ProcedurePass import hu.bme.mit.theta.xcfa.passes.ProcedurePassManager internal fun XCFA.optimizeFurther(passes: List): XCFA { - if (passes.isEmpty()) return this - val passManager = ProcedurePassManager(passes) - val copy: XcfaProcedureBuilder.() -> XcfaProcedureBuilder = { - XcfaProcedureBuilder( - name = name, - manager = passManager, - params = getParams().toMutableList(), - vars = getVars().toMutableSet(), - locs = getLocs().toMutableSet(), - edges = getEdges().toMutableSet(), - metaData = metaData.toMutableMap() - ).also { it.copyMetaLocs(this) } - } + if (passes.isEmpty()) return this + val passManager = ProcedurePassManager(passes) + val copy: XcfaProcedureBuilder.() -> XcfaProcedureBuilder = { + XcfaProcedureBuilder( + name = name, + manager = passManager, + params = getParams().toMutableList(), + vars = getVars().toMutableSet(), + locs = getLocs().toMutableSet(), + edges = getEdges().toMutableSet(), + metaData = metaData.toMutableMap(), + ) + .also { it.copyMetaLocs(this) } + } - val builder = XcfaBuilder(name, vars.toMutableSet()) - procedureBuilders.forEach { builder.addProcedure(it.copy()) } - initProcedureBuilders.forEach { (proc, params) -> - val initProc = builder.getProcedures().find { it.name == proc.name } ?: proc.copy() - builder.addEntryPoint(initProc, params) - } - return builder.build() -} \ No newline at end of file + val builder = XcfaBuilder(name, globalVars.toMutableSet()) + procedureBuilders.forEach { builder.addProcedure(it.copy()) } + initProcedureBuilders.forEach { (proc, params) -> + val initProc = builder.getProcedures().find { it.name == proc.name } ?: proc.copy() + builder.addEntryPoint(initProc, params) + } + return builder.build() +} diff --git a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaOcChecker.kt b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaOcChecker.kt index c9cb77d55d..5c219c6efb 100644 --- a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaOcChecker.kt +++ b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/oc/XcfaOcChecker.kt @@ -523,7 +523,7 @@ class XcfaOcChecker( private fun VarDecl.threadVar(pid: Int): VarDecl = if ( - this !== memoryDecl && xcfa.vars.none { it.wrappedVar == this && !it.threadLocal } + this !== memoryDecl && xcfa.globalVars.none { it.wrappedVar == this && !it.threadLocal } ) { // if not global var cast( localVars diff --git a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/por/XcfaSporLts.kt b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/por/XcfaSporLts.kt index eff8485d1f..3e0e45d2ea 100644 --- a/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/por/XcfaSporLts.kt +++ b/subprojects/xcfa/xcfa-analysis/src/main/java/hu/bme/mit/theta/xcfa/analysis/por/XcfaSporLts.kt @@ -40,351 +40,379 @@ import java.util.function.Predicate import kotlin.random.Random /** - * LTS with a POR (Partial Order Reduction) algorithm applied as a filter when returning enabled actions. - * The algorithm is similar to the static source-set based POR algorithm described in the following paper: - * Abdulla, P., Aronis, S., Jonsson, B., Sagonas, K. (2017): - * Comparing source sets and persistent sets for partial order reduction + * LTS with a POR (Partial Order Reduction) algorithm applied as a filter when returning enabled + * actions. The algorithm is similar to the static source-set based POR algorithm described in the + * following paper: Abdulla, P., Aronis, S., Jonsson, B., Sagonas, K. (2017): Comparing source sets + * and persistent sets for partial order reduction * * @param xcfa the XCFA of the verified program */ -open class XcfaSporLts(protected val xcfa: XCFA) : LTS>, XcfaAction> { +open class XcfaSporLts(protected val xcfa: XCFA) : + LTS>, XcfaAction> { - companion object { + companion object { - private val dependencySolver: Solver = Z3SolverFactory.getInstance().createSolver() - var random: Random = Random.Default - } + private val dependencySolver: Solver = Z3SolverFactory.getInstance().createSolver() + var random: Random = Random.Default + } - protected var simpleXcfaLts = getXcfaLts() + protected var simpleXcfaLts = getXcfaLts() - /* CACHE COLLECTIONS */ + /* CACHE COLLECTIONS */ - /** - * Global variables used by an edge. - */ - private val usedVars: MutableMap>> = mutableMapOf() + /** Global variables used by an edge. */ + private val usedVars: MutableMap>> = mutableMapOf() - /** - * Global variables that are used by the key edge or by edges reachable from the - * current state via a given edge. - */ - private val influencedVars: MutableMap>> = mutableMapOf() + /** + * Global variables that are used by the key edge or by edges reachable from the current state via + * a given edge. + */ + private val influencedVars: MutableMap>> = mutableMapOf() - /** - * Backward edges in the CFA (an edge of a loop). - */ - private val backwardEdges: MutableSet> = mutableSetOf() + /** Backward edges in the CFA (an edge of a loop). */ + private val backwardEdges: MutableSet> = mutableSetOf() - /** - * Variables associated to mutex identifiers. TODO: this should really be solved by storing VarDecls in FenceLabel. - */ - protected val fenceVars: MutableMap> = mutableMapOf() - private val String.fenceVar - get() = fenceVars.getOrPut("") { Decls.Var(if (this == "") "__THETA_atomic_mutex_" else this, Bool()) } + /** + * Variables associated to mutex identifiers. TODO: this should really be solved by storing + * VarDecls in FenceLabel. + */ + protected val fenceVars: MutableMap> = mutableMapOf() + private val String.fenceVar + get() = + fenceVars.getOrPut("") { + Decls.Var(if (this == "") "__THETA_atomic_mutex_" else this, Bool()) + } - init { - collectBackwardEdges() - } + init { + collectBackwardEdges() + } - /** - * Returns the enabled actions in the ARG from the given state filtered with a POR algorithm. - * - * @param state the state whose enabled actions we would like to know - * @return the enabled actions - */ - override fun getEnabledActionsFor(state: XcfaState>): Set = - getEnabledActionsFor(state, simpleXcfaLts.getEnabledActionsFor(state)) + /** + * Returns the enabled actions in the ARG from the given state filtered with a POR algorithm. + * + * @param state the state whose enabled actions we would like to know + * @return the enabled actions + */ + override fun getEnabledActionsFor( + state: XcfaState> + ): Set = getEnabledActionsFor(state, simpleXcfaLts.getEnabledActionsFor(state)) - /** - * Calculates the source set starting from every (or some of the) enabled transition; the minimal source set is returned. - */ - protected open fun getEnabledActionsFor( - state: XcfaState>, allEnabledActions: Collection - ): Set { - var minimalSourceSet = setOf() - val sourceSetFirstActions = getSourceSetFirstActions(state, allEnabledActions) - for (firstActions in sourceSetFirstActions) { - val sourceSet = calculateSourceSet(state, allEnabledActions, firstActions) - if (minimalSourceSet.isEmpty() || sourceSet.size < minimalSourceSet.size) { - minimalSourceSet = sourceSet - } - } - return minimalSourceSet + /** + * Calculates the source set starting from every (or some of the) enabled transition; the minimal + * source set is returned. + */ + protected open fun getEnabledActionsFor( + state: XcfaState>, + allEnabledActions: Collection, + ): Set { + var minimalSourceSet = setOf() + val sourceSetFirstActions = getSourceSetFirstActions(state, allEnabledActions) + for (firstActions in sourceSetFirstActions) { + val sourceSet = calculateSourceSet(state, allEnabledActions, firstActions) + if (minimalSourceSet.isEmpty() || sourceSet.size < minimalSourceSet.size) { + minimalSourceSet = sourceSet + } } + return minimalSourceSet + } - /** - * Returns the possible starting actions of a source set. - * - * @param allEnabledActions the enabled actions in the present state - * @return the possible starting actions of a source set - */ - protected fun getSourceSetFirstActions( - state: XcfaState>, - allEnabledActions: Collection - ): Collection> { - val enabledActionsByProcess = allEnabledActions.groupBy(XcfaAction::pid) - val enabledProcesses = enabledActionsByProcess.keys.toList().shuffled(random) - return enabledProcesses.map { pid -> - val firstProcesses = mutableSetOf(pid) - checkMutexBlocks(state, pid, firstProcesses, enabledActionsByProcess) - firstProcesses.flatMap { enabledActionsByProcess[it] ?: emptyList() } - } + /** + * Returns the possible starting actions of a source set. + * + * @param allEnabledActions the enabled actions in the present state + * @return the possible starting actions of a source set + */ + protected fun getSourceSetFirstActions( + state: XcfaState>, + allEnabledActions: Collection, + ): Collection> { + val enabledActionsByProcess = allEnabledActions.groupBy(XcfaAction::pid) + val enabledProcesses = enabledActionsByProcess.keys.toList().shuffled(random) + return enabledProcesses.map { pid -> + val firstProcesses = mutableSetOf(pid) + checkMutexBlocks(state, pid, firstProcesses, enabledActionsByProcess) + firstProcesses.flatMap { enabledActionsByProcess[it] ?: emptyList() } } + } - /** - * Checks whether a process is blocked by a mutex and if it is, it adds the process that blocks it to the set of - * first processes. - * - * @param state the current state - * @param pid the process whose blocking is to be checked - * @param firstProcesses the set of first processes - * @param enabledActionsByProcess the enabled actions grouped by processes - * @return the set of first processes - */ - private fun checkMutexBlocks( - state: XcfaState>, pid: Int, firstProcesses: MutableSet, - enabledActionsByProcess: Map> - ) { - val processState = checkNotNull(state.processes[pid]) - if (!processState.paramsInitialized) return - val disabledOutEdges = processState.locs.peek().outgoingEdges.filter { edge -> - enabledActionsByProcess[pid]?.none { action -> action.target == edge.target } ?: true - } - disabledOutEdges.forEach { edge -> - edge.getFlatLabels().filterIsInstance().forEach { fence -> - fence.labels.filter { it.startsWith("mutex_lock") }.forEach { lock -> - val mutex = lock.substringAfter('(').substringBefore(')') - state.mutexes[mutex]?.let { pid2 -> - if (pid2 !in firstProcesses) { - firstProcesses.add(pid2) - checkMutexBlocks(state, pid2, firstProcesses, enabledActionsByProcess) - } - } - } + /** + * Checks whether a process is blocked by a mutex and if it is, it adds the process that blocks it + * to the set of first processes. + * + * @param state the current state + * @param pid the process whose blocking is to be checked + * @param firstProcesses the set of first processes + * @param enabledActionsByProcess the enabled actions grouped by processes + * @return the set of first processes + */ + private fun checkMutexBlocks( + state: XcfaState>, + pid: Int, + firstProcesses: MutableSet, + enabledActionsByProcess: Map>, + ) { + val processState = checkNotNull(state.processes[pid]) + if (!processState.paramsInitialized) return + val disabledOutEdges = + processState.locs.peek().outgoingEdges.filter { edge -> + enabledActionsByProcess[pid]?.none { action -> action.target == edge.target } ?: true + } + disabledOutEdges.forEach { edge -> + edge.getFlatLabels().filterIsInstance().forEach { fence -> + fence.labels + .filter { it.startsWith("mutex_lock") } + .forEach { lock -> + val mutex = lock.substringAfter('(').substringBefore(')') + state.mutexes[mutex]?.let { pid2 -> + if (pid2 !in firstProcesses) { + firstProcesses.add(pid2) + checkMutexBlocks(state, pid2, firstProcesses, enabledActionsByProcess) + } } - } + } + } } + } - /** - * Calculates a source set of enabled actions starting from a particular action. - * - * @param enabledActions the enabled actions in the present state - * @param firstActions the actions that will be added to the source set as the first actions - * @return a source set of enabled actions - */ - private fun calculateSourceSet( - state: XcfaState>, - enabledActions: Collection, - firstActions: Collection - ): Set { - if (firstActions.any { it.isBackward }) { - return enabledActions.toSet() - } - val sourceSet = firstActions.toMutableSet() - val otherActions = - (enabledActions.toMutableSet() subtract sourceSet).toMutableSet() // actions not in the source set - var addedNewAction = true - while (addedNewAction) { - addedNewAction = false - val actionsToRemove = mutableSetOf() - for (action in otherActions) { - // for every action that is not in the source set it is checked whether it should be added to the source set - // (because it is dependent with an action already in the source set) - if (sourceSet.any { dependent(state, it, action) }) { - if (action.isBackward) { - return enabledActions.toSet() // see POR algorithm for the reason of handling backward edges this way - } - sourceSet.add(action) - actionsToRemove.add(action) - addedNewAction = true - } - } - actionsToRemove.forEach(otherActions::remove) + /** + * Calculates a source set of enabled actions starting from a particular action. + * + * @param enabledActions the enabled actions in the present state + * @param firstActions the actions that will be added to the source set as the first actions + * @return a source set of enabled actions + */ + private fun calculateSourceSet( + state: XcfaState>, + enabledActions: Collection, + firstActions: Collection, + ): Set { + if (firstActions.any { it.isBackward }) { + return enabledActions.toSet() + } + val sourceSet = firstActions.toMutableSet() + val otherActions = + (enabledActions.toMutableSet() subtract sourceSet) + .toMutableSet() // actions not in the source set + var addedNewAction = true + while (addedNewAction) { + addedNewAction = false + val actionsToRemove = mutableSetOf() + for (action in otherActions) { + // for every action that is not in the source set it is checked whether it should be added + // to the source set + // (because it is dependent with an action already in the source set) + if (sourceSet.any { dependent(state, it, action) }) { + if (action.isBackward) { + return enabledActions + .toSet() // see POR algorithm for the reason of handling backward edges this way + } + sourceSet.add(action) + actionsToRemove.add(action) + addedNewAction = true } - return sourceSet + } + actionsToRemove.forEach(otherActions::remove) } + return sourceSet + } - /** - * Determines whether an action is dependent with another one (based on the notions introduced for the POR - * algorithm) already in the source set. - * - * @param sourceSetAction the action in the source set - * @param action the other action (not in the source set) - * @return true, if the two actions are dependent in the context of source sets - */ - private fun dependent( - state: XcfaState>, sourceSetAction: XcfaAction, action: XcfaAction - ): Boolean { - if (sourceSetAction.pid == action.pid) return true + /** + * Determines whether an action is dependent with another one (based on the notions introduced for + * the POR algorithm) already in the source set. + * + * @param sourceSetAction the action in the source set + * @param action the other action (not in the source set) + * @return true, if the two actions are dependent in the context of source sets + */ + private fun dependent( + state: XcfaState>, + sourceSetAction: XcfaAction, + action: XcfaAction, + ): Boolean { + if (sourceSetAction.pid == action.pid) return true - val sourceSetActionVars = getCachedUsedVars(getEdge(sourceSetAction)) - val influencedVars = getInfluencedVars(getEdge(action)) - if ((influencedVars intersect sourceSetActionVars).isNotEmpty()) return true + val sourceSetActionVars = getCachedUsedVars(getEdge(sourceSetAction)) + val influencedVars = getInfluencedVars(getEdge(action)) + if ((influencedVars intersect sourceSetActionVars).isNotEmpty()) return true - return indirectlyDependent(state, sourceSetAction, sourceSetActionVars, influencedVars) - } + return indirectlyDependent(state, sourceSetAction, sourceSetActionVars, influencedVars) + } - protected fun indirectlyDependent( - state: XcfaState>, sourceSetAction: XcfaAction, - sourceSetActionVars: Set>, influencedVars: Set> - ): Boolean { - val sourceSetActionMemLocs = sourceSetActionVars.pointsTo(xcfa) - val influencedMemLocs = influencedVars.pointsTo(xcfa) - val intersection = sourceSetActionMemLocs intersect influencedMemLocs - if (intersection.isEmpty()) return false // they cannot point to the same memory location even based on static info + protected fun indirectlyDependent( + state: XcfaState>, + sourceSetAction: XcfaAction, + sourceSetActionVars: Set>, + influencedVars: Set>, + ): Boolean { + val sourceSetActionMemLocs = sourceSetActionVars.pointsTo(xcfa) + val influencedMemLocs = influencedVars.pointsTo(xcfa) + val intersection = sourceSetActionMemLocs intersect influencedMemLocs + if (intersection.isEmpty()) + return false // they cannot point to the same memory location even based on static info - val derefs = sourceSetAction.label.dereferences.map { it.array } - var expr: Expr = Or(intersection.flatMap { memLoc -> derefs.map { Eq(memLoc, it) } }) - expr = (state.sGlobal.innerState as? ExplState)?.let { s -> - ExprUtils.simplify(expr, s.`val`) - } ?: ExprUtils.simplify(expr) - if (expr == True()) return true - return WithPushPop(dependencySolver).use { - dependencySolver.add(PathUtils.unfold(state.sGlobal.toExpr(), 0)) - dependencySolver.add( - PathUtils.unfold(expr, 0) - ) // is it always given that the state will produce 0 indexed constants? - dependencySolver.check().isSat // two pointers may point to the same memory location - } + val derefs = sourceSetAction.label.dereferences.map { it.array } + var expr: Expr = Or(intersection.flatMap { memLoc -> derefs.map { Eq(memLoc, it) } }) + expr = + (state.sGlobal.innerState as? ExplState)?.let { s -> ExprUtils.simplify(expr, s.`val`) } + ?: ExprUtils.simplify(expr) + if (expr == True()) return true + return WithPushPop(dependencySolver).use { + dependencySolver.add(PathUtils.unfold(state.sGlobal.toExpr(), 0)) + dependencySolver.add( + PathUtils.unfold(expr, 0) + ) // is it always given that the state will produce 0 indexed constants? + dependencySolver.check().isSat // two pointers may point to the same memory location } + } - /** - * Returns the global variables that an edge uses (it is present in one of its labels). - * Mutex variables are also considered to avoid running into a deadlock and stop exploration. - * - * @param edge whose global variables are to be returned - * @return the set of used global variables - */ - private fun getDirectlyUsedVars(edge: XcfaEdge): Set> { - val globalVars = xcfa.vars.map(XcfaGlobalVar::wrappedVar) - return edge.getFlatLabels().flatMap { label -> - label.collectVars().filter { it in globalVars } union - ((label as? FenceLabel)?.labels - ?.filter { it.startsWith("start_cond_wait") || it.startsWith("cond_signal") } - ?.map { it.substringAfter("(").substringBefore(")").split(",")[0] } - ?.map { it.fenceVar } ?: listOf()) - }.toSet() union edge.acquiredEmbeddedFenceVars.let { mutexes -> - if (mutexes.size <= 1) setOf() else mutexes.map { it.fenceVar } - } - } + /** + * Returns the global variables that an edge uses (it is present in one of its labels). Mutex + * variables are also considered to avoid running into a deadlock and stop exploration. + * + * @param edge whose global variables are to be returned + * @return the set of used global variables + */ + private fun getDirectlyUsedVars(edge: XcfaEdge): Set> { + val globalVars = xcfa.globalVars.map(XcfaGlobalVar::wrappedVar) + return edge + .getFlatLabels() + .flatMap { label -> + label.collectVars().filter { it in globalVars } union + ((label as? FenceLabel) + ?.labels + ?.filter { it.startsWith("start_cond_wait") || it.startsWith("cond_signal") } + ?.map { it.substringAfter("(").substringBefore(")").split(",")[0] } + ?.map { it.fenceVar } ?: listOf()) + } + .toSet() union + edge.acquiredEmbeddedFenceVars.let { mutexes -> + if (mutexes.size <= 1) setOf() else mutexes.map { it.fenceVar } + } + } - /** - * Returns the global variables that an edge uses or if it is the start of an atomic block the global variables - * that are used in the atomic block. The result is cached. - * - * @param edge whose global variables are to be returned - * @return the set of directly or indirectly used global variables - */ - protected fun getCachedUsedVars(edge: XcfaEdge): Set> { - if (edge in usedVars) return usedVars[edge]!! - val flatLabels = edge.getFlatLabels() - val mutexes = flatLabels.filterIsInstance().flatMap { it.acquiredMutexes }.toMutableSet() - val vars = if (mutexes.isEmpty()) { - getDirectlyUsedVars(edge) - } else { - getVarsWithBFS(edge) { it.mutexOperations(mutexes) }.toSet() - } - usedVars[edge] = vars - return vars - } + /** + * Returns the global variables that an edge uses or if it is the start of an atomic block the + * global variables that are used in the atomic block. The result is cached. + * + * @param edge whose global variables are to be returned + * @return the set of directly or indirectly used global variables + */ + protected fun getCachedUsedVars(edge: XcfaEdge): Set> { + if (edge in usedVars) return usedVars[edge]!! + val flatLabels = edge.getFlatLabels() + val mutexes = + flatLabels.filterIsInstance().flatMap { it.acquiredMutexes }.toMutableSet() + val vars = + if (mutexes.isEmpty()) { + getDirectlyUsedVars(edge) + } else { + getVarsWithBFS(edge) { it.mutexOperations(mutexes) }.toSet() + } + usedVars[edge] = vars + return vars + } - /** - * Returns the global variables used by the given edge or by edges that are reachable - * via the given edge ("influenced vars"). - * - * @param edge whose successor edges' global variables are to be returned. - * @return the set of influenced global variables - */ - protected fun getInfluencedVars(edge: XcfaEdge): Set> { - if (edge in influencedVars) return influencedVars[edge]!! - val vars = getVarsWithBFS(edge) { true } - influencedVars[edge] = vars - return vars - } + /** + * Returns the global variables used by the given edge or by edges that are reachable via the + * given edge ("influenced vars"). + * + * @param edge whose successor edges' global variables are to be returned. + * @return the set of influenced global variables + */ + protected fun getInfluencedVars(edge: XcfaEdge): Set> { + if (edge in influencedVars) return influencedVars[edge]!! + val vars = getVarsWithBFS(edge) { true } + influencedVars[edge] = vars + return vars + } - /** - * Returns global variables encountered in a search starting from a given edge. - * - * @param startEdge the start point of the search - * @param goFurther the predicate that tells whether more edges have to be explored through this edge - * @return the set of encountered global variables - */ - private fun getVarsWithBFS(startEdge: XcfaEdge, goFurther: Predicate): Set> { - val vars = mutableSetOf>() - val exploredEdges = mutableListOf() - val edgesToExplore = mutableListOf() - edgesToExplore.add(startEdge) - while (edgesToExplore.isNotEmpty()) { - val exploring = edgesToExplore.removeFirst() - vars.addAll(getDirectlyUsedVars(exploring)) - if (goFurther.test(exploring)) { - val successiveEdges = getSuccessiveEdges(exploring) - for (newEdge in successiveEdges) { - if (newEdge !in exploredEdges) { - edgesToExplore.add(newEdge) - } - } - } - exploredEdges.add(exploring) + /** + * Returns global variables encountered in a search starting from a given edge. + * + * @param startEdge the start point of the search + * @param goFurther the predicate that tells whether more edges have to be explored through this + * edge + * @return the set of encountered global variables + */ + private fun getVarsWithBFS(startEdge: XcfaEdge, goFurther: Predicate): Set> { + val vars = mutableSetOf>() + val exploredEdges = mutableListOf() + val edgesToExplore = mutableListOf() + edgesToExplore.add(startEdge) + while (edgesToExplore.isNotEmpty()) { + val exploring = edgesToExplore.removeFirst() + vars.addAll(getDirectlyUsedVars(exploring)) + if (goFurther.test(exploring)) { + val successiveEdges = getSuccessiveEdges(exploring) + for (newEdge in successiveEdges) { + if (newEdge !in exploredEdges) { + edgesToExplore.add(newEdge) + } } - return vars + } + exploredEdges.add(exploring) } + return vars + } - /** - * Returns the xcfa edge of the given action. - * - * @param action the action whose edge is to be returned - * @return the edge of the action - */ - protected open fun getEdge(action: XcfaAction) = action.edge + /** + * Returns the xcfa edge of the given action. + * + * @param action the action whose edge is to be returned + * @return the edge of the action + */ + protected open fun getEdge(action: XcfaAction) = action.edge - /** - * Returns the outgoing edges of the target of the given edge. For start threads, the first edges of the started - * procedures are also included. - * - * @param edge the edge whose target's outgoing edges are to be returned - * @return the outgoing edges of the target of the edge - */ - private fun getSuccessiveEdges(edge: XcfaEdge): Set { - val outgoingEdges = edge.target.outgoingEdges.toMutableSet() - val startThreads = edge.getFlatLabels().filterIsInstance().toList() - if (startThreads.isNotEmpty()) { // for start thread labels, the thread procedure must be explored, too! - startThreads.forEach { startThread -> - outgoingEdges.addAll(xcfa.procedures.first { it.name == startThread.name }.initLoc.outgoingEdges) - } - } - return outgoingEdges + /** + * Returns the outgoing edges of the target of the given edge. For start threads, the first edges + * of the started procedures are also included. + * + * @param edge the edge whose target's outgoing edges are to be returned + * @return the outgoing edges of the target of the edge + */ + private fun getSuccessiveEdges(edge: XcfaEdge): Set { + val outgoingEdges = edge.target.outgoingEdges.toMutableSet() + val startThreads = edge.getFlatLabels().filterIsInstance().toList() + if ( + startThreads.isNotEmpty() + ) { // for start thread labels, the thread procedure must be explored, too! + startThreads.forEach { startThread -> + outgoingEdges.addAll( + xcfa.procedures.first { it.name == startThread.name }.initLoc.outgoingEdges + ) + } } + return outgoingEdges + } - /** - * Determines whether this action is a backward action. - * - * @return true, if the action is a backward action - */ - protected open val XcfaAction.isBackward: Boolean get() = backwardEdges.any { it.first == source && it.second == target } + /** + * Determines whether this action is a backward action. + * + * @return true, if the action is a backward action + */ + protected open val XcfaAction.isBackward: Boolean + get() = backwardEdges.any { it.first == source && it.second == target } - /** - * Collects backward edges of the given XCFA. - */ - private fun collectBackwardEdges() { - for (procedure in xcfa.procedures) { - // DFS for every procedure of the XCFA to discover backward edges - val visitedLocations = mutableSetOf() - val stack = Stack() + /** Collects backward edges of the given XCFA. */ + private fun collectBackwardEdges() { + for (procedure in xcfa.procedures) { + // DFS for every procedure of the XCFA to discover backward edges + val visitedLocations = mutableSetOf() + val stack = Stack() - stack.push(procedure.initLoc) // start from the initial location of the procedure - while (stack.isNotEmpty()) { - val visiting = stack.pop() - visitedLocations.add(visiting) - for (outgoingEdge in visiting.outgoingEdges) { - val target = outgoingEdge.target - if (target in visitedLocations) { // backward edge - backwardEdges.add(outgoingEdge.source to outgoingEdge.target) - } else { - stack.push(target) - } - } - } + stack.push(procedure.initLoc) // start from the initial location of the procedure + while (stack.isNotEmpty()) { + val visiting = stack.pop() + visitedLocations.add(visiting) + for (outgoingEdge in visiting.outgoingEdges) { + val target = outgoingEdge.target + if (target in visitedLocations) { // backward edge + backwardEdges.add(outgoingEdge.source to outgoingEdge.target) + } else { + stack.push(target) + } } + } } -} \ No newline at end of file + } +} diff --git a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToBoundedChecker.kt b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToBoundedChecker.kt index 5825b0d922..1f928b4ab3 100644 --- a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToBoundedChecker.kt +++ b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToBoundedChecker.kt @@ -18,9 +18,14 @@ package hu.bme.mit.theta.xcfa.cli.checkers import hu.bme.mit.theta.analysis.Trace import hu.bme.mit.theta.analysis.algorithm.EmptyProof import hu.bme.mit.theta.analysis.algorithm.SafetyChecker -import hu.bme.mit.theta.analysis.algorithm.bounded.BoundedChecker +import hu.bme.mit.theta.analysis.algorithm.SafetyResult +import hu.bme.mit.theta.analysis.algorithm.bounded.* +import hu.bme.mit.theta.analysis.pred.PredPrec +import hu.bme.mit.theta.analysis.pred.PredState +import hu.bme.mit.theta.analysis.ptr.PtrPrec import hu.bme.mit.theta.analysis.ptr.PtrState import hu.bme.mit.theta.common.logging.Logger +import hu.bme.mit.theta.frontend.ParseContext import hu.bme.mit.theta.graphsolver.patterns.constraints.MCM import hu.bme.mit.theta.solver.SolverFactory import hu.bme.mit.theta.xcfa.analysis.* @@ -28,35 +33,87 @@ import hu.bme.mit.theta.xcfa.cli.params.BoundedConfig import hu.bme.mit.theta.xcfa.cli.params.XcfaConfig import hu.bme.mit.theta.xcfa.cli.utils.getSolver import hu.bme.mit.theta.xcfa.model.XCFA +import java.util.* fun getBoundedChecker( xcfa: XCFA, mcm: MCM, + parseContext: ParseContext, config: XcfaConfig<*, *>, logger: Logger, ): SafetyChecker>, XcfaAction>, XcfaPrec<*>> { val boundedConfig = config.backendConfig.specConfig as BoundedConfig - return BoundedChecker( - monolithicExpr = xcfa.toMonolithicExpr(), - bmcSolver = - tryGetSolver(boundedConfig.bmcConfig.bmcSolver, boundedConfig.bmcConfig.validateBMCSolver) - ?.createSolver(), - bmcEnabled = { !boundedConfig.bmcConfig.disable }, - lfPathOnly = { !boundedConfig.bmcConfig.nonLfPath }, - itpSolver = - tryGetSolver(boundedConfig.itpConfig.itpSolver, boundedConfig.itpConfig.validateItpSolver) - ?.createItpSolver(), - imcEnabled = { !boundedConfig.itpConfig.disable }, - indSolver = - tryGetSolver(boundedConfig.indConfig.indSolver, boundedConfig.indConfig.validateIndSolver) - ?.createSolver(), - kindEnabled = { !boundedConfig.indConfig.disable }, - valToState = { xcfa.valToState(it) }, - biValToAction = { val1, val2 -> xcfa.valToAction(val1, val2) }, - logger = logger, - ) + val monolithicExpr = + xcfa.toMonolithicExpr(parseContext).let { + if (boundedConfig.reversed) it.createReversed() else it + } + + val baseChecker = { monolithicExpr: MonolithicExpr -> + BoundedChecker( + monolithicExpr = monolithicExpr, + bmcSolver = + tryGetSolver(boundedConfig.bmcConfig.bmcSolver, boundedConfig.bmcConfig.validateBMCSolver) + ?.createSolver(), + bmcEnabled = { !boundedConfig.bmcConfig.disable }, + lfPathOnly = { !boundedConfig.bmcConfig.nonLfPath }, + itpSolver = + tryGetSolver(boundedConfig.itpConfig.itpSolver, boundedConfig.itpConfig.validateItpSolver) + ?.createItpSolver(), + imcEnabled = { !boundedConfig.itpConfig.disable }, + indSolver = + tryGetSolver(boundedConfig.indConfig.indSolver, boundedConfig.indConfig.validateIndSolver) + ?.createSolver(), + kindEnabled = { !boundedConfig.indConfig.disable }, + valToState = monolithicExpr.valToState, + biValToAction = monolithicExpr.biValToAction, + logger = logger, + ) + } + + val checker = + if (boundedConfig.cegar) { + val cegarChecker = + MonolithicExprCegarChecker( + monolithicExpr, + baseChecker, + logger, + getSolver(boundedConfig.bmcConfig.bmcSolver, false), + ) + object : + SafetyChecker< + EmptyProof, + Trace>, XcfaAction>, + XcfaPrec>, + > { + override fun check( + initPrec: XcfaPrec> + ): SafetyResult>, XcfaAction>> { + val result = + cegarChecker.check(initPrec.p.innerPrec) // states are PredState, actions are XcfaAction + if (result.isUnsafe) { + val cex = result.asUnsafe().cex as Trace + val locs = + (0 until cex.length()).map { i -> cex.actions[i].source } + + cex.actions[cex.length() - 1].target + val states = locs.mapIndexed { i, it -> XcfaState(xcfa, it, PtrState(cex.states[i])) } + return SafetyResult.unsafe(Trace.of(states, cex.actions), result.proof) + } else + return result + as SafetyResult>, XcfaAction>> + } + + override fun check(): + SafetyResult>, XcfaAction>> { + return check(boundedConfig.initPrec.predPrec(xcfa)) + } + } + } else { + baseChecker(monolithicExpr) + } + + return checker as SafetyChecker>, XcfaAction>, XcfaPrec<*>> } diff --git a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToChecker.kt b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToChecker.kt index 2861be5cb9..c6bc1153fe 100644 --- a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToChecker.kt +++ b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToChecker.kt @@ -37,17 +37,18 @@ fun getChecker( parseContext: ParseContext, logger: Logger, uniqueLogger: Logger, -): SafetyChecker<*, *, XcfaPrec<*>> = +): SafetyChecker<*, *, *> = if (config.backendConfig.inProcess) { InProcessChecker(xcfa, config, parseContext, logger) } else { when (config.backendConfig.backend) { Backend.CEGAR -> getCegarChecker(xcfa, mcm, config, logger) - Backend.BOUNDED -> getBoundedChecker(xcfa, mcm, config, logger) + Backend.BOUNDED -> getBoundedChecker(xcfa, mcm, parseContext, config, logger) Backend.OC -> getOcChecker(xcfa, mcm, config, logger) Backend.LAZY -> TODO() Backend.PORTFOLIO -> getPortfolioChecker(xcfa, mcm, config, parseContext, logger, uniqueLogger) + Backend.MDD -> getMddChecker(xcfa, mcm, parseContext, config, logger) Backend.NONE -> SafetyChecker< ARG>, XcfaAction>, diff --git a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToMddChecker.kt b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToMddChecker.kt new file mode 100644 index 0000000000..8543c53465 --- /dev/null +++ b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/checkers/ConfigToMddChecker.kt @@ -0,0 +1,80 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hu.bme.mit.theta.xcfa.cli.checkers + +import hu.bme.mit.theta.analysis.algorithm.SafetyChecker +import hu.bme.mit.theta.analysis.algorithm.bounded.createAbstract +import hu.bme.mit.theta.analysis.algorithm.bounded.createReversed +import hu.bme.mit.theta.analysis.algorithm.mdd.MddCex +import hu.bme.mit.theta.analysis.algorithm.mdd.MddChecker +import hu.bme.mit.theta.analysis.algorithm.mdd.MddProof +import hu.bme.mit.theta.analysis.algorithm.mdd.varordering.orderVarsFromRandomStartingPoints +import hu.bme.mit.theta.analysis.expr.ExprAction +import hu.bme.mit.theta.common.logging.Logger +import hu.bme.mit.theta.frontend.ParseContext +import hu.bme.mit.theta.graphsolver.patterns.constraints.MCM +import hu.bme.mit.theta.solver.SolverFactory +import hu.bme.mit.theta.solver.SolverPool +import hu.bme.mit.theta.xcfa.analysis.* +import hu.bme.mit.theta.xcfa.cli.params.* +import hu.bme.mit.theta.xcfa.cli.utils.getSolver +import hu.bme.mit.theta.xcfa.model.XCFA + +fun getMddChecker( + xcfa: XCFA, + mcm: MCM, + parseContext: ParseContext, + config: XcfaConfig<*, *>, + logger: Logger, +): SafetyChecker { + val mddConfig = config.backendConfig.specConfig as MddConfig + + val refinementSolverFactory: SolverFactory = getSolver(mddConfig.solver, mddConfig.validateSolver) + + val monolithicExpr = + xcfa + .toMonolithicExpr(parseContext, initValues = true) + .let { if (mddConfig.reversed) it.createReversed() else it } + .let { + if (mddConfig.cegar) it.createAbstract(mddConfig.initPrec.predPrec(xcfa).p.innerPrec) + else it + } + + val initRel = monolithicExpr.initExpr + val initIndexing = monolithicExpr.initOffsetIndex + val transRel = + object : ExprAction { + override fun toExpr() = monolithicExpr.transExpr + + override fun nextIndexing() = monolithicExpr.transOffsetIndex + } + val safetyProperty = monolithicExpr.propExpr + val stmts = xcfa.procedures.flatMap { it.edges.map { xcfaEdge -> xcfaEdge.label.toStmt() } }.toSet() + val variableOrder = orderVarsFromRandomStartingPoints(monolithicExpr.vars, stmts) + val solverPool = SolverPool(refinementSolverFactory) + val iterationStrategy = mddConfig.iterationStrategy + + return MddChecker.create( + initRel, + initIndexing, + transRel, + safetyProperty, + variableOrder, + solverPool, + logger, + iterationStrategy, + ) +} diff --git a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/ParamValues.kt b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/ParamValues.kt index 8f1695855c..c2d46bd504 100644 --- a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/ParamValues.kt +++ b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/ParamValues.kt @@ -69,6 +69,7 @@ enum class Backend { OC, LAZY, PORTFOLIO, + MDD, NONE, } @@ -391,7 +392,7 @@ enum class InitPrec( ), ALLGLOBALS( explPrec = { xcfa -> - XcfaPrec(PtrPrec(ExplPrec.of(xcfa.vars.map { it.wrappedVar }), emptySet())) + XcfaPrec(PtrPrec(ExplPrec.of(xcfa.globalVars.map { it.wrappedVar }), emptySet())) }, predPrec = { error("ALLGLOBALS is not interpreted for the predicate domain.") }, ), diff --git a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/XcfaConfig.kt b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/XcfaConfig.kt index ca5d3637c7..73b52dbcda 100644 --- a/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/XcfaConfig.kt +++ b/subprojects/xcfa/xcfa-cli/src/main/java/hu/bme/mit/theta/xcfa/cli/params/XcfaConfig.kt @@ -16,6 +16,7 @@ package hu.bme.mit.theta.xcfa.cli.params import com.beust.jcommander.Parameter +import hu.bme.mit.theta.analysis.algorithm.mdd.MddChecker.IterationStrategy import hu.bme.mit.theta.analysis.expr.refinement.PruneStrategy import hu.bme.mit.theta.common.logging.Logger import hu.bme.mit.theta.frontend.ParseContext @@ -190,6 +191,7 @@ data class BackendConfig( Backend.OC -> OcConfig() as T Backend.LAZY -> null Backend.PORTFOLIO -> PortfolioConfig() as T + Backend.MDD -> MddConfig() as T Backend.NONE -> null } } @@ -282,6 +284,12 @@ data class HornConfig( data class BoundedConfig( @Parameter(names = ["--max-bound"], description = "Maximum bound to check. Use 0 for no limit.") var maxBound: Int = 0, + @Parameter(names = ["--reversed"], description = "Create a reversed monolithic expression") + var reversed: Boolean = false, + @Parameter(names = ["--cegar"], description = "Wrap the check in a predicate-based CEGAR loop") + var cegar: Boolean = false, + @Parameter(names = ["--initprec"], description = "Wrap the check in a predicate-based CEGAR loop") + var initPrec: InitPrec = InitPrec.EMPTY, val bmcConfig: BMCConfig = BMCConfig(), val indConfig: InductionConfig = InductionConfig(), val itpConfig: InterpolationConfig = InterpolationConfig(), @@ -377,6 +385,28 @@ data class PortfolioConfig( var portfolio: String = "COMPLEX" ) : SpecBackendConfig +data class MddConfig( + @Parameter(names = ["--solver", "--mdd-solver"], description = "MDD solver name") + var solver: String = "Z3", + @Parameter( + names = ["--validate-solver", "--validate-mdd-solver"], + description = + "Activates a wrapper, which validates the assertions in the solver in each (SAT) check. Filters some solver issues.", + ) + var validateSolver: Boolean = false, + @Parameter( + names = ["--iteration-strategy"], + description = "Iteration strategy for the MDD checker", + ) + var iterationStrategy: IterationStrategy = IterationStrategy.GSAT, + @Parameter(names = ["--reversed"], description = "Create a reversed monolithic expression") + var reversed: Boolean = false, + @Parameter(names = ["--cegar"], description = "Wrap the check in a predicate-based CEGAR loop") + var cegar: Boolean = false, + @Parameter(names = ["--initprec"], description = "Wrap the check in a predicate-based CEGAR loop") + var initPrec: InitPrec = InitPrec.EMPTY, +) : SpecBackendConfig + data class OutputConfig( @Parameter(names = ["--version"], description = "Display version", help = true) var versionInfo: Boolean = false, diff --git a/subprojects/xcfa/xcfa-cli/src/test/java/hu/bme/mit/theta/xcfa/cli/XcfaCliVerifyTest.kt b/subprojects/xcfa/xcfa-cli/src/test/java/hu/bme/mit/theta/xcfa/cli/XcfaCliVerifyTest.kt index 084cbdbd98..0d767337e5 100644 --- a/subprojects/xcfa/xcfa-cli/src/test/java/hu/bme/mit/theta/xcfa/cli/XcfaCliVerifyTest.kt +++ b/subprojects/xcfa/xcfa-cli/src/test/java/hu/bme/mit/theta/xcfa/cli/XcfaCliVerifyTest.kt @@ -22,6 +22,7 @@ import hu.bme.mit.theta.frontend.chc.ChcFrontend import hu.bme.mit.theta.solver.smtlib.SmtLibSolverManager import hu.bme.mit.theta.xcfa.cli.XcfaCli.Companion.main import java.nio.file.Path +import java.util.concurrent.TimeUnit import java.util.stream.Stream import kotlin.io.path.absolutePathString import kotlin.io.path.createTempDirectory @@ -30,6 +31,7 @@ import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assumptions import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource @@ -113,6 +115,16 @@ class XcfaCliVerifyTest { ) } + @JvmStatic + fun finiteStateSpaceC(): Stream { + return Stream.of( + Arguments.of("/c/litmustest/singlethread/00assignment.c", null), + Arguments.of("/c/litmustest/singlethread/13typedef.c", "--domain PRED_CART"), + Arguments.of("/c/litmustest/singlethread/15addition.c", null), + Arguments.of("/c/litmustest/singlethread/20testinline.c", null), + ) + } + @JvmStatic fun cFilesShort(): Stream { return Stream.of( @@ -304,6 +316,24 @@ class XcfaCliVerifyTest { main(params) } + @ParameterizedTest + @MethodSource("finiteStateSpaceC") + @Timeout(value = 10, unit = TimeUnit.SECONDS, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + fun testCVerifyMDD(filePath: String, extraArgs: String?) { + val params = + arrayOf( + "--backend", + "MDD", + "--input-type", + "C", + "--input", + javaClass.getResource(filePath)!!.path, + "--stacktrace", + "--debug", + ) + main(params) + } + @ParameterizedTest @MethodSource("singleThreadedCFiles") fun testCVerifyIMC(filePath: String, extraArgs: String?) { diff --git a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/Utils.kt b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/Utils.kt index f48bd757d8..952f891ec4 100644 --- a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/Utils.kt +++ b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/Utils.kt @@ -59,7 +59,7 @@ fun XcfaLabel.getFlatLabels(): List = } fun XCFA.collectVars(): Iterable> = - vars.map { it.wrappedVar } union procedures.map { it.vars }.flatten() + globalVars.map { it.wrappedVar } union procedures.map { it.vars }.flatten() fun XCFA.collectAssumes(): Iterable> = procedures @@ -255,7 +255,7 @@ private fun XcfaLabel.collectGlobalVars(globalVars: Set>): VarAccessM * second is similar for write access. */ fun XcfaEdge.collectIndirectGlobalVarAccesses(xcfa: XCFA): VarAccessMap { - val globalVars = xcfa.vars.map(XcfaGlobalVar::wrappedVar).toSet() + val globalVars = xcfa.globalVars.map(XcfaGlobalVar::wrappedVar).toSet() val flatLabels = getFlatLabels() val mutexes = flatLabels.filterIsInstance().flatMap { it.acquiredMutexes }.toMutableSet() @@ -287,7 +287,7 @@ fun XcfaEdge.getGlobalVarsWithNeededMutexes( xcfa: XCFA, currentMutexes: Set, ): List { - val globalVars = xcfa.vars.map(XcfaGlobalVar::wrappedVar).toSet() + val globalVars = xcfa.globalVars.map(XcfaGlobalVar::wrappedVar).toSet() val neededMutexes = currentMutexes.toMutableSet() val accesses = mutableListOf() getFlatLabels().forEach { label -> diff --git a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/gson/XcfaAdapter.kt b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/gson/XcfaAdapter.kt index 4c7c6673ef..69d2d5d6dc 100644 --- a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/gson/XcfaAdapter.kt +++ b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/gson/XcfaAdapter.kt @@ -36,7 +36,7 @@ class XcfaAdapter(val gsonSupplier: () -> Gson) : TypeAdapter() { writer.name("name").value(value.name) // vars writer.name("vars") - gson.toJson(gson.toJsonTree(value.vars), writer) + gson.toJson(gson.toJsonTree(value.globalVars), writer) // procedures writer.name("procedures").beginArray() diff --git a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/Builders.kt b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/Builders.kt index bf8d3e3730..268aa4b3fb 100644 --- a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/Builders.kt +++ b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/Builders.kt @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package hu.bme.mit.theta.xcfa.model import hu.bme.mit.theta.core.decl.VarDecl @@ -22,238 +21,280 @@ import hu.bme.mit.theta.core.type.Type import hu.bme.mit.theta.xcfa.passes.ProcedurePassManager import java.util.* -@DslMarker -annotation class XcfaDsl +@DslMarker annotation class XcfaDsl @XcfaDsl -class XcfaBuilder @JvmOverloads constructor( - var name: String, - private val vars: MutableSet = LinkedHashSet(), - val heapMap: MutableMap, VarDecl<*>> = LinkedHashMap(), - private val procedures: MutableSet = LinkedHashSet(), - private val initProcedures: MutableList>>> = ArrayList(), - val metaData: MutableMap = LinkedHashMap() +class XcfaBuilder +@JvmOverloads +constructor( + var name: String, + private val vars: MutableSet = LinkedHashSet(), + val heapMap: MutableMap, VarDecl<*>> = LinkedHashMap(), + private val procedures: MutableSet = LinkedHashSet(), + private val initProcedures: MutableList>>> = ArrayList(), + val metaData: MutableMap = LinkedHashMap(), ) { - fun getVars(): Set = vars - fun getProcedures(): Set = procedures - fun getInitProcedures(): List>>> = initProcedures - - fun build(): XCFA { - return XCFA( - name = name, - vars = vars, - procedureBuilders = procedures, - initProcedureBuilders = initProcedures - ) - } + fun getVars(): Set = vars - fun addVar(toAdd: XcfaGlobalVar) { - vars.add(toAdd) - } + fun getProcedures(): Set = procedures - fun addProcedure(toAdd: XcfaProcedureBuilder) { - procedures.add(toAdd) - toAdd.parent = this - } + fun getInitProcedures(): List>>> = initProcedures - fun addEntryPoint(toAdd: XcfaProcedureBuilder, params: List>) { - addProcedure(toAdd) - initProcedures.add(Pair(toAdd, params)) - } + fun build(): XCFA { + return XCFA( + name = name, + globalVars = vars, + procedureBuilders = procedures, + initProcedureBuilders = initProcedures, + ) + } + + fun addVar(toAdd: XcfaGlobalVar) { + vars.add(toAdd) + } + + fun addProcedure(toAdd: XcfaProcedureBuilder) { + procedures.add(toAdd) + toAdd.parent = this + } + + fun addEntryPoint(toAdd: XcfaProcedureBuilder, params: List>) { + addProcedure(toAdd) + initProcedures.add(Pair(toAdd, params)) + } } @XcfaDsl -class XcfaProcedureBuilder @JvmOverloads constructor( - var name: String, - val manager: ProcedurePassManager, - private val params: MutableList, ParamDirection>> = ArrayList(), - private val vars: MutableSet> = LinkedHashSet(), - private val locs: MutableSet = LinkedHashSet(), - private val edges: MutableSet = LinkedHashSet(), - val metaData: MutableMap = LinkedHashMap() +class XcfaProcedureBuilder +@JvmOverloads +constructor( + var name: String, + val manager: ProcedurePassManager, + private val params: MutableList, ParamDirection>> = ArrayList(), + private val vars: MutableSet> = LinkedHashSet(), + private val locs: MutableSet = LinkedHashSet(), + private val edges: MutableSet = LinkedHashSet(), + val metaData: MutableMap = LinkedHashMap(), ) { - lateinit var initLoc: XcfaLocation - private set - var finalLoc: Optional = Optional.empty() - private set - var errorLoc: Optional = Optional.empty() - private set - lateinit var parent: XcfaBuilder - private lateinit var built: XcfaProcedure - private lateinit var optimized: XcfaProcedureBuilder - private lateinit var partlyOptimized: XcfaProcedureBuilder - private var lastOptimized: Int = -1 - fun getParams(): List, ParamDirection>> = when { - this::optimized.isInitialized -> optimized.params - this::partlyOptimized.isInitialized -> partlyOptimized.params - else -> params - } + lateinit var initLoc: XcfaLocation + private set - fun getVars(): Set> = when { - this::optimized.isInitialized -> optimized.vars - this::partlyOptimized.isInitialized -> partlyOptimized.vars - else -> vars - } + var finalLoc: Optional = Optional.empty() + private set - fun getLocs(): Set = when { - this::optimized.isInitialized -> optimized.locs - this::partlyOptimized.isInitialized -> partlyOptimized.locs - else -> locs - } + var errorLoc: Optional = Optional.empty() + private set - fun getEdges(): Set = when { - this::optimized.isInitialized -> optimized.edges - this::partlyOptimized.isInitialized -> partlyOptimized.edges - else -> edges - } + lateinit var parent: XcfaBuilder + private lateinit var built: XcfaProcedure + private lateinit var optimized: XcfaProcedureBuilder + private lateinit var partlyOptimized: XcfaProcedureBuilder + private var lastOptimized: Int = -1 - fun optimize() { - if (!this::optimized.isInitialized) { - var that = this - for (pass in manager.passes.flatten()) { - that = pass.run(that) - } - optimized = that - } + fun getParams(): List, ParamDirection>> = + when { + this::optimized.isInitialized -> optimized.params + this::partlyOptimized.isInitialized -> partlyOptimized.params + else -> params } - fun optimize(phase: Int): Boolean { // true, if optimization is finished (no more phases to execute) - if (this::optimized.isInitialized || phase >= manager.passes.size) return true - if (phase <= lastOptimized) return lastOptimized >= manager.passes.size - 1 - check(phase == lastOptimized + 1) { "Wrong optimization phase!" } - - var that = if (this::partlyOptimized.isInitialized) partlyOptimized else this - for (pass in manager.passes[phase]) { - that = pass.run(that) - } - - partlyOptimized = that - lastOptimized = phase - if (phase >= manager.passes.size - 1) optimized = that - return phase >= manager.passes.size - 1 + fun getVars(): Set> = + when { + this::optimized.isInitialized -> optimized.vars + this::partlyOptimized.isInitialized -> partlyOptimized.vars + else -> vars } - fun build(parent: XCFA): XcfaProcedure { - if (this::built.isInitialized) return built; - if (!this::optimized.isInitialized) optimize() - built = XcfaProcedure( - name = optimized.name, - params = optimized.params, - vars = optimized.vars, - locs = optimized.locs, - edges = optimized.edges, - initLoc = optimized.initLoc, - finalLoc = optimized.finalLoc, - errorLoc = optimized.errorLoc - ) - built.parent = parent - return built + fun getLocs(): Set = + when { + this::optimized.isInitialized -> optimized.locs + this::partlyOptimized.isInitialized -> partlyOptimized.locs + else -> locs } - fun addParam(toAdd: VarDecl<*>, dir: ParamDirection) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - params.add(Pair(toAdd, dir)) - vars.add(toAdd) + fun getEdges(): Set = + when { + this::optimized.isInitialized -> optimized.edges + this::partlyOptimized.isInitialized -> partlyOptimized.edges + else -> edges } - fun addVar(toAdd: VarDecl<*>) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - vars.add(toAdd) + fun optimize() { + if (!this::optimized.isInitialized) { + var that = this + for (pass in manager.passes.flatten()) { + that = pass.run(that) + } + optimized = that } - - fun removeVar(toRemove: VarDecl<*>) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - vars.remove(toRemove) + } + + fun optimize( + phase: Int + ): Boolean { // true, if optimization is finished (no more phases to execute) + if (this::optimized.isInitialized || phase >= manager.passes.size) return true + if (phase <= lastOptimized) return lastOptimized >= manager.passes.size - 1 + check(phase == lastOptimized + 1) { "Wrong optimization phase!" } + + var that = if (this::partlyOptimized.isInitialized) partlyOptimized else this + for (pass in manager.passes[phase]) { + that = pass.run(that) } - @JvmOverloads - fun createErrorLoc(metaData: MetaData = EmptyMetaData) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - if (errorLoc.isEmpty) { - errorLoc = Optional.of(XcfaLocation(name + "_error", error = true, metadata = metaData)) - locs.add(errorLoc.get()) - } + partlyOptimized = that + lastOptimized = phase + if (phase >= manager.passes.size - 1) optimized = that + return phase >= manager.passes.size - 1 + } + + fun build(parent: XCFA): XcfaProcedure { + if (this::built.isInitialized) return built + if (!this::optimized.isInitialized) optimize() + built = + XcfaProcedure( + name = optimized.name, + params = optimized.params, + vars = optimized.vars, + locs = optimized.locs, + edges = optimized.edges, + initLoc = optimized.initLoc, + finalLoc = optimized.finalLoc, + errorLoc = optimized.errorLoc, + ) + built.parent = parent + return built + } + + fun addParam(toAdd: VarDecl<*>, dir: ParamDirection) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } + params.add(Pair(toAdd, dir)) + vars.add(toAdd) + } - @JvmOverloads - fun createFinalLoc(metaData: MetaData = EmptyMetaData) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - if (finalLoc.isEmpty) { - finalLoc = Optional.of(XcfaLocation(name + "_final", final = true, metadata = metaData)) - locs.add(finalLoc.get()) - } + fun addVar(toAdd: VarDecl<*>) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } + vars.add(toAdd) + } - @JvmOverloads - fun createInitLoc(metaData: MetaData = EmptyMetaData) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - if (!this::initLoc.isInitialized) { - initLoc = XcfaLocation(name + "_init", initial = true, metadata = metaData) - locs.add(initLoc) - } + fun removeVar(toRemove: VarDecl<*>) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } + vars.remove(toRemove) + } - fun copyMetaLocs(from: XcfaProcedureBuilder) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - initLoc = from.initLoc - finalLoc = from.finalLoc - errorLoc = from.errorLoc + @JvmOverloads + fun createErrorLoc(metaData: MetaData = EmptyMetaData) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" + } + if (errorLoc.isEmpty) { + errorLoc = Optional.of(XcfaLocation(name + "_error", error = true, metadata = metaData)) + locs.add(errorLoc.get()) } + } - fun addEdge(toAdd: XcfaEdge) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - addLoc(toAdd.source) - addLoc(toAdd.target) - edges.add(toAdd) - toAdd.source.outgoingEdges.add(toAdd) - toAdd.target.incomingEdges.add(toAdd) + @JvmOverloads + fun createFinalLoc(metaData: MetaData = EmptyMetaData) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } + if (finalLoc.isEmpty) { + finalLoc = Optional.of(XcfaLocation(name + "_final", final = true, metadata = metaData)) + locs.add(finalLoc.get()) + } + } - fun addLoc(toAdd: XcfaLocation) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - if (!locs.contains(toAdd)) { - check(!toAdd.error) - check(!toAdd.initial) - check(!toAdd.final) - locs.add(toAdd) - } + @JvmOverloads + fun createInitLoc(metaData: MetaData = EmptyMetaData) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } + if (!this::initLoc.isInitialized) { + initLoc = XcfaLocation(name + "_init", initial = true, metadata = metaData) + locs.add(initLoc) + } + } - fun removeEdge(toRemove: XcfaEdge) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - toRemove.source.outgoingEdges.remove(toRemove) - toRemove.target.incomingEdges.remove(toRemove) - edges.remove(toRemove) + fun copyMetaLocs(from: XcfaProcedureBuilder) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" + } + initLoc = from.initLoc + finalLoc = from.finalLoc + errorLoc = from.errorLoc + } + + fun addEdge(toAdd: XcfaEdge) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" + } + addLoc(toAdd.source) + addLoc(toAdd.target) + edges.add(toAdd) + toAdd.source.outgoingEdges.add(toAdd) + toAdd.target.incomingEdges.add(toAdd) + } + + fun addLoc(toAdd: XcfaLocation) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" + } + if (!locs.contains(toAdd)) { + check(!toAdd.error) + check(!toAdd.initial) + check(!toAdd.final) + locs.add(toAdd) } + } - fun removeLoc(toRemove: XcfaLocation) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - locs.remove(toRemove) + fun removeEdge(toRemove: XcfaEdge) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } + toRemove.source.outgoingEdges.remove(toRemove) + toRemove.target.incomingEdges.remove(toRemove) + edges.remove(toRemove) + } + + fun removeLoc(toRemove: XcfaLocation) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" + } + locs.remove(toRemove) + } - fun removeLocs(pred: (XcfaLocation) -> Boolean) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - while (locs.any(pred)) { - locs.removeIf(pred) - edges.removeIf { - pred(it.source).also { removing -> - if (removing) { - it.target.incomingEdges.remove(it) - } - } - } + fun removeLocs(pred: (XcfaLocation) -> Boolean) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" + } + while (locs.any(pred)) { + locs.removeIf(pred) + edges.removeIf { + pred(it.source).also { removing -> + if (removing) { + it.target.incomingEdges.remove(it) + } } + } } + } - fun changeVars(varLut: Map, VarDecl<*>>) { - check(!this::optimized.isInitialized) { "Cannot add/remove new elements after optimization passes!" } - val savedVars = ArrayList(vars) - vars.clear() - savedVars.forEach { vars.add(checkNotNull(varLut[it])) } - val savedParams = ArrayList(params) - params.clear() - savedParams.forEach { params.add(Pair(checkNotNull(varLut[it.first]), it.second)) } + fun changeVars(varLut: Map, VarDecl<*>>) { + check(!this::optimized.isInitialized) { + "Cannot add/remove new elements after optimization passes!" } -} \ No newline at end of file + val savedVars = ArrayList(vars) + vars.clear() + savedVars.forEach { vars.add(checkNotNull(varLut[it])) } + val savedParams = ArrayList(params) + params.clear() + savedParams.forEach { params.add(Pair(checkNotNull(varLut[it.first]), it.second)) } + } +} diff --git a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/XCFA.kt b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/XCFA.kt index 9e90b22d57..8873df6632 100644 --- a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/XCFA.kt +++ b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/model/XCFA.kt @@ -23,7 +23,7 @@ import java.util.* class XCFA( val name: String, - val vars: Set, // global variables + val globalVars: Set, // global variables val procedureBuilders: Set = emptySet(), val initProcedureBuilders: List>>> = emptyList(), ) { @@ -69,7 +69,7 @@ class XCFA( other as XCFA if (name != other.name) return false - if (vars != other.vars) return false + if (globalVars != other.globalVars) return false if (procedures != other.procedures) return false if (initProcedures != other.initProcedures) return false @@ -79,7 +79,7 @@ class XCFA( override fun hashCode(): Int { if (cachedHash != null) return cachedHash as Int var result = name.hashCode() - result = 31 * result + vars.hashCode() + result = 31 * result + globalVars.hashCode() result = 31 * result + procedures.hashCode() result = 31 * result + initProcedures.hashCode() cachedHash = result @@ -87,7 +87,7 @@ class XCFA( } override fun toString(): String { - return "XCFA(name='$name', vars=$vars, procedures=$procedures, initProcedures=$initProcedures)" + return "XCFA(name='$name', vars=$globalVars, procedures=$procedures, initProcedures=$initProcedures)" } } @@ -144,6 +144,8 @@ data class XcfaEdge( fun withTarget(target: XcfaLocation): XcfaEdge = XcfaEdge(source, target, label, metadata) fun withSource(source: XcfaLocation): XcfaEdge = XcfaEdge(source, target, label, metadata) + + fun withMetadata(metadata: MetaData): XcfaEdge = XcfaEdge(source, target, label, metadata) } data class XcfaGlobalVar diff --git a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/EmptyEdgeRemovalPass.kt b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/EmptyEdgeRemovalPass.kt index 43c29685ee..dd9adf3563 100644 --- a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/EmptyEdgeRemovalPass.kt +++ b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/EmptyEdgeRemovalPass.kt @@ -16,6 +16,7 @@ package hu.bme.mit.theta.xcfa.passes import hu.bme.mit.theta.core.stmt.Stmts.Assume +import hu.bme.mit.theta.core.type.booltype.BoolExprs.False import hu.bme.mit.theta.core.type.booltype.BoolExprs.True import hu.bme.mit.theta.xcfa.model.* @@ -24,31 +25,48 @@ class EmptyEdgeRemovalPass : ProcedurePass { override fun run(builder: XcfaProcedureBuilder): XcfaProcedureBuilder { while (true) { + builder.getEdges().filter { it.label.isSureStuck() }.forEach { builder.removeEdge(it) } + val edge = builder.getEdges().find { it.label.isNop() && !it.target.error && !it.target.final && !it.source.initial && - (it.source.outgoingEdges.size == 1 || it.target.incomingEdges.size == 1) && - it.metadata is EmptyMetaData + (it.source.outgoingEdges.size == 1 || it.target.incomingEdges.size == 1) } ?: return builder val collapseBefore = edge.source.outgoingEdges.size == 1 builder.removeEdge(edge) if (collapseBefore) { val incomingEdges = edge.source.incomingEdges.toList() incomingEdges.forEach { builder.removeEdge(it) } - incomingEdges.forEach { builder.addEdge(it.withTarget(edge.target)) } + incomingEdges.forEach { + builder.addEdge( + it.withTarget(edge.target).withMetadata(it.metadata.combine(edge.metadata)) + ) + } builder.removeLoc(edge.source) } else { val outgoingEdges = edge.target.outgoingEdges.toList() outgoingEdges.forEach { builder.removeEdge(it) } - outgoingEdges.forEach { builder.addEdge(it.withSource(edge.source)) } + outgoingEdges.forEach { + builder.addEdge( + it.withSource(edge.source).withMetadata(edge.metadata.combine(it.metadata)) + ) + } builder.removeLoc(edge.target) } } } + private fun XcfaLabel.isSureStuck(): Boolean = + when (this) { + is SequenceLabel -> labels.any { it.isSureStuck() } + is NondetLabel -> labels.all { it.isSureStuck() } + is StmtLabel -> stmt == Assume(False()) + else -> false + } + private fun XcfaLabel.isNop(): Boolean = when (this) { is NondetLabel -> labels.all { it.isNop() } @@ -56,5 +74,5 @@ class EmptyEdgeRemovalPass : ProcedurePass { is NopLabel -> true is StmtLabel -> stmt == Assume(True()) else -> false - }.and(metadata is EmptyMetaData) + } } diff --git a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/ProcedurePassManager.kt b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/ProcedurePassManager.kt index 59949a3a7d..f9979b3aae 100644 --- a/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/ProcedurePassManager.kt +++ b/subprojects/xcfa/xcfa/src/main/java/hu/bme/mit/theta/xcfa/passes/ProcedurePassManager.kt @@ -51,6 +51,7 @@ class CPasses(checkOverflow: Boolean, parseContext: ParseContext, uniqueWarningL listOf( // trying to inline procedures InlineProceduresPass(parseContext), + EmptyEdgeRemovalPass(), RemoveDeadEnds(), EliminateSelfLoops(), ), diff --git a/subprojects/xcfa/xcfa/src/test/java/hu/bme/mit/theta/xcfa/gson/GsonTest.kt b/subprojects/xcfa/xcfa/src/test/java/hu/bme/mit/theta/xcfa/gson/GsonTest.kt index 23235c44a1..bb30bddb81 100644 --- a/subprojects/xcfa/xcfa/src/test/java/hu/bme/mit/theta/xcfa/gson/GsonTest.kt +++ b/subprojects/xcfa/xcfa/src/test/java/hu/bme/mit/theta/xcfa/gson/GsonTest.kt @@ -97,7 +97,7 @@ class GsonTest { val x_symbol = NamedSymbol("x") symbolTable.add(x_symbol) val env = Env() - env.define(x_symbol, xcfaSource.vars.find { it.wrappedVar.name == "x" }!!.wrappedVar) + env.define(x_symbol, xcfaSource.globalVars.find { it.wrappedVar.name == "x" }!!.wrappedVar) val gson = getGson(symbolTable, env, true) val output = gson.fromJson(gson.toJson(xcfaSource), XCFA::class.java) @@ -124,8 +124,11 @@ class GsonTest { symbolTable.add(x_symbol) symbolTable.add(thr1_symbol) val env = Env() - env.define(x_symbol, xcfaSource.vars.find { it.wrappedVar.name == "x" }!!.wrappedVar) - env.define(thr1_symbol, xcfaSource.vars.find { it.wrappedVar.name == "thr1" }!!.wrappedVar) + env.define(x_symbol, xcfaSource.globalVars.find { it.wrappedVar.name == "x" }!!.wrappedVar) + env.define( + thr1_symbol, + xcfaSource.globalVars.find { it.wrappedVar.name == "thr1" }!!.wrappedVar, + ) val gson = getGson(symbolTable, env, true) val output = gson.fromJson(gson.toJson(xcfaSource), XCFA::class.java) diff --git a/subprojects/xsts/xsts-analysis/src/main/java/hu/bme/mit/theta/xsts/analysis/mdd/XstsMddChecker.java b/subprojects/xsts/xsts-analysis/src/main/java/hu/bme/mit/theta/xsts/analysis/mdd/XstsMddChecker.java index b8ae5f098e..9acd255d3f 100644 --- a/subprojects/xsts/xsts-analysis/src/main/java/hu/bme/mit/theta/xsts/analysis/mdd/XstsMddChecker.java +++ b/subprojects/xsts/xsts-analysis/src/main/java/hu/bme/mit/theta/xsts/analysis/mdd/XstsMddChecker.java @@ -52,6 +52,8 @@ import hu.bme.mit.theta.core.utils.indexings.VarIndexingFactory; import hu.bme.mit.theta.solver.SolverPool; import hu.bme.mit.theta.xsts.XSTS; +import hu.bme.mit.theta.xsts.analysis.XstsVarOrderingKt; + import java.util.ArrayList; import java.util.List; @@ -111,8 +113,11 @@ public SafetyResult check(Void input) { final var initToExprResult = StmtUtils.toExpr(xsts.getInit(), VarIndexingFactory.indexing(0)); - for (var v : xsts.getVars()) { - final var domainSize = /*v.getType() instanceof BoolType ? 2 :*/ 0; + final var orderedVars = XstsVarOrderingKt.orderVars(xsts); + for (var v : xsts.getStateVars()) { + final var + domainSize = /*Math.max(v.getType().getDomainSize().getFiniteSize().intValue(), 0)*/ + 0; stateOrder.createOnTop(MddVariableDescriptor.create(v.getConstDecl(0), domainSize)); diff --git a/subprojects/xsts/xsts-analysis/src/main/kotlin/hu/bme/mit/theta/xsts/analysis/XstsVarOrdering.kt b/subprojects/xsts/xsts-analysis/src/main/kotlin/hu/bme/mit/theta/xsts/analysis/XstsVarOrdering.kt new file mode 100644 index 0000000000..a156ace009 --- /dev/null +++ b/subprojects/xsts/xsts-analysis/src/main/kotlin/hu/bme/mit/theta/xsts/analysis/XstsVarOrdering.kt @@ -0,0 +1,69 @@ +/* + * Copyright 2024 Budapest University of Technology and Economics + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hu.bme.mit.theta.xsts.analysis + +import hu.bme.mit.theta.analysis.algorithm.mdd.varordering.orderVarsFromRandomStartingPoints +import hu.bme.mit.theta.core.decl.VarDecl +import hu.bme.mit.theta.core.stmt.IfStmt +import hu.bme.mit.theta.core.stmt.NonDetStmt +import hu.bme.mit.theta.core.stmt.SequenceStmt +import hu.bme.mit.theta.core.stmt.Stmt +import hu.bme.mit.theta.xsts.XSTS + +fun XSTS.orderVars(): List> { + val flattened = flattenStmts(tran) + val orderedVars = orderVarsFromRandomStartingPoints(this.stateVars.toList(), flattened) + return orderedVars +} + +fun cartesianProduct(vararg sets: Set<*>): Set> = + sets + .fold(listOf(listOf())) { acc, set -> + acc.flatMap { list -> set.map { element -> list + element } } + } + .toSet() + +private fun flattenStmts(stmt: Stmt): Set { + return when(stmt) { + is NonDetStmt -> { + stmt.stmts.flatMap { flattenStmts(it) }.toSet() + } + is SequenceStmt -> { + cartesianProduct(*(stmt.stmts.map { flattenStmts(it) }.toTypedArray())).map { SequenceStmt.of(it as List) }.toSet() + } + is IfStmt -> { + flattenStmts(stmt.then) + flattenStmts(stmt.elze) + } + else -> { + setOf(stmt) + } + } +} + +//private fun collectStmts(stmt: Stmt): Set { +// return when(stmt) { +// is NonDetStmt -> { +// stmt.stmts.flatMap { collectStmts(it) }.toSet() +// } +// is SequenceStmt -> { +// stmt.stmts.flatMap { collectStmts(it) }.toSet() +// } +// else -> { +// setOf(stmt) +// } +// } +//} \ No newline at end of file diff --git a/subprojects/xsts/xsts-analysis/src/test/java/hu/bme/mit/theta/xsts/analysis/XstsMddCheckerTest.java b/subprojects/xsts/xsts-analysis/src/test/java/hu/bme/mit/theta/xsts/analysis/XstsMddCheckerTest.java index 4208cdfd6a..84434e4cb1 100644 --- a/subprojects/xsts/xsts-analysis/src/test/java/hu/bme/mit/theta/xsts/analysis/XstsMddCheckerTest.java +++ b/subprojects/xsts/xsts-analysis/src/test/java/hu/bme/mit/theta/xsts/analysis/XstsMddCheckerTest.java @@ -132,6 +132,11 @@ public static Collection data() { "src/test/resources/property/count_up_down2.prop", true }, + { + "src/test/resources/model/count_up_down.xsts", + "src/test/resources/property/count_up_down2.prop", + true + }, // {"src/test/resources/model/bhmr2007.xsts", // "src/test/resources/property/bhmr2007.prop", true}, @@ -162,6 +167,12 @@ public static Collection data() { // // { "src/test/resources/model/if2.xsts", // "src/test/resources/property/if2.prop", false} + + { + "src/test/resources/model/localvars3.xsts", + "src/test/resources/property/localvars3.prop", + false + }, }); } diff --git a/subprojects/xsts/xsts-analysis/src/test/resources/model/localvars3.xsts b/subprojects/xsts/xsts-analysis/src/test/resources/model/localvars3.xsts new file mode 100644 index 0000000000..ccc26dc1a0 --- /dev/null +++ b/subprojects/xsts/xsts-analysis/src/test/resources/model/localvars3.xsts @@ -0,0 +1,20 @@ +var x: integer = 1 +var y: integer = 1 + +trans { + assume x<16 && x>0; + local var a:integer=x; + a:=a+x; + y:=a; + x:=0; +} or { + assume y<16 && y>0; + local var a:integer=y; + a:=a+y; + x:=a; + y:=0; +} + +init{} + +env{} \ No newline at end of file diff --git a/subprojects/xsts/xsts-analysis/src/test/resources/property/localvars3.prop b/subprojects/xsts/xsts-analysis/src/test/resources/property/localvars3.prop new file mode 100644 index 0000000000..2dac5b7a15 --- /dev/null +++ b/subprojects/xsts/xsts-analysis/src/test/resources/property/localvars3.prop @@ -0,0 +1,3 @@ +prop{ + x!=8 +} \ No newline at end of file diff --git a/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/XSTS.java b/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/XSTS.java index 97d6bad079..b5617c11da 100644 --- a/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/XSTS.java +++ b/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/XSTS.java @@ -15,6 +15,8 @@ */ package hu.bme.mit.theta.xsts; +import static com.google.common.base.Preconditions.checkNotNull; + import hu.bme.mit.theta.common.container.Containers; import hu.bme.mit.theta.core.decl.VarDecl; import hu.bme.mit.theta.core.stmt.NonDetStmt; @@ -22,16 +24,13 @@ import hu.bme.mit.theta.core.type.booltype.BoolType; import hu.bme.mit.theta.core.utils.ExprUtils; import hu.bme.mit.theta.core.utils.StmtUtils; - -import java.util.Collection; -import java.util.Collections; import java.util.Set; -import static com.google.common.base.Preconditions.checkNotNull; - public final class XSTS { - private final Collection> vars; + private final Set> vars; + private final Set> stateVars; + private final Set> localVars; private final Set> ctrlVars; private final NonDetStmt tran; @@ -41,51 +40,86 @@ public final class XSTS { private final Expr initFormula; private final Expr prop; + public NonDetStmt getTran() { + return tran; + } + + public NonDetStmt getEnv() { + return env; + } + public NonDetStmt getInit() { return init; } - public Collection> getVars() { - return vars; + public Expr getInitFormula() { + return initFormula; } public Expr getProp() { return prop; } - public NonDetStmt getTran() { - return tran; + public Set> getVars() { + return vars; } - public Expr getInitFormula() { - return initFormula; + public Set> getLocalVars() { + return localVars; } - public NonDetStmt getEnv() { - return env; + public Set> getStateVars() { + return stateVars; } public Set> getCtrlVars() { return ctrlVars; } - public XSTS(final Set> ctrlVars, - final NonDetStmt init, final NonDetStmt tran, final NonDetStmt env, - final Expr initFormula, final Expr prop) { + public XSTS( + final Set> ctrlVars, + final NonDetStmt init, + final NonDetStmt tran, + final NonDetStmt env, + final Expr initFormula, + final Expr prop) { this.tran = checkNotNull(tran); this.init = checkNotNull(init); this.env = checkNotNull(env); this.initFormula = checkNotNull(initFormula); this.prop = checkNotNull(prop); - this.ctrlVars = ctrlVars; - - final Set> tmpVars = Containers.createSet(); - tmpVars.addAll(StmtUtils.getVars(tran)); - tmpVars.addAll(StmtUtils.getVars(env)); - tmpVars.addAll(StmtUtils.getVars(init)); - tmpVars.addAll(ExprUtils.getVars(initFormula)); - tmpVars.addAll(ExprUtils.getVars(prop)); - this.vars = Collections.unmodifiableCollection(tmpVars); + this.ctrlVars = checkNotNull(ctrlVars); + + this.vars = Containers.createSet(); + vars.addAll(StmtUtils.getVars(tran)); + vars.addAll(StmtUtils.getVars(env)); + vars.addAll(StmtUtils.getVars(init)); + vars.addAll(ExprUtils.getVars(initFormula)); + vars.addAll(ExprUtils.getVars(prop)); + this.stateVars = this.vars; + this.localVars = Containers.createSet(); } + public XSTS( + final Set> stateVars, + final Set> localVars, + final Set> ctrlVars, + final NonDetStmt init, + final NonDetStmt tran, + final NonDetStmt env, + final Expr initFormula, + final Expr prop) { + this.tran = checkNotNull(tran); + this.init = checkNotNull(init); + this.env = checkNotNull(env); + this.initFormula = checkNotNull(initFormula); + this.prop = checkNotNull(prop); + this.ctrlVars = checkNotNull(ctrlVars); + + this.vars = Containers.createSet(); + this.vars.addAll(checkNotNull(stateVars)); + this.vars.addAll(checkNotNull(localVars)); + this.stateVars = stateVars; + this.localVars = localVars; + } } diff --git a/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/dsl/XstsSpecification.java b/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/dsl/XstsSpecification.java index fc3c91d6c8..ff7a78bf51 100644 --- a/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/dsl/XstsSpecification.java +++ b/subprojects/xsts/xsts/src/main/java/hu/bme/mit/theta/xsts/dsl/XstsSpecification.java @@ -15,6 +15,12 @@ */ package hu.bme.mit.theta.xsts.dsl; +import static com.google.common.base.Preconditions.checkNotNull; +import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And; +import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool; +import static hu.bme.mit.theta.core.utils.TypeUtils.cast; + import hu.bme.mit.theta.common.container.Containers; import hu.bme.mit.theta.common.dsl.*; import hu.bme.mit.theta.core.decl.VarDecl; @@ -24,17 +30,12 @@ import hu.bme.mit.theta.core.type.booltype.BoolType; import hu.bme.mit.theta.core.type.enumtype.EnumType; import hu.bme.mit.theta.core.utils.ExprUtils; +import hu.bme.mit.theta.core.utils.StmtUtils; import hu.bme.mit.theta.xsts.XSTS; import hu.bme.mit.theta.xsts.dsl.gen.XstsDslParser.XstsContext; - import java.util.*; import java.util.regex.Pattern; - -import static com.google.common.base.Preconditions.checkNotNull; -import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And; -import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool; -import static hu.bme.mit.theta.core.utils.TypeUtils.cast; +import java.util.stream.Collectors; public class XstsSpecification implements DynamicScope { @@ -74,74 +75,127 @@ public XSTS instantiate() { typeDeclContext.literals.forEach(litCtx -> literalNames.add(litCtx.name.getText())); customTypeShortNames.addAll(literalNames); final EnumType enumType = EnumType.of(typeName, literalNames); - literalNames - .stream() + literalNames.stream() .map(litName -> EnumType.makeLongName(enumType, litName)) .map(fullLitName -> XstsCustomLiteralSymbol.of(enumType, fullLitName)) - .forEach(symbol -> { - declare(symbol); - env.define(symbol, symbol.instantiate()); - }); + .forEach( + symbol -> { + declare(symbol); + env.define(symbol, symbol.instantiate()); + }); final XstsCustomTypeSymbol typeDeclSymbol = XstsCustomTypeSymbol.of(enumType); typeTable.add(typeDeclSymbol); env.define(typeDeclSymbol, enumType); } - for (var varDeclContext : context.variableDeclarations) { - final String varName = varDeclContext.name.getText(); - if (tempVarPattern.matcher(varName).matches()) { - throw new ParseException(varDeclContext, - "Variable name '" + varName + "' is reserved!"); - } - if (customTypeShortNames.contains(varName)) - throw new ParseException(varDeclContext, - String.format("Variable name '%s' matches at least one declared enum literal", varName)); - - final XstsVariableSymbol symbol = new XstsVariableSymbol(typeTable, varDeclContext); - declare(symbol); - - final VarDecl var = symbol.instantiate(env); - if (varDeclContext.CTRL() != null) { - ctrlVars.add(var); - } - if (varDeclContext.initValue != null) { - var scope = new BasicDynamicScope(this); - if (var.getType() instanceof EnumType enumType) { - env.push(); - enumType.getValues().forEach(literal -> { - Symbol fullNameSymbol = resolve(EnumType.makeLongName(enumType, literal)).orElseThrow(); - if (fullNameSymbol instanceof XstsCustomLiteralSymbol fNameCustLitSymbol) { - var customSymbol = XstsCustomLiteralSymbol.copyWithName(fNameCustLitSymbol, literal); - scope.declare(customSymbol); - env.define(customSymbol, customSymbol.instantiate()); - } else { - throw new IllegalArgumentException(String.format("%s is not a literal of type %s", literal, enumType.getName())); - } - }); - } - initExprs.add(Eq(var.getRef(), - new XstsExpression(scope, typeTable, varDeclContext.initValue).instantiate( - env))); - if (var.getType() instanceof EnumType) - env.pop(); - } - env.define(symbol, var); - } - - final NonDetStmt tranSet = new XstsTransitionSet(this, typeTable, - context.tran.transitionSet()).instantiate(env); - final NonDetStmt initSet = new XstsTransitionSet(this, typeTable, - context.init.transitionSet()).instantiate(env); - final NonDetStmt envSet = new XstsTransitionSet(this, typeTable, - context.env.transitionSet()).instantiate(env); + final Set> stateVars = + context.variableDeclarations.stream() + .map( + varDeclContext -> { + final String varName = varDeclContext.name.getText(); + if (tempVarPattern.matcher(varName).matches()) { + throw new ParseException( + varDeclContext, + "Variable name '" + varName + "' is reserved!"); + } + if (customTypeShortNames.contains(varName)) + throw new ParseException( + varDeclContext, + String.format( + "Variable name '%s' matches at least one" + + " declared enum literal", + varName)); + + final XstsVariableSymbol symbol = + new XstsVariableSymbol(typeTable, varDeclContext); + declare(symbol); + + final VarDecl var = symbol.instantiate(env); + if (varDeclContext.CTRL() != null) { + ctrlVars.add(var); + } + if (varDeclContext.initValue != null) { + var scope = new BasicDynamicScope(this); + if (var.getType() instanceof EnumType enumType) { + env.push(); + enumType.getValues() + .forEach( + literal -> { + Symbol fullNameSymbol = + resolve( + EnumType + .makeLongName( + enumType, + literal)) + .orElseThrow(); + if (fullNameSymbol + instanceof + XstsCustomLiteralSymbol + fNameCustLitSymbol) { + var customSymbol = + XstsCustomLiteralSymbol + .copyWithName( + fNameCustLitSymbol, + literal); + scope.declare(customSymbol); + env.define( + customSymbol, + customSymbol + .instantiate()); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a" + + " literal" + + " of type" + + " %s", + literal, + enumType + .getName())); + } + }); + } + initExprs.add( + Eq( + var.getRef(), + new XstsExpression( + scope, + typeTable, + varDeclContext.initValue) + .instantiate(env))); + if (var.getType() instanceof EnumType) env.pop(); + } + env.define(symbol, var); + return var; + }) + .collect(Collectors.toUnmodifiableSet()); + + final NonDetStmt tranSet = + new XstsTransitionSet(this, typeTable, context.tran.transitionSet()) + .instantiate(env); + final NonDetStmt initSet = + new XstsTransitionSet(this, typeTable, context.init.transitionSet()) + .instantiate(env); + final NonDetStmt envSet = + new XstsTransitionSet(this, typeTable, context.env.transitionSet()) + .instantiate(env); final Expr initFormula = ExprUtils.simplify(And(initExprs)); - final Expr prop = cast( - new XstsExpression(this, typeTable, context.prop).instantiate(env), Bool()); + final Expr prop = + cast(new XstsExpression(this, typeTable, context.prop).instantiate(env), Bool()); + + final Set> localVars = Containers.createSet(); + localVars.addAll(StmtUtils.getVars(tranSet)); + localVars.addAll(StmtUtils.getVars(envSet)); + localVars.addAll(StmtUtils.getVars(initSet)); + localVars.addAll(ExprUtils.getVars(initFormula)); + localVars.addAll(ExprUtils.getVars(prop)); + localVars.removeAll(stateVars); - return new XSTS(ctrlVars, initSet, tranSet, envSet, initFormula, prop); + return new XSTS( + stateVars, localVars, ctrlVars, initSet, tranSet, envSet, initFormula, prop); } @Override