diff --git a/benchmark/scripts/find_heap_bound.py b/benchmark/scripts/find_heap_bound.py index e89cfa5d2f..764fdf0075 100755 --- a/benchmark/scripts/find_heap_bound.py +++ b/benchmark/scripts/find_heap_bound.py @@ -6,10 +6,13 @@ from typing import NamedTuple from subprocess import TimeoutExpired import logging +from functools import total_ordering from monitor_job import monitor_job, JobFailedError BaseHeapSize = NamedTuple('JavaHeapSize', [('value', int), ('suffix', str)]) + +@total_ordering class HeapSize(BaseHeapSize): K_FACTOR = 1024 M_FACTOR = 1024*1024 @@ -51,6 +54,16 @@ def __add__(self, rhs): def __sub__(self, rhs): return HeapSize.from_bytes(self.toBytes() - rhs.toBytes()) + + def __eq__(self, rhs): + return self.toBytes() == rhs.toBytes() + + # Defining __eq__ for total_ordering forces us to explicitly inherit __hash__ + __hash__ = BaseHeapSize.__hash__ + + def __ge__(self, rhs): + return self.toBytes() >= rhs.toBytes() + @classmethod def from_str(cls, s: str): regex = '(\d+)([kKmMgG])?' @@ -97,6 +110,8 @@ def parseargs(): parser.add_argument("--timeout-factor", type=float, default=4.0, help="Multiple of wallclock time of first successful run " "that counts as a timeout, runs over this time count as a fail") + parser.add_argument("--context", type=int, default=0, + help="Number of extra steps above the minimum bound to run") return parser.parse_args() @@ -137,16 +152,23 @@ def main(): seen = set() timeout = None # Set by first successful run cur = HeapSize.from_str(args.start_size) - while cur not in seen: + last_success = cur + + # Do binary search + while cur not in seen and (step is None or step >= min_step): seen.add(cur) try: cmd = mk_cmd(args.java, cur, args.args) - logger.info("Running {}".format(" ".join(cmd))) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Running {}".format(" ".join(cmd))) + else: + logger.info("Running {}".format(cur)) stats = monitor_job(cmd, timeout=timeout) logger.debug(stats) if timeout is None: timeout = stats.wall_clock_time * args.timeout_factor logger.debug("Timeout set to {} s".format(timeout)) + last_success = cur results.append((cur, stats)) if step is None: step = (cur / 2).round_to(min_step) @@ -166,6 +188,33 @@ def main(): cur = (cur + step).round_to(min_step) logger.debug("Next = {}, step = {}".format(cur, step)) + # Run extra steps for some context above the minimum size + extra_steps = [] + if args.context > 0: + for i in range(1, args.context): + diff = min_step * i + heap_size = last_success + diff + if heap_size not in seen: + extra_steps.append(heap_size) + log_steps = ", ".join([str(e) for e in extra_steps]) # Pretty print + logger.info("Because context is {}, running extra heap sizes: {}".format(args.context, log_steps)) + + for cur in extra_steps: + logger.debug("Next = {}".format(cur)) + seen.add(cur) + try: + cmd = mk_cmd(args.java, cur, args.args) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Running {}".format(" ".join(cmd))) + else: + logger.info("Running {}".format(cur)) + stats = monitor_job(cmd, timeout=timeout) + logger.debug(stats) + results.append((cur, stats)) + except (JobFailedError, TimeoutExpired) as e: + logger.debug(job_failed_msg(e)) + results.append((cur, None)) + sorted_results = sorted(results, key=lambda tup: tup[0].toBytes(), reverse=True) table = [["Xmx", "Max RSS (MiB)", "Wall Clock (s)", "User Time (s)", "System Time (s)"]] diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index a9f4442165..02d3574021 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -121,9 +121,13 @@ object CheckWidths extends Pass { // This is a leaf check of the "local" width-correctness of one expression node, so no recursion. expr match { case e @ UIntLiteral(v, w: IntWidth) if math.max(1, v.bitLength) > w.width => - errors.append(new WidthTooSmall(info, target.serialize, v)) + if (w.width > 0 || (w.width == 0 && v != 0)) { // UInt<0>(0) is allowed + errors.append(new WidthTooSmall(info, target.serialize, v)) + } case e @ SIntLiteral(v, w: IntWidth) if v.bitLength + 1 > w.width => - errors.append(new WidthTooSmall(info, target.serialize, v)) + if (w.width > 0 || (w.width == 0 && v != 0)) { // SInt<0>(0) is allowed + errors.append(new WidthTooSmall(info, target.serialize, v)) + } case e @ DoPrim(op, Seq(a, b), _, tpe) => (op, a.tpe, b.tpe) match { case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) diff --git a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala index d3510326a4..2d1cad8de7 100644 --- a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala +++ b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala @@ -35,6 +35,10 @@ class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) actual should be(expected) finalState } + protected def removeSkip(c: ir.Circuit): ir.Circuit = { + def onStmt(s: ir.Statement): ir.Statement = s.mapStmt(onStmt) + c.mapModule(m => m.mapStmt(onStmt)) + } } private object LeanTransformSpec { diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index 4e22ff51a3..4844fa6f59 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -3,20 +3,11 @@ package firrtlTests import firrtl._ +import firrtl.options.Dependency import firrtl.passes._ import firrtl.testutils._ -class ZeroWidthTests extends FirrtlFlatSpec { - def transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ZeroWidth) - private def exec(input: String) = { - val circuit = parse(input) - transforms - .foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => - p.runTransform(c) - } - .circuit - .serialize - } +class ZeroWidthTests extends LeanTransformSpec(Seq(Dependency(ZeroWidth))) { // ============================= "Zero width port" should " be deleted" in { val input = @@ -30,7 +21,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<1> | x <= UInt<1>(0)""".stripMargin - (parse(exec(input))) should be(parse(check)) + compile(input).circuit.serialize should be(parse(check).serialize) } "Add of <0> and <2> " should " put in zero" in { val input = @@ -44,21 +35,19 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<3> | x <= add(UInt<1>(0), UInt<2>(2))""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } - "Mux on <0>" should "put in zero" in { + "Mux on <0>" should "not be allowed" in { + // Note that this used to be allowed, but the support seems to have bit-rotted + // and modern firrtl enforces 1-bit UInt for muxes. val input = """circuit Top : | module Top : | input y : UInt<0> | output x : UInt | x <= mux(y, UInt<2>(2), UInt<2>(1))""".stripMargin - val check = - """circuit Top : - | module Top : - | output x : UInt<2> - | x <= mux(UInt<1>(0), UInt<2>(2), UInt<2>(1))""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + val e = intercept[PassException] { compile(input) } + assert(e.getMessage.contains("A mux condition must be of type 1-bit UInt")) } "Bundle with field of <0>" should "get deleted" in { val input = @@ -66,13 +55,16 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input y : { a: UInt<0> } | output x : { a: UInt<0>, b: UInt<1>} + | x.b <= UInt(1) | x.a <= y.a""".stripMargin val check = """circuit Top : | module Top : | output x : { b: UInt<1> } - | skip""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + | skip + | x.b <= UInt(1) + | """.stripMargin + compile(input).circuit.serialize should be(parse(check).serialize) } "Vector with type of <0>" should "get deleted" in { val input = @@ -85,7 +77,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + removeSkip(compile(input).circuit).serialize should be(parse(check).serialize) } "Node with <0>" should "be removed" in { val input = @@ -97,7 +89,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "IsInvalid on <0>" should "be deleted" in { val input = @@ -109,7 +101,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "Expression in node with type <0>" should "be replaced by UInt<1>(0)" in { val input = @@ -123,7 +115,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input x: UInt<1> | node z = add(x, UInt<1>(0))""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "Expression in cat with type <0>" should "be removed" in { val input = @@ -137,7 +129,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input x: UInt<1> | node z = x""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "Nested cats with type <0>" should "be removed" in { val input = @@ -151,7 +143,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "Nested cats where one has type <0>" should "be unaffected" in { val input = @@ -167,9 +159,11 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | node a = cat(x, z)""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "Stop with type <0>" should "be replaced with UInt(0)" in { + // Note that this used to be allowed, but the support seems to have bit-rotted + // and modern firrtl enforces 1-bit UInt for stop enables. val input = """circuit Top : | module Top : @@ -178,14 +172,8 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input y: UInt<0> | input z: UInt<1> | stop(clk, y, 1)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clk: Clock - | input x: UInt<1> - | input z: UInt<1> - | stop(clk, UInt(0), 1)""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + val e = intercept[PassException] { compile(input) } + assert(e.getMessage.contains("Enable must be a 1-bit UIntType typed signal")) } "Print with type <0>" should "be replaced with UInt(0)" in { val input = @@ -203,7 +191,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | printf(clk, UInt(1), "%d %d %d\n", x, UInt(0), z)""".stripMargin - (parse(exec(input)).serialize) should be(parse(check).serialize) + compile(input).circuit.serialize should be(parse(check).serialize) } "Andr of zero-width expression" should "return true" in { @@ -218,7 +206,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<1> | x <= UInt<1>(1)""".stripMargin - (parse(exec(input))) should be(parse(check)) + compile(input).circuit.serialize should be(parse(check).serialize) } "Cat of SInt with zero-width" should "keep type correctly" in { @@ -235,7 +223,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input y : SInt<1> | output z : UInt<1> | z <= asUInt(y)""".stripMargin - (parse(exec(input))) should be(parse(check)) + compile(input).circuit.serialize should be(parse(check).serialize) } "dshl with zero-width" should "canonicalize to the un-shifted expression" in { @@ -252,7 +240,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input y : SInt<1> | output z : SInt<1> | z <= y""".stripMargin - (parse(exec(input))) should be(parse(check)) + compile(input).circuit.serialize should be(parse(check).serialize) } "Memories with zero-width data-type" should "be fully removed" in { @@ -315,7 +303,24 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input rwMask: UInt<1> | |${Seq.tabulate(17)(_ => " skip").mkString("\n")}""".stripMargin - parse(exec(input)) should be(parse(check)) + compile(input).circuit.serialize should be(parse(check).serialize) + } + + "zero width literals" should "be permissible" in { + val input = + """circuit Foo: + | module Foo: + | output x : UInt<1> + | output y : SInt<3> + | + | x <= UInt<0>(0) + | y <= SInt<0>(0) + |""".stripMargin + + val result = compile(input).circuit + val lines = result.serialize.split('\n').map(_.trim) + assert(lines.contains("x <= UInt<1>(\"h0\")")) + assert(lines.contains("y <= SInt<1>(\"h0\")")) } }