Skip to content

Commit

Permalink
A reference to the successor of this declaration was made, but it has…
Browse files Browse the repository at this point in the history
… no successor.
  • Loading branch information
petravandenbos-utwente committed Jul 17, 2023
1 parent f36c2ba commit 5866439
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 17 deletions.
3 changes: 2 additions & 1 deletion src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ final case class TAnyClass[G]()(implicit val o: Origin = DiagnosticOrigin) exten
final case class TAxiomatic[G](adt: Ref[G, AxiomaticDataType[G]], args: Seq[Type[G]])(implicit val o: Origin = DiagnosticOrigin) extends DeclaredType[G] with TAxiomaticImpl[G]
final case class TEnum[G](enum: Ref[G, Enum[G]])(implicit val o: Origin = DiagnosticOrigin) extends DeclaredType[G]
final case class TProverType[G](ref: Ref[G, ProverType[G]])(implicit val o: Origin = DiagnosticOrigin) extends DeclaredType[G] with TProverTypeImpl[G]
final case class TVeyMontChannel[G](channelType: String)(implicit val o: Origin = DiagnosticOrigin) extends DeclaredType[G] with TVeyMontChannelImpl[G]

sealed trait ParRegion[G] extends NodeFamily[G] with ParRegionImpl[G]
final case class ParParallel[G](regions: Seq[ParRegion[G]])(val blame: Blame[ParPreconditionFailed])(implicit val o: Origin) extends ParRegion[G] with ParParallelImpl[G]
Expand Down Expand Up @@ -994,7 +995,7 @@ sealed trait JavaClassDeclaration[G] extends ClassDeclaration[G] with JavaClassD
final class JavaSharedInitialization[G](val isStatic: Boolean, val initialization: Statement[G])(implicit val o: Origin) extends JavaClassDeclaration[G] with JavaSharedInitializationImpl[G]
final class JavaFields[G](val modifiers: Seq[JavaModifier[G]], val t: Type[G], val decls: Seq[JavaVariableDeclaration[G]])(implicit val o: Origin) extends JavaClassDeclaration[G] with JavaFieldsImpl[G]
final class JavaConstructor[G](val modifiers: Seq[JavaModifier[G]], val name: String, val parameters: Seq[JavaParam[G]], val typeParameters: Seq[Variable[G]], val signals: Seq[Type[G]], val body: Statement[G], val contract: ApplicableContract[G])(val blame: Blame[ConstructorFailure])(implicit val o: Origin) extends JavaClassDeclaration[G] with JavaConstructorImpl[G]
final class JavaParam[G](val modifiers: Seq[JavaModifier[G]], val name: String, val t: Type[G])(implicit val o: Origin) extends Declaration[G]
final class JavaParam[G](val modifiers: Seq[JavaModifier[G]], val name: String, val t: Type[G])(implicit val o: Origin) extends Declaration[G] with JavaParamImpl[G]

final class JavaMethod[G](val modifiers: Seq[JavaModifier[G]], val returnType: Type[G], val dims: Int, val name: String, val parameters: Seq[JavaParam[G]], val typeParameters: Seq[Variable[G]], val signals: Seq[Type[G]], val body: Option[Statement[G]], val contract: ApplicableContract[G])(val blame: Blame[CallableFailure])(implicit val o: Origin) extends JavaClassDeclaration[G] with JavaMethodImpl[G]
final class JavaAnnotationMethod[G](val returnType: Type[G], val name: String, val default: Option[Expr[G]])(implicit val o: Origin) extends JavaClassDeclaration[G] with JavaAnnotationMethodImpl[G]
Expand Down
2 changes: 1 addition & 1 deletion src/col/vct/col/ast/lang/JavaFieldsImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ trait JavaFieldsImpl[G] { this: JavaFields[G] =>
override def isStatic = modifiers.contains(JavaStatic[G]())

override def layout(implicit ctx: Ctx): Doc =
Doc.lspread(modifiers) <> t <+> Doc.spread(decls) <> ";"
Doc.rspread(modifiers) <> t <+> Doc.spread(decls) <> ";"
}
2 changes: 1 addition & 1 deletion src/col/vct/col/ast/lang/JavaLocalDeclarationImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ import vct.col.print.{Ctx, Doc, Group}

trait JavaLocalDeclarationImpl[G] { this: JavaLocalDeclaration[G] =>
override def layout(implicit ctx: Ctx): Doc =
Group(Doc.lspread(modifiers) <> t <+> Doc.args(decls))
Group(Doc.rspread(modifiers) <> t <+> Doc.args(decls))
}
2 changes: 1 addition & 1 deletion src/col/vct/col/ast/lang/JavaMethodImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ trait JavaMethodImpl[G] extends Declarator[G] { this: JavaMethod[G] =>
Doc.stack(Seq(
contract,
Group(Group(Group(
Doc.lspread(modifiers) <>
Doc.rspread(modifiers) <>
(if(typeParameters.isEmpty) Empty else Text("<") <> Doc.args(typeParameters) <> ">" <+> Empty) <>
returnType <+> name <> "[]".repeat(dims)) <> "(" <> Doc.args(parameters) <> ")") <>
(if(signals.isEmpty) Empty else Empty <>> Group(Text("throws") <+> Doc.args(signals)))
Expand Down
9 changes: 9 additions & 0 deletions src/col/vct/col/ast/lang/JavaParamImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package vct.col.ast.lang

import vct.col.ast.JavaParam
import vct.col.print.{Ctx, Doc, Text}

trait JavaParamImpl[G] { this: JavaParam[G] =>
override def layout(implicit ctx: Ctx): Doc = Text(t + " " + name)

}
2 changes: 2 additions & 0 deletions src/col/vct/col/ast/type/TUnionImpl.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package vct.col.ast.`type`

import vct.col.ast.TUnion
import vct.col.print.{Ctx, Doc}

trait TUnionImpl[G] { this: TUnion[G] =>
override def layout(implicit ctx: Ctx): Doc = Doc.spread(types)

}
9 changes: 9 additions & 0 deletions src/col/vct/col/ast/type/TVeyMontChannelImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package vct.col.ast.`type`

import vct.col.ast.TVeyMontChannel
import vct.col.print.{Ctx, Doc, Text}

trait TVeyMontChannelImpl[G] { this: TVeyMontChannel[G] =>
override def layout(implicit ctx: Ctx): Doc =
Text(this.channelType)
}
127 changes: 114 additions & 13 deletions src/rewrite/vct/rewrite/veymont/ParalleliseVeyMontThreads.scala
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
package vct.rewrite.veymont

import hre.util.ScopedStack
import vct.col.ast.RewriteHelpers.{RewriteDeref, RewriteJavaClass, RewriteJavaConstructor, RewriteMethodInvocation}
import vct.col.ast.{Assert, Assign, Block, BooleanValue, Branch, Class, ClassDeclaration, Declaration, Deref, DerefVeyMontThread, Eval, Expr, InstanceField, InstanceMethod, JavaClass, JavaConstructor, JavaInvocation, JavaMethod, JavaNamedType, JavaTClass, Local, Loop, MethodInvocation, Node, Program, RunMethod, Scope, Statement, TClass, ThisObject, ThisSeqProg, Type, VeyMontAssignExpression, VeyMontCommExpression, VeyMontCondition, VeyMontSeqProg, VeyMontThread}
import vct.col.ast.RewriteHelpers.{RewriteClass, RewriteDeref, RewriteJavaClass, RewriteJavaConstructor, RewriteMethodInvocation}
import vct.col.ast.{ApplicableContract, Assert, Assign, Block, BooleanValue, Branch, Class, ClassDeclaration, Declaration, Deref, DerefVeyMontThread, Eval, Expr, InstanceField, InstanceMethod, JavaClass, JavaConstructor, JavaInvocation, JavaLocal, JavaMethod, JavaNamedType, JavaParam, JavaPublic, JavaTClass, Local, Loop, MethodInvocation, NewObject, Node, Procedure, Program, RunMethod, Scope, Statement, TClass, TVeyMontChannel, ThisObject, ThisSeqProg, Type, UnitAccountedPredicate, Variable, VeyMontAssignExpression, VeyMontCommExpression, VeyMontCondition, VeyMontSeqProg, VeyMontThread}
import vct.col.origin.Origin
import vct.col.resolve.ctx.RefJavaMethod
import vct.col.rewrite.{Generation, Rewriter, RewriterBuilderArg, Rewritten}
import vct.col.util.SuccessionMap
import vct.result.VerificationError.UserError
import vct.rewrite.veymont.ParalleliseVeyMontThreads.{ChannelFieldOrigin, ParalliseVeyMontThreadsError, ThreadClassOrigin}
import vct.rewrite.veymont.ParalleliseVeyMontThreads.{ChannelFieldOrigin, ParalliseVeyMontThreadsError, ThreadClassOrigin, getChannelClassName, getThreadClassName, getVarName}

import java.lang

object ParalleliseVeyMontThreads extends RewriterBuilderArg[JavaClass[_]] {
override def key: String = "ParalleliseVeyMontThreads"

override def desc: String = "Generate classes for VeyMont threads in parallel program"

private val channelClassName = "Channel"
private val threadClassName = "Thread"

def getChannelClassName(channelType: Type[_]): String =
channelType.toString.capitalize + channelClassName

def getThreadClassName(thread: VeyMontThread[_]) : String =
thread.o.preferredName.capitalize + threadClassName

def getVarName(v: Variable[_]) = v.o.preferredName

case class ParalliseVeyMontThreadsError(node : Node[_], msg: String) extends UserError {
override def code: String = "ParalleliseVeyMontThreadsError"

override def text: String = node.o.messageInContext(msg)
}

case class ThreadClassOrigin(thread: VeyMontThread[_]) extends Origin {
override def preferredName: String = thread.o.preferredName.toUpperCase() + "Thread"
override def preferredName: String = getThreadClassName(thread)

override def context: String = thread.o.context

Expand All @@ -47,11 +60,13 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
private val threadBuildingBlocks: ScopedStack[ThreadBuildingBlocks[Pre]] = ScopedStack()
private val threadClassSucc: SuccessionMap[VeyMontThread[Pre],Class[Post]] = SuccessionMap()
private val threadMethodSucc: SuccessionMap[(VeyMontThread[Pre],ClassDeclaration[Pre]),InstanceMethod[Post]] = SuccessionMap()
private val channelClassSucc: SuccessionMap[Type[Pre],JavaClass[Post]] = SuccessionMap()
private val channelClassName = "Channel"
private val givenClassSucc: SuccessionMap[Type[Pre],Class[Post]] = SuccessionMap()
private val givenClassConstrSucc: SuccessionMap[Type[Pre],Procedure[Pre]] = SuccessionMap()

override def dispatch(decl : Declaration[Pre]) : Unit = {
decl match {
case p: Procedure[Pre] => givenClassConstrSucc.update(p.returnType,p)
case c : Class[Pre] => globalDeclarations.declare(dispatchGivenClass(c))
case seqProg: VeyMontSeqProg[Pre] => dispatchThreads(seqProg)
case thread: VeyMontThread[Pre] => dispatchThread(thread)
case other => rewriteDefault(other)
Expand All @@ -70,17 +85,75 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
val (channelClasses,indexedChannelInfo) = extractChannelInfo(seqProg)
channelClasses.foreach{ case (t,c) =>
globalDeclarations.declare(c)
channelClassSucc.update(t,c)
}
seqProg.threads.foreach(thread => {
val threadField = new InstanceField[Post](dispatch(thread.threadType), Set.empty)(thread.o)
val threadField = new InstanceField[Post](TClass(givenClassSucc.ref(thread.threadType)), Set.empty)(thread.o)
val channelFields = getChannelFields(thread, indexedChannelInfo, channelClasses)
threadBuildingBlocks.having(new ThreadBuildingBlocks(seqProg.runMethod, seqProg.methods, channelFields, channelClasses, thread, threadField)) {
dispatch(thread)
}
})
}

private def dispatchGivenClass(c: Class[Pre]): Class[Post] = {
val rw = GivenClassRewriter()
val gc = new RewriteClass[Pre, Post](c)(rw).rewrite(
declarations = classDeclarations.collect {
(givenClassConstrSucc.get(TClass(c.ref)).get +: c.declarations).foreach(d => rw.dispatch(d))
}._1)
givenClassSucc.update(TClass(c.ref),gc)
gc
}

case class GivenClassRewriter() extends Rewriter[Pre] {
override val allScopes = outer.allScopes

val rewritingConstr: ScopedStack[(Seq[Variable[Pre]],TClass[Pre])] = ScopedStack()

override def dispatch(decl: Declaration[Pre]): Unit = decl match {
case p: Procedure[Pre] => p.returnType match {
case tc: TClass[Pre] => rewritingConstr.having(p.args,tc){ classDeclarations.declare(createClassConstructor(p)) };
case _ => ??? //("This procedure is expected to have a class as return type");
}
case other => rewriteDefault(other)
}

def createClassConstructor(p: Procedure[Pre]): JavaConstructor[Post] =
new JavaConstructor[Post](Seq(JavaPublic[Post]()(p.o)),
rewritingConstr.top._2.cls.decl.o.preferredName,
p.args.map(createJavaParam),
variables.dispatch(p.typeArgs),
Seq.empty, dispatch(p.body.get),
dispatch(p.contract))(null)(p.o)

def createJavaParam(v: Variable[Pre]): JavaParam[Post] =
new JavaParam[Post](Seq.empty, getVarName(v), dispatch(v.t))(v.o)

override def dispatch(e: Expr[Pre]): Expr[Post] = e match {
case l: Local[Pre] =>
if(rewritingConstr.nonEmpty && rewritingConstr.top._1.contains(l.ref.decl))
JavaLocal[Post](getVarName(l.ref.decl))(null)(e.o)
else rewriteDefault(l)
case no: NewObject[Pre] =>
val newClassType = TClass(no.cls)
if(rewritingConstr.nonEmpty && rewritingConstr.top._2 == newClassType)
NewObject(givenClassSucc.ref[Post,Class[Post]](newClassType))(no.o)
else rewriteDefault(no)
case t: ThisObject[Pre] =>
val thisClassType = TClass(t.cls)
if(rewritingConstr.nonEmpty && rewritingConstr.top._2 == thisClassType)
ThisObject(givenClassSucc.ref[Post,Class[Post]](thisClassType))(t.o)
else rewriteDefault(t)
case other => rewriteDefault(other)
}

override def dispatch(t: Type[Pre]): Type[Post] = {
if(rewritingConstr.nonEmpty && rewritingConstr.top._2 == t)
TClass(givenClassSucc.ref(t))
else rewriteDefault(t)
}
}

private def extractChannelInfo(seqProg: VeyMontSeqProg[Pre]): (Map[Type[Pre], JavaClass[Post]], Seq[ChannelInfo[Pre]]) = {
val channelInfo = collectChannelsFromRun(seqProg) ++ collectChannelsFromMethods(seqProg)
val indexedChannelInfo: Seq[ChannelInfo[Pre]] = channelInfo.groupBy(_.channelName).values.flatMap(chanInfoSeq =>
Expand All @@ -99,15 +172,46 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
}

private def createThreadClass(thread: VeyMontThread[Pre], threadRes: ThreadBuildingBlocks[Pre], threadMethods: Seq[ClassDeclaration[Post]]): Unit = {
val threadConstr = createThreadClassConstructor(thread,threadRes.threadField)
val threadRun = getThreadRunFromDecl(thread, threadRes.runMethod)
classDeclarations.scope {
val threadClass = new Class[Post](
(threadRes.threadField +: threadRes.channelFields.values.toSeq) ++ (threadRun +: threadMethods), Seq(), BooleanValue(true)(thread.o))(ThreadClassOrigin(thread))
(threadRes.threadField +: threadRes.channelFields.values.toSeq) ++ (threadConstr +: threadRun +: threadMethods), Seq(), BooleanValue(true)(thread.o))(ThreadClassOrigin(thread))
globalDeclarations.declare(threadClass)
threadClassSucc.update(thread, threadClass)
}
}

private def createThreadClassConstructor(thread: VeyMontThread[Pre], threadField: InstanceField[Post]): JavaConstructor[Post] = {
val threadConstrArgBlocks = thread.args.map{
case l: Local[Pre] => (l.ref.decl.o.preferredName,dispatch(l.t))
case other => throw ParalliseVeyMontThreadsError(other,"This node is expected to be an argument of seq_prog, and have type Local")
}
val threadConstrArgs: Seq[JavaParam[Post]] =
threadConstrArgBlocks.map{ case (a,t) => new JavaParam[Post](Seq.empty, a, t)(ThreadClassOrigin(thread)) }
val passedArgs = threadConstrArgs.map(a => JavaLocal[Post](a.name)(null)(ThreadClassOrigin(thread)))
val ThreadTypeName = thread.threadType match {
case tc: TClass[Pre] => tc.cls.decl.o.preferredName
case _ => throw ParalliseVeyMontThreadsError(thread.threadType,"This type is expected to be a class")
}
val threadConstrBody = {
Assign(getThisVeyMontDeref(thread,ThreadClassOrigin(thread),threadField),
JavaInvocation[Post](None, Seq.empty, "new " + ThreadTypeName, passedArgs, Seq.empty, Seq.empty)(null)(ThreadClassOrigin(thread)))(null)(ThreadClassOrigin(thread))
}
val threadConstrContract = new ApplicableContract[Post](
UnitAccountedPredicate[Post](BooleanValue(true)(ThreadClassOrigin(thread)))(ThreadClassOrigin(thread)),
UnitAccountedPredicate[Post](BooleanValue(true)(ThreadClassOrigin(thread)))(ThreadClassOrigin(thread)),
BooleanValue(true)(ThreadClassOrigin(thread)),
Seq.empty, Seq.empty, Seq.empty, None)(null)(ThreadClassOrigin(thread))
new JavaConstructor[Post](
Seq(JavaPublic[Post]()(ThreadClassOrigin(thread))),
getThreadClassName(thread),
threadConstrArgs,
Seq.empty, Seq.empty,
threadConstrBody,
threadConstrContract)(ThreadClassOrigin(thread))(ThreadClassOrigin(thread))
}

private def getThreadMethodFromDecl(thread: VeyMontThread[Pre])(decl: ClassDeclaration[Pre]): InstanceMethod[Post] = decl match {
case m: InstanceMethod[Pre] => getThreadMethod(m)
case _ => throw ParalliseVeyMontThreadsError(thread, "Methods of seq_program need to be of type InstanceMethod")
Expand All @@ -123,14 +227,11 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
.filter( chanInfo => chanInfo.comExpr.receiver.decl == thread || chanInfo.comExpr.sender.decl == thread)
.map { chanInfo =>
val chanFieldOrigin = ChannelFieldOrigin(chanInfo.channelName,chanInfo.comExpr.assign)
val chanField = new InstanceField[Post](JavaTClass(channelClassSucc.ref(chanInfo.channelType),Seq.empty), Set.empty)(chanFieldOrigin)
val chanField = new InstanceField[Post](TVeyMontChannel(getChannelClassName(chanInfo.channelType)), Set.empty)(chanFieldOrigin)
((chanInfo.comExpr, chanInfo.comExpr.o), chanField)
}.toMap
}

private def getChannelClassName(channelType: Type[_]): String =
channelType.toString.capitalize + channelClassName

private def generateChannelClasses(channelInfo: Seq[ChannelInfo[Pre]]) : Map[Type[Pre],JavaClass[Post]] = {
val channelTypes = channelInfo.map(_.channelType).toSet
channelTypes.map(channelType =>
Expand Down

0 comments on commit 5866439

Please sign in to comment.