From 69cf8f1a084e477284d8ae48a4a7abc499e92991 Mon Sep 17 00:00:00 2001 From: Adam Izraelevitz Date: Fri, 22 Dec 2017 19:34:39 -0500 Subject: [PATCH] API change: out-of-bounds vec accesses now invalid, not first element (#685) [skip formal checks] Generate nicer name for remove accesses --- .run_formal_checks.sh | 2 +- src/main/scala/firrtl/LoweringCompilers.scala | 1 + src/main/scala/firrtl/Namespace.scala | 3 + src/main/scala/firrtl/Utils.scala | 28 +++++- .../scala/firrtl/passes/RemoveAccesses.scala | 11 ++- src/test/scala/firrtlTests/UnitTests.scala | 92 +++++++++++++++---- 6 files changed, 115 insertions(+), 22 deletions(-) diff --git a/.run_formal_checks.sh b/.run_formal_checks.sh index ace9e4b057..942b3e1a66 100755 --- a/.run_formal_checks.sh +++ b/.run_formal_checks.sh @@ -12,7 +12,7 @@ DUT=$1 if [ $TRAVIS_PULL_REQUEST = "false" ]; then echo "Not a pull request, no formal check" exit 0 -elif [[ $TRAVIS_COMMIT_MESSAGE == *"[skip formal checks]"* ]]; then +elif git log --format=%B --no-merges $TRAVIS_BRANCH..HEAD | grep '\[skip formal checks\]'; then echo "Commit message says to skip formal checks" exit 0 else diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index f032868a32..57aa1533dd 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -59,6 +59,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { passes.ReplaceAccesses, passes.ExpandConnects, passes.RemoveAccesses, + passes.Uniquify, passes.ExpandWhens, passes.CheckInitialization, passes.ResolveKinds, diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index cf8472a6b0..9e1cae67e1 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -40,6 +40,9 @@ class Namespace private { } } +/* TODO(azidar): Make Namespace return unique names that will not conflict with expanded + * names after LowerTypes expands names (like the Uniquify pass). + */ object Namespace { // Initializes a namespace from a Module def apply(m: DefModule): Namespace = { diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index c965964455..76f58f3005 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -130,6 +130,7 @@ class FIRRTLException(val str: String) extends Exception(str) object Utils extends LazyLogging { def throwInternalError = error("Internal Error! Please file an issue at https://github.com/ucb-bar/firrtl/issues") + private[firrtl] def time[R](block: => R): (Double, R) = { val t0 = System.nanoTime() val result = block @@ -139,7 +140,7 @@ object Utils extends LazyLogging { } /** Removes all [[firrtl.ir.EmptyStmt]] statements and condenses - * [[firrtl.ir.Block]] statements. + * [[firrtl.ir.Block]] statements. */ def squashEmpty(s: Statement): Statement = s map squashEmpty match { case Block(stmts) => @@ -152,6 +153,31 @@ object Utils extends LazyLogging { case sx => sx } + /** Provide a nice name to create a temporary **/ + def niceName(e: Expression): String = niceName(1)(e) + def niceName(depth: Int)(e: Expression): String = { + e match { + case WRef(name, _, _, _) if name(0) == '_' => name + case WRef(name, _, _, _) => "_" + name + case WSubAccess(expr, index, _, _) if depth <= 0 => niceName(depth)(expr) + case WSubAccess(expr, index, _, _) => niceName(depth)(expr) + niceName(depth - 1)(index) + case WSubField(expr, field, _, _) => niceName(depth)(expr) + "_" + field + case WSubIndex(expr, index, _, _) => niceName(depth)(expr) + "_" + index + case Reference(name, _) if name(0) == '_' => name + case Reference(name, _) => "_" + name + case SubAccess(expr, index, _) if depth <= 0 => niceName(depth)(expr) + case SubAccess(expr, index, _) => niceName(depth)(expr) + niceName(depth - 1)(index) + case SubField(expr, field, _) => niceName(depth)(expr) + "_" + field + case SubIndex(expr, index, _) => niceName(depth)(expr) + "_" + index + case DoPrim(op, args, consts, _) if depth <= 0 => "_" + op + case DoPrim(op, args, consts, _) => "_" + op + (args.map(niceName(depth - 1)) ++ consts.map("_" + _)).mkString("") + case Mux(cond, tval, fval, _) if depth <= 0 => "_mux" + case Mux(cond, tval, fval, _) => "_mux" + Seq(cond, tval, fval).map(niceName(depth - 1)).mkString("") + case UIntLiteral(value, _) => "_" + value + case SIntLiteral(value, _) => "_" + value + } + } + /** Indent the results of [[ir.FirrtlNode.serialize]] */ def indent(str: String) = str replaceAllLiterally ("\n", "\n ") diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 253e6a15d4..37b92a9e76 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -15,10 +15,12 @@ import scala.collection.mutable */ object RemoveAccesses extends Pass { private def AND(e1: Expression, e2: Expression) = - DoPrim(And, Seq(e1, e2), Nil, BoolType) + if(e1 == one) e2 + else if(e2 == one) e1 + else DoPrim(And, Seq(e1, e2), Nil, BoolType) private def EQV(e1: Expression, e2: Expression): Expression = - DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe) + DoPrim(Eq, Seq(e1, e2), Nil, BoolType) /** Container for a base expression and its corresponding guard */ @@ -82,7 +84,7 @@ object RemoveAccesses extends Pass { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { def create_temp(e: Expression): (Statement, Expression) = { - val n = namespace.newTemp + val n = namespace.newName(niceName(e)) (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } @@ -101,7 +103,8 @@ object RemoveAccesses extends Pass { stmts += wire rs.zipWithIndex foreach { case (x, i) if i < temps.size => - stmts += Connect(get_info(s),getTemp(i),x.base) + stmts += IsInvalid(get_info(s),getTemp(i)) + stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) case (x, i) => stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) } diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index c51798196c..447998298e 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -6,24 +6,27 @@ import java.io._ import org.scalatest._ import org.scalatest.prop._ import firrtl._ -import firrtl.ir.Circuit +import firrtl.ir._ import firrtl.passes._ import firrtl.transforms._ -import firrtl.Parser.IgnoreInfo +import FirrtlCheckers._ class UnitTests extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = { - val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit - - val lines = c.serialize.split("\n") map normalized + val lines = execute(input, transforms).circuit.serialize.split("\n") map normalized expected foreach { e => lines should contain(e) } } + def execute(input: String, transforms: Seq[Transform]): CircuitState = { + val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit + CircuitState(c, UnknownForm, None, None) + } + "Pull muxes" should "not be exponential in runtime" in { val passes = Seq( ToWorkingIR, @@ -192,6 +195,7 @@ class UnitTests extends FirrtlFlatSpec { val check = Seq("c <= mux(pred, a, pad(b, 32))") executeTest(input, check, passes) } + "Indexes into sub-accesses" should "be dealt with" in { val passes = Seq( ToWorkingIR, @@ -218,16 +222,16 @@ class UnitTests extends FirrtlFlatSpec { //TODO(azidar): I realize this is brittle, but unfortunately there // isn't a better way to test this pass val check = Seq( - """wire _GEN_0 : { a : UInt<8>}""", - """_GEN_0.a <= table[0].a""", + """wire _table_1 : { a : UInt<8>}""", + """_table_1.a is invalid""", """when UInt<1>("h1") :""", - """_GEN_0.a <= table[1].a""", - """wire _GEN_1 : UInt<8>""", - """when eq(UInt<1>("h0"), _GEN_0.a) :""", - """otherTable[0].a <= _GEN_1""", - """when eq(UInt<1>("h1"), _GEN_0.a) :""", - """otherTable[1].a <= _GEN_1""", - """_GEN_1 <= UInt<1>("h0")""" + """_table_1.a <= table[1].a""", + """wire _otherTable_table_1_a_a : UInt<8>""", + """when eq(UInt<1>("h0"), _table_1.a) :""", + """otherTable[0].a <= _otherTable_table_1_a_a""", + """when eq(UInt<1>("h1"), _table_1.a) :""", + """otherTable[1].a <= _otherTable_table_1_a_a""", + """_otherTable_table_1_a_a <= UInt<1>("h0")""" ) executeTest(input, check, passes) } @@ -376,4 +380,60 @@ class UnitTests extends FirrtlFlatSpec { lines should contain(e) } } + + + "Out of bound accesses" should "be invalid" in { + val passes = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths, + PullMuxes, + ExpandConnects, + RemoveAccesses, + ResolveGenders, + new ConstantPropagation + ) + val input = + """circuit Top : + | module Top : + | input index: UInt<2> + | output out: UInt<16> + | wire array: UInt<16>[3] + | out <= array[index]""".stripMargin + + val result = execute(input, passes) + + def u(value: Int) = UIntLiteral(BigInt(value), IntWidth(scala.math.max(BigInt(value).bitLength, 1))) + + val ut16 = UIntType(IntWidth(BigInt(16))) + val ut2 = UIntType(IntWidth(BigInt(2))) + val ut1 = UIntType(IntWidth(BigInt(1))) + + val mgen = WRef("_array_index", ut16, WireKind, MALE) + val fgen = WRef("_array_index", ut16, WireKind, FEMALE) + val index = WRef("index", ut2, PortKind, MALE) + val out = WRef("out", ut16, PortKind, FEMALE) + + def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, ut1) + def array(v: Int): Expression = WSubIndex(WRef("array", VectorType(ut16, 3), WireKind, MALE), v, ut16, MALE) + + result should containTree { case DefWire(_, "_array_index", `ut16`) => true } + result should containTree { case IsInvalid(_, `fgen`) => true } + + val eq0 = eq(u(0), index) + val array0 = array(0) + result should containTree { case Conditionally(_, `eq0`, Connect(_, `fgen`, `array0`), EmptyStmt) => true } + + val eq1 = eq(u(1), index) + val array1 = array(1) + result should containTree { case Conditionally(_, `eq1`, Connect(_, `fgen`, `array1`), EmptyStmt) => true } + + val eq2 = eq(u(2), index) + val array2 = array(2) + result should containTree { case Conditionally(_, `eq2`, Connect(_, `fgen`, `array2`), EmptyStmt) => true } + + result should containTree { case Connect(_, `out`, mgen) => true } + } }