Skip to content

Commit

Permalink
Merge PointerArray fallibility and nullability
Browse files Browse the repository at this point in the history
  • Loading branch information
superaxander committed Oct 4, 2024
1 parent 5f32f5b commit dee46c0
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 37 deletions.
8 changes: 3 additions & 5 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1892,11 +1892,9 @@ final case class NewArray[G](
initialize: Boolean,
)(val blame: Blame[ArraySizeError])(implicit val o: Origin)
extends Expr[G] with NewArrayImpl[G]
final case class NewPointerArray[G](
element: Type[G],
size: Expr[G],
fallible: Boolean,
)(val blame: Blame[ArraySizeError])(implicit val o: Origin)
final case class NewPointerArray[G](element: Type[G], size: Expr[G])(
val blame: Blame[ArraySizeError]
)(implicit val o: Origin)
extends Expr[G] with NewPointerArrayImpl[G]
final case class NewNonNullPointerArray[G](element: Type[G], size: Expr[G])(
val blame: Blame[ArraySizeError]
Expand Down
4 changes: 2 additions & 2 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1569,8 +1569,8 @@ abstract class CoercingRewriter[Pre <: Generation]()
Neq(coerce(left, sharedType), coerce(right, sharedType))
case na @ NewArray(element, dims, moreDims, initialize) =>
NewArray(element, dims.map(int), moreDims, initialize)(na.blame)
case na @ NewPointerArray(element, size, fallible) =>
NewPointerArray(element, size, fallible)(na.blame)
case na @ NewPointerArray(element, size) =>
NewPointerArray(element, size)(na.blame)
case na @ NewNonNullPointerArray(element, size) =>
NewNonNullPointerArray(element, size)(na.blame)
case NewObject(cls) => NewObject(cls)
Expand Down
28 changes: 15 additions & 13 deletions src/rewrite/vct/rewrite/EncodeArrayValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
: mutable.Map[(Type[Pre], Int, Int, Boolean), Procedure[Post]] = mutable
.Map()

val pointerArrayCreationMethods: mutable.Map[(Type[Pre], Boolean), Procedure[Post]] =
val pointerArrayCreationMethods: mutable.Map[Type[Pre], Procedure[Post]] =
mutable.Map()
val nonNullPointerArrayCreationMethods
: mutable.Map[Type[Pre], Procedure[Post]] = mutable.Map()
Expand Down Expand Up @@ -496,10 +496,9 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
def makePointerCreationMethodFor(
elementType: Type[Pre],
nullable: Boolean,
fallible: Boolean,
) = {
implicit val o: Origin = arrayCreationOrigin
// fallible? then 'ar != null ==> ...'; otherwise 'ar != null ** ...'
// !nullable? then 'ar != null ==> ...'; otherwise 'ar != null ** ...'
// ar.length == size
// forall ar[i] :: Perm(ar[i], write)
// (if type ar[i] is pointer or struct):
Expand Down Expand Up @@ -528,7 +527,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
(PointerBlockLength(result)(FramedPtrBlockLength) === sizeArg.get) &*
(PointerBlockOffset(result)(FramedPtrBlockOffset) === zero)

if (nullable) { ensures = (result !== Null()) &* ensures }
// Pointer location needs pointer add, not pointer subscript
ensures =
ensures &* makeStruct.makePerm(
Expand Down Expand Up @@ -556,10 +554,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
ensures &* foldStar(permFields.map(_._1))

ensures =
if (!fallible)
(result !== Null()) &* ensures
else
Star(Implies(result !== Null(), ensures), tt)
if (nullable) { Star(Implies(result !== Null(), ensures), tt) }
else { ensures }

procedure(
blame = AbstractApplicable,
Expand All @@ -570,7 +566,13 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
args = Seq(sizeArg),
requires = UnitAccountedPredicate(requires),
ensures = UnitAccountedPredicate(ensures),
)(o.where(name = "make_pointer_array_" + elementType.toString + (if (fallible) "_fallible" else "")))
)(o.where(name =
"make_pointer_array_" + elementType.toString +
(if (nullable)
"_nullable"
else
"")
))
}))
}

Expand Down Expand Up @@ -601,10 +603,10 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
Nil,
Nil,
)(ArrayCreationFailed(newArr))
case newPointerArr @ NewPointerArray(element, size, fallible) =>
case newPointerArr @ NewPointerArray(element, size) =>
val method = pointerArrayCreationMethods.getOrElseUpdate(
(element, fallible),
makePointerCreationMethodFor(element, nullable = true, fallible=fallible),
element,
makePointerCreationMethodFor(element, nullable = true),
)
ProcedureInvocation[Post](
method.ref,
Expand All @@ -617,7 +619,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
case newPointerArr @ NewNonNullPointerArray(element, size) =>
val method = nonNullPointerArrayCreationMethods.getOrElseUpdate(
element,
makePointerCreationMethodFor(element, nullable = false, fallible=false),
makePointerCreationMethodFor(element, nullable = false),
)
ProcedureInvocation[Post](
method.ref,
Expand Down
4 changes: 3 additions & 1 deletion src/rewrite/vct/rewrite/TrivialAddrOf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] {
val newPointer = Eval(
PreAssignExpression(
newTarget,
NewPointerArray(newValue.t, const[Post](1), fallible=false)(PanicBlame("Size is > 0")),
NewNonNullPointerArray(newValue.t, const[Post](1))(PanicBlame(
"Size is > 0"
)),
)(blame)
)
(newPointer, newTarget, newValue)
Expand Down
6 changes: 3 additions & 3 deletions src/rewrite/vct/rewrite/VariableToPointer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] {
Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame(
"Initialisation should always succeed"
)),
NewPointerArray(
NewNonNullPointerArray(
fieldMap(f).t.asPointer.get.element,
const(1),
)(PanicBlame("Size is > 0")),
Expand All @@ -125,7 +125,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] {
Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame(
"Initialisation should always succeed"
)),
NewPointerArray(
NewNonNullPointerArray(
fieldMap(f).t.asPointer.get.element,
const(1),
)(PanicBlame("Size is > 0")),
Expand Down Expand Up @@ -178,7 +178,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] {
Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame(
"Initialisation should always succeed"
)),
NewPointerArray(
NewNonNullPointerArray(
fieldMap(f).t.asPointer.get.element,
const(1),
)(PanicBlame("Size is > 0")),
Expand Down
11 changes: 8 additions & 3 deletions src/rewrite/vct/rewrite/lang/LangCPPToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2740,11 +2740,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
(sizeOption, init.init) match {
case (None, None) => throw WrongCPPType(decl)
case (Some(size), None) =>
val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame)
val newArr =
NewNonNullPointerArray[Post](t, rw.dispatch(size))(cta.blame)
Block(Seq(LocalDecl(v), assignLocal(v.get, newArr)))
case (None, Some(CPPLiteralArray(exprs))) =>
val newArr =
NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame)
NewNonNullPointerArray[Post](t, c_const[Post](exprs.size))(
cta.blame
)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs, o)
Expand All @@ -2755,7 +2758,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
if (realSize < exprs.size)
logger.warn(s"Excess elements in array initializer: '${decl}'")
val newArr =
NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame)
NewNonNullPointerArray[Post](t, c_const[Post](realSize))(
cta.blame
)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs.take(realSize.intValue), o)
Expand Down
19 changes: 11 additions & 8 deletions src/rewrite/vct/rewrite/lang/LangCToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
(t1, rw.dispatch(r))
case _ => throw UnsupportedMalloc(c)
}
NewPointerArray(rw.dispatch(t1), size, fallible=true)(ArrayMallocFailed(inv))(c.o)
NewPointerArray(rw.dispatch(t1), size)(ArrayMallocFailed(inv))(c.o)
case CCast(CInvocation(CLocal("__vercors_malloc"), _, _, _), _) =>
throw UnsupportedMalloc(c)
case CCast(n @ Null(), t) if t.asPointer.isDefined => rw.dispatch(n)
Expand Down Expand Up @@ -645,10 +645,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
// val decl: Statement[Post] = LocalDecl(cNameSuccessor(d))
val assign: Statement[Post] = assignLocal(
Local(cNameSuccessor(d).ref),
NewPointerArray[Post](
NewNonNullPointerArray[Post](
getInnerType(cNameSuccessor(d).t),
Local(v.ref),
fallible=false,
)(PanicBlame("Shared memory sizes cannot be negative.")),
)
declarations ++= Seq(cNameSuccessor(d))
Expand All @@ -660,10 +659,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
val assign: Statement[Post] = assignLocal(
Local(cNameSuccessor(d).ref),
// Since we set the size and blame together, we can assume the blame is not None
NewPointerArray[Post](
NewNonNullPointerArray[Post](
getInnerType(cNameSuccessor(d).t),
CIntegerValue(size),
fallible=false
)(blame.get),
)
declarations ++= Seq(cNameSuccessor(d))
Expand Down Expand Up @@ -1131,11 +1129,14 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
(sizeOption, init.init) match {
case (None, None) => throw WrongCType(decl)
case (Some(size), None) =>
val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame)
val newArr =
NewNonNullPointerArray[Post](t, rw.dispatch(size))(cta.blame)
Block(Seq(LocalDecl(v), assignLocal(v.get, newArr)))
case (None, Some(CLiteralArray(exprs))) =>
val newArr =
NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame)
NewNonNullPointerArray[Post](t, c_const[Post](exprs.size))(
cta.blame
)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs, o)
Expand All @@ -1146,7 +1147,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
if (realSize < exprs.size)
logger.warn(s"Excess elements in array initializer: '${decl}'")
val newArr =
NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame)
NewNonNullPointerArray[Post](t, c_const[Post](realSize))(
cta.blame
)
Block(
Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++
assignliteralArray(v, exprs.take(realSize.intValue), o)
Expand Down
4 changes: 2 additions & 2 deletions src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
Block(Seq(
assignLocal(
v,
NewPointerArray[Post](newT, elements)(PanicBlame(
NewNonNullPointerArray[Post](newT, elements)(PanicBlame(
"allocation should never fail"
)),
),
Expand All @@ -738,7 +738,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre])
case _ =>
assignLocal(
v,
NewPointerArray[Post](newT, elements)(PanicBlame(
NewNonNullPointerArray[Post](newT, elements)(PanicBlame(
"allocation should never fail"
)),
)
Expand Down

0 comments on commit dee46c0

Please sign in to comment.