Skip to content

Commit

Permalink
Implemented basic kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
Ellen Wittingen committed Sep 15, 2023
1 parent 277ed00 commit 65537a5
Show file tree
Hide file tree
Showing 18 changed files with 295 additions and 202 deletions.
13 changes: 3 additions & 10 deletions examples/concepts/cpp/Namespaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
int x;

namespace spaceA {
int x;

//@ context Perm(x, write);
//@ ensures x == 90;
Expand All @@ -28,21 +27,15 @@ namespace spaceA {
}
}

//@ context Perm(spaceA::x, write);
//@ context Perm(x, write);
int main() {
x = 99;
spaceA::x = 5;
//@ assert spaceA::x == 5;
int varA = spaceA::incr();
//@ assert varA == 6;
//@ assert spaceA::x == 90;
//@ assert varA == 100;
//@ assert x == 90;
int varB = spaceA::spaceB::incr();
//@ assert varB == 92;
spaceA::spaceB::doNothing();
int varX = spaceA::x;
//@ assert varX == 90;

//@ assert x == 99;
//@ assert x == 90;
return 0;
}
16 changes: 12 additions & 4 deletions examples/concepts/sycl/MethodResolving.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@

void test2() {
sycl::queue myQueue;
// sycl::handler myHandler;

sycl::event myEvent = myQueue.submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::range<3>(3,3,3), // global range
/*@ */
// sycl::range<3> myRange;

sycl::event myEvent = myQueue.submit(
/*@ requires true; */
[&](sycl::handler& cgh) {
cgh.parallel_for(sycl::range<3>(6,4,2), // global range
/*@ requires it.get_id(0) > -1; */
[=] (sycl::item<3> it) {
//[kernel code]
int a = it.get_id(1) + 3;
int a = it.get_id(0) + it.get_id(1) + it.get_id(2) + 3;
int b = a + 45;
int c = it.get_range(0);
int d = it.get_linear_id();
});
});
int b = 20;
Expand Down
16 changes: 16 additions & 0 deletions res/universal/res/cpp/sycl/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ namespace sycl {

namespace item {
/*@ pure @*/ int get_id(int dimension);
/*@ pure @*/ int get_linear_id();
/*@ pure @*/ int get_range(int dimension);
}

namespace nd_item {
/*@ pure @*/ int get_global_id(int dimension);
/*@ pure @*/ int get_global_linear_id();
/*@ pure @*/ int get_global_range(int dimension);

/*@ pure @*/ int get_local_id(int dimension);
/*@ pure @*/ int get_local_linear_id();
/*@ pure @*/ int get_local_range(int dimension);

/*@ pure @*/ int get_group_id(int dimension);
/*@ pure @*/ int get_group_linear_id();
/*@ pure @*/ int get_group_range(int dimension);
}

}
9 changes: 4 additions & 5 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,6 @@ final class CPPLocalDeclaration[G](val decl: CPPDeclaration[G])(implicit val o:
final class CPPFunctionDefinition[G](val contract: ApplicableContract[G], val specs: Seq[CPPDeclarationSpecifier[G]], val declarator: CPPDeclarator[G], val body: Statement[G])(val blame: Blame[CallableFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with CPPFunctionDefinitionImpl[G] {
var ref: Option[RefCPPGlobalDeclaration[G]] = None
}
final class CPPNamespaceDefinition[G](val name: String, val declarations: Seq[GlobalDeclaration[G]])(implicit val o: Origin) extends GlobalDeclaration[G] with CPPNamespaceDefinitionImpl[G]

sealed trait CPPStatement[G] extends Statement[G] with CPPStatementImpl[G]
final case class CPPDeclarationStatement[G](decl: CPPLocalDeclaration[G])(implicit val o: Origin) extends CPPStatement[G] with CPPDeclarationStatementImpl[G]
Expand Down Expand Up @@ -1041,10 +1040,10 @@ sealed trait SYCLClassObject[G] extends CPPExpr[G]
final case class SYCLEvent[G](kernel: GlobalDeclaration[G])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLEventImpl[G]
final case class SYCLQueue[G](kernels: Seq[GlobalDeclaration[G]])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLQueueImpl[G]
final case class SYCLHandler[G](kernel: GlobalDeclaration[G])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLHandlerImpl[G]
final case class SYCLItem[G](dimCount: Int, dimensions: Seq[Int])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLItemImpl[G]
final case class SYCLNDItem[G](dimCount: Int, dimensions: Seq[Expr[G]])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLNDItemImpl[G]
final case class SYCLRange[G](dimCount: Int, dimensions: Seq[Int])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLRangeImpl[G]
final case class SYCLNDRange[G](dimCount: Int, dimensions: Seq[Expr[G]])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLNDRangeImpl[G]
final case class SYCLItem[G](dimensions: Seq[Int])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLItemImpl[G]
final case class SYCLNDItem[G](dimensions: Seq[Expr[G]])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLNDItemImpl[G]
final case class SYCLRange[G](dimensions: Seq[Int])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLRangeImpl[G]
final case class SYCLNDRange[G](dimensions: Seq[Expr[G]])(implicit val o: Origin) extends SYCLClassObject[G] with SYCLNDRangeImpl[G]

sealed trait SYCLKernelDefinition[G] extends GlobalDeclaration[G]
final case class SYCLBasicKernelDefinition[G](dimensions: Expr[G], body: GlobalDeclaration[G])(implicit val o: Origin) extends SYCLKernelDefinition[G] with SYCLBasicKernelDefinitionImpl[G]
Expand Down
13 changes: 0 additions & 13 deletions src/col/vct/col/ast/lang/CPPNamespaceDefinitionImpl.scala

This file was deleted.

4 changes: 2 additions & 2 deletions src/col/vct/col/ast/lang/SYCLItemImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import vct.col.ast.{SYCLItem, SYCLTItem, Type}
import vct.col.print.{Ctx, Doc, Group, Text}

trait SYCLItemImpl[G] { this: SYCLItem[G] =>
override def t: Type[G] = SYCLTItem(dimCount)
override def t: Type[G] = SYCLTItem(dimensions.size)

override def layout(implicit ctx: Ctx): Doc =
Group(Text("sycl::item") <> "<" <> Text(dimCount.toString) <> ">" <>
Group(Text("sycl::item") <> "<" <> Text(dimensions.size.toString) <> ">" <>
"(" <> Text(dimensions.mkString(",")) <> ")")
}
4 changes: 2 additions & 2 deletions src/col/vct/col/ast/lang/SYCLNDItemImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import vct.col.ast.{SYCLNDItem, SYCLTNDItem, Type}
import vct.col.print.{Ctx, Doc, Group, Text}

trait SYCLNDItemImpl[G] { this: SYCLNDItem[G] =>
override def t: Type[G] = SYCLTNDItem(dimCount)
override def t: Type[G] = SYCLTNDItem(dimensions.size)

override def layout(implicit ctx: Ctx): Doc =
Group(Text("sycl::nd_item") <> "<" <> Text(dimCount.toString) <> ">" <>
Group(Text("sycl::nd_item") <> "<" <> Text(dimensions.size.toString) <> ">" <>
"(" <> Text(dimensions.map(x => x.show).mkString(",")) <> ")")
}
4 changes: 2 additions & 2 deletions src/col/vct/col/ast/lang/SYCLNDRangeImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import vct.col.ast.{SYCLNDRange, SYCLTNDRange, Type}
import vct.col.print.{Ctx, Doc, Group, Text}

trait SYCLNDRangeImpl[G] { this: SYCLNDRange[G] =>
override def t: Type[G] = SYCLTNDRange(dimCount)
override def t: Type[G] = SYCLTNDRange(dimensions.size)

override def layout(implicit ctx: Ctx): Doc =
Group(Text("sycl::nd_range") <> "<" <> Text(dimCount.toString) <> ">" <>
Group(Text("sycl::nd_range") <> "<" <> Text(dimensions.size.toString) <> ">" <>
"(" <> Text(dimensions.map(x => x.show).mkString(",")) <> ")")
}
4 changes: 2 additions & 2 deletions src/col/vct/col/ast/lang/SYCLRangeImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import vct.col.ast.{SYCLRange, SYCLTRange, Type}
import vct.col.print.{Ctx, Doc, Group, Text}

trait SYCLRangeImpl[G] { this: SYCLRange[G] =>
override def t: Type[G] = SYCLTRange(dimCount)
override def t: Type[G] = SYCLTRange(dimensions.size)

override def layout(implicit ctx: Ctx): Doc =
Group(Text("sycl::range") <> "<" <> Text(dimCount.toString) <> ">" <>
Group(Text("sycl::range") <> "<" <> Text(dimensions.size.toString) <> ">" <>
"(" <> Text(dimensions.mkString(",")) <> ")")
}
1 change: 0 additions & 1 deletion src/col/vct/col/feature/FeatureRainbow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ class FeatureRainbow[G] {
case node: CPPLocalDeclaration[G] => CPPSpecific
case node: CPPLong[G] => CPPSpecific
case node: CPPName[G] => CPPSpecific
case node: CPPNamespaceDefinition[G] => CPPSpecific
case node: CPPParam[G] => CPPSpecific
case node: CPPPrimitiveType[G] => CPPSpecific
case node: CPPPure[G] => CPPSpecific
Expand Down
1 change: 0 additions & 1 deletion src/col/vct/col/resolve/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ case object ResolveReferences extends LazyLogging {
ctx
.copy(currentResult = Some(RefCPPLambdaDefinition(func)))
.declare(CPP.paramsFromDeclarator(func.declarator) ++ scanLabels(func.body) ++ func.contract.givenArgs ++ func.contract.yieldsArgs)
case ns: CPPNamespaceDefinition[G] => ctx.declare(ns.declarations)
case func: CPPGlobalDeclaration[G] =>
if (func.decl.contract.nonEmpty && func.decl.inits.size > 1) {
throw MultipleForwardDeclarationContractError(func)
Expand Down
3 changes: 0 additions & 3 deletions src/col/vct/col/resolve/ctx/Referrable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ sealed trait Referrable[G] {
case RefCPPTranslationUnit(_) => ""
case RefCPPParam(decl) => CPP.nameFromDeclarator(decl.declarator)
case RefCPPFunctionDefinition(decl) => CPP.nameFromDeclarator(decl.declarator)
case RefCPPNamespaceDefinition(decl) => decl.name
case RefCPPGlobalDeclaration(decls, initIdx) => CPP.nameFromDeclarator(decls.decl.inits(initIdx).decl)
case RefCPPLocalDeclaration(decls, initIdx) => CPP.nameFromDeclarator(decls.decl.inits(initIdx).decl)
case RefJavaNamespace(_) => ""
Expand Down Expand Up @@ -93,7 +92,6 @@ case object Referrable {
case decl: CPPTranslationUnit[G] => RefCPPTranslationUnit(decl)
case decl: CPPParam[G] => RefCPPParam(decl)
case decl: CPPFunctionDefinition[G] => RefCPPFunctionDefinition(decl)
case decl: CPPNamespaceDefinition[G] => RefCPPNamespaceDefinition(decl)
case decl: CPPGlobalDeclaration[G] => return decl.decl.inits.indices.map(RefCPPGlobalDeclaration(decl, _))
case decl: JavaNamespace[G] => RefJavaNamespace(decl)
case decl: JavaClass[G] => RefJavaClass(decl)
Expand Down Expand Up @@ -212,7 +210,6 @@ case class RefCPPParam[G](decl: CPPParam[G]) extends Referrable[G] with CPPNameT
case class RefCPPFunctionDefinition[G](decl: CPPFunctionDefinition[G]) extends Referrable[G] with CPPNameTarget[G] with CPPInvocationTarget[G] with ResultTarget[G]
case class RefCPPLambdaDefinition[G](decl: CPPLambdaDefinition[G]) extends Referrable[G] with CPPInvocationTarget[G] with ResultTarget[G] with CPPTypeNameTarget[G] with CPPDerefTarget[G]
case class RefCPPLambda[G](decl: CPPLambdaRef[G]) extends Referrable[G] with CPPTypeNameTarget[G] with CPPDerefTarget[G]
case class RefCPPNamespaceDefinition[G](decl: CPPNamespaceDefinition[G]) extends Referrable[G]
case class RefCPPGlobalDeclaration[G](decls: CPPGlobalDeclaration[G], initIdx: Int) extends Referrable[G] with CPPNameTarget[G] with CPPInvocationTarget[G] with ResultTarget[G]
case class RefCPPLocalDeclaration[G](decls: CPPLocalDeclaration[G], initIdx: Int) extends Referrable[G] with CPPNameTarget[G]
case class RefJavaNamespace[G](decl: JavaNamespace[G]) extends Referrable[G]
Expand Down
55 changes: 9 additions & 46 deletions src/col/vct/col/resolve/lang/CPP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ case object CPP {
}
}

def replacePotentialSYCLClassInstance[G](name: String, ctx: ReferenceResolutionContext[G]): String = {
def replacePotentialClassmemberName[G](name: String, ctx: ReferenceResolutionContext[G]): String = {
if (name.contains('.') && name.count(x => x == '.') == 1) {
// Class method, replace with SYCL equivalent
val classVarName = name.split('.').head
Expand All @@ -145,56 +145,19 @@ case object CPP {
}

def findCPPName[G](name: String, genericArg: Option[Int], ctx: ReferenceResolutionContext[G]): Seq[CPPNameTarget[G]] = {
val targetName: String = replacePotentialSYCLClassInstance(name, ctx)
val targetName: String = replacePotentialClassmemberName(name, ctx)

var nameSeq = targetName.split("::")
if (nameSeq.length == 1) {
ctx.stack.flatten.collect {
case target: CPPNameTarget[G] if target.name == targetName => target
}
} else {
val ctxTarget: Option[RefCPPNamespaceDefinition[G]] = ctx.stack.flatten.collectFirst {
case namespace: RefCPPNamespaceDefinition[G] if namespace.name == nameSeq.head => namespace
}

ctxTarget match {
case Some(ref) =>
nameSeq = nameSeq.drop(1);
var foundNamespace: Option[CPPNamespaceDefinition[G]] = Some(ref.decl)
var returnVal: Seq[CPPNameTarget[G]] = Seq()
while (nameSeq.nonEmpty) {
if (foundNamespace.isEmpty) {
return Seq()
}
var targets = ctx.stack.flatten.collect {
case target: CPPNameTarget[G] if target.name == targetName => target
}

if (nameSeq.length > 1) {
// Look for nested namespaces
foundNamespace = foundNamespace.get.declarations.collectFirst {
case namespace: CPPNamespaceDefinition[G] if namespace.name == nameSeq.head => namespace
}
} else {
// Look for final nameTarget
returnVal = findDeclInNamespace(nameSeq.head, foundNamespace.get)
if (returnVal.isEmpty) {
returnVal = foundNamespace.get.declarations.collectFirst {
case namespace: CPPNamespaceDefinition[G] if namespace.name == nameSeq.head =>
findDeclInNamespace("constructor", namespace)
}.getOrElse(Seq())
}
}
nameSeq = nameSeq.drop(1)
}
returnVal
case None => Seq()
}
if (targets.isEmpty && !name.endsWith("::constructor")) {
// Not a known method, so search for constructor
targets = findCPPName(name + "::constructor", genericArg, ctx)
}
targets
}

def findDeclInNamespace[G](name: String, namespace: CPPNamespaceDefinition[G]): Seq[CPPNameTarget[G]] =
namespace.declarations.collect {
case funcDef: CPPFunctionDefinition[G] if getDeclaratorInfo(funcDef.declarator).name == name => RefCPPFunctionDefinition(funcDef)
case globalDecl: CPPGlobalDeclaration[G] if getDeclaratorInfo(globalDecl.decl.inits.head.decl).name == name => RefCPPGlobalDeclaration(globalDecl, 0)
}

def findForwardDeclaration[G](declarator: CPPDeclarator[G], ctx: ReferenceResolutionContext[G]): Option[RefCPPGlobalDeclaration[G]] =
ctx.stack.flatten.collectFirst {
Expand Down
2 changes: 0 additions & 2 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1665,8 +1665,6 @@ abstract class CoercingRewriter[Pre <: Generation]() extends AbstractRewriter[Pr
declaration
case definition: CPPFunctionDefinition[Pre] =>
definition
case namespace: CPPNamespaceDefinition[Pre] =>
namespace
case declaration: CPPGlobalDeclaration[Pre] =>
declaration
case namespace: JavaNamespace[Pre] =>
Expand Down
2 changes: 1 addition & 1 deletion src/parsers/antlr4/LangCPPParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ nestedNameSpecifier:
| nestedNameSpecifier Template? simpleTemplateId Doublecolon;

lambdaExpression:
lambdaIntroducer lambdaDeclarator? compoundStatement;
valEmbedContract? lambdaIntroducer lambdaDeclarator? compoundStatement;

lambdaIntroducer: LeftBracket lambdaCapture? RightBracket;

Expand Down
Loading

0 comments on commit 65537a5

Please sign in to comment.