Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESI] BSP: variable sized reads and gearboxing read responses #8095

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions frontends/PyCDE/integration_test/esitester.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,19 @@ def construct(ports):
clk=ports.clk,
rst=ports.rst,
rst_value=0,
ce=read_loc_ce)
ce=read_loc_ce,
name="read_loc")
read_loc.assign(cmd.data.as_uint())

mem_data_ce = Wire(Bits(1))
mem_data = Reg(Bits(64),
mem_data = Reg(Bits(96),
clk=ports.clk,
rst=ports.rst,
rst_value=0,
ce=mem_data_ce)
ce=mem_data_ce,
name="mem_data")

response_data = mem_data
response_data = mem_data.as_bits(64)
response_chan, response_ready = Channel(Bits(64)).wrap(
response_data, cmd_valid)
resp_ready_wire.assign(response_ready)
Expand All @@ -97,7 +99,10 @@ def construct(ports):
mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
cmd_chan_wire.assign(mmio_rw_cmd_chan)

tag = Counter(8)(clk=ports.clk, rst=ports.rst, increment=mmio_xact)
tag = Counter(8)(clk=ports.clk,
rst=ports.rst,
clear=Bits(1)(0),
increment=mmio_xact)

hostmem_read_req, hostmem_read_req_ready = Channel(
esi.HostMem.ReadReqType).wrap({
Expand All @@ -107,7 +112,7 @@ def construct(ports):

hostmem_read_resp = esi.HostMem.read(appid=AppID("ReadMem_hostread"),
req=hostmem_read_req,
data_type=Bits(64))
data_type=mem_data.type)
hostmem_read_resp_data, hostmem_read_resp_valid = hostmem_read_resp.unwrap(
1)
mem_data.assign(hostmem_read_resp_data.data)
Expand Down Expand Up @@ -144,10 +149,14 @@ def construct(ports):
mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
cmd_chan_wire.assign(mmio_rw_cmd_chan)

tag = Counter(8)(clk=ports.clk, rst=ports.rst, increment=mmio_xact)
tag = Counter(8)(clk=ports.clk,
rst=ports.rst,
clear=Bits(1)(0),
increment=mmio_xact)

cycle_counter = Counter(64)(clk=ports.clk,
rst=ports.rst,
clear=Bits(1)(0),
increment=Bits(1)(1))

hostmem_write_req, _ = esi.HostMem.wrap_write_req(
Expand Down
261 changes: 199 additions & 62 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from math import ceil

from ..common import Clock, Input, Output, Reset
from ..constructs import AssignableSignal, NamedWire, Wire
from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset
from ..constructs import (AssignableSignal, ControlReg, Counter, NamedWire, Reg,
Wire)
from .. import esi
from ..module import Module, generator, modparams
from ..signals import BitsSignal, BundleSignal, ChannelSignal
from ..support import clog2
from ..types import (Array, Bits, Bundle, BundledChannel, Channel,
ChannelDirection, StructType, UInt)
ChannelDirection, StructType, Type, UInt)

from typing import Dict, List, Tuple
import typing
Expand Down Expand Up @@ -255,6 +257,187 @@ def build_addr_read(
return sel_bits, client_addr_chan


@modparams
def TaggedGearbox(input_bitwidth: int,
                  output_bitwidth: int) -> type["TaggedGearboxImpl"]:
  """Build a gearbox to convert the upstream data to the client data
  type. Assumes a struct {tag, data} and only gearboxes the data. Tag is stored
  separately and the struct is re-assembled later on."""

  class TaggedGearboxImpl(Module):
    clk = Clock()
    rst = Reset()
    # Upstream side: fixed-width beats of host memory read data.
    in_ = InputChannel(
        StructType([
            ("tag", esi.HostMem.TagType),
            ("data", Bits(input_bitwidth)),
        ]))
    # Client side: same tag, data resized to the client's width.
    out = OutputChannel(
        StructType([
            ("tag", esi.HostMem.TagType),
            ("data", Bits(output_bitwidth)),
        ]))

    @generator
    def build(ports):
      ready_for_upstream = Wire(Bits(1), name="ready_for_upstream")
      upstream_tag_and_data, upstream_valid = ports.in_.unwrap(
          ready_for_upstream)
      upstream_data = upstream_tag_and_data.data
      # An upstream transaction occurs when both valid and ready are high.
      upstream_xact = ready_for_upstream & upstream_valid

      # Determine if gearboxing is necessary and whether it needs to be
      # gearboxed up or just sliced down.
      if output_bitwidth == input_bitwidth:
        # Same width: pass the data straight through.
        client_data_bits = upstream_data
        client_valid = upstream_valid
      elif output_bitwidth < input_bitwidth:
        # Slice down: the client consumes only the low `output_bitwidth` bits
        # of each upstream beat. (Bug fix: was `[:input_bitwidth]`, which kept
        # the full upstream width and could not match the output channel's
        # `Bits(output_bitwidth)` data field.)
        client_data_bits = upstream_data[:output_bitwidth]
        client_valid = upstream_valid
      else:
        # Create registers equal to the number of upstream transactions needed
        # to fill the client data. Set the output to the concatenation of said
        # registers.
        chunks = ceil(output_bitwidth / input_bitwidth)
        reg_ces = [Wire(Bits(1)) for _ in range(chunks)]
        regs = [
            upstream_data.reg(ports.clk,
                              ports.rst,
                              ce=reg_ces[idx],
                              name=f"chunk_reg_{idx}") for idx in range(chunks)
        ]
        client_data_bits = BitsSignal.concat(reversed(regs))[:output_bitwidth]

        # Use counter to determine to which register to write and determine if
        # the registers are all full.
        clear_counter = Wire(Bits(1))
        counter_width = clog2(chunks)
        counter = Counter(counter_width)(clk=ports.clk,
                                         rst=ports.rst,
                                         clear=clear_counter,
                                         increment=upstream_xact)
        # Bug fix: only flag the output word complete once the LAST chunk has
        # actually been accepted. Without the `upstream_xact` gate, the valid
        # ControlReg was set while merely *waiting* for the final beat
        # (counter already at chunks-1), letting downstream consume a word
        # whose last chunk register held stale data.
        set_client_valid = upstream_xact & (counter.out == chunks - 1)
        client_xact = Wire(Bits(1))
        client_valid = ControlReg(ports.clk, ports.rst, [set_client_valid],
                                  [client_xact])
        # `ready_for_upstream` is driven from the client channel's ready
        # below, so this is client_valid & client_ready.
        client_xact.assign(client_valid & ready_for_upstream)
        clear_counter.assign(client_xact)
        # One-hot write enable per chunk register, selected by the counter.
        for idx, reg_ce in enumerate(reg_ces):
          reg_ce.assign(upstream_xact &
                        (counter.out == UInt(counter_width)(idx)))

      # Construct the output channel. Shared logic across all three cases.
      # NOTE(review): in the pass-through and slice-down cases the data is
      # presented combinationally while the tag comes from a register, so the
      # tag lags the data by one accepted beat — confirm tags are stable
      # across a burst or register the data path to match.
      tag_reg = upstream_tag_and_data.tag.reg(ports.clk,
                                              ports.rst,
                                              ce=upstream_xact,
                                              name="tag_reg")
      client_channel, client_ready = TaggedGearboxImpl.out.type.wrap(
          {
              "tag": tag_reg,
              "data": client_data_bits,
          }, client_valid)
      ready_for_upstream.assign(client_ready)
      ports.out = client_channel

  return TaggedGearboxImpl


def HostmemReadProcessor(read_width: int, hostmem_module,
                         reqs: List[esi._OutputBundleSetter]):
  """Construct a host memory read request module to orchestrate the read
  connections. Responsible for both gearboxing the data, multiplexing the
  requests, reassembling out-of-order responses and routing the responses to
  the correct clients.

  Generate this module dynamically to allow for multiple read clients of
  multiple types to be directly accommodated."""

  class HostmemReadProcessorImpl(Module):
    clk = Clock()
    rst = Reset()

    # Add an output port for each read client.
    reqPortMap: Dict[esi._OutputBundleSetter, str] = {}
    for req in reqs:
      name = "client_" + req.client_name_str
      # Dynamically create one Output port per client, named after it.
      locals()[name] = Output(req.type)
      reqPortMap[req] = name

    # And then the port which goes to the host.
    upstream = Output(hostmem_module.read.type)

    @generator
    def build(ports):
      """Build the read side of the HostMem service."""

      # If there's no read clients, just return a no-op read bundle.
      if len(reqs) == 0:
        # Wrap a never-valid (valid=0) request channel with zeroed fields.
        upstream_req_channel, _ = Channel(hostmem_module.UpstreamReadReq).wrap(
            {
                "tag": 0,
                "length": 0,
                "address": 0
            }, 0)
        upstream_read_bundle, _ = hostmem_module.read.type.pack(
            req=upstream_req_channel)
        ports.upstream = upstream_read_bundle
        return

      # TODO: mux together multiple read clients.
      assert len(reqs) == 1, "Only one read client supported for now."

      # Pack the upstream bundle and leave the request as a wire.
      # The wire is assigned after the client loop below, once the client's
      # request channel is available.
      upstream_req_channel = Wire(Channel(hostmem_module.UpstreamReadReq))
      upstream_read_bundle, froms = hostmem_module.read.type.pack(
          req=upstream_req_channel)
      ports.upstream = upstream_read_bundle
      upstream_resp_channel = froms["resp"]

      for client in reqs:
        # Find the response channel in the request bundle.
        resp_type = [
            c.channel for c in client.type.channels if c.name == 'resp'
        ][0]
        # TODO: route the response to the correct client.
        # TODO: tag re-writing to deal with tag aliasing.
        # Pretend to demux the upstream response channel.
        demuxed_upstream_channel = upstream_resp_channel

        # TODO: Should responses come back out-of-order (interleaved tags),
        # re-order them here so the gearbox doesn't get confused. (Longer term.)
        # For now, only support one outstanding transaction at a time. This has
        # the additional benefit of letting the upstream tag be the client
        # identifier. TODO: Implement the gating logic here.

        # Gearbox the data to the client's data type.
        client_type = resp_type.inner_type
        gearbox = TaggedGearbox(read_width, client_type.data.bitwidth)(
            clk=ports.clk, rst=ports.rst, in_=demuxed_upstream_channel)
        # Re-assemble the client's {tag, data} struct, bitcasting the raw
        # gearboxed bits back to the client's declared data type.
        client_resp_channel = gearbox.out.transform(lambda m: client_type({
            "tag": m.tag,
            "data": m.data.bitcast(client_type.data)
        }))

        # Assign the client response to the correct port.
        client_bundle, froms = client.type.pack(resp=client_resp_channel)
        client_req = froms["req"]
        # Set the port for the client request.
        setattr(ports, HostmemReadProcessorImpl.reqPortMap[client],
                client_bundle)

      # Assign the multiplexed read request to the upstream request.
      # Length is the client data width in bytes, rounded up.
      # TODO: mux together multiple read clients.
      upstream_req_channel.assign(
          client_req.transform(lambda r: hostmem_module.UpstreamReadReq({
              "address": r.address,
              "length": (client_type.data.bitwidth + 7) // 8,
              "tag": r.tag
          })))
      # Clear the shared class-level map so a subsequent call to
      # HostmemReadProcessor starts from an empty mapping.
      HostmemReadProcessorImpl.reqPortMap.clear()

  return HostmemReadProcessorImpl


@modparams
def ChannelHostMem(read_width: int,
write_width: int) -> typing.Type['ChannelHostMemImpl']:
Expand All @@ -268,7 +451,7 @@ class ChannelHostMemImpl(esi.ServiceImplementation):

UpstreamReadReq = StructType([
("address", UInt(64)),
("length", UInt(32)),
("length", UInt(32)), # In bytes.
("tag", UInt(8)),
])
read = Output(
Expand All @@ -277,7 +460,7 @@ class ChannelHostMemImpl(esi.ServiceImplementation):
BundledChannel(
"resp", ChannelDirection.FROM,
StructType([
("tag", UInt(8)),
("tag", esi.HostMem.TagType),
("data", Bits(read_width)),
])),
]))
Expand All @@ -294,8 +477,18 @@ class ChannelHostMemImpl(esi.ServiceImplementation):

@generator
def generate(ports, bundles: esi._ServiceGeneratorBundles):
# Split the read side out into a separate module. Must assign the output
# ports to the clients since we can't service a request in a different
# module.
read_reqs = [req for req in bundles.to_client_reqs if req.port == 'read']
ports.read = ChannelHostMemImpl.build_tagged_read_mux(ports, read_reqs)
read_proc_module = HostmemReadProcessor(read_width, ChannelHostMemImpl,
read_reqs)
read_proc = read_proc_module(clk=ports.clk, rst=ports.rst)
ports.read = read_proc.upstream
for req in read_reqs:
req.assign(getattr(read_proc, read_proc_module.reqPortMap[req]))

# The write side.
write_reqs = [
req for req in bundles.to_client_reqs if req.port == 'write'
]
Expand Down Expand Up @@ -353,60 +546,4 @@ def build_tagged_write_mux(
write_acks[0].assign(ack_tag)
return upstream_write_bundle

@staticmethod
def build_tagged_read_mux(
ports, reqs: List[esi._OutputBundleSetter]) -> BundleSignal:
"""Build the read side of the HostMem service."""

if len(reqs) == 0:
req, req_ready = Channel(ChannelHostMemImpl.UpstreamReadReq).wrap(
{
"tag": 0,
"length": 0,
"address": 0
}, 0)
read_bundle, _ = ChannelHostMemImpl.read.type.pack(req=req)
return read_bundle

# TODO: mux together multiple read clients.
assert len(reqs) == 1, "Only one read client supported for now."

req = Wire(Channel(ChannelHostMemImpl.UpstreamReadReq))
read_bundle, froms = ChannelHostMemImpl.read.type.pack(req=req)
resp_chan_ready = Wire(Bits(1))
resp_data, resp_valid = froms["resp"].unwrap(resp_chan_ready)
for client in reqs:
resp_type = [
c.channel for c in client.type.channels if c.name == 'resp'
][0]
client_read_type = resp_type.inner_type.data
# TODO: support gearboxing up to the correct width.
assert client_read_type.width == read_width, \
"Gearboxing not yet supported."
client_tag = resp_data.tag
client_resp_valid = resp_valid
client_resp, client_resp_ready = Channel(resp_type).wrap(
{
# TODO: tag re-writing to deal with tag aliasing.
"tag": client_tag,
"data": resp_data.data.bitcast(client_read_type),
},
client_resp_valid)
# TODO: mux this properly.
resp_chan_ready.assign(client_resp_ready)

client_bundle, froms = client.type.pack(resp=client_resp)
client_req = froms["req"]
client.assign(client_bundle)

# Assign the multiplexed read request to the upstream request.
req.assign(
client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReadReq({
"address": r.address,
"length": 1,
"tag": r.tag
})))

return read_bundle

return ChannelHostMemImpl
19 changes: 9 additions & 10 deletions frontends/PyCDE/src/pycde/bsp/cosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ def build(ports):
appid=esi.AppID("__cosim_hostmem"),
clk=ports.clk,
rst=ports.rst)
resp_wire = Wire(
Channel(
StructType([
("tag", UInt(8)),
("data", Bits(ESI_Cosim_UserTopWrapper.HostMemWidth)),
])))
req = hostmem.read.unpack(resp=resp_wire)['req']
data = esi.CallService.call(esi.AppID("__cosim_hostmem_read"), req,
resp_wire.type)
resp_wire.assign(data)

resp_channel = esi.ChannelService.from_host(
esi.AppID("__cosim_hostmem_read_resp"),
StructType([
("tag", UInt(8)),
("data", Bits(ESI_Cosim_UserTopWrapper.HostMemWidth)),
]))
req = hostmem.read.unpack(resp=resp_channel)['req']
esi.ChannelService.to_host(esi.AppID("__cosim_hostmem_read_req"), req)

ack_wire = Wire(Channel(UInt(8)))
write_req = hostmem.write.unpack(ackTag=ack_wire)['req']
Expand Down
Loading
Loading