diff --git a/examples/concepts/c/malloc_free.c b/examples/concepts/c/malloc_free.c index b803d639dc..b5d6943d35 100644 --- a/examples/concepts/c/malloc_free.c +++ b/examples/concepts/c/malloc_free.c @@ -13,6 +13,7 @@ struct e{ int main(){ int* xs = (int*) malloc(sizeof(int)*3); + if (xs == NULL) return 1; xs[0] = 3; xs[1] = 2; @@ -20,6 +21,7 @@ int main(){ free(xs); int** xxs = (int * *) malloc(sizeof(int *)*3); + if (xxs == NULL) return 1; int temp[3] = {1,2,3}; xxs[0] = temp; @@ -27,17 +29,23 @@ int main(){ free(xxs); struct d* ys = (struct d*) malloc(3*sizeof(struct d)); + if (ys == NULL) return 1; + ys[0].x = 3; ys[1].x = 2; ys[2].x = 1; free(ys); struct e* a = (struct e*) malloc(1*sizeof(struct e)); + if (a == NULL) return 1; + a->s.x = 1; struct d* b = &(a->s); free(a); float * z = (float *) malloc(sizeof(float)); + if (z == NULL) return 1; + z[0] = 3.0; *z = 2.0; free(z); diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 59d8d36153..44359b9d0f 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -1820,9 +1820,11 @@ final case class NewArray[G]( initialize: Boolean, )(val blame: Blame[ArraySizeError])(implicit val o: Origin) extends Expr[G] with NewArrayImpl[G] -final case class NewPointerArray[G](element: Type[G], size: Expr[G])( - val blame: Blame[ArraySizeError] -)(implicit val o: Origin) +final case class NewPointerArray[G]( + element: Type[G], + size: Expr[G], + fallible: Boolean, +)(val blame: Blame[ArraySizeError])(implicit val o: Origin) extends Expr[G] with NewPointerArrayImpl[G] final case class FreePointer[G](pointer: Expr[G])( val blame: Blame[PointerFreeError] diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 8de7a33d8a..154ffa0981 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1552,8 +1552,8 @@ abstract class CoercingRewriter[Pre <: Generation]() Neq(coerce(left, sharedType), coerce(right, sharedType)) case na @ NewArray(element, dims, moreDims, initialize) => NewArray(element, dims.map(int), moreDims, initialize)(na.blame) - case na @ NewPointerArray(element, size) => - NewPointerArray(element, size)(na.blame) + case na @ NewPointerArray(element, size, fallible) => + NewPointerArray(element, size, fallible)(na.blame) case NewObject(cls) => NewObject(cls) case NoPerm() => NoPerm() case Not(arg) => Not(bool(arg)) diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index db6fe989f1..48ed5e9cc0 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -104,7 +104,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { : mutable.Map[(Type[Pre], Int, Int, Boolean), Procedure[Post]] = mutable .Map() - val pointerArrayCreationMethods: mutable.Map[Type[Pre], Procedure[Post]] = + val pointerArrayCreationMethods: mutable.Map[(Type[Pre], Boolean), Procedure[Post]] = mutable.Map() val freeMethods: mutable.Map[Type[ @@ -503,10 +503,10 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { case _ => false } - def makePointerCreationMethodFor(elementType: Type[Pre]) = { + def makePointerCreationMethodFor(elementType: Type[Pre], fallible: Boolean) = { implicit val o: Origin = arrayCreationOrigin - // ar != null - // ar.length == dim0 + // fallible? then 'ar != null ==> ...'; otherwise 'ar != null ** ...' + // ar.length == size // forall ar[i] :: Perm(ar[i], write) // (if type ar[i] is pointer or struct): // forall i,j :: i!=j ==> ar[i] != ar[j] @@ -528,7 +528,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Seq(access(i), access(j)), ) - var ensures = (result !== Null()) &* + var ensures = (PointerBlockLength(result)(FramedPtrBlockLength) === sizeArg.get) &* (PointerBlockOffset(result)(FramedPtrBlockOffset) === zero) // Pointer location needs pointer add, not pointer subscript @@ -557,6 +557,12 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { else ensures &* foldStar(permFields.map(_._1)) + ensures = + if (!fallible) + (result !== Null()) &* ensures + else + Star(Implies(result !== Null(), ensures), tt) + procedure( blame = AbstractApplicable, contractBlame = TrueSatisfiable, @@ -564,7 +570,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { args = Seq(sizeArg), requires = UnitAccountedPredicate(requires), ensures = UnitAccountedPredicate(ensures), - )(o.where(name = "make_pointer_array_" + elementType.toString)) + )(o.where(name = "make_pointer_array_" + elementType.toString + (if (fallible) "_fallible" else ""))) })) } @@ -595,9 +601,9 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, )(ArrayCreationFailed(newArr)) - case newPointerArr @ NewPointerArray(element, size) => + case newPointerArr @ NewPointerArray(element, size, fallible) => val method = pointerArrayCreationMethods - .getOrElseUpdate(element, makePointerCreationMethodFor(element)) + .getOrElseUpdate((element, fallible), makePointerCreationMethodFor(element, fallible)) ProcedureInvocation[Post]( method.ref, Seq(dispatch(size)), diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index edc400f193..db71dfb1d8 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -98,7 +98,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { val newPointer = Eval( PreAssignExpression( newTarget, - NewPointerArray(newValue.t, const[Post](1))(PanicBlame("Size is > 0")), + NewPointerArray(newValue.t, const[Post](1), fallible=false)(PanicBlame("Size is > 0")), )(blame) ) (newPointer, newTarget, newValue) diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index 737d301e7c..cf93cc2f94 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -2737,11 +2737,11 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (sizeOption, init.init) match { case (None, None) => throw WrongCPPType(decl) case (Some(size), None) => - val newArr = NewPointerArray[Post](t, rw.dispatch(size))(cta.blame) + val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame) Block(Seq(LocalDecl(v), assignLocal(v.get, newArr))) case (None, Some(CPPLiteralArray(exprs))) => val newArr = - NewPointerArray[Post](t, c_const[Post](exprs.size))(cta.blame) + NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs, o) @@ -2752,7 +2752,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) if (realSize < exprs.size) logger.warn(s"Excess elements in array initializer: '${decl}'") val newArr = - NewPointerArray[Post](t, c_const[Post](realSize))(cta.blame) + NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs.take(realSize.intValue), o) diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index b187e5f7f4..adcc96b903 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -424,7 +424,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (t1, rw.dispatch(r)) case _ => throw UnsupportedMalloc(c) } - NewPointerArray(rw.dispatch(t1), size)(ArrayMallocFailed(inv))(c.o) + NewPointerArray(rw.dispatch(t1), size, fallible=true)(ArrayMallocFailed(inv))(c.o) case CCast(CInvocation(CLocal("__vercors_malloc"), _, _, _), _) => throw UnsupportedMalloc(c) case CCast(n @ Null(), t) if t.asPointer.isDefined => rw.dispatch(n) @@ -650,6 +650,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) NewPointerArray[Post]( getInnerType(cNameSuccessor(d).t), Local(v.ref), + fallible=false, )(PanicBlame("Shared memory sizes cannot be negative.")), ) declarations ++= Seq(cNameSuccessor(d)) @@ -664,6 +665,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) NewPointerArray[Post]( getInnerType(cNameSuccessor(d).t), CIntegerValue(size), + fallible=false )(blame.get), ) declarations ++= Seq(cNameSuccessor(d)) @@ -1127,11 +1129,11 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (sizeOption, init.init) match { case (None, None) => throw WrongCType(decl) case (Some(size), None) => - val newArr = NewPointerArray[Post](t, rw.dispatch(size))(cta.blame) + val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame) Block(Seq(LocalDecl(v), assignLocal(v.get, newArr))) case (None, Some(CLiteralArray(exprs))) => val newArr = - NewPointerArray[Post](t, c_const[Post](exprs.size))(cta.blame) + NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs, o) @@ -1142,7 +1144,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) if (realSize < exprs.size) logger.warn(s"Excess elements in array initializer: '${decl}'") val newArr = - NewPointerArray[Post](t, c_const[Post](realSize))(cta.blame) + NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs.take(realSize.intValue), o) diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index b550566831..c97c8be7d5 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -33,6 +33,8 @@ class CSpec extends VercorsSpec { int main(){ struct e* a = (struct e*) malloc(1*sizeof(struct e)); + if (a == NULL) return 1; + a->s.x = 1; struct d* b = &(a->s); free(a); @@ -40,6 +42,16 @@ class CSpec extends VercorsSpec { } """ + vercors should fail withCode "ptrNull" using silicon in "use malloc result without null check" c + """ + #include + int main(){ + int* xs = (int*) malloc(1*sizeof(int)); + *xs = 12; + free(xs); + } + """ + vercors should verify using silicon in "free null pointer" c """ #include @@ -402,6 +414,7 @@ class CSpec extends VercorsSpec { struct nested *np = NULL; np = (struct nested*) NULL; np = (struct nested*) malloc(sizeof(struct nested)); + if (np == NULL) return; np->inner = NULL; np->inner = (struct nested*) NULL; }