Skip to content

Commit

Permalink
method invokation refs fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
petravandenbos-utwente committed Jul 4, 2023
1 parent 6c93b76 commit 8cb1d51
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 81 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package vct.rewrite.veymont

import vct.col.ast.{Assign, Deref, InstanceField, JavaClass}
import vct.col.rewrite.{Generation, Rewritten}

class ParallelCommExprBuildingBlocks[Pre <: Generation](val channelField: InstanceField[Rewritten[Pre]], val channelClass: JavaClass[Rewritten[Pre]],
val thisChanField: Deref[Rewritten[Pre]], val assign: Assign[Pre] ) {

}
197 changes: 118 additions & 79 deletions src/rewrite/vct/rewrite/veymont/ParalleliseVeyMontThreads.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package vct.rewrite.veymont
import hre.util.ScopedStack
import vct.col.ast
import vct.col.ast.RewriteHelpers.{RewriteAssign, RewriteDeref, RewriteJavaClass, RewriteMethodInvocation}
import vct.col.ast.{Assert, Assign, Block, BooleanValue, Branch, Class, ClassDeclaration, Declaration, Deref, DerefVeyMontThread, Eval, Expr, InstanceField, InstanceMethod, JavaClass, JavaInvocation, JavaMethod, JavaNamedType, JavaTClass, Local, Loop, MethodInvocation, Node, Program, RunMethod, Scope, Skip, Statement, TClass, ThisObject, ThisSeqProg, Type, VeyMontAssignExpression, VeyMontCommExpression, VeyMontCondition, VeyMontSeqProg, VeyMontThread}
import vct.col.ast.{Assert, Assign, Block, BooleanValue, Branch, Class, ClassDeclaration, Declaration, Deref, DerefVeyMontThread, Eval, Expr, InstanceField, InstanceMethod, JavaClass, JavaInvocation, JavaMethod, JavaNamedType, JavaTClass, Local, Loop, MethodInvocation, Node, Program, RunMethod, Scope, Statement, TClass, ThisObject, ThisSeqProg, Type, VeyMontAssignExpression, VeyMontCommExpression, VeyMontCondition, VeyMontSeqProg, VeyMontThread}
import vct.col.origin.{Origin, PreferredNameOrigin}
import vct.col.ref.Ref
import vct.col.resolve.ctx.RefJavaMethod
Expand Down Expand Up @@ -56,31 +56,42 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
override def dispatch(decl : Declaration[Pre]) : Unit = {
decl match {
case seqProg: VeyMontSeqProg[Pre] =>
val channelInfo = collectChannelsFromRun(seqProg) ++ collectChannelsFromMethods(seqProg)
val indexedChannelInfo : Seq[ChannelInfo[Pre]] = channelInfo.groupBy(_.channelName).values.flatMap(chanInfoSeq =>
if (chanInfoSeq.size <= 1) chanInfoSeq
else chanInfoSeq.zipWithIndex.map{ case (chanInfo,index) => new ChannelInfo(chanInfo.comExpr,chanInfo.channelType,chanInfo.channelName + index) }).toSeq
val channelClasses = generateChannelClasses(indexedChannelInfo)
val channelFields = getChannelFields(indexedChannelInfo, channelClasses)
seqProg.threads.foreach(thread => {
val threadField = new InstanceField[Post](dispatch(thread.threadType), Set.empty)(thread.o)
threadBuildingBlocks.having(new ThreadBuildingBlocks(seqProg.runMethod, seqProg.methods, channelFields, channelClasses, thread, threadField)) {
dispatch(thread)
}
})
val (channelClasses,channelFields) = extractChannelInfo(seqProg)
dispatchThreads(seqProg, channelClasses, channelFields)
case thread: VeyMontThread[Pre] => {
if(threadBuildingBlocks.nonEmpty) {
val threadRes: ThreadBuildingBlocks[Pre] = threadBuildingBlocks.top
val threadMethods: Seq[ClassDeclaration[Post]] = createThreadMethod(thread, threadRes)
createThreadClass(thread, threadRes, threadMethods)
} else rewriteDefault(thread)
dispatchThread(thread)
}
//case m: InstanceMethod[Pre] => getThreadMethod(m)
//case r: RunMethod[Pre] => getThreadRunMethod(r)
case other => rewriteDefault(other)
}
}

private def dispatchThread(thread: VeyMontThread[Pre]): Unit = {
if (threadBuildingBlocks.nonEmpty) {
val threadRes: ThreadBuildingBlocks[Pre] = threadBuildingBlocks.top
val threadMethods: Seq[ClassDeclaration[Post]] = createThreadMethod(thread, threadRes)
createThreadClass(thread, threadRes, threadMethods)
} else rewriteDefault(thread)
}

private def dispatchThreads(seqProg: VeyMontSeqProg[Pre], channelClasses: Map[Type[Pre], JavaClass[Post]], channelFields: Map[(VeyMontCommExpression[Pre], Origin), InstanceField[Post]]): Unit = {
seqProg.threads.foreach(thread => {
val threadField = new InstanceField[Post](dispatch(thread.threadType), Set.empty)(thread.o)
threadBuildingBlocks.having(new ThreadBuildingBlocks(seqProg.runMethod, seqProg.methods, channelFields, channelClasses, thread, threadField)) {
dispatch(thread)
}
})
}

private def extractChannelInfo(seqProg: VeyMontSeqProg[Pre]): (Map[Type[Pre], JavaClass[Post]], Map[(VeyMontCommExpression[Pre], Origin), InstanceField[Post]]) = {
val channelInfo = collectChannelsFromRun(seqProg) ++ collectChannelsFromMethods(seqProg)
val indexedChannelInfo: Seq[ChannelInfo[Pre]] = channelInfo.groupBy(_.channelName).values.flatMap(chanInfoSeq =>
if (chanInfoSeq.size <= 1) chanInfoSeq
else chanInfoSeq.zipWithIndex.map { case (chanInfo, index) => new ChannelInfo(chanInfo.comExpr, chanInfo.channelType, chanInfo.channelName + index) }).toSeq
val channelClasses = generateChannelClasses(indexedChannelInfo)
val channelFields = getChannelFields(indexedChannelInfo, channelClasses)
(channelClasses, channelFields)
}

private def createThreadMethod(thread: VeyMontThread[Pre], threadRes: ThreadBuildingBlocks[Pre]) = {
threadRes.methods.map { preMethod =>
val postMethod = getThreadMethodFromDecl(thread)(preMethod)
Expand All @@ -95,11 +106,8 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
}.values.toSeq
val threadRun = getThreadRunFromDecl(thread, threadRes.runMethod)
classDeclarations.scope {
//classDeclarations.collect {
val threadClass = new Class[Post](
(threadRes.threadField +: channelFieldsForThread) ++ (threadRun +: threadMethods),
Seq(),
BooleanValue(true)(thread.o))(ThreadClassOrigin(thread))
(threadRes.threadField +: channelFieldsForThread) ++ (threadRun +: threadMethods), Seq(), BooleanValue(true)(thread.o))(ThreadClassOrigin(thread))
globalDeclarations.declare(threadClass)
threadClassSucc.update(thread, threadClass)
}
Expand All @@ -126,7 +134,7 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
channelTypes.map(channelType =>
channelType -> {
val chanClassPre = channelClass.asInstanceOf[JavaClass[Pre]]
val rw = new ChannelClassGenerator(channelType)
val rw = ChannelClassGenerator(channelType)
new RewriteJavaClass[Pre, Post](chanClassPre)(rw).rewrite(decls = classDeclarations.collect {
chanClassPre.decls.foreach(d => rw.dispatch(d))
}._1)
Expand All @@ -138,7 +146,6 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
override val allScopes = outer.allScopes

override def dispatch(t: Type[Pre]): Type[Post] = t match {
// case jt: JavaType[Pre] => jt match {
case jnt: JavaNamedType[Pre] =>
if (jnt.names.head._1 == "MessageType") {
dispatch(channelType.asInstanceOf[Type[Pre]])
Expand Down Expand Up @@ -192,80 +199,112 @@ case class ParalleliseVeyMontThreads[Pre <: Generation](channelClass: JavaClass[
if(threadBuildingBlocks.nonEmpty) {
val thread = threadBuildingBlocks.top.thread
node match {
case c: VeyMontCondition[Pre] => c.condition.find{ case (threadRef,_) =>
threadRef.decl == thread
} match {
case Some((_,threadExpr)) => dispatch(threadExpr)
case _ => throw ParalliseVeyMontThreadsError(node, "Condition of if statement or while loop must contain an expression for every thread")
}
case m: MethodInvocation[Pre] => m.obj match {
case threadRef: DerefVeyMontThread[Pre] => m.rewrite(obj = dispatch(threadRef))
case _ => threadMethodSucc.get((thread, m.ref.decl)) match {
case Some(postMethod) => m.rewrite(obj = dispatch(m.obj), ref = postMethod.ref, m.args.map(dispatch))
case None => throw ParalliseVeyMontThreadsError(m, "No successor for this method found")
}
}
case d: Deref[Pre] => d.obj match {
case t: DerefVeyMontThread[Pre] =>
val threadField = threadBuildingBlocks.top.threadField
if(t.ref.decl == thread) {
d.rewrite(
obj = Deref(ThisObject(threadClassSucc.ref[Post, Class[Post]](thread))(thread.o), threadField.ref[InstanceField[Post]])(null)(d.o)
)
}
else rewriteDefault(node)

}
case c: VeyMontCondition[Pre] => paralleliseThreadCondition(node, thread, c)
case m: MethodInvocation[Pre] => updateThreadRefMethodInvoc(thread, m)
case d: Deref[Pre] => updateThreadRefInDeref(node, thread, d)
case t: DerefVeyMontThread[Pre] => updateThreadRefVeyMontDeref(node, thread, t)
case _ => rewriteDefault(node)
}
} else rewriteDefault(node)
}

private def updateThreadRefVeyMontDeref(node: Expr[Pre], thread: VeyMontThread[Pre], t: DerefVeyMontThread[Pre]) = {
if (t.ref.decl == thread) {
getThisVeyMontDeref(thread, t.o, threadBuildingBlocks.top.threadField)
} else rewriteDefault(node)
}

override def dispatch(st : Statement[Pre]) : Statement[Post] = {
private def updateThreadRefInDeref(node: Expr[Pre], thread: VeyMontThread[Pre], d: Deref[Pre]) = {
d.obj match {
case t: DerefVeyMontThread[Pre] =>
if (t.ref.decl == thread) {
d.rewrite(
obj = getThisVeyMontDeref(thread, d.o, threadBuildingBlocks.top.threadField)
)
}
else rewriteDefault(node)
}
}

private def updateThreadRefMethodInvoc(thread: VeyMontThread[Pre], m: MethodInvocation[Pre]) = {
m.obj match {
case threadRef: DerefVeyMontThread[Pre] => m.rewrite(obj = dispatch(threadRef))
case _ => threadMethodSucc.get((thread, m.ref.decl)) match {
case Some(postMethod) => m.rewrite(obj = dispatch(m.obj), ref = postMethod.ref, m.args.map(dispatch))
case None => throw ParalliseVeyMontThreadsError(m, "No successor for this method found")
}
}
}

private def paralleliseThreadCondition(node: Expr[Pre], thread: VeyMontThread[Pre], c: VeyMontCondition[Pre]) = {
c.condition.find { case (threadRef, _) =>
threadRef.decl == thread
} match {
case Some((_, threadExpr)) => dispatch(threadExpr)
case _ => throw ParalliseVeyMontThreadsError(node, "Condition of if statement or while loop must contain an expression for every thread")
}
}

private def getThisVeyMontDeref(thread: VeyMontThread[Pre], o: Origin, threadField: InstanceField[Rewritten[Pre]]) = {
Deref(ThisObject(threadClassSucc.ref[Post, Class[Post]](thread))(thread.o), threadField.ref[InstanceField[Post]])(null)(o)
}

override def dispatch(st : Statement[Pre]) : Statement[Post] = {
if (threadBuildingBlocks.nonEmpty) {
val thread = threadBuildingBlocks.top.thread
st match {
case v@VeyMontCommExpression(recv,sender,chanType,assign) =>
val channelField = threadBuildingBlocks.top.channelFields((v,v.o))
val channelClass = threadBuildingBlocks.top.channelClasses(chanType)
val thisChanField = Deref(ThisObject(threadClassSucc.ref[Post, Class[Post]](thread))(thread.o), channelField.ref[InstanceField[Post]])(null)(assign.o)
val assignment = assign.asInstanceOf[Assign[Pre]]
if (recv.decl == thread) {
val readMethod = findChannelClassMethod(v, channelClass, "readValue")
val assignValue = JavaInvocation(Some(thisChanField),Seq.empty, "readValue",Seq.empty, Seq.empty,Seq.empty)(null)(v.o)
assignValue.ref = Some(RefJavaMethod(readMethod))
Assign(dispatch(assignment.target),assignValue)(null)(v.o)
} else if(sender.decl == thread) {
val writeMethod = findChannelClassMethod(v, channelClass, "writeValue")
val writeInvoc = JavaInvocation(Some(thisChanField),Seq.empty,"writeValue",Seq(dispatch(assignment.value)),Seq.empty,Seq.empty)(null)(v.o)
writeInvoc.ref = Some(RefJavaMethod(writeMethod))
Eval(writeInvoc)(v.o)
}
else Skip()(assign.o)
case v: VeyMontCommExpression[Pre] =>
paralleliseVeyMontCommExpr(thread, v, createParComBlocks(thread, v))
case v@VeyMontAssignExpression(threadRef, assign) =>
if (threadRef.decl == thread)
dispatch(assign)
else Skip()(assign.o)
case a: Assign[Pre] => rewriteDefault(st)
else Block(Seq.empty)(assign.o)
case a: Assign[Pre] => Assign(dispatch(a.target),dispatch(a.value))(a.blame)(a.o)
case Branch(_) => rewriteDefault(st)
case Loop(_, _, _, _, _) => rewriteDefault(st)
case Scope(_, _) => rewriteDefault(st)
case Block(_) => rewriteDefault(st)
case Eval(expr) => expr match {
case m: MethodInvocation[Pre] => m.obj match {
case _: ThisSeqProg[Pre] => Eval(dispatch(expr))(st.o)
case d: DerefVeyMontThread[Pre] => if(d.ref.decl == thread) Eval(dispatch(expr))(st.o) else Skip()(st.o)
case _ => throw ParalliseVeyMontThreadsError(st, "Statement not allowed in seq_program")
}
case _ => rewriteDefault(st)
}
case _: Assert[Pre] => Skip()(st.o)
case Eval(expr) => paralleliseMethodInvocation(st, thread, expr)
case _: Assert[Pre] => Block(Seq.empty)(st.o)
case _ => throw ParalliseVeyMontThreadsError(st, "Statement not allowed in seq_program")
}
} else rewriteDefault(st)
}

private def createParComBlocks(thread: VeyMontThread[Pre], v: VeyMontCommExpression[Pre]): ParallelCommExprBuildingBlocks[Pre] = {
val channelField = threadBuildingBlocks.top.channelFields((v, v.o))
val channelClass = threadBuildingBlocks.top.channelClasses(v.chanType)
val thisChanField = Deref(ThisObject(threadClassSucc.ref[Post, Class[Post]](thread))(thread.o), channelField.ref[InstanceField[Post]])(null)(v.assign.o)
val assignment = v.assign.asInstanceOf[Assign[Pre]]
new ParallelCommExprBuildingBlocks(channelField, channelClass, thisChanField, assignment)
}

private def paralleliseMethodInvocation(st: Statement[Pre], thread: VeyMontThread[Pre], expr: Expr[Pre]): Statement[Post] = {
expr match {
case m: MethodInvocation[Pre] => m.obj match {
case _: ThisSeqProg[Pre] => Eval(m.rewrite(obj = ThisObject(threadClassSucc.ref[Post, Class[Post]](thread))(thread.o), ref = threadMethodSucc.ref((thread, m.ref.decl))))(st.o)
case d: DerefVeyMontThread[Pre] => if (d.ref.decl == thread) Eval(dispatch(expr))(st.o) else Block(Seq.empty)(st.o)
case _ => throw ParalliseVeyMontThreadsError(st, "Statement not allowed in seq_program")
}
case _ => throw ParalliseVeyMontThreadsError(st, "Statement not allowed in seq_program")
}
}

private def paralleliseVeyMontCommExpr(thread: VeyMontThread[Pre], v: VeyMontCommExpression[Pre], blocks: ParallelCommExprBuildingBlocks[Pre]): Statement[Post] = {
if (v.receiver.decl == thread) {
val readMethod = findChannelClassMethod(v, blocks.channelClass, "readValue")
val assignValue = JavaInvocation(Some(blocks.thisChanField), Seq.empty, "readValue", Seq.empty, Seq.empty, Seq.empty)(null)(v.o)
assignValue.ref = Some(RefJavaMethod(readMethod))
Assign(dispatch(blocks.assign.target), assignValue)(null)(v.o)
} else if (v.sender.decl == thread) {
val writeMethod = findChannelClassMethod(v, blocks.channelClass, "writeValue")
val writeInvoc = JavaInvocation(Some(blocks.thisChanField), Seq.empty, "writeValue", Seq(dispatch(blocks.assign.value)), Seq.empty, Seq.empty)(null)(v.o)
writeInvoc.ref = Some(RefJavaMethod(writeMethod))
Eval(writeInvoc)(v.o)
}
else Block(Seq.empty)(blocks.assign.o)
}

private def findChannelClassMethod(v: VeyMontCommExpression[Pre], channelClass: JavaClass[Post], methodName: String): JavaMethod[Post] = {
val method = channelClass.decls.find {
case jm: JavaMethod[Post] => jm.name == methodName
Expand Down
3 changes: 1 addition & 2 deletions src/rewrite/vct/rewrite/veymont/ThreadBuildingBlocks.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package vct.rewrite.veymont

import vct.col.ast.{ClassDeclaration, InstanceField, JavaClass, Statement, Type, VeyMontCommExpression, VeyMontThread}
import vct.col.ast.{ClassDeclaration, InstanceField, JavaClass, Type, VeyMontCommExpression, VeyMontThread}
import vct.col.origin.Origin
import vct.col.ref.Ref
import vct.col.rewrite.{Generation, Rewritten}

class ThreadBuildingBlocks[Pre <: Generation](
Expand Down

0 comments on commit 8cb1d51

Please sign in to comment.