diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 6908a3a1ed..24ebc7813c 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -315,12 +315,13 @@ object JsonProtocol extends LazyLogging { // this used on the first invocation to check all annotations do so def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst .flatMap({ - case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) - case obj: JObject if requireClassField => - throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") - case JObject(fields) => findTypeHints(fields.map(_._2)) - case JArray(arr) => findTypeHints(arr) - case _ => Seq() + case JObject(fields) => + val hint = fields.collectFirst { case ("class", JString(name)) => name } + if (requireClassField && hint.isEmpty) + throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $fields") + hint ++: findTypeHints(fields.map(_._2)) + case JArray(arr) => findTypeHints(arr) + case _ => Seq() }) .distinct diff --git a/src/main/scala/firrtl/package.scala b/src/main/scala/firrtl/package.scala index 844d84ec41..67d5e52c47 100644 --- a/src/main/scala/firrtl/package.scala +++ b/src/main/scala/firrtl/package.scala @@ -10,10 +10,10 @@ package object firrtl { implicit def annoSeqToSeq(as: AnnotationSeq): Seq[Annotation] = as.toSeq /* Options as annotations compatibility items */ - @deprecated("Use firrtl.stage.TargetDirAnnotation", "FIRRTL 1.2") + @deprecated("Use firrtl.options.TargetDirAnnotation", "FIRRTL 1.2") type TargetDirAnnotation = firrtl.options.TargetDirAnnotation - @deprecated("Use firrtl.stage.TargetDirAnnotation", "FIRRTL 1.2") + @deprecated("Use firrtl.options.TargetDirAnnotation", "FIRRTL 1.2") val TargetDirAnnotation = firrtl.options.TargetDirAnnotation type WRef = ir.Reference diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala index 637b57e0f6..8063c62787 100644 --- a/src/main/scala/firrtl/passes/memlib/MemConf.scala +++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala @@ -13,7 +13,21 @@ case object MaskedReadWritePort extends MemPort("mrw") object MemPort { - val all = Set(ReadPort, WritePort, MaskedWritePort, ReadWritePort, MaskedReadWritePort) + // This is the order that ports will render in MemConf.portsStr + val ordered: Seq[MemPort] = Seq( + MaskedReadWritePort, + MaskedWritePort, + ReadWritePort, + WritePort, + ReadPort + ) + + val all: Set[MemPort] = ordered.toSet + // uses orderedPorts when sorting MemPorts + implicit def ordering: Ordering[MemPort] = { + val orderedPorts = ordered.zipWithIndex.toMap + Ordering.by(e => orderedPorts(e)) + } def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s) @@ -38,7 +52,8 @@ case class MemConf( ports: Map[MemPort, Int], maskGranularity: Option[Int]) { - private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",") + private def portsStr = + ports.toSeq.sortBy(_._1).map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",") private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("") // Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index 515c8af981..7f4266fd06 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -283,8 +283,12 @@ case class FirrtlCircuitAnnotation(circuit: Circuit) extends NoTargetAnnotation /* Caching the hashCode for a large circuit is necessary due to repeated queries, e.g., in * [[Compiler.propagateAnnotations]]. Not caching the hashCode will cause severe performance degredations for large * [[Annotations]]. + * @note Uses the hashCode of the name of the circuit. Creating a HashMap with different Circuits + * that nevertheless have the same name is extremely uncommon so collisions are not a concern. + * Include productPrefix so that this doesn't collide with other types that use a similar + * strategy and hash the same String. */ - override lazy val hashCode: Int = circuit.hashCode + override lazy val hashCode: Int = (this.productPrefix + circuit.main).hashCode } diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index f1434ad297..7f2207af82 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -52,6 +52,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { val resets = mutable.HashMap.empty[String, Reset] val asyncResets = mutable.HashMap.empty[String, Reset] val invalids = computeInvalids(m) + lazy val namespace = Namespace(m) def onStmt(stmt: Statement): Statement = { stmt match { case reg @ DefRegister(_, name, _, _, reset, init) if isPreset(name) => @@ -93,6 +94,17 @@ object RemoveReset extends Transform with DependencyAPIMigration { // addUpdate(info, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty) val infox = MultiInfo(reset.info, reset.info, info) Connect(infox, ref, expr) + /* Synchronously reset register that has reset value but only an invalid connection */ + case IsInvalid(iinfo, ref @ WRef(rname, tpe, RegKind, _)) if resets.contains(rname) => + // We need to mux with the invalid value to be consistent with async reset registers + val dummyWire = DefWire(iinfo, namespace.newName(rname), tpe) + val wireRef = Reference(dummyWire).copy(flow = SourceFlow) + val invalid = IsInvalid(iinfo, wireRef) + // Now mux between the invalid wire and the reset value + val Reset(cond, init, info) = resets(rname) + val muxType = Utils.mux_type_and_widths(init, wireRef) + val connect = Connect(info, ref, Mux(cond, init, wireRef, muxType)) + Block(Seq(dummyWire, invalid, connect)) case other => other.map(onStmt) } } diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala index 0da4320443..3e07542bb3 100644 --- a/src/test/scala/firrtl/JsonProtocolSpec.scala +++ b/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -31,6 +31,8 @@ object JsonProtocolTestClasses { with HasSerializationHints { def typeHints = Seq(param.getClass) } + + case class SimpleAnnotation(alpha: String) extends NoTargetAnnotation } import JsonProtocolTestClasses._ @@ -68,4 +70,14 @@ class JsonProtocolSpec extends AnyFlatSpec { val deserAnno = serializeAndDeserialize(anno) assert(anno == deserAnno) } + + "JSON object order" should "not affect deserialization" in { + val anno = SimpleAnnotation("hello") + val serializedAnno = """[{ + "alpha": "hello", + "class": "firrtlTests.JsonProtocolTestClasses$SimpleAnnotation" + }]""" + val deserAnno = JsonProtocol.deserialize(serializedAnno).head + assert(anno == deserAnno) + } } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index d9dc2e5779..e8d72043de 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -693,4 +693,16 @@ circuit Top : |""".stripMargin compileAndEmit(CircuitState(parse(input), ChirrtlForm)) } + + "MemPorts" should "serialize in a deterministic order regardless" in { + def compare(seq1: Seq[MemPort]) { + val m1 = MemConf("memconf", 8, 16, seq1.map(s => s -> 1).toMap, None) + val m2 = MemConf("memconf", 8, 16, seq1.reverse.map(s => s -> 1).toMap, None) + m1.toString should be(m2.toString) + } + + compare(Seq(ReadPort, WritePort)) + compare(Seq(MaskedWritePort, ReadWritePort)) + compare(Seq(MaskedReadWritePort, WritePort, ReadWritePort)) + } } diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala index 1adeeed8d3..666320b71b 100644 --- a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala +++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala @@ -8,7 +8,7 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.testutils.FirrtlCheckers._ import firrtl.{CircuitState, WRef} -import firrtl.ir.{Connect, DefRegister, Mux} +import firrtl.ir.{Connect, DefRegister, IsInvalid, Mux, UIntLiteral} import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { @@ -47,6 +47,32 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true } } + it should "generate a reset mux for a sync reset register with an invalid connection" in { + Given("an 8-bit register 'foo' initialized to UInt(3) with an invalid connection") + val input = + """|circuit Example : + | module Example : + | input clock : Clock + | input rst : UInt<1> + | input in : UInt<8> + | output out : UInt<8> + | + | reg foo : UInt<8>, clock with : (reset => (rst, UInt(3))) + | foo is invalid + | out <= foo""".stripMargin + + val outputState = toLowFirrtl(input) + + Then("'foo' should not have a reset") + outputState should containTree { + case DefRegister(_, "foo", _, _, UIntLiteral(value, _), WRef("foo", _, _, _)) if value == 0 => true + } + And("'foo' is connected to a mux with its old reset value") + outputState should containTree { + case Connect(_, WRef("foo", _, _, _), Mux(_, UIntLiteral(value, _), _, _)) if value == 3 => true + } + } + it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in { Given("aggregate register 'foo' with 2-bit field 'a' and 1-bit field 'b'") And("aggregate, invalid wire 'bar' with the same fields")