Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
API change: out-of-bounds vec accesses now invalid, not first element (
Browse files Browse the repository at this point in the history
…#685)

[skip formal checks]
Generate nicer name for remove accesses
  • Loading branch information
azidar authored and jackkoenig committed Dec 23, 2017
1 parent 19abcb0 commit 69cf8f1
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .run_formal_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ DUT=$1
if [ $TRAVIS_PULL_REQUEST = "false" ]; then
echo "Not a pull request, no formal check"
exit 0
elif [[ $TRAVIS_COMMIT_MESSAGE == *"[skip formal checks]"* ]]; then
elif git log --format=%B --no-merges $TRAVIS_BRANCH..HEAD | grep '\[skip formal checks\]'; then
echo "Commit message says to skip formal checks"
exit 0
else
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/firrtl/LoweringCompilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
passes.ReplaceAccesses,
passes.ExpandConnects,
passes.RemoveAccesses,
passes.Uniquify,
passes.ExpandWhens,
passes.CheckInitialization,
passes.ResolveKinds,
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/firrtl/Namespace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class Namespace private {
}
}

/* TODO(azidar): Make Namespace return unique names that will not conflict with expanded
* names after LowerTypes expands names (like the Uniquify pass).
*/
object Namespace {
// Initializes a namespace from a Module
def apply(m: DefModule): Namespace = {
Expand Down
28 changes: 27 additions & 1 deletion src/main/scala/firrtl/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class FIRRTLException(val str: String) extends Exception(str)
object Utils extends LazyLogging {
def throwInternalError =
error("Internal Error! Please file an issue at https://github.com/ucb-bar/firrtl/issues")

private[firrtl] def time[R](block: => R): (Double, R) = {
val t0 = System.nanoTime()
val result = block
Expand All @@ -139,7 +140,7 @@ object Utils extends LazyLogging {
}

/** Removes all [[firrtl.ir.EmptyStmt]] statements and condenses
* [[firrtl.ir.Block]] statements.
* [[firrtl.ir.Block]] statements.
*/
def squashEmpty(s: Statement): Statement = s map squashEmpty match {
case Block(stmts) =>
Expand All @@ -152,6 +153,31 @@ object Utils extends LazyLogging {
case sx => sx
}

/** Provide a nice name to create a temporary **/
def niceName(e: Expression): String = niceName(1)(e)
def niceName(depth: Int)(e: Expression): String = {
e match {
case WRef(name, _, _, _) if name(0) == '_' => name
case WRef(name, _, _, _) => "_" + name
case WSubAccess(expr, index, _, _) if depth <= 0 => niceName(depth)(expr)
case WSubAccess(expr, index, _, _) => niceName(depth)(expr) + niceName(depth - 1)(index)
case WSubField(expr, field, _, _) => niceName(depth)(expr) + "_" + field
case WSubIndex(expr, index, _, _) => niceName(depth)(expr) + "_" + index
case Reference(name, _) if name(0) == '_' => name
case Reference(name, _) => "_" + name
case SubAccess(expr, index, _) if depth <= 0 => niceName(depth)(expr)
case SubAccess(expr, index, _) => niceName(depth)(expr) + niceName(depth - 1)(index)
case SubField(expr, field, _) => niceName(depth)(expr) + "_" + field
case SubIndex(expr, index, _) => niceName(depth)(expr) + "_" + index
case DoPrim(op, args, consts, _) if depth <= 0 => "_" + op
case DoPrim(op, args, consts, _) => "_" + op + (args.map(niceName(depth - 1)) ++ consts.map("_" + _)).mkString("")
case Mux(cond, tval, fval, _) if depth <= 0 => "_mux"
case Mux(cond, tval, fval, _) => "_mux" + Seq(cond, tval, fval).map(niceName(depth - 1)).mkString("")
case UIntLiteral(value, _) => "_" + value
case SIntLiteral(value, _) => "_" + value
}
}

/** Indent the results of [[ir.FirrtlNode.serialize]] */
def indent(str: String) = str replaceAllLiterally ("\n", "\n ")

Expand Down
11 changes: 7 additions & 4 deletions src/main/scala/firrtl/passes/RemoveAccesses.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ import scala.collection.mutable
*/
object RemoveAccesses extends Pass {
private def AND(e1: Expression, e2: Expression) =
DoPrim(And, Seq(e1, e2), Nil, BoolType)
if(e1 == one) e2
else if(e2 == one) e1
else DoPrim(And, Seq(e1, e2), Nil, BoolType)

private def EQV(e1: Expression, e2: Expression): Expression =
DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe)
DoPrim(Eq, Seq(e1, e2), Nil, BoolType)

/** Container for a base expression and its corresponding guard
*/
Expand Down Expand Up @@ -82,7 +84,7 @@ object RemoveAccesses extends Pass {
val namespace = Namespace(m)
def onStmt(s: Statement): Statement = {
def create_temp(e: Expression): (Statement, Expression) = {
val n = namespace.newTemp
val n = namespace.newName(niceName(e))
(DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
}

Expand All @@ -101,7 +103,8 @@ object RemoveAccesses extends Pass {
stmts += wire
rs.zipWithIndex foreach {
case (x, i) if i < temps.size =>
stmts += Connect(get_info(s),getTemp(i),x.base)
stmts += IsInvalid(get_info(s),getTemp(i))
stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
case (x, i) =>
stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
}
Expand Down
92 changes: 76 additions & 16 deletions src/test/scala/firrtlTests/UnitTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@ import java.io._
import org.scalatest._
import org.scalatest.prop._
import firrtl._
import firrtl.ir.Circuit
import firrtl.ir._
import firrtl.passes._
import firrtl.transforms._
import firrtl.Parser.IgnoreInfo
import FirrtlCheckers._

class UnitTests extends FirrtlFlatSpec {
private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = {
val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
(c: CircuitState, t: Transform) => t.runTransform(c)
}.circuit

val lines = c.serialize.split("\n") map normalized
val lines = execute(input, transforms).circuit.serialize.split("\n") map normalized

expected foreach { e =>
lines should contain(e)
}
}

def execute(input: String, transforms: Seq[Transform]): CircuitState = {
val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
(c: CircuitState, t: Transform) => t.runTransform(c)
}.circuit
CircuitState(c, UnknownForm, None, None)
}

"Pull muxes" should "not be exponential in runtime" in {
val passes = Seq(
ToWorkingIR,
Expand Down Expand Up @@ -192,6 +195,7 @@ class UnitTests extends FirrtlFlatSpec {
val check = Seq("c <= mux(pred, a, pad(b, 32))")
executeTest(input, check, passes)
}

"Indexes into sub-accesses" should "be dealt with" in {
val passes = Seq(
ToWorkingIR,
Expand All @@ -218,16 +222,16 @@ class UnitTests extends FirrtlFlatSpec {
//TODO(azidar): I realize this is brittle, but unfortunately there
// isn't a better way to test this pass
val check = Seq(
"""wire _GEN_0 : { a : UInt<8>}""",
"""_GEN_0.a <= table[0].a""",
"""wire _table_1 : { a : UInt<8>}""",
"""_table_1.a is invalid""",
"""when UInt<1>("h1") :""",
"""_GEN_0.a <= table[1].a""",
"""wire _GEN_1 : UInt<8>""",
"""when eq(UInt<1>("h0"), _GEN_0.a) :""",
"""otherTable[0].a <= _GEN_1""",
"""when eq(UInt<1>("h1"), _GEN_0.a) :""",
"""otherTable[1].a <= _GEN_1""",
"""_GEN_1 <= UInt<1>("h0")"""
"""_table_1.a <= table[1].a""",
"""wire _otherTable_table_1_a_a : UInt<8>""",
"""when eq(UInt<1>("h0"), _table_1.a) :""",
"""otherTable[0].a <= _otherTable_table_1_a_a""",
"""when eq(UInt<1>("h1"), _table_1.a) :""",
"""otherTable[1].a <= _otherTable_table_1_a_a""",
"""_otherTable_table_1_a_a <= UInt<1>("h0")"""
)
executeTest(input, check, passes)
}
Expand Down Expand Up @@ -376,4 +380,60 @@ class UnitTests extends FirrtlFlatSpec {
lines should contain(e)
}
}


"Out of bound accesses" should "be invalid" in {
val passes = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
ResolveGenders,
InferWidths,
PullMuxes,
ExpandConnects,
RemoveAccesses,
ResolveGenders,
new ConstantPropagation
)
val input =
"""circuit Top :
| module Top :
| input index: UInt<2>
| output out: UInt<16>
| wire array: UInt<16>[3]
| out <= array[index]""".stripMargin

val result = execute(input, passes)

def u(value: Int) = UIntLiteral(BigInt(value), IntWidth(scala.math.max(BigInt(value).bitLength, 1)))

val ut16 = UIntType(IntWidth(BigInt(16)))
val ut2 = UIntType(IntWidth(BigInt(2)))
val ut1 = UIntType(IntWidth(BigInt(1)))

val mgen = WRef("_array_index", ut16, WireKind, MALE)
val fgen = WRef("_array_index", ut16, WireKind, FEMALE)
val index = WRef("index", ut2, PortKind, MALE)
val out = WRef("out", ut16, PortKind, FEMALE)

def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, ut1)
def array(v: Int): Expression = WSubIndex(WRef("array", VectorType(ut16, 3), WireKind, MALE), v, ut16, MALE)

result should containTree { case DefWire(_, "_array_index", `ut16`) => true }
result should containTree { case IsInvalid(_, `fgen`) => true }

val eq0 = eq(u(0), index)
val array0 = array(0)
result should containTree { case Conditionally(_, `eq0`, Connect(_, `fgen`, `array0`), EmptyStmt) => true }

val eq1 = eq(u(1), index)
val array1 = array(1)
result should containTree { case Conditionally(_, `eq1`, Connect(_, `fgen`, `array1`), EmptyStmt) => true }

val eq2 = eq(u(2), index)
val array2 = array(2)
result should containTree { case Conditionally(_, `eq2`, Connect(_, `fgen`, `array2`), EmptyStmt) => true }

result should containTree { case Connect(_, `out`, mgen) => true }
}
}

0 comments on commit 69cf8f1

Please sign in to comment.