Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Fix dynamic SubAccess of zero-length vectors (#1450)
Browse files Browse the repository at this point in the history
* Fix dynamic SubAccess of zero-length vectors

* Fixes #230
* Add new ZeroLengthVecs pass that occurs before RemoveAccesses
* Include this in stage.Forms.MidForm
* Add to High->Mid order in compiler test based on @seldridge feedback

* Use validif to produce out-of-bounds value in ZeroLengthVecs

* Update scaladoc

* Fix test imports
  • Loading branch information
albert-magyar authored Apr 7, 2020
1 parent a9034ba commit 1a03e63
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/main/scala/firrtl/passes/RemoveAccesses.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ object RemoveAccesses extends Pass {

override val prerequisites =
Seq( Dependency(PullMuxes),
Dependency(ZeroLengthVecs),
Dependency(ReplaceAccesses),
Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped

Expand Down
69 changes: 69 additions & 0 deletions src/main/scala/firrtl/passes/ZeroLengthVecs.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// See LICENSE for license details.

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.options.{Dependency, PreservesAll}

/** Handles dynamic accesses to zero-length vectors.
*
* @note Removes assignments that use a zero-length vector as a sink
* @note Removes signals resulting from accesses to a zero-length vector from attach groups
* @note Removes attaches that become degenerate after zero-length-accessor removal
* @note Replaces "source" references to elements of zero-length vectors with always-invalid validif
*/
object ZeroLengthVecs extends Pass with PreservesAll[Transform] {
override val prerequisites =
Seq( Dependency(PullMuxes),
Dependency(ResolveKinds),
Dependency(InferTypes),
Dependency(ExpandConnects) )

// Pass in an expression, not just a type, since it's not possible to generate an expression of
// interval type with the type alone unless you declare a component
private def replaceWithDontCare(toReplace: Expression): Expression = {
val default = toReplace.tpe match {
case UIntType(w) => UIntLiteral(0, w)
case SIntType(w) => SIntLiteral(0, w)
case FixedType(w, p) => FixedLiteral(0, w, p)
case it: IntervalType =>
val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0))
val zeroLit = DoPrim(AsInterval, Seq(SIntLiteral(0)), Seq(0, 0, 0), zeroType)
DoPrim(Clip, Seq(zeroLit, toReplace), Nil, it)
}
ValidIf(UIntLiteral(0), default, toReplace.tpe)
}

private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match {
case (_, VectorType(_, 0)) => true
case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e)
case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e)
case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e)
case _ => false
}

// The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex
// Map before matching because we want don't-cares to propagate UP expression trees
private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match {
case _: WSubIndex | _: WSubAccess | _: WSubField =>
if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr
case e => e map dropZeroLenSubAccesses
}

// Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial
private def onStmt(stmt: Statement): Statement = stmt match {
case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt
case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt
case Attach(info, sinks) =>
val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike))
if (filtered.exprs.length < 2) EmptyStmt else filtered
case s => s.map(onStmt).map(dropZeroLenSubAccesses)
}

override def run(c: Circuit): Circuit = {
c.copy(modules = c.modules.map(m => m.map(onStmt)))
}
}
1 change: 1 addition & 0 deletions src/main/scala/firrtl/stage/Forms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ object Forms {
Dependency(passes.ReplaceAccesses),
Dependency(passes.ExpandConnects),
Dependency(passes.RemoveAccesses),
Dependency(passes.ZeroLengthVecs),
Dependency[passes.ExpandWhensAndCheck],
Dependency[passes.RemoveIntervals],
Dependency(passes.ConvertFixedToSInt),
Expand Down
13 changes: 7 additions & 6 deletions src/test/scala/firrtlTests/LoweringCompilersSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
passes.PullMuxes,
passes.ReplaceAccesses,
passes.ExpandConnects,
passes.ZeroLengthVecs,
passes.RemoveAccesses,
passes.Uniquify,
passes.ExpandWhens,
Expand Down Expand Up @@ -156,17 +157,17 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds),
Add(6, Seq(Dependency(firrtl.passes.ResolveKinds),
Dependency(firrtl.passes.InferTypes))),
Del(6),
Del(7),
Add(6, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])),
Del(10),
Del(8),
Add(7, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])),
Del(11),
Del(12),
Add(11, Seq(Dependency(firrtl.passes.ResolveFlows),
Del(13),
Add(12, Seq(Dependency(firrtl.passes.ResolveFlows),
Dependency[firrtl.passes.InferWidths])),
Del(13)
Del(14)
)
compare(legacyTransforms(new HighFirrtlToMiddleFirrtl), tm, patches)
}
Expand Down
68 changes: 68 additions & 0 deletions src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// See LICENSE for license details.

package firrtlTests

import firrtl._
import firrtl.passes._
import firrtl.testutils.FirrtlFlatSpec

class ZeroLengthVecsSpec extends FirrtlFlatSpec {
val transforms = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
ResolveFlows,
new InferWidths,
ZeroLengthVecs,
CheckTypes)
protected def exec(input: String) = {
transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
(c: CircuitState, t: Transform) => t.runTransform(c)
}.circuit.serialize
}

"ZeroLengthVecs" should "drop subaccesses to zero-length vectors" in {
val input =
"""circuit bar :
| module bar :
| input i : { a : UInt<8>, b : UInt<4> }[0]
| input sel : UInt<1>
| output foo : UInt<1>[0]
| output o : UInt<8>
| foo[UInt<1>(0)] <= UInt<1>(0)
| o <= i[sel].a
|""".stripMargin
val check =
"""circuit bar :
| module bar :
| input i : { a : UInt<8>, b : UInt<4> }[0]
| input sel : UInt<1>
| output foo : UInt<1>[0]
| output o : UInt<8>
| skip
| o <= validif(UInt<1>(0), UInt<8>(0))
|""".stripMargin
(parse(exec(input))) should be (parse(check))
}

"ZeroLengthVecs" should "handle intervals correctly" in {
val input =
"""circuit bar :
| module bar :
| input i : Interval[3,4].0[0]
| input sel : UInt<1>
| output o : Interval[3,4].0
| o <= i[sel]
|""".stripMargin
val check =
"""circuit bar :
| module bar :
| input i : Interval[3,4].0[0]
| input sel : UInt<1>
| output o : Interval[3,4].0
| o <= validif(UInt<1>(0), clip(asInterval(SInt<1>(0), 0, 0, 0), i[sel]))
|""".stripMargin
(parse(exec(input))) should be (parse(check))
}

}

0 comments on commit 1a03e63

Please sign in to comment.