diff --git a/frontends/PyCDE/integration_test/esitester.py b/frontends/PyCDE/integration_test/esitester.py
index b42ad9d21d10..103865d76e5d 100644
--- a/frontends/PyCDE/integration_test/esitester.py
+++ b/frontends/PyCDE/integration_test/esitester.py
@@ -22,7 +22,7 @@
 # RUN: esi-cosim.py --source %t -- esitester cosim env dmatest
 
 import pycde
-from pycde import AppID, Clock, Module, Reset, generator
+from pycde import AppID, Clock, Module, Reset, generator, modparams
 from pycde.bsp import cosim
 from pycde.constructs import Counter, Reg, Wire
 from pycde.esi import CallService
@@ -58,60 +58,71 @@ def construct(ports):
     CallService.call(AppID("PrintfExample"), arg_chan, Bits(0))
 
 
-class ReadMem(Module):
-  """Module which reads host memory at a certain address as given by writes to
-  MMIO register 0x8. Stores the read value and responds to all MMIO reads with
-  the stored value."""
-
-  clk = Clock()
-  rst = Reset()
-
-  @generator
-  def construct(ports):
-    cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType))
-    resp_ready_wire = Wire(Bits(1))
-    cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire)
-    mmio_xact = cmd_valid & resp_ready_wire
-
-    read_loc_ce = mmio_xact & cmd.write & (cmd.offset == 0x8)
-    read_loc = Reg(UInt(64),
-                   clk=ports.clk,
-                   rst=ports.rst,
-                   rst_value=0,
-                   ce=read_loc_ce)
-    read_loc.assign(cmd.data.as_uint())
-
-    mem_data_ce = Wire(Bits(1))
-    mem_data = Reg(Bits(64),
-                   clk=ports.clk,
-                   rst=ports.rst,
-                   rst_value=0,
-                   ce=mem_data_ce)
-
-    response_data = mem_data
-    response_chan, response_ready = Channel(Bits(64)).wrap(
-        response_data, cmd_valid)
-    resp_ready_wire.assign(response_ready)
-
-    mmio_rw = esi.MMIO.read_write(appid=AppID("ReadMem"))
-    mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
-    cmd_chan_wire.assign(mmio_rw_cmd_chan)
-
-    tag = Counter(8)(clk=ports.clk, rst=ports.rst, increment=mmio_xact)
-
-    hostmem_read_req, hostmem_read_req_ready = Channel(
-        esi.HostMem.ReadReqType).wrap({
-            "tag": tag.out,
-            "address": read_loc
-        }, read_loc_ce.reg(ports.clk, ports.rst))
-
-    hostmem_read_resp = esi.HostMem.read(appid=AppID("ReadMem_hostread"),
-                                         req=hostmem_read_req,
-                                         data_type=Bits(64))
-    hostmem_read_resp_data, hostmem_read_resp_valid = hostmem_read_resp.unwrap(
-        1)
-    mem_data.assign(hostmem_read_resp_data.data)
-    mem_data_ce.assign(hostmem_read_resp_valid)
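+# Note: with @modparams, each width parameter elaborates a distinct module
+# type, so e.g. ReadMem(64) and ReadMem(96) would be separate generated
+# modules.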
+@modparams
+def ReadMem(width: int):
+
+  class ReadMem(Module):
+    """Module which reads host memory at the address given by writes to MMIO
+    register 0x8. Stores the read value and responds to all MMIO reads with
+    the stored value."""
+
+    clk = Clock()
+    rst = Reset()
+
+    @generator
+    def construct(ports):
+      cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType))
+      resp_ready_wire = Wire(Bits(1))
+      cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire)
+      mmio_xact = cmd_valid & resp_ready_wire
+
+      read_loc_ce = mmio_xact & cmd.write & (cmd.offset == 0x8)
+      read_loc = Reg(UInt(64),
+                     clk=ports.clk,
+                     rst=ports.rst,
+                     rst_value=0,
+                     ce=read_loc_ce,
+                     name="read_loc")
+      read_loc.assign(cmd.data.as_uint())
+
+      mem_data_ce = Wire(Bits(1))
+      mem_data = Reg(Bits(width),
+                     clk=ports.clk,
+                     rst=ports.rst,
+                     rst_value=0,
+                     ce=mem_data_ce,
+                     name="mem_data")
+
+      response_data = mem_data.as_bits(64)
+      response_chan, response_ready = Channel(Bits(64)).wrap(
+          response_data, cmd_valid)
+      resp_ready_wire.assign(response_ready)
+
+      mmio_rw = esi.MMIO.read_write(appid=AppID("ReadMem"))
+      mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
+      cmd_chan_wire.assign(mmio_rw_cmd_chan)
+
+      tag = Counter(8)(clk=ports.clk,
+                       rst=ports.rst,
+                       clear=Bits(1)(0),
+                       increment=mmio_xact)
+
+      # Ignoring the ready signal isn't safe, but for cosim it's probably fine.
+      hostmem_read_req, hostmem_read_req_ready = Channel(
+          esi.HostMem.ReadReqType).wrap({
+              "tag": tag.out,
+              "address": read_loc
+          }, read_loc_ce.reg(ports.clk, ports.rst))
+
+      hostmem_read_resp = esi.HostMem.read(appid=AppID("ReadMem_hostread"),
+                                           req=hostmem_read_req,
+                                           data_type=mem_data.type)
+      hostmem_read_resp_data, hostmem_read_resp_valid = hostmem_read_resp.unwrap(
+          1)
+      mem_data.assign(hostmem_read_resp_data.data)
+      mem_data_ce.assign(hostmem_read_resp_valid)
+
+  return ReadMem
 
 
 class WriteMem(Module):
@@ -144,10 +155,14 @@ def construct(ports):
     mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
     cmd_chan_wire.assign(mmio_rw_cmd_chan)
 
-    tag = Counter(8)(clk=ports.clk, rst=ports.rst, increment=mmio_xact)
+    tag = Counter(8)(clk=ports.clk,
+                     rst=ports.rst,
+                     clear=Bits(1)(0),
+                     increment=mmio_xact)
 
     cycle_counter = Counter(64)(clk=ports.clk,
                                 rst=ports.rst,
+                                clear=Bits(1)(0),
                                 increment=Bits(1)(1))
 
     hostmem_write_req, _ = esi.HostMem.wrap_write_req(
@@ -167,7 +182,10 @@ class EsiTesterTop(Module):
   @generator
   def construct(ports):
     PrintfExample(clk=ports.clk, rst=ports.rst)
-    ReadMem(clk=ports.clk, rst=ports.rst)
+    # Once I get read muxing working, enable all three.
+    # ReadMem(32)(clk=ports.clk, rst=ports.rst)
+    # ReadMem(64)(clk=ports.clk, rst=ports.rst)
+    ReadMem(96)(clk=ports.clk, rst=ports.rst)
     WriteMem(clk=ports.clk, rst=ports.rst)
 
 
diff --git a/frontends/PyCDE/src/pycde/bsp/common.py b/frontends/PyCDE/src/pycde/bsp/common.py
index b61aa8360ded..2d856f12a07b 100644
--- a/frontends/PyCDE/src/pycde/bsp/common.py
+++ b/frontends/PyCDE/src/pycde/bsp/common.py
@@ -3,15 +3,17 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from __future__ import annotations
+from math import ceil
 
-from ..common import Clock, Input, Output, Reset
-from ..constructs import AssignableSignal, NamedWire, Wire
+from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset
+from ..constructs import (AssignableSignal, ControlReg, Counter, NamedWire, Reg,
+                          Wire)
 from .. import esi
 from ..module import Module, generator, modparams
 from ..signals import BitsSignal, BundleSignal, ChannelSignal
 from ..support import clog2
 from ..types import (Array, Bits, Bundle, BundledChannel, Channel,
-                     ChannelDirection, StructType, UInt)
+                     ChannelDirection, StructType, Type, UInt)
 
 from typing import Dict, List, Tuple
 import typing
@@ -255,6 +257,187 @@ def build_addr_read(
     return sel_bits, client_addr_chan
 
 
+@modparams
+def TaggedGearbox(input_bitwidth: int,
+                  output_bitwidth: int) -> type["TaggedGearboxImpl"]:
+  """Build a gearbox to convert the upstream data to the client data
+  type. Assumes a struct {tag, data} and only gearboxes the data. Tag is stored
+  separately and the struct is re-assembled later on."""
+
+  class TaggedGearboxImpl(Module):
+    clk = Clock()
+    rst = Reset()
+    in_ = InputChannel(
+        StructType([
+            ("tag", esi.HostMem.TagType),
+            ("data", Bits(input_bitwidth)),
+        ]))
+    out = OutputChannel(
+        StructType([
+            ("tag", esi.HostMem.TagType),
+            ("data", Bits(output_bitwidth)),
+        ]))
+
+    @generator
+    def build(ports):
+      ready_for_upstream = Wire(Bits(1), name="ready_for_upstream")
+      upstream_tag_and_data, upstream_valid = ports.in_.unwrap(
+          ready_for_upstream)
+      upstream_data = upstream_tag_and_data.data
+      upstream_xact = ready_for_upstream & upstream_valid
+
+      # Determine if gearboxing is necessary and whether it needs to be
+      # gearboxed up or just sliced down.
+      if output_bitwidth == input_bitwidth:
+        client_data_bits = upstream_data
+        client_valid = upstream_valid
+      elif output_bitwidth < input_bitwidth:
+        client_data_bits = upstream_data[:output_bitwidth]
+        client_valid = upstream_valid
+      else:
+        # Create registers equal to the number of upstream transactions needed
+        # to fill the client data. Set the output to the concatenation of said
+        # registers.
+        chunks = ceil(output_bitwidth / input_bitwidth)
+        reg_ces = [Wire(Bits(1)) for _ in range(chunks)]
+        regs = [
+            upstream_data.reg(ports.clk,
+                              ports.rst,
+                              ce=reg_ces[idx],
+                              name=f"chunk_reg_{idx}") for idx in range(chunks)
+        ]
+        client_data_bits = BitsSignal.concat(reversed(regs))[:output_bitwidth]
+
+        # Use a counter to determine which register to write to and whether
+        # the registers are all full.
+        clear_counter = Wire(Bits(1))
+        counter_width = clog2(chunks)
+        counter = Counter(counter_width)(clk=ports.clk,
+                                         rst=ports.rst,
+                                         clear=clear_counter,
+                                         increment=upstream_xact)
+        set_client_valid = counter.out == chunks - 1
+        client_xact = Wire(Bits(1))
+        client_valid = ControlReg(ports.clk, ports.rst, [set_client_valid],
+                                  [client_xact])
+        client_xact.assign(client_valid & ready_for_upstream)
+        clear_counter.assign(client_xact)
+        for idx, reg_ce in enumerate(reg_ces):
+          reg_ce.assign(upstream_xact &
+                        (counter.out == UInt(counter_width)(idx)))
+
+      # Construct the output channel. Shared logic across all three cases.
+      tag_reg = upstream_tag_and_data.tag.reg(ports.clk,
+                                              ports.rst,
+                                              ce=upstream_xact,
+                                              name="tag_reg")
+      client_channel, client_ready = TaggedGearboxImpl.out.type.wrap(
+          {
+              "tag": tag_reg,
+              "data": client_data_bits,
+          }, client_valid)
+      ready_for_upstream.assign(client_ready)
+      ports.out = client_channel
+
+  return TaggedGearboxImpl
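+
+# Illustrative use (hypothetical widths): convert a 64-bit upstream response
+# stream into a 96-bit client, i.e. two upstream beats per client response:
+#
+#   gearbox = TaggedGearbox(64, 96)(clk=clk, rst=rst, in_=upstream_resp_chan)
+#   client_resp_chan = gearbox.out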
+
+
+def HostmemReadProcessor(read_width: int, hostmem_module,
+                         reqs: List[esi._OutputBundleSetter]):
+  """Construct a host memory read request module to orchestrate the read
+  connections. Responsible for gearboxing the data, multiplexing the requests,
+  reassembling out-of-order responses, and routing the responses to the
+  correct clients.
+
+  Generate this module dynamically to allow for multiple read clients of
+  multiple types to be directly accommodated."""
+
+  class HostmemReadProcessorImpl(Module):
+    clk = Clock()
+    rst = Reset()
+
+    # Add an output port for each read client.
+    reqPortMap: Dict[esi._OutputBundleSetter, str] = {}
+    for req in reqs:
+      name = "client_" + req.client_name_str
+      locals()[name] = Output(req.type)
+      reqPortMap[req] = name
+
+    # And then the port which goes to the host.
+    upstream = Output(hostmem_module.read.type)
+
+    @generator
+    def build(ports):
+      """Build the read side of the HostMem service."""
+
+      # If there are no read clients, just return a no-op read bundle.
+      if len(reqs) == 0:
+        upstream_req_channel, _ = Channel(hostmem_module.UpstreamReadReq).wrap(
+            {
+                "tag": 0,
+                "length": 0,
+                "address": 0
+            }, 0)
+        upstream_read_bundle, _ = hostmem_module.read.type.pack(
+            req=upstream_req_channel)
+        ports.upstream = upstream_read_bundle
+        return
+
+      # TODO: mux together multiple read clients.
+      assert len(reqs) == 1, "Only one read client supported for now."
+
+      # Pack the upstream bundle and leave the request as a wire.
+      upstream_req_channel = Wire(Channel(hostmem_module.UpstreamReadReq))
+      upstream_read_bundle, froms = hostmem_module.read.type.pack(
+          req=upstream_req_channel)
+      ports.upstream = upstream_read_bundle
+      upstream_resp_channel = froms["resp"]
+
+      for client in reqs:
+        # Find the response channel in the request bundle.
+        resp_type = [
+            c.channel for c in client.type.channels if c.name == 'resp'
+        ][0]
+        # TODO: route the response to the correct client.
+        # TODO: tag re-writing to deal with tag aliasing.
+        # Pretend to demux the upstream response channel.
+        demuxed_upstream_channel = upstream_resp_channel
+
+        # TODO: Should responses come back out-of-order (interleaved tags),
+        # re-order them here so the gearbox doesn't get confused. (Longer term.)
+        # For now, only support one outstanding transaction at a time. This has
+        # the additional benefit of letting the upstream tag be the client
+        # identifier. TODO: Implement the gating logic here.
+
+        # Gearbox the data to the client's data type.
+        client_type = resp_type.inner_type
+        gearbox = TaggedGearbox(read_width, client_type.data.bitwidth)(
+            clk=ports.clk, rst=ports.rst, in_=demuxed_upstream_channel)
+        client_resp_channel = gearbox.out.transform(lambda m: client_type({
+            "tag": m.tag,
+            "data": m.data.bitcast(client_type.data)
+        }))
+
+        # Assign the client response to the correct port.
+        client_bundle, froms = client.type.pack(resp=client_resp_channel)
+        client_req = froms["req"]
+        # Set the port for the client request.
+        setattr(ports, HostmemReadProcessorImpl.reqPortMap[client],
+                client_bundle)
+
+        # Assign the multiplexed read request to the upstream request.
+        # TODO: mux together multiple read clients.
+        upstream_req_channel.assign(
+            client_req.transform(lambda r: hostmem_module.UpstreamReadReq({
+                "address": r.address,
+                "length": (client_type.data.bitwidth + 7) // 8,
+                "tag": r.tag
+            })))
+      HostmemReadProcessorImpl.reqPortMap.clear()
+
+  return HostmemReadProcessorImpl
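+
+# For a single read client named "ReadMem_hostread" (a hypothetical example),
+# the generated module would have roughly this port list:
+#
+#   clk, rst
+#   client_ReadMem_hostread : Output(<the client's read bundle type>)
+#   upstream                : Output(hostmem_module.read.type)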
("tag", UInt(8)), ]) read = Output( @@ -277,7 +460,7 @@ class ChannelHostMemImpl(esi.ServiceImplementation): BundledChannel( "resp", ChannelDirection.FROM, StructType([ - ("tag", UInt(8)), + ("tag", esi.HostMem.TagType), ("data", Bits(read_width)), ])), ])) @@ -294,8 +477,18 @@ class ChannelHostMemImpl(esi.ServiceImplementation): @generator def generate(ports, bundles: esi._ServiceGeneratorBundles): + # Split the read side out into a separate module. Must assign the output + # ports to the clients since we can't service a request in a different + # module. read_reqs = [req for req in bundles.to_client_reqs if req.port == 'read'] - ports.read = ChannelHostMemImpl.build_tagged_read_mux(ports, read_reqs) + read_proc_module = HostmemReadProcessor(read_width, ChannelHostMemImpl, + read_reqs) + read_proc = read_proc_module(clk=ports.clk, rst=ports.rst) + ports.read = read_proc.upstream + for req in read_reqs: + req.assign(getattr(read_proc, read_proc_module.reqPortMap[req])) + + # The write side. write_reqs = [ req for req in bundles.to_client_reqs if req.port == 'write' ] @@ -353,60 +546,4 @@ def build_tagged_write_mux( write_acks[0].assign(ack_tag) return upstream_write_bundle - @staticmethod - def build_tagged_read_mux( - ports, reqs: List[esi._OutputBundleSetter]) -> BundleSignal: - """Build the read side of the HostMem service.""" - - if len(reqs) == 0: - req, req_ready = Channel(ChannelHostMemImpl.UpstreamReadReq).wrap( - { - "tag": 0, - "length": 0, - "address": 0 - }, 0) - read_bundle, _ = ChannelHostMemImpl.read.type.pack(req=req) - return read_bundle - - # TODO: mux together multiple read clients. - assert len(reqs) == 1, "Only one read client supported for now." - - req = Wire(Channel(ChannelHostMemImpl.UpstreamReadReq)) - read_bundle, froms = ChannelHostMemImpl.read.type.pack(req=req) - resp_chan_ready = Wire(Bits(1)) - resp_data, resp_valid = froms["resp"].unwrap(resp_chan_ready) - for client in reqs: - resp_type = [ - c.channel for c in client.type.channels if c.name == 'resp' - ][0] - client_read_type = resp_type.inner_type.data - # TODO: support gearboxing up to the correct width. - assert client_read_type.width == read_width, \ - "Gearboxing not yet supported." - client_tag = resp_data.tag - client_resp_valid = resp_valid - client_resp, client_resp_ready = Channel(resp_type).wrap( - { - # TODO: tag re-writing to deal with tag aliasing. - "tag": client_tag, - "data": resp_data.data.bitcast(client_read_type), - }, - client_resp_valid) - # TODO: mux this properly. - resp_chan_ready.assign(client_resp_ready) - - client_bundle, froms = client.type.pack(resp=client_resp) - client_req = froms["req"] - client.assign(client_bundle) - - # Assign the multiplexed read request to the upstream request. 
-      req.assign(
-          client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReadReq({
-              "address": r.address,
-              "length": 1,
-              "tag": r.tag
-          })))
-
-      return read_bundle
-
   return ChannelHostMemImpl
diff --git a/frontends/PyCDE/src/pycde/bsp/cosim.py b/frontends/PyCDE/src/pycde/bsp/cosim.py
index 54e5586c3ba5..85de665d0e77 100644
--- a/frontends/PyCDE/src/pycde/bsp/cosim.py
+++ b/frontends/PyCDE/src/pycde/bsp/cosim.py
@@ -61,16 +61,15 @@ def build(ports):
                                        appid=esi.AppID("__cosim_hostmem"),
                                        clk=ports.clk,
                                        rst=ports.rst)
-    resp_wire = Wire(
-        Channel(
-            StructType([
-                ("tag", UInt(8)),
-                ("data", Bits(ESI_Cosim_UserTopWrapper.HostMemWidth)),
-            ])))
-    req = hostmem.read.unpack(resp=resp_wire)['req']
-    data = esi.CallService.call(esi.AppID("__cosim_hostmem_read"), req,
-                                resp_wire.type)
-    resp_wire.assign(data)
+
+    resp_channel = esi.ChannelService.from_host(
+        esi.AppID("__cosim_hostmem_read_resp"),
+        StructType([
+            ("tag", UInt(8)),
+            ("data", Bits(ESI_Cosim_UserTopWrapper.HostMemWidth)),
+        ]))
+    req = hostmem.read.unpack(resp=resp_channel)['req']
+    esi.ChannelService.to_host(esi.AppID("__cosim_hostmem_read_req"), req)
 
     ack_wire = Wire(Channel(UInt(8)))
     write_req = hostmem.write.unpack(ackTag=ack_wire)['req']
diff --git a/frontends/PyCDE/src/pycde/constructs.py b/frontends/PyCDE/src/pycde/constructs.py
index e1fe0081499c..f380c667bb0d 100644
--- a/frontends/PyCDE/src/pycde/constructs.py
+++ b/frontends/PyCDE/src/pycde/constructs.py
@@ -263,6 +263,7 @@ def Counter(width: int):
   class Counter(Module):
     clk = Clock()
     rst = Reset()
+    clear = Input(Bits(1))
     increment = Input(Bits(1))
 
     out = Output(UInt(width))
@@ -273,7 +274,8 @@ def construct(ports):
                   rst=ports.rst,
                   rst_value=0,
-                  ce=ports.increment)
-      count.assign((count + 1).as_uint(width))
+                  ce=ports.increment | ports.clear)
+      next = (count + 1).as_uint(width)
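+      # 'clear' takes priority over 'increment': when clear is asserted the
+      # count resets to 0 on the next cycle; otherwise increment adds one.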
+      count.assign(Mux(ports.clear, next, UInt(width)(0)))
       ports.out = count
 
   return Counter
diff --git a/frontends/PyCDE/src/pycde/esi.py b/frontends/PyCDE/src/pycde/esi.py
index a5ee4f2dd2f7..7e2aa162905d 100644
--- a/frontends/PyCDE/src/pycde/esi.py
+++ b/frontends/PyCDE/src/pycde/esi.py
@@ -174,6 +174,10 @@ def add_record(self,
   def client_name(self) -> List[AppID]:
     return [AppID(x) for x in self.req.relativeAppIDPath]
 
+  @property
+  def client_name_str(self) -> str:
+    return "_".join([str(appid) for appid in self.client_name])
+
   def assign(self, new_value: ChannelSignal):
     """Assign the generated channel to this request."""
     if self._bundle_to_replace is None:
@@ -184,6 +188,12 @@ def assign(self, new_value: ChannelSignal):
         f"Channel type mismatch. Expected {self.type}, got {new_value.type}.")
     msft.replaceAllUsesWith(self._bundle_to_replace, new_value.value)
     self._bundle_to_replace = None
+    self.req = None
+
+  def cleanup(self):
+    """Null out all the references to all the ops to allow them to be GC'd."""
+    self.req = None
+    self.rec = None
 
 
 class _ServiceGeneratorBundles:
@@ -219,6 +229,13 @@ def check_unconnected_outputs(self):
         name_str = str(req.client_name)
         raise ValueError(f"{name_str} has not been connected.")
 
+  def cleanup(self):
+    """Null out all the references to all the ops to allow them to be GC'd."""
+    for req in self._output_reqs:
+      req.cleanup()
+    self._req = None
+    self._rec = None
+
 
 class ServiceImplementationModuleBuilder(ModuleLikeBuilderBase):
   """Define how to build ESI service implementations. Unlike Modules, there is
@@ -243,7 +260,8 @@ def instantiate(self, impl, inputs: Dict[str, Signal], appid: AppID):
                               impl_opts=opts,
                               loc=self.loc)
 
-  def generate_svc_impl(self, serviceReq: raw_esi.ServiceImplementReqOp,
+  def generate_svc_impl(self, sys: System,
+                        serviceReq: raw_esi.ServiceImplementReqOp,
                         record_op: raw_esi.ServiceImplRecordOp) -> bool:
     """"Generate the service inline and replace the `ServiceInstanceOp` which is
     being implemented."""
@@ -251,7 +269,7 @@ def generate_svc_impl(self, serviceReq: raw_esi.ServiceImplementReqOp,
     assert len(self.generators) == 1
     generator: Generator = list(self.generators.values())[0]
     ports = self.generator_port_proxy(serviceReq.operation.operands, self)
-    with self.GeneratorCtxt(self, ports, serviceReq, generator.loc):
+    with sys, self.GeneratorCtxt(self, ports, serviceReq, generator.loc):
 
       # Run the generator.
       bundles = _ServiceGeneratorBundles(self, serviceReq, record_op)
@@ -268,7 +286,17 @@ def generate_svc_impl(self, serviceReq: raw_esi.ServiceImplementReqOp,
       for idx, port_value in enumerate(ports._output_values):
         msft.replaceAllUsesWith(serviceReq.operation.results[idx],
                                 port_value.value)
-      serviceReq.operation.erase()
+
+      # Erase the service request op so as to avoid bundles with no consumers.
+      serviceReq.operation.erase()
+
+      # The service implementation generator could have instantiated new modules,
+      # so we need to generate them. Don't run the appID indexer since during a
+      # pass, the IR can be invalid and the indexer assumes it is valid.
+      sys.generate(skip_appid_index=True)
+      # Now that the bundles should be assigned, we can clean up the bundles and
+      # delete the service request op.
+      bundles.cleanup()
 
     return rc
 
@@ -339,14 +367,9 @@ def _implement_service(self, req: ir.Operation, decl: ir.Operation,
     if impl_name not in self._registry:
       return False
     (impl, sys) = self._registry[impl_name]
-    with sys:
-      ret = impl._builder.generate_svc_impl(serviceReq=req.opview,
-                                            record_op=rec.opview)
-      # The service implementation generator could have instantiated new modules,
-      # so we need to generate them. Don't run the appID indexer since during a
-      # pass, the IR can be invalid and the indexers assumes it is valid.
-      sys.generate(skip_appid_index=True)
-      return ret
+    return impl._builder.generate_svc_impl(sys,
+                                           serviceReq=req.opview,
+                                           record_op=rec.opview)
 
 
 _service_generator_registry = _ServiceGeneratorRegistry()
@@ -514,9 +537,11 @@ def _op(sym_name: ir.StringAttr):
 class _HostMem(ServiceDecl):
   """ESI standard service to request read or write access to host memory."""
 
+  TagType = UInt(8)
+
   ReadReqType = StructType([
       ("address", UInt(64)),
-      ("tag", UInt(8)),
+      ("tag", TagType),
   ])
 
   def __init__(self):
@@ -531,14 +556,14 @@ def write_req_bundle_type(self, data_type: Type) -> Bundle:
     ])
     return Bundle([
         BundledChannel("req", ChannelDirection.FROM, write_req_type),
-        BundledChannel("ackTag", ChannelDirection.TO, UInt(8))
+        BundledChannel("ackTag", ChannelDirection.TO, _HostMem.TagType),
     ])
 
   def write_req_channel_type(self, data_type: Type) -> StructType:
     """Return a write request struct type for 'data_type'."""
     return StructType([
         ("address", UInt(64)),
-        ("tag", UInt(8)),
+        ("tag", _HostMem.TagType),
         ("data", data_type),
     ])
diff --git a/frontends/PyCDE/src/pycde/signals.py b/frontends/PyCDE/src/pycde/signals.py
index c6de3f9175f9..d92aa9f28c0f 100644
--- a/frontends/PyCDE/src/pycde/signals.py
+++ b/frontends/PyCDE/src/pycde/signals.py
@@ -13,7 +13,7 @@
 
 from contextvars import ContextVar
 from functools import singledispatchmethod
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 import re
 
 import numpy as np
@@ -336,7 +336,7 @@ def __get_item__value(self, idx: BitVectorSignal) -> BitVectorSignal:
     return self.slice(idx, 1)
 
   @staticmethod
-  def concat(items: List[BitVectorSignal]):
+  def concat(items: Iterable[BitVectorSignal]):
     """Concatenate a list of bitvectors into one larger bitvector."""
     from .dialects import comb
     return comb.ConcatOp(*items)
diff --git a/frontends/PyCDE/test/test_esi.py b/frontends/PyCDE/test/test_esi.py
index 2644670b4691..460220f330c1 100644
--- a/frontends/PyCDE/test/test_esi.py
+++ b/frontends/PyCDE/test/test_esi.py
@@ -1,7 +1,6 @@
 # RUN: rm -rf %t
 # RUN: %PYTHON% %s %t 2>&1 | FileCheck %s
 
-import unittest
 from pycde import (Clock, Input, InputChannel, Output, OutputChannel, Module,
                    Reset, generator, types)
 from pycde import esi
diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
index 71a24bfe7907..c99f95e419ba 100644
--- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
+++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
@@ -451,9 +451,9 @@ class CosimHostMem : public HostMem {
 
     // Setup the read side callback.
     ChannelDesc readArg, readResp;
-    if (!rpcClient->getChannelDesc("__cosim_hostmem_read.arg", readArg) ||
-        !rpcClient->getChannelDesc("__cosim_hostmem_read.result", readResp))
-      throw std::runtime_error("Could not find HostMem channels");
+    if (!rpcClient->getChannelDesc("__cosim_hostmem_read_req.data", readArg) ||
+        !rpcClient->getChannelDesc("__cosim_hostmem_read_resp.data", readResp))
+      throw std::runtime_error("Could not find HostMem read channels");
 
     const esi::Type *readRespType =
         getType(ctxt, new StructType(readResp.type(),
                                      {{"tag", new UIntType("ui8", 8)},
                                       {"data", new BitsType("i64", 64)}}));
     const esi::Type *readReqType =
         getType(ctxt, new StructType(readArg.type(),
                                      {{"address", new UIntType("ui64", 64)},
                                       {"length", new UIntType("ui32", 32)},
                                       {"tag", new UIntType("ui8", 8)}}));
 
-    // Get ports, create the function, then connect to it.
+    // Get ports. Unfortunately, we can't model this as a callback since there
+    // will sometimes be multiple responses per request.
     readRespPort = std::make_unique<WriteCosimChannelPort>(
         rpcClient->stub.get(), readResp, readRespType,
-        "__cosim_hostmem_read.result");
+        "__cosim_hostmem_read_resp.data");
     readReqPort = std::make_unique<ReadCosimChannelPort>(
         rpcClient->stub.get(), readArg, readReqType,
-        "__cosim_hostmem_read.arg");
-    read.reset(CallService::Callback::get(acc, AppID("__cosim_hostmem_read"),
-                                          *readRespPort, *readReqPort));
-    read->connect([this](const MessageData &req) { return serviceRead(req); },
-                  true);
+        "__cosim_hostmem_read_req.data");
+    readReqPort->connect(
+        [this](const MessageData &req) { return serviceRead(req); });
 
     // Setup the write side callback.
     ChannelDesc writeArg, writeResp;
     if (!rpcClient->getChannelDesc("__cosim_hostmem_write.arg", writeArg) ||
         !rpcClient->getChannelDesc("__cosim_hostmem_write.result", writeResp))
-      throw std::runtime_error("Could not find HostMem channels");
+      throw std::runtime_error("Could not find HostMem write channels");
 
     const esi::Type *writeRespType =
         getType(ctxt, new UIntType(writeResp.type(), 8));
@@ -506,7 +505,7 @@ class CosimHostMem : public HostMem {
 
   // Service the read request as a callback. Simply reads the data from the
  // location specified. TODO: check that the memory has been mapped.
-  MessageData serviceRead(const MessageData &reqBytes) {
+  bool serviceRead(const MessageData &reqBytes) {
     const HostMemReadReq *req = reqBytes.as<HostMemReadReq>();
     acc.getLogger().debug(
         [&](std::string &subsystem, std::string &msg,
             std::unique_ptr<std::map<std::string, std::any>> &details) {
           subsystem = "HostMem";
           msg = "Read request: addr=0x" + toHex(req->address) +
                 " len=" + std::to_string(req->length) +
                 " tag=" + std::to_string(req->tag);
         });
+    // Send one response per 8 bytes.
     uint64_t *dataPtr = reinterpret_cast<uint64_t *>(req->address);
-    HostMemReadResp resp{.data = *dataPtr, .tag = req->tag};
-    acc.getLogger().debug(
-        [&](std::string &subsystem, std::string &msg,
-            std::unique_ptr<std::map<std::string, std::any>> &details) {
-          subsystem = "HostMem";
-          msg = "Read result: data=0x" + toHex(resp.data) +
-                " tag=" + std::to_string(resp.tag);
-        });
-    return MessageData::from(resp);
+    for (uint32_t i = 0, e = (req->length + 7) / 8; i < e; ++i) {
+      HostMemReadResp resp{.data = dataPtr[i], .tag = req->tag};
+      acc.getLogger().debug(
+          [&](std::string &subsystem, std::string &msg,
+              std::unique_ptr<std::map<std::string, std::any>> &details) {
+            subsystem = "HostMem";
+            msg = "Read result: data=0x" + toHex(resp.data) +
+                  " tag=" + std::to_string(resp.tag);
+          });
+      readRespPort->write(MessageData::from(resp));
+    }
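+    // e.g. a request with length=12 produces (12 + 7) / 8 == 2 responses,
+    // carrying dataPtr[0] and dataPtr[1], each with the request's tag.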
+    return true;
   }
 
   // Service a write request as a callback. Simply write the data to the
diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
index fedf99fe86d3..aaa4c7e665df 100644
--- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
+++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
@@ -103,26 +103,32 @@ void dmaTest(AcceleratorConnection *conn, Accelerator *acc) {
   // Enable the host memory service.
   auto hostmem = conn->getService<services::HostMem>();
   hostmem->start();
-  auto scratchRegion = hostmem->allocate(8, /*memOpts=*/{});
+  auto scratchRegion = hostmem->allocate(16, /*memOpts=*/{});
   uint64_t *dataPtr = static_cast<uint64_t *>(scratchRegion->getPtr());
 
   // Initiate a test read.
   auto *readMem =
       acc->getPorts().at(AppID("ReadMem")).getAs<services::MMIO::MMIORegion>();
-  *dataPtr = 0x12345678;
-  readMem->write(8, (uint64_t)dataPtr);
-
-  // Wait for the accelerator to read the correct value. Timeout and fail after
-  // 10ms.
-  uint64_t val = 0;
-  for (int i = 0; i < 100; ++i) {
-    val = readMem->read(0);
-    if (val == *dataPtr)
-      break;
-    std::this_thread::sleep_for(std::chrono::microseconds(100));
+  for (size_t i = 0; i < 8; ++i) {
+    // Use 64-bit literals so the shifted patterns don't overflow 32-bit ints.
+    dataPtr[0] = 0x12345678ull << i;
+    dataPtr[1] = 0xDEADBEEFull << i;
+    readMem->write(8, (uint64_t)dataPtr);
+
+    // Wait for the accelerator to read the correct value. Timeout and fail
+    // after 10ms.
+    uint64_t val = 0;
+    for (int i = 0; i < 100; ++i) {
+      val = readMem->read(0);
+      if (val == *dataPtr)
+        break;
+      std::this_thread::sleep_for(std::chrono::microseconds(100));
+    }
+    if (val != *dataPtr)
+      throw std::runtime_error("DMA read test failed. Expected " +
+                               std::to_string(*dataPtr) + ", got " +
+                               std::to_string(val));
   }
-  if (val != *dataPtr)
-    throw std::runtime_error("DMA read test failed");
 
   // Initiate a test write.
   auto *writeMem =