From 61d5cd6755b43523dc2be08f4c4e22995fed1a34 Mon Sep 17 00:00:00 2001 From: John Demme Date: Fri, 31 Jan 2025 23:58:47 +0000 Subject: [PATCH] [ESI Runtime] Pluggable channel engines Introduces the concept of 'engines' -- things which are responsible for transmitting/recieving messages over a channel. Since this is a new concept (and was previously done only in cosim), add a new section to the manifest. Re-use a bunch of stuff which supports services. Hacky, but it works. The entire manifest thing could use a re-think and second iteration. It is, however, make-it-work time. --- include/circt/Dialect/ESI/ESIInterfaces.td | 2 +- include/circt/Dialect/ESI/ESIManifest.td | 3 +- .../Dialect/ESI/runtime/loopback.mlir | 2 - lib/Dialect/ESI/ESIOps.cpp | 4 +- lib/Dialect/ESI/ESIServices.cpp | 6 +- lib/Dialect/ESI/runtime/CMakeLists.txt | 1 + .../ESI/runtime/cpp/include/esi/Accelerator.h | 27 ++- .../ESI/runtime/cpp/include/esi/Engines.h | 117 ++++++++++++ .../ESI/runtime/cpp/include/esi/Ports.h | 14 +- .../ESI/runtime/cpp/include/esi/Services.h | 80 +++++---- .../ESI/runtime/cpp/include/esi/Types.h | 9 +- .../runtime/cpp/include/esi/backends/Cosim.h | 17 +- .../runtime/cpp/include/esi/backends/Trace.h | 11 +- .../runtime/cpp/include/esi/backends/Xrt.h | 7 - .../ESI/runtime/cpp/lib/Accelerator.cpp | 29 +++ lib/Dialect/ESI/runtime/cpp/lib/Engines.cpp | 73 ++++++++ lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp | 83 +++++---- lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp | 4 +- lib/Dialect/ESI/runtime/cpp/lib/Services.cpp | 149 +++++++--------- .../ESI/runtime/cpp/lib/backends/Cosim.cpp | 167 +++++++++++------- .../ESI/runtime/cpp/lib/backends/Trace.cpp | 90 +++++----- .../ESI/runtime/cpp/lib/backends/Xrt.cpp | 22 +-- .../ESI/runtime/cpp/tools/esitester.cpp | 5 + test/Dialect/ESI/manifest.mlir | 4 +- test/Dialect/ESI/services.mlir | 4 +- 25 files changed, 600 insertions(+), 330 deletions(-) create mode 100644 lib/Dialect/ESI/runtime/cpp/include/esi/Engines.h create mode 100644 lib/Dialect/ESI/runtime/cpp/lib/Engines.cpp diff --git a/include/circt/Dialect/ESI/ESIInterfaces.td b/include/circt/Dialect/ESI/ESIInterfaces.td index 764a718ce6db..f2f5fe9cde2b 100644 --- a/include/circt/Dialect/ESI/ESIInterfaces.td +++ b/include/circt/Dialect/ESI/ESIInterfaces.td @@ -64,7 +64,7 @@ def IsManifestData : OpInterface<"IsManifestData"> { }]; let methods = [ - StaticInterfaceMethod< + InterfaceMethod< "Get the class name for this op.", "StringRef", "getManifestClass", (ins) >, diff --git a/include/circt/Dialect/ESI/ESIManifest.td b/include/circt/Dialect/ESI/ESIManifest.td index d6a18749032e..a80e68baa0cf 100644 --- a/include/circt/Dialect/ESI/ESIManifest.td +++ b/include/circt/Dialect/ESI/ESIManifest.td @@ -99,6 +99,7 @@ def ServiceImplRecordOp : ESI_Op<"manifest.service_impl", [ }]; let arguments = (ins AppIDAttr:$appID, + DefaultValuedAttr:$isEngine, OptionalAttr:$service, OptionalAttr:$stdService, StrAttr:$serviceImplName, @@ -106,7 +107,7 @@ def ServiceImplRecordOp : ESI_Op<"manifest.service_impl", [ let regions = (region SizedRegion<1>:$reqDetails); let assemblyFormat = [{ qualified($appID) (`svc` $service^)? (`std` $stdService^)? - `by` $serviceImplName `with` $implDetails + `by` $serviceImplName (`engine` $isEngine^)? `with` $implDetails attr-dict-with-keyword custom($reqDetails) }]; diff --git a/integration_test/Dialect/ESI/runtime/loopback.mlir b/integration_test/Dialect/ESI/runtime/loopback.mlir index 59ee41311dd9..e955176ea233 100644 --- a/integration_test/Dialect/ESI/runtime/loopback.mlir +++ b/integration_test/Dialect/ESI/runtime/loopback.mlir @@ -143,8 +143,6 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // QUERY-HIER: * Instance:top // QUERY-HIER: * Ports: // QUERY-HIER: internal_write: -// QUERY-HIER: ack: !esi.channel -// QUERY-HIER: req: !esi.channel> // QUERY-HIER: func1: function i16(i16) // QUERY-HIER: structFunc: function !hw.struct(!hw.struct) // QUERY-HIER: arrayFunc: function !hw.array<2xsi8>(!hw.array<1xsi8>) diff --git a/lib/Dialect/ESI/ESIOps.cpp b/lib/Dialect/ESI/ESIOps.cpp index 6e54eb33b3dd..ea01c650583f 100644 --- a/lib/Dialect/ESI/ESIOps.cpp +++ b/lib/Dialect/ESI/ESIOps.cpp @@ -768,7 +768,9 @@ void ESIPureModuleOp::setHWModuleType(hw::ModuleType type) { // Manifest ops. //===----------------------------------------------------------------------===// -StringRef ServiceImplRecordOp::getManifestClass() { return "service"; } +StringRef ServiceImplRecordOp::getManifestClass() { + return getIsEngine() ? "engine" : "service"; +} void ServiceImplRecordOp::getDetails(SmallVectorImpl &results) { auto *ctxt = getContext(); diff --git a/lib/Dialect/ESI/ESIServices.cpp b/lib/Dialect/ESI/ESIServices.cpp index 97d2945c8622..982e9fce1750 100644 --- a/lib/Dialect/ESI/ESIServices.cpp +++ b/lib/Dialect/ESI/ESIServices.cpp @@ -58,6 +58,7 @@ instantiateCosimEndpointOps(ServiceImplementReqOp implReq, } Block &connImplBlock = implRecord.getReqDetails().front(); + implRecord.setIsEngine(true); OpBuilder implRecords = OpBuilder::atBlockEnd(&connImplBlock); // Assemble the name to use for an endpoint. @@ -300,8 +301,9 @@ ServiceGeneratorDispatcher::generate(ServiceImplementReqOp req, // the generator for possible modification. OpBuilder b(req); auto implRecord = b.create( - req.getLoc(), req.getAppID(), req.getServiceSymbolAttr(), - req.getStdServiceAttr(), req.getImplTypeAttr(), b.getDictionaryAttr({})); + req.getLoc(), req.getAppID(), /*isEngine=*/false, + req.getServiceSymbolAttr(), req.getStdServiceAttr(), + req.getImplTypeAttr(), b.getDictionaryAttr({})); implRecord.getReqDetails().emplaceBlock(); return genF->second(req, decl, implRecord); diff --git a/lib/Dialect/ESI/runtime/CMakeLists.txt b/lib/Dialect/ESI/runtime/CMakeLists.txt index d815cb7d13f0..0e345f38863b 100644 --- a/lib/Dialect/ESI/runtime/CMakeLists.txt +++ b/lib/Dialect/ESI/runtime/CMakeLists.txt @@ -100,6 +100,7 @@ set(ESICppRuntimeSources ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Context.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Design.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Engines.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Logging.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Manifest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/lib/Services.cpp diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h index 04c4b1b4d0d6..c4dd4117b21f 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h @@ -23,6 +23,7 @@ #include "esi/Context.h" #include "esi/Design.h" +#include "esi/Engines.h" #include "esi/Manifest.h" #include "esi/Ports.h" #include "esi/Services.h" @@ -92,11 +93,6 @@ class AcceleratorConnection { // each level of the tree. using ServiceTable = std::map; - /// Request the host side channel ports for a particular instance (identified - /// by the AppID path). For convenience, provide the bundle type. - virtual std::map - requestChannelsFor(AppIDPath, const BundleType *, const ServiceTable &) = 0; - /// Return a pointer to the accelerator 'service' thread (or threads). If the /// thread(s) are not running, they will be started when this method is /// called. `std::thread` is used. If users don't want the runtime to spin up @@ -126,7 +122,23 @@ class AcceleratorConnection { /// accelerator to this connection. Returns a raw pointer to the object. Accelerator *takeOwnership(std::unique_ptr accel); + /// Create a new engine for channel communication with the accelerator. The + /// default is to call the global `createEngine` to get an engine which has + /// registered itself. Individual accelerator connection backends can override + /// this to customize behavior. + virtual void createEngine(const std::string &engineTypeName, AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients); + virtual const BundleEngineMap &getEngineMapFor(AppIDPath id) { + return clientEngines[id]; + } + protected: + /// If `createEngine` is overridden, this method should be called to register + /// the engine and all of the channels it services. + void registerEngine(AppIDPath idPath, std::unique_ptr engine, + const HWClientDetails &clients); + /// Called by `getServiceImpl` exclusively. It wraps the pointer returned by /// this in a unique_ptr and caches it. Separate this from the /// wrapping/caching since wrapping/caching is an implementation detail. @@ -135,6 +147,11 @@ class AcceleratorConnection { const ServiceImplDetails &details, const HWClientDetails &clients) = 0; + /// Collection of owned engines. + std::map> ownedEngines; + /// Mapping of clients to their servicing engines. + std::map clientEngines; + private: /// ESI accelerator context. Context &ctxt; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Engines.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Engines.h new file mode 100644 index 000000000000..612acabb5d29 --- /dev/null +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Engines.h @@ -0,0 +1,117 @@ +//===- Engines.h - Implement port communication -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DO NOT EDIT! +// This file is distributed as part of an ESI package. The source for this file +// should always be modified within CIRCT. +// +//===----------------------------------------------------------------------===// +// +// Engines (as in DMA engine) implement the actual communication between the +// host and the accelerator. They are low level of the ESI runtime API and are +// not intended to be used directly by users. +// +// They are called "engines" rather than "DMA engines" since communication need +// not be implemented via DMA. +// +//===----------------------------------------------------------------------===// + +// NOLINTNEXTLINE(llvm-header-guard) +#ifndef ESI_ENGINGES_H +#define ESI_ENGINGES_H + +#include "esi/Common.h" +#include "esi/Ports.h" +#include "esi/Services.h" +#include "esi/Utils.h" + +#include +#include + +namespace esi { + +/// Engines implement the actual channel communication between the host and the +/// accelerator. Engines can support multiple channels. They are low level of +/// the ESI runtime API and are not intended to be used directly by users. +class Engine { +public: + virtual ~Engine() = default; + /// Start the engine, if applicable. + virtual void connect(){}; + /// Stop the engine, if applicable. + virtual void disconnect(){}; + /// Get a port for a channel, from the cache if it exists or create it. An + /// engine may override this method if different behavior is desired. + virtual ChannelPort &requestPort(AppIDPath idPath, + const std::string &channelName, + BundleType::Direction dir, const Type *type); + +protected: + /// Each engine needs to know how to create a ports. This method is called if + /// a port doesn't exist in the engine cache. + virtual std::unique_ptr + createPort(AppIDPath idPath, const std::string &channelName, + BundleType::Direction dir, const Type *type) = 0; + +private: + std::map, std::unique_ptr> + ownedPorts; +}; + +/// Since engines can support multiple channels BUT not necessarily all of the +/// channels in a bundle, a mapping from bundle channels to engines is needed. +class BundleEngineMap { + friend class AcceleratorConnection; + +public: + /// Request ports for all the channels in a bundle. If the engine doesn't + /// exist for a particular channel, skip said channel. + PortMap requestPorts(const AppIDPath &idPath, + const BundleType *bundleType) const; + +private: + /// Set a particlar engine for a particular channel. Should only be called by + /// AcceleratorConnection while registering engines. + void setEngine(const std::string &channelName, Engine *engine); + std::map bundleEngineMap; +}; + +namespace registry { + +/// Create an engine by name. This is the primary way to create engines for +/// "normal" backends. +std::unique_ptr createEngine(AcceleratorConnection &conn, + const std::string &dmaEngineName, + AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients); + +namespace internal { + +/// Engines can register themselves for pluggable functionality. +using EngineCreate = std::function( + AcceleratorConnection &conn, AppIDPath idPath, + const ServiceImplDetails &details, const HWClientDetails &clients)>; +void registerEngine(const std::string &name, EngineCreate create); + +/// Helper struct to register engines. +template +struct RegisterEngine { + RegisterEngine(const char *name) { registerEngine(name, &TEngine::create); } +}; + +#define REGISTER_ENGINE(Name, TEngine) \ + static ::esi::registry::internal::RegisterEngine \ + __register_engine____LINE__(Name) + +} // namespace internal +} // namespace registry + +} // namespace esi + +#endif // ESI_PORTS_H diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h index c6fdf466428c..02cf39e07cd5 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h @@ -25,6 +25,9 @@ namespace esi { +class ChannelPort; +using PortMap = std::map; + /// Unidirectional channels are the basic communication primitive between the /// host and accelerator. A 'ChannelPort' is the host side of a channel. It can /// be either read or write but not both. At this level, channels are untyped -- @@ -190,7 +193,7 @@ class BundlePort { } /// Construct a port. - BundlePort(AppID id, std::map channels); + BundlePort(AppID id, const BundleType *type, PortMap channels); virtual ~BundlePort() = default; /// Get the ID of the port. @@ -202,9 +205,7 @@ class BundlePort { /// ordinary users should not use. You have been warned. WriteChannelPort &getRawWrite(const std::string &name) const; ReadChannelPort &getRawRead(const std::string &name) const; - const std::map &getChannels() const { - return channels; - } + const PortMap &getChannels() const { return channels; } /// Cast this Bundle port to a subclass which is actually useful. Returns /// nullptr if the cast fails. @@ -224,9 +225,10 @@ class BundlePort { return result; } -private: +protected: AppID id; - std::map channels; + const BundleType *type; + PortMap channels; }; } // namespace esi diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h index cb871b618a96..15ec362a3fbe 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h @@ -28,6 +28,7 @@ namespace esi { class AcceleratorConnection; +class Engine; namespace services { /// Add a custom interface to a service client at a particular point in the @@ -45,6 +46,7 @@ class ServicePort : public BundlePort { class Service { public: using Type = const std::type_info &; + Service(AcceleratorConnection &conn) : conn(conn) {} virtual ~Service() = default; virtual std::string getServiceSymbol() const = 0; @@ -56,19 +58,21 @@ class Service { /// calling the `getService` method on `AcceleratorConnection` to get the /// global service, implying that the child service does not need to use the /// service it is replacing. - virtual Service *getChildService(AcceleratorConnection *conn, - Service::Type service, AppIDPath id = {}, + virtual Service *getChildService(Service::Type service, AppIDPath id = {}, std::string implName = {}, ServiceImplDetails details = {}, HWClientDetails clients = {}); /// Get specialized port for this service to attach to the given appid path. /// Null returns mean nothing to attach. - virtual ServicePort *getPort(AppIDPath id, const BundleType *type, - const std::map &, - AcceleratorConnection &) const { + virtual BundlePort *getPort(AppIDPath id, const BundleType *type) const { return nullptr; } + + AcceleratorConnection &getConnection() const { return conn; } + +protected: + AcceleratorConnection &conn; }; /// A service for which there are no standard services registered. Requires @@ -76,13 +80,16 @@ class Service { /// the ones in StdServices.h. class CustomService : public Service { public: - CustomService(AppIDPath idPath, const ServiceImplDetails &details, + CustomService(AppIDPath idPath, AcceleratorConnection &, + const ServiceImplDetails &details, const HWClientDetails &clients); virtual ~CustomService() = default; virtual std::string getServiceSymbol() const override { return serviceSymbol; } + virtual BundlePort *getPort(AppIDPath id, + const BundleType *type) const override; protected: std::string serviceSymbol; @@ -92,6 +99,7 @@ class CustomService : public Service { /// Information about the Accelerator system. class SysInfo : public Service { public: + using Service::Service; virtual ~SysInfo() = default; virtual std::string getServiceSymbol() const override; @@ -116,9 +124,7 @@ class MMIO : public Service { uint32_t size; }; - MMIO(Context &ctxt, AppIDPath idPath, std::string implName, - const ServiceImplDetails &details, const HWClientDetails &clients); - MMIO() = default; + MMIO(AcceleratorConnection &, const HWClientDetails &clients); virtual ~MMIO() = default; /// Read a 64-bit value from the global MMIO space. @@ -133,8 +139,7 @@ class MMIO : public Service { /// If the service is a MMIO service, return a region of the MMIO space which /// peers into ours. - virtual Service *getChildService(AcceleratorConnection *conn, - Service::Type service, AppIDPath id = {}, + virtual Service *getChildService(Service::Type service, AppIDPath id = {}, std::string implName = {}, ServiceImplDetails details = {}, HWClientDetails clients = {}) override; @@ -142,9 +147,8 @@ class MMIO : public Service { virtual std::string getServiceSymbol() const override; /// Get a MMIO region port for a particular region descriptor. - virtual ServicePort *getPort(AppIDPath id, const BundleType *type, - const std::map &, - AcceleratorConnection &) const override; + virtual BundlePort *getPort(AppIDPath id, + const BundleType *type) const override; private: /// MMIO base address table. @@ -195,6 +199,7 @@ class HostMem : public Service { public: static constexpr std::string_view StdName = "esi.service.std.hostmem"; + using Service::Service; virtual ~HostMem() = default; virtual std::string getServiceSymbol() const override; @@ -246,22 +251,20 @@ class HostMem : public Service { /// Service for calling functions. class FuncService : public Service { public: - FuncService(AcceleratorConnection *acc, AppIDPath id, - const std::string &implName, ServiceImplDetails details, + FuncService(AppIDPath id, AcceleratorConnection &, ServiceImplDetails details, HWClientDetails clients); virtual std::string getServiceSymbol() const override; - virtual ServicePort *getPort(AppIDPath id, const BundleType *type, - const std::map &, - AcceleratorConnection &) const override; + virtual BundlePort *getPort(AppIDPath id, + const BundleType *type) const override; /// A function call which gets attached to a service port. class Function : public ServicePort { friend class FuncService; - Function(AppID id, const std::map &channels); + using ServicePort::ServicePort; public: - static Function *get(AppID id, WriteChannelPort &arg, + static Function *get(AppID id, BundleType *type, WriteChannelPort &arg, ReadChannelPort &result); void connect(); @@ -269,16 +272,18 @@ class FuncService : public Service { virtual std::optional toString() const override { const esi::Type *argType = - dynamic_cast(arg.getType())->getInner(); + dynamic_cast(type->findChannel("arg").first) + ->getInner(); const esi::Type *resultType = - dynamic_cast(result.getType())->getInner(); + dynamic_cast(type->findChannel("result").first) + ->getInner(); return "function " + resultType->getID() + "(" + argType->getID() + ")"; } private: std::mutex callMutex; - WriteChannelPort &arg; - ReadChannelPort &result; + WriteChannelPort *arg; + ReadChannelPort *result; }; private: @@ -288,22 +293,21 @@ class FuncService : public Service { /// Service for servicing function calls from the accelerator. class CallService : public Service { public: - CallService(AcceleratorConnection *acc, AppIDPath id, std::string implName, - ServiceImplDetails details, HWClientDetails clients); + CallService(AcceleratorConnection &acc, AppIDPath id, + ServiceImplDetails details); virtual std::string getServiceSymbol() const override; - virtual ServicePort *getPort(AppIDPath id, const BundleType *type, - const std::map &, - AcceleratorConnection &) const override; + virtual BundlePort *getPort(AppIDPath id, + const BundleType *type) const override; /// A function call which gets attached to a service port. class Callback : public ServicePort { friend class CallService; - Callback(AcceleratorConnection &acc, AppID id, - const std::map &channels); + Callback(AcceleratorConnection &acc, AppID id, const BundleType *, + PortMap channels); public: - static Callback *get(AcceleratorConnection &acc, AppID id, + static Callback *get(AcceleratorConnection &acc, AppID id, BundleType *type, WriteChannelPort &result, ReadChannelPort &arg); /// Connect a callback to code which will be executed when the accelerator @@ -315,15 +319,17 @@ class CallService : public Service { virtual std::optional toString() const override { const esi::Type *argType = - dynamic_cast(arg.getType())->getInner(); + dynamic_cast(type->findChannel("arg").first) + ->getInner(); const esi::Type *resultType = - dynamic_cast(result.getType())->getInner(); + dynamic_cast(type->findChannel("result").first) + ->getInner(); return "callback " + resultType->getID() + "(" + argType->getID() + ")"; } private: - ReadChannelPort &arg; - WriteChannelPort &result; + ReadChannelPort *arg; + WriteChannelPort *result; AcceleratorConnection &acc; }; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Types.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Types.h index ab3c39d0a74c..23ef991c205b 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Types.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Types.h @@ -16,8 +16,8 @@ #ifndef ESI_TYPES_H #define ESI_TYPES_H -#include #include +#include #include #include @@ -54,6 +54,13 @@ class BundleType : public Type { const ChannelVector &getChannels() const { return channels; } std::ptrdiff_t getBitWidth() const override { return -1; }; + std::pair findChannel(std::string name) const { + for (auto [channelName, dir, type] : channels) + if (channelName == name) + return std::make_pair(type, dir); + throw std::runtime_error("Channel '" + name + "' not found in bundle"); + } + protected: ChannelVector channels; }; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Cosim.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Cosim.h index a554e2237032..dd6ed8d09020 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Cosim.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Cosim.h @@ -32,9 +32,12 @@ class ChannelDesc; namespace backends { namespace cosim { +class CosimEngine; /// Connect to an ESI simulation. class CosimAccelerator : public esi::AcceleratorConnection { + friend class CosimEngine; + public: CosimAccelerator(Context &, std::string hostname, uint16_t port); ~CosimAccelerator(); @@ -50,18 +53,15 @@ class CosimAccelerator : public esi::AcceleratorConnection { // Set the way this connection will retrieve the manifest. void setManifestMethod(ManifestMethod method); - /// Request the host side channel ports for a particular instance (identified - /// by the AppID path). For convenience, provide the bundle type and direction - /// of the bundle port. - virtual std::map - requestChannelsFor(AppIDPath, const BundleType *, - const ServiceTable &) override; - // C++ doesn't have a mechanism to forward declare a nested class and we don't // want to include the generated header here. So we have to wrap it in a // forward-declared struct we write ourselves. struct StubContainer; + void createEngine(const std::string &engineTypeName, AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients) override; + protected: virtual Service *createService(Service::Type service, AppIDPath path, std::string implName, @@ -74,9 +74,6 @@ class CosimAccelerator : public esi::AcceleratorConnection { // We own all channels connected to rpcClient since their lifetime is tied to // rpcClient. std::set> channels; - // Map from client path to channel assignments for that client. - std::map> - clientChannelAssignments; ManifestMethod manifestMethod = Cosim; }; diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h index c73835a9a5fa..42f7990d00a2 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h @@ -64,14 +64,13 @@ class TraceAccelerator : public esi::AcceleratorConnection { /// Internal implementation. struct Impl; - - /// Request the host side channel ports for a particular instance (identified - /// by the AppID path). For convenience, provide the bundle type. - std::map - requestChannelsFor(AppIDPath, const BundleType *, - const ServiceTable &) override; + Impl &getImpl(); protected: + void createEngine(const std::string &engineTypeName, AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients) override; + virtual Service *createService(Service::Type service, AppIDPath idPath, std::string implName, const ServiceImplDetails &details, diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h index 4186f04380f0..3c0efef2a564 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h @@ -37,13 +37,6 @@ class XrtAccelerator : public esi::AcceleratorConnection { static std::unique_ptr connect(Context &, std::string connectionString); - /// Request the host side channel ports for a particular instance (identified - /// by the AppID path). For convenience, provide the bundle type and direction - /// of the bundle port. - std::map - requestChannelsFor(AppIDPath, const BundleType *, - const ServiceTable &) override; - protected: virtual Service *createService(Service::Type service, AppIDPath path, std::string implName, diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp index 9df78be3f4e3..798fb489c385 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp @@ -42,6 +42,35 @@ AcceleratorServiceThread *AcceleratorConnection::getServiceThread() { serviceThread = std::make_unique(); return serviceThread.get(); } +void AcceleratorConnection::createEngine(const std::string &engineTypeName, + AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients) { + std::unique_ptr engine = ::esi::registry::createEngine( + *this, engineTypeName, idPath, details, clients); + registerEngine(idPath, std::move(engine), clients); +} + +void AcceleratorConnection::registerEngine(AppIDPath idPath, + std::unique_ptr engine, + const HWClientDetails &clients) { + assert(engine); + auto [engineIter, _] = ownedEngines.emplace(idPath, std::move(engine)); + + // Engine is now owned by the accelerator connection, so the std::unique_ptr + // is no longer valid. Resolve a new one from the map iter. + Engine *enginePtr = engineIter->second.get(); + // Compute our parents idPath path. + AppIDPath prefix = std::move(idPath); + if (prefix.size() > 0) + prefix.pop_back(); + + for (const auto &client : clients) { + AppIDPath fullClientPath = prefix + client.relPath; + for (const auto &channel : client.channelAssignments) + clientEngines[fullClientPath].setEngine(channel.first, enginePtr); + } +} services::Service *AcceleratorConnection::getService(Service::Type svcType, AppIDPath id, diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Engines.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Engines.cpp new file mode 100644 index 000000000000..f47bfb9a54a1 --- /dev/null +++ b/lib/Dialect/ESI/runtime/cpp/lib/Engines.cpp @@ -0,0 +1,73 @@ +//===- Engines.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DO NOT EDIT! +// This file is distributed as part of an ESI package. The source for this file +// should always be modified within CIRCT (lib/dialect/ESI/runtime/cpp/). +// +//===----------------------------------------------------------------------===// + +#include "esi/Engines.h" + +using namespace esi; + +ChannelPort &Engine::requestPort(AppIDPath idPath, + const std::string &channelName, + BundleType::Direction dir, const Type *type) { + auto portIter = ownedPorts.find(std::make_pair(idPath, channelName)); + if (portIter != ownedPorts.end()) + return *portIter->second; + std::unique_ptr port = + createPort(idPath, channelName, dir, type); + ChannelPort &ret = *port; + ownedPorts.emplace(std::make_pair(idPath, channelName), std::move(port)); + return ret; +} + +PortMap BundleEngineMap::requestPorts(const AppIDPath &idPath, + const BundleType *bundleType) const { + PortMap ports; + for (auto [channelName, dir, type] : bundleType->getChannels()) { + auto engineIter = bundleEngineMap.find(channelName); + if (engineIter == bundleEngineMap.end()) + continue; + + ports.emplace(channelName, engineIter->second->requestPort( + idPath, channelName, dir, type)); + } + return ports; +} + +void BundleEngineMap::setEngine(const std::string &channelName, + Engine *engine) { + auto [it, inserted] = bundleEngineMap.try_emplace(channelName, engine); + if (!inserted) + throw std::runtime_error("Channel already exists in engine map"); +} + +namespace { +std::map engineRegistry; +} + +std::unique_ptr +registry::createEngine(AcceleratorConnection &conn, + const std::string &dmaEngineName, AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients) { + auto it = engineRegistry.find(dmaEngineName); + if (it == engineRegistry.end()) + throw std::runtime_error("Unknown engine: " + dmaEngineName); + return it->second(conn, idPath, details, clients); +} + +void registry::internal::registerEngine(const std::string &name, + EngineCreate create) { + auto tried = engineRegistry.try_emplace(name, create); + if (!tried.second) + throw std::runtime_error("Engine already exists in registry"); +} diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp index 15e4919e70d7..a908aff3cdda 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp @@ -44,11 +44,16 @@ class Manifest::Impl { void scanServiceDecls(AcceleratorConnection &, const nlohmann::json &, ServiceTable &) const; + void createEngine(AcceleratorConnection &, AppIDPath appID, + const nlohmann::json &) const; + /// Get a Service for the service specified in 'json'. Update the - /// activeServices table. + /// activeServices table. TODO: re-using this for the engines section is a + /// terrible hack. Figure out a better way. services::Service *getService(AppIDPath idPath, AcceleratorConnection &, const nlohmann::json &, - ServiceTable &activeServices) const; + ServiceTable &activeServices, + bool isEngine = false) const; /// Get all the services in the description of an instance. Update the active /// services table. @@ -273,12 +278,20 @@ std::unique_ptr Manifest::Impl::buildAccelerator(AcceleratorConnection &acc) const { ServiceTable activeSvcs; + auto designJson = manifestJson.at("design"); + + // Create all of the engines at the top level of the design. + // TODO: support engines at lower levels. + auto enginesIter = designJson.find("engines"); + if (enginesIter != designJson.end()) + for (auto &engineDesc : enginesIter.value()) + createEngine(acc, {}, engineDesc); + // Get the initial active services table. Update it as we descend down. auto svcDecls = manifestJson.at("serviceDeclarations"); scanServiceDecls(acc, svcDecls, activeSvcs); // Get the services instantiated at the top level. - auto designJson = manifestJson.at("design"); std::vector services = getServices({}, acc, designJson, activeSvcs); @@ -301,24 +314,34 @@ Manifest::Impl::getModInfo(const nlohmann::json &json) const { return std::nullopt; } +/// TODO: Hack. This method is a giant hack to reuse the getService method for +/// engines. It works, but it ain't pretty and it ain't right. +void Manifest::Impl::createEngine(AcceleratorConnection &acc, AppIDPath idPath, + const nlohmann::json &eng) const { + ServiceTable dummy; + getService(idPath, acc, eng, dummy, /*isEngine=*/true); +} + void Manifest::Impl::scanServiceDecls(AcceleratorConnection &acc, const nlohmann::json &svcDecls, ServiceTable &activeServices) const { for (auto &svcDecl : svcDecls) { - if (auto f = svcDecl.find("serviceName"); f != svcDecl.end()) { - // Get the implementation details. - ServiceImplDetails svcDetails; - for (auto &detail : svcDecl.items()) - svcDetails[detail.key()] = getAny(detail.value()); - - // Create the service. - services::Service::Type svcId = - services::ServiceRegistry::lookupServiceType(f.value()); - auto svc = acc.getService(svcId, /*id=*/{}, /*implName=*/"", - /*details=*/svcDetails, /*clients=*/{}); - if (svc) - activeServices[svcDecl.at("symbol")] = svc; - } + // Get the implementation details. + ServiceImplDetails svcDetails; + for (auto &detail : svcDecl.items()) + svcDetails[detail.key()] = getAny(detail.value()); + + // Create the service. + auto serviceNameIter = svcDecl.find("serviceName"); + std::string serviceName; + if (serviceNameIter != svcDecl.end()) + serviceName = serviceNameIter.value(); + services::Service::Type svcId = + services::ServiceRegistry::lookupServiceType(serviceName); + auto svc = acc.getService(svcId, /*id=*/{}, /*implName=*/"", + /*details=*/svcDetails, /*clients=*/{}); + if (svc) + activeServices[svcDecl.at("symbol")] = svc; } } @@ -352,10 +375,11 @@ Manifest::Impl::getChildInstance(AppIDPath idPath, AcceleratorConnection &acc, services, ports); } -services::Service * -Manifest::Impl::getService(AppIDPath idPath, AcceleratorConnection &acc, - const nlohmann::json &svcJson, - ServiceTable &activeServices) const { +services::Service *Manifest::Impl::getService(AppIDPath idPath, + AcceleratorConnection &acc, + const nlohmann::json &svcJson, + ServiceTable &activeServices, + bool isEngine) const { AppID id = parseIDChecked(svcJson.at("appID")); idPath.push_back(id); @@ -407,12 +431,16 @@ Manifest::Impl::getService(AppIDPath idPath, AcceleratorConnection &acc, services::Service::Type svcType = services::ServiceRegistry::lookupServiceType( activeServiceIter->second->getServiceSymbol()); - svc = activeServiceIter->second->getChildService( - &acc, svcType, idPath, implName, svcDetails, clientDetails); + svc = activeServiceIter->second->getChildService(svcType, idPath, implName, + svcDetails, clientDetails); } else { services::Service::Type svcType = services::ServiceRegistry::lookupServiceType(service); - svc = acc.getService(svcType, idPath, implName, svcDetails, clientDetails); + if (isEngine) + acc.createEngine(implName, idPath, svcDetails, clientDetails); + else + svc = + acc.getService(svcType, idPath, implName, svcDetails, clientDetails); } if (svc) @@ -471,15 +499,10 @@ Manifest::Impl::getBundlePorts(AcceleratorConnection &acc, AppIDPath idPath, "' is not a bundle type"); idPath.push_back(parseIDChecked(content.at("appID"))); - std::map portChannels = - acc.requestChannelsFor(idPath, bundleType, activeServices); - services::ServicePort *svcPort = - svc->getPort(idPath, bundleType, portChannels, acc); + BundlePort *svcPort = svc->getPort(idPath, bundleType); if (svcPort) ret.emplace_back(svcPort); - else - ret.emplace_back(new BundlePort(idPath.back(), portChannels)); // Since we share idPath between iterations, pop the last element before the // next iteration. idPath.pop_back(); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp index cd5fd6596811..6fe43df250ec 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp @@ -19,8 +19,8 @@ using namespace esi; -BundlePort::BundlePort(AppID id, std::map channels) - : id(id), channels(channels) {} +BundlePort::BundlePort(AppID id, const BundleType *type, PortMap channels) + : id(id), type(type), channels(channels) {} WriteChannelPort &BundlePort::getRawWrite(const std::string &name) const { auto f = channels.find(name); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp index e7bafc327100..c44a5e0325b0 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp @@ -15,6 +15,7 @@ #include "esi/Services.h" #include "esi/Accelerator.h" +#include "esi/Engines.h" #include "zlib.h" @@ -24,12 +25,11 @@ using namespace esi; using namespace esi::services; -Service *Service::getChildService(AcceleratorConnection *conn, - Service::Type service, AppIDPath id, +Service *Service::getChildService(Service::Type service, AppIDPath id, std::string implName, ServiceImplDetails details, HWClientDetails clients) { - return conn->getService(service, id, implName, details, clients); + return conn.getService(service, id, implName, details, clients); } std::string SysInfo::getServiceSymbol() const { return "__builtin_SysInfo"; } @@ -53,8 +53,8 @@ std::string SysInfo::getJsonManifest() const { // MMIO class implementations. //===----------------------------------------------------------------------===// -MMIO::MMIO(Context &ctxt, AppIDPath idPath, std::string implName, - const ServiceImplDetails &details, const HWClientDetails &clients) { +MMIO::MMIO(AcceleratorConnection &conn, const HWClientDetails &clients) + : Service(conn) { for (const HWClientDetail &client : clients) { auto offsetIter = client.implOptions.find("offset"); if (offsetIter == client.implOptions.end()) @@ -79,9 +79,7 @@ MMIO::MMIO(Context &ctxt, AppIDPath idPath, std::string implName, std::string MMIO::getServiceSymbol() const { return std::string(MMIO::StdName); } -ServicePort *MMIO::getPort(AppIDPath id, const BundleType *type, - const std::map &, - AcceleratorConnection &conn) const { +BundlePort *MMIO::getPort(AppIDPath id, const BundleType *type) const { auto regionIter = regions.find(id); if (regionIter == regions.end()) return nullptr; @@ -92,10 +90,8 @@ ServicePort *MMIO::getPort(AppIDPath id, const BundleType *type, namespace { class MMIOPassThrough : public MMIO { public: - MMIOPassThrough(Context &ctxt, AppIDPath idPath, std::string implName, - const ServiceImplDetails &details, - const HWClientDetails &clients, MMIO *parent) - : MMIO(ctxt, idPath, implName, details, clients), parent(parent) {} + MMIOPassThrough(const HWClientDetails &clients, MMIO *parent) + : MMIO(parent->getConnection(), clients), parent(parent) {} uint64_t read(uint32_t addr) const override { return parent->read(addr); } void write(uint32_t addr, uint64_t data) override { parent->write(addr, data); @@ -106,15 +102,12 @@ class MMIOPassThrough : public MMIO { }; } // namespace -Service *MMIO::getChildService(AcceleratorConnection *conn, - Service::Type service, AppIDPath id, +Service *MMIO::getChildService(Service::Type service, AppIDPath id, std::string implName, ServiceImplDetails details, HWClientDetails clients) { if (service != typeid(MMIO)) - return Service::getChildService(conn, service, id, implName, details, - clients); - return new MMIOPassThrough(conn->getCtxt(), id, implName, details, clients, - this); + return Service::getChildService(service, id, implName, details, clients); + return new MMIOPassThrough(clients, this); } //===----------------------------------------------------------------------===// @@ -122,7 +115,7 @@ Service *MMIO::getChildService(AcceleratorConnection *conn, //===----------------------------------------------------------------------===// MMIO::MMIORegion::MMIORegion(AppID id, MMIO *parent, RegionDescriptor desc) - : ServicePort(id, {}), parent(parent), desc(desc) {} + : ServicePort(id, nullptr, {}), parent(parent), desc(desc) {} uint64_t MMIO::MMIORegion::read(uint32_t addr) const { if (addr >= desc.size) throw std::runtime_error("MMIO read out of bounds: " + toHex(addr)); @@ -134,7 +127,8 @@ void MMIO::MMIORegion::write(uint32_t addr, uint64_t data) { parent->write(desc.base + addr, data); } -MMIOSysInfo::MMIOSysInfo(const MMIO *mmio) : mmio(mmio) {} +MMIOSysInfo::MMIOSysInfo(const MMIO *mmio) + : SysInfo(mmio->getConnection()), mmio(mmio) {} uint32_t MMIOSysInfo::getEsiVersion() const { uint64_t reg; @@ -165,10 +159,10 @@ std::vector MMIOSysInfo::getCompressedManifest() const { std::string HostMem::getServiceSymbol() const { return "__builtin_HostMem"; } -CustomService::CustomService(AppIDPath idPath, +CustomService::CustomService(AppIDPath idPath, AcceleratorConnection &conn, const ServiceImplDetails &details, const HWClientDetails &clients) - : id(idPath) { + : Service(conn), id(idPath) { if (auto f = details.find("service"); f != details.end()) { serviceSymbol = std::any_cast(f->second); // Strip off initial '@'. @@ -176,9 +170,15 @@ CustomService::CustomService(AppIDPath idPath, } } -FuncService::FuncService(AcceleratorConnection *acc, AppIDPath idPath, - const std::string &implName, - ServiceImplDetails details, HWClientDetails clients) { +BundlePort *CustomService::getPort(AppIDPath id, const BundleType *type) const { + return new BundlePort(id.back(), type, + conn.getEngineMapFor(id).requestPorts(id, type)); +} + +FuncService::FuncService(AppIDPath idPath, AcceleratorConnection &conn, + ServiceImplDetails details, HWClientDetails clients) + : Service(conn) { + if (auto f = details.find("service"); f != details.end()) // Strip off initial '@'. symbol = std::any_cast(f->second).substr(1); @@ -186,42 +186,38 @@ FuncService::FuncService(AcceleratorConnection *acc, AppIDPath idPath, std::string FuncService::getServiceSymbol() const { return symbol; } -ServicePort * -FuncService::getPort(AppIDPath id, const BundleType *type, - const std::map &channels, - AcceleratorConnection &acc) const { - return new Function(id.back(), channels); +BundlePort *FuncService::getPort(AppIDPath id, const BundleType *type) const { + return new Function(id.back(), type, + conn.getEngineMapFor(id).requestPorts(id, type)); } -FuncService::Function::Function( - AppID id, const std::map &channels) - : ServicePort(id, channels), - arg(dynamic_cast(channels.at("arg"))), - result(dynamic_cast(channels.at("result"))) { - assert(channels.size() == 2 && "FuncService must have exactly two channels"); -} - -FuncService::Function *FuncService::Function::get(AppID id, +FuncService::Function *FuncService::Function::get(AppID id, BundleType *type, WriteChannelPort &arg, ReadChannelPort &result) { - return new Function(id, {{"arg", arg}, {"result", result}}); + return new Function( + id, type, {{std::string("arg"), arg}, {std::string("result"), result}}); + return nullptr; } void FuncService::Function::connect() { - arg.connect(); - result.connect(); + if (channels.size() != 2) + throw std::runtime_error("FuncService must have exactly two channels"); + arg = &getRawWrite("arg"); + arg->connect(); + result = &getRawRead("result"); + result->connect(); } std::future FuncService::Function::call(const MessageData &argData) { std::scoped_lock lock(callMutex); - arg.write(argData); - return result.readAsync(); + arg->write(argData); + return result->readAsync(); } -CallService::CallService(AcceleratorConnection *acc, AppIDPath idPath, - std::string implName, ServiceImplDetails details, - HWClientDetails clients) { +CallService::CallService(AcceleratorConnection &acc, AppIDPath idPath, + ServiceImplDetails details) + : Service(acc) { if (auto f = details.find("service"); f != details.end()) // Strip off initial '@'. symbol = std::any_cast(f->second).substr(1); @@ -229,63 +225,46 @@ CallService::CallService(AcceleratorConnection *acc, AppIDPath idPath, std::string CallService::getServiceSymbol() const { return symbol; } -ServicePort * -CallService::getPort(AppIDPath id, const BundleType *type, - const std::map &channels, - AcceleratorConnection &acc) const { - return new Callback(acc, id.back(), channels); +BundlePort *CallService::getPort(AppIDPath id, const BundleType *type) const { + return new Callback(conn, id.back(), type, + conn.getEngineMapFor(id).requestPorts(id, type)); } -ReadChannelPort &getRead(const std::map &channels, - const std::string &name) { - auto f = channels.find(name); - if (f == channels.end()) - throw std::runtime_error("CallService must have an '" + name + "' channel"); - return dynamic_cast(f->second); -} - -WriteChannelPort &getWrite(const std::map &channels, - const std::string &name) { - auto f = channels.find(name); - if (f == channels.end()) - throw std::runtime_error("CallService must have an '" + name + "' channel"); - return dynamic_cast(f->second); -} - -CallService::Callback::Callback( - AcceleratorConnection &acc, AppID id, - const std::map &channels) - : ServicePort(id, channels), arg(getRead(channels, "arg")), - result(getWrite(channels, "result")), acc(acc) { +CallService::Callback::Callback(AcceleratorConnection &acc, AppID id, + const BundleType *type, PortMap channels) + : ServicePort(id, type, channels), acc(acc) { if (channels.size() != 2) throw std::runtime_error("CallService must have exactly two channels"); } CallService::Callback *CallService::Callback::get(AcceleratorConnection &acc, - AppID id, + AppID id, BundleType *type, WriteChannelPort &result, ReadChannelPort &arg) { - return new Callback(acc, id, {{"arg", arg}, {"result", result}}); + return new Callback(acc, id, type, {{"arg", arg}, {"result", result}}); } void CallService::Callback::connect( std::function callback, bool quick) { - result.connect(); + if (channels.size() != 2) + throw std::runtime_error("CallService must have exactly two channels"); + result = &getRawWrite("result"); + result->connect(); + arg = &getRawRead("arg"); if (quick) { // If it's quick, we can just call the callback directly. - arg.connect([this, callback](MessageData argMsg) -> bool { + arg->connect([this, callback](MessageData argMsg) -> bool { MessageData resultMsg = callback(std::move(argMsg)); - this->result.write(std::move(resultMsg)); + this->result->write(std::move(resultMsg)); return true; }); } else { // If it's not quick, we need to use the service thread. - arg.connect(); + arg->connect(); acc.getServiceThread()->addListener( - {&arg}, - [this, callback](ReadChannelPort *, MessageData argMsg) -> void { + {arg}, [this, callback](ReadChannelPort *, MessageData argMsg) -> void { MessageData resultMsg = callback(std::move(argMsg)); - this->result.write(std::move(resultMsg)); + this->result->write(std::move(resultMsg)); }); } } @@ -297,9 +276,11 @@ Service *ServiceRegistry::createService(AcceleratorConnection *acc, HWClientDetails clients) { // TODO: Add a proper registration mechanism. if (svcType == typeid(FuncService)) - return new FuncService(acc, id, implName, details, clients); + return new FuncService(id, *acc, details, clients); if (svcType == typeid(CallService)) - return new CallService(acc, id, implName, details, clients); + return new CallService(*acc, id, details); + if (svcType == typeid(CustomService)) + return new CustomService(id, *acc, details, clients); return nullptr; } diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index 88e83674ba33..6e44f4851a7f 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -14,6 +14,7 @@ //===----------------------------------------------------------------------===// #include "esi/backends/Cosim.h" +#include "esi/Engines.h" #include "esi/Services.h" #include "esi/Utils.h" @@ -135,7 +136,8 @@ CosimAccelerator::~CosimAccelerator() { namespace { class CosimSysInfo : public SysInfo { public: - CosimSysInfo(ChannelServer::Stub *rpcClient) : rpcClient(rpcClient) {} + CosimSysInfo(CosimAccelerator &conn, ChannelServer::Stub *rpcClient) + : SysInfo(conn), rpcClient(rpcClient) {} uint32_t getEsiVersion() const override { ::esi::cosim::Manifest response = getManifest(); @@ -291,45 +293,6 @@ class ReadCosimChannelPort } // namespace -std::map CosimAccelerator::requestChannelsFor( - AppIDPath idPath, const BundleType *bundleType, const ServiceTable &) { - std::map channelResults; - - // Find the client details for the port at 'fullPath'. - auto f = clientChannelAssignments.find(idPath); - if (f == clientChannelAssignments.end()) - return channelResults; - const std::map &channelAssignments = f->second; - - // Each channel in a bundle has a separate cosim endpoint. Find them all. - for (auto [name, dir, type] : bundleType->getChannels()) { - auto f = channelAssignments.find(name); - if (f == channelAssignments.end()) - throw std::runtime_error("Could not find channel assignment for '" + - idPath.toStr() + "." + name + "'"); - std::string channelName = f->second; - - // Get the endpoint, which may or may not exist. Construct the port. - // Everything is validated when the client calls 'connect()' on the port. - ChannelDesc chDesc; - if (!rpcClient->getChannelDesc(channelName, chDesc)) - throw std::runtime_error("Could not find channel '" + channelName + - "' in cosimulation"); - - ChannelPort *port; - if (BundlePort::isWrite(dir)) { - port = new WriteCosimChannelPort(rpcClient->stub.get(), chDesc, type, - channelName); - } else { - port = new ReadCosimChannelPort(rpcClient->stub.get(), chDesc, type, - channelName); - } - channels.emplace(port); - channelResults.emplace(name, *port); - } - return channelResults; -} - /// Get the channel description for a channel name. Iterate through the list /// each time. Since this will only be called a small number of times on a small /// list, it's not worth doing anything fancy. @@ -351,7 +314,9 @@ bool StubContainer::getChannelDesc(const std::string &channelName, namespace { class CosimMMIO : public MMIO { public: - CosimMMIO(Context &ctxt, StubContainer *rpcClient) { + CosimMMIO(CosimAccelerator &conn, Context &ctxt, StubContainer *rpcClient, + const HWClientDetails &clients) + : MMIO(conn, clients) { // We have to locate the channels ourselves since this service might be used // to retrieve the manifest. ChannelDesc cmdArg, cmdResp; @@ -372,8 +337,11 @@ class CosimMMIO : public MMIO { cmdRespPort = std::make_unique( rpcClient->stub.get(), cmdResp, i64Type, "__cosim_mmio_read_write.result"); - cmdMMIO.reset(FuncService::Function::get(AppID("__cosim_mmio"), *cmdArgPort, - *cmdRespPort)); + auto *bundleType = new BundleType( + "cosimMMIO", {{"arg", BundleType::Direction::To, cmdType}, + {"result", BundleType::Direction::From, i64Type}}); + cmdMMIO.reset(FuncService::Function::get(AppID("__cosim_mmio"), bundleType, + *cmdArgPort, *cmdRespPort)); cmdMMIO->connect(); } @@ -441,7 +409,7 @@ class CosimHostMem : public HostMem { public: CosimHostMem(AcceleratorConnection &acc, Context &ctxt, StubContainer *rpcClient) - : acc(acc), ctxt(ctxt), rpcClient(rpcClient) {} + : HostMem(acc), acc(acc), ctxt(ctxt), rpcClient(rpcClient) {} void start() override { // We have to locate the channels ourselves since this service might be used @@ -498,8 +466,13 @@ class CosimHostMem : public HostMem { writeReqPort = std::make_unique( rpcClient->stub.get(), writeArg, writeReqType, "__cosim_hostmem_write.arg"); + auto *bundleType = new BundleType( + "cosimHostMem", + {{"arg", BundleType::Direction::To, writeReqType}, + {"result", BundleType::Direction::From, writeRespType}}); write.reset(CallService::Callback::get(acc, AppID("__cosim_hostmem_write"), - *writeRespPort, *writeReqPort)); + bundleType, *writeRespPort, + *writeReqPort)); write->connect([this](const MessageData &req) { return serviceWrite(req); }, true); } @@ -595,42 +568,104 @@ class CosimHostMem : public HostMem { std::unique_ptr writeReqPort; std::unique_ptr write; }; - } // namespace +namespace esi::backends::cosim { +/// Implement the magic cosim channel communication. +class CosimEngine : public Engine { +public: + CosimEngine(CosimAccelerator &conn, AppIDPath idPath, + const ServiceImplDetails &details, const HWClientDetails &clients) + : conn(conn) { + // Compute our parents idPath path. + AppIDPath prefix = std::move(idPath); + if (prefix.size() > 0) + prefix.pop_back(); + + for (auto client : clients) { + AppIDPath fullClientPath = prefix + client.relPath; + std::map channelAssignments; + for (auto assignment : client.channelAssignments) + if (assignment.second.type == "cosim") + channelAssignments[assignment.first] = std::any_cast( + assignment.second.implOptions.at("name")); + clientChannelAssignments[fullClientPath] = std::move(channelAssignments); + } + } + + std::unique_ptr createPort(AppIDPath idPath, + const std::string &channelName, + BundleType::Direction dir, + const Type *type) override; + +private: + CosimAccelerator &conn; + std::map> + clientChannelAssignments; +}; +} // namespace esi::backends::cosim + +std::unique_ptr +CosimEngine::createPort(AppIDPath idPath, const std::string &channelName, + BundleType::Direction dir, const Type *type) { + + // Find the client details for the port at 'fullPath'. + auto f = clientChannelAssignments.find(idPath); + if (f == clientChannelAssignments.end()) + throw std::runtime_error("Could not find port for '" + idPath.toStr() + + "." + channelName + "'"); + const std::map &channelAssignments = f->second; + auto cosimChannelNameIter = channelAssignments.find(channelName); + if (cosimChannelNameIter == channelAssignments.end()) + throw std::runtime_error("Could not find channel '" + idPath.toStr() + "." + + channelName + "' in cosimulation"); + + // Get the endpoint, which may or may not exist. Construct the port. + // Everything is validated when the client calls 'connect()' on the port. + ChannelDesc chDesc; + if (!conn.rpcClient->getChannelDesc(cosimChannelNameIter->second, chDesc)) + throw std::runtime_error("Could not find channel '" + idPath.toStr() + "." + + channelName + "' in cosimulation"); + + std::unique_ptr port; + std::string fullChannelName = idPath.toStr() + "." + channelName; + if (BundlePort::isWrite(dir)) + port = std::make_unique( + conn.rpcClient->stub.get(), chDesc, type, fullChannelName); + else + port = std::make_unique( + conn.rpcClient->stub.get(), chDesc, type, fullChannelName); + return port; +} + +void CosimAccelerator::createEngine(const std::string &engineTypeName, + AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients) { + + std::unique_ptr engine = nullptr; + if (engineTypeName == "cosim") + engine = std::make_unique(*this, idPath, details, clients); + else + engine = ::esi::registry::createEngine(*this, engineTypeName, idPath, + details, clients); + registerEngine(idPath, std::move(engine), clients); +} Service *CosimAccelerator::createService(Service::Type svcType, AppIDPath idPath, std::string implName, const ServiceImplDetails &details, const HWClientDetails &clients) { - // Compute our parents idPath path. - AppIDPath prefix = std::move(idPath); - if (prefix.size() > 0) - prefix.pop_back(); - - // Get the channel assignments for each client. - for (auto client : clients) { - AppIDPath fullClientPath = prefix + client.relPath; - std::map channelAssignments; - for (auto assignment : client.channelAssignments) - if (assignment.second.type == "cosim") - channelAssignments[assignment.first] = std::any_cast( - assignment.second.implOptions.at("name")); - clientChannelAssignments[fullClientPath] = std::move(channelAssignments); - } - if (svcType == typeid(services::MMIO)) { - return new CosimMMIO(getCtxt(), rpcClient); + return new CosimMMIO(*this, getCtxt(), rpcClient, clients); } else if (svcType == typeid(services::HostMem)) { return new CosimHostMem(*this, getCtxt(), rpcClient); } else if (svcType == typeid(SysInfo)) { switch (manifestMethod) { case ManifestMethod::Cosim: - return new CosimSysInfo(rpcClient->stub.get()); + return new CosimSysInfo(*this, rpcClient->stub.get()); case ManifestMethod::MMIO: return new MMIOSysInfo(getService()); } - } else if (svcType == typeid(CustomService) && implName == "cosim") { - return new CustomService(idPath, details, clients); } return nullptr; } diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp index 4ba30c83d2c9..2b9e197bb542 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp @@ -33,12 +33,17 @@ constexpr uint32_t ESIVersion = 0; namespace { class TraceChannelPort; -} +class TraceEngine; +} // namespace + +TraceAccelerator::Impl &TraceAccelerator::getImpl() { return *impl; } struct esi::backends::trace::TraceAccelerator::Impl { + friend class TraceAccelerator; Impl(Mode mode, std::filesystem::path manifestJson, std::filesystem::path traceFile) : manifestJson(manifestJson), traceFile(traceFile) { + engine = std::make_unique(*this); if (!std::filesystem::exists(manifestJson)) throw std::runtime_error("manifest file '" + manifestJson.string() + "' does not exist"); @@ -63,16 +68,10 @@ struct esi::backends::trace::TraceAccelerator::Impl { } } - Service *createService(Service::Type svcType, AppIDPath idPath, - const ServiceImplDetails &details, + Service *createService(TraceAccelerator &conn, Service::Type svcType, + AppIDPath idPath, const ServiceImplDetails &details, const HWClientDetails &clients); - /// Request the host side channel ports for a particular instance (identified - /// by the AppID path). For convenience, provide the bundle type and direction - /// of the bundle port. - std::map requestChannelsFor(AppIDPath, - const BundleType *); - void adoptChannelPort(ChannelPort *port) { channels.emplace_back(port); } void write(const AppIDPath &id, const std::string &portName, const void *data, @@ -89,6 +88,7 @@ struct esi::backends::trace::TraceAccelerator::Impl { std::filesystem::path manifestJson; std::filesystem::path traceFile; std::vector> channels; + std::unique_ptr engine; }; void TraceAccelerator::Impl::write(const AppIDPath &id, @@ -146,17 +146,11 @@ TraceAccelerator::TraceAccelerator(Context &ctxt, Mode mode, } TraceAccelerator::~TraceAccelerator() { disconnect(); } -Service *TraceAccelerator::createService(Service::Type svcType, - AppIDPath idPath, std::string implName, - const ServiceImplDetails &details, - const HWClientDetails &clients) { - return impl->createService(svcType, idPath, details, clients); -} namespace { class TraceSysInfo : public SysInfo { public: - TraceSysInfo(std::filesystem::path manifestJson) - : manifestJson(manifestJson) {} + TraceSysInfo(AcceleratorConnection &conn, std::filesystem::path manifestJson) + : SysInfo(conn), manifestJson(manifestJson) {} uint32_t getEsiVersion() const override { return ESIVersion; } @@ -232,39 +226,39 @@ class ReadTraceChannelPort : public ReadChannelPort { } // namespace namespace { -class TraceCustomService : public CustomService { +class TraceEngine : public Engine { public: - TraceCustomService(TraceAccelerator::Impl &impl, AppIDPath idPath, - const ServiceImplDetails &details, - const HWClientDetails &clients) - : CustomService(idPath, details, clients) {} -}; -} // namespace + TraceEngine(TraceAccelerator::Impl &impl) : impl(impl) {} -std::map -TraceAccelerator::Impl::requestChannelsFor(AppIDPath idPath, - const BundleType *bundleType) { - std::map channels; - for (auto [name, dir, type] : bundleType->getChannels()) { - ChannelPort *port; + std::unique_ptr createPort(AppIDPath idPath, + const std::string &channelName, + BundleType::Direction dir, + const Type *type) override { + std::unique_ptr port; if (BundlePort::isWrite(dir)) - port = new WriteTraceChannelPort(*this, type, idPath, name); + port = std::make_unique(impl, type, idPath, + channelName); else - port = new ReadTraceChannelPort(*this, type); - channels.emplace(name, *port); - adoptChannelPort(port); + port = std::make_unique(impl, type); + return port; } - return channels; -} -std::map TraceAccelerator::requestChannelsFor( - AppIDPath idPath, const BundleType *bundleType, const ServiceTable &) { - return impl->requestChannelsFor(idPath, bundleType); +private: + TraceAccelerator::Impl &impl; +}; +} // namespace + +void TraceAccelerator::createEngine(const std::string &dmaEngineName, + AppIDPath idPath, + const ServiceImplDetails &details, + const HWClientDetails &clients) { + registerEngine(idPath, std::make_unique(getImpl()), clients); } class TraceMMIO : public MMIO { public: - TraceMMIO(TraceAccelerator::Impl &impl) : impl(impl) {} + TraceMMIO(TraceAccelerator &conn, const HWClientDetails &clients) + : MMIO(conn, clients), impl(conn.getImpl()) {} virtual uint64_t read(uint32_t addr) const override { uint64_t data = rand(); @@ -286,7 +280,7 @@ class TraceMMIO : public MMIO { class TraceHostMem : public HostMem { public: - TraceHostMem(TraceAccelerator::Impl &impl) : impl(impl) {} + TraceHostMem(TraceAccelerator &conn) : HostMem(conn), impl(conn.getImpl()) {} struct TraceHostMemRegion : public HostMemRegion { TraceHostMemRegion(std::size_t size, TraceAccelerator::Impl &impl) @@ -338,18 +332,16 @@ class TraceHostMem : public HostMem { TraceAccelerator::Impl &impl; }; -Service * -TraceAccelerator::Impl::createService(Service::Type svcType, AppIDPath idPath, - const ServiceImplDetails &details, - const HWClientDetails &clients) { +Service *TraceAccelerator::createService(Service::Type svcType, + AppIDPath idPath, std::string implName, + const ServiceImplDetails &details, + const HWClientDetails &clients) { if (svcType == typeid(SysInfo)) - return new TraceSysInfo(manifestJson); + return new TraceSysInfo(*this, getImpl().manifestJson); if (svcType == typeid(MMIO)) - return new TraceMMIO(*this); + return new TraceMMIO(*this, clients); if (svcType == typeid(HostMem)) return new TraceHostMem(*this); - if (svcType == typeid(CustomService)) - return new TraceCustomService(*this, idPath, details, clients); return nullptr; } diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp index 6406982f43d4..ed6879bf5491 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Xrt.cpp @@ -81,12 +81,6 @@ struct esi::backends::xrt::XrtAccelerator::Impl { ip = ::xrt::ip(device, uuid, kernel_name); } - std::map requestChannelsFor(AppIDPath, - const BundleType *) { - // TODO: Create plugin DMA engines. Instantiate elsewhere. - return {}; - } - ::xrt::device device; ::xrt::ip ip; int32_t memoryGroup; @@ -103,7 +97,8 @@ XrtAccelerator::~XrtAccelerator() { disconnect(); } namespace { class XrtMMIO : public MMIO { public: - XrtMMIO(::xrt::ip &ip) : ip(ip) {} + XrtMMIO(XrtAccelerator &conn, ::xrt::ip &ip, const HWClientDetails &clients) + : MMIO(conn, clients), ip(ip) {} uint64_t read(uint32_t addr) const override { auto lo = static_cast(ip.read_register(addr)); @@ -124,8 +119,8 @@ namespace { /// Host memory service specialized to XRT. class XrtHostMem : public HostMem { public: - XrtHostMem(::xrt::device &device, int32_t memoryGroup) - : device(device), memoryGroup(memoryGroup){}; + XrtHostMem(XrtAccelerator &conn, ::xrt::device &device, int32_t memoryGroup) + : HostMem(conn), device(device), memoryGroup(memoryGroup){}; struct XrtHostMemRegion : public HostMemRegion { XrtHostMemRegion(::xrt::device &device, std::size_t size, @@ -161,19 +156,14 @@ class XrtHostMem : public HostMem { }; } // namespace -std::map XrtAccelerator::requestChannelsFor( - AppIDPath idPath, const BundleType *bundleType, const ServiceTable &) { - return impl->requestChannelsFor(idPath, bundleType); -} - Service *XrtAccelerator::createService(Service::Type svcType, AppIDPath id, std::string implName, const ServiceImplDetails &details, const HWClientDetails &clients) { if (svcType == typeid(MMIO)) - return new XrtMMIO(impl->ip); + return new XrtMMIO(*this, impl->ip, clients); else if (svcType == typeid(HostMem)) - return new XrtHostMem(impl->device, impl->memoryGroup); + return new XrtHostMem(*this, impl->device, impl->memoryGroup); else if (svcType == typeid(SysInfo)) return new MMIOSysInfo(getService()); return nullptr; diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp index 85fb4201d723..473a575706a7 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp @@ -101,6 +101,11 @@ void registerCallbacks(AcceleratorConnection *conn, Accelerator *accel) { return MessageData(); }, true); + else + std::cerr << "PrintfExample port is not a CallService::Callback" + << std::endl; + } else { + std::cerr << "No PrintfExample port found" << std::endl; } } diff --git a/test/Dialect/ESI/manifest.mlir b/test/Dialect/ESI/manifest.mlir index 5c6580d98225..e2f4783c3a43 100644 --- a/test/Dialect/ESI/manifest.mlir +++ b/test/Dialect/ESI/manifest.mlir @@ -55,7 +55,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // HIER-LABEL: esi.manifest.compressed <"{{.+}}"> // HIER-LABEL: esi.manifest.hier_root @top { -// HIER-NEXT: esi.manifest.service_impl #esi.appid<"cosim"> svc @HostComms by "cosim" with {} { +// HIER-NEXT: esi.manifest.service_impl #esi.appid<"cosim"> svc @HostComms by "cosim" engine with {} { // HIER-NEXT: esi.manifest.impl_conn [#esi.appid<"loopback_inst"[0]>, #esi.appid<"loopback_tohw">] req <@HostComms::@Recv>(!esi.bundle<[!esi.channel to "recv"]>) channels {recv = {name = "loopback_inst[0].loopback_tohw.recv", type = "cosim"}} // HIER-NEXT: esi.manifest.impl_conn [#esi.appid<"loopback_inst"[0]>, #esi.appid<"loopback_fromhw">] req <@HostComms::@Send>(!esi.bundle<[!esi.channel from "send"]>) channels {send = {name = "loopback_inst[0].loopback_fromhw.send", type = "cosim"}} // HIER-NEXT: esi.manifest.impl_conn [#esi.appid<"loopback_inst"[0]>, #esi.appid<"loopback_fromhw_i0">] req <@HostComms::@SendI0>(!esi.bundle<[!esi.channel from "send"]>) channels {send = {name = "loopback_inst[0].loopback_fromhw_i0.send", type = "cosim"}} @@ -111,7 +111,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: } // CHECK-NEXT: ], -// CHECK-LABEL: "services": [ +// CHECK-LABEL: "engines": [ // CHECK-NEXT: { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "cosim" diff --git a/test/Dialect/ESI/services.mlir b/test/Dialect/ESI/services.mlir index 31c6713bba1f..631890955fbc 100644 --- a/test/Dialect/ESI/services.mlir +++ b/test/Dialect/ESI/services.mlir @@ -106,7 +106,7 @@ hw.module @InOutLoopback (in %clk: !seq.clock) { // CONN-LABEL: esi.pure_module @LoopbackCosimPure { // CONN-NEXT: [[clk:%.+]] = esi.pure_module.input "clk" : !seq.clock // CONN-NEXT: [[rst:%.+]] = esi.pure_module.input "rst" : i1 -// CONN-NEXT: esi.manifest.service_impl #esi.appid<"cosim"> svc @HostComms by "cosim" with {} { +// CONN-NEXT: esi.manifest.service_impl #esi.appid<"cosim"> svc @HostComms by "cosim" engine with {} { // CONN-NEXT: esi.manifest.impl_conn [#esi.appid<"loopback_inout">] req <@HostComms::@ReqResp>(!esi.bundle<[!esi.channel to "req", !esi.channel from "resp"]>) channels {req = {name = "loopback_inout.req", type = "cosim"}, resp = {name = "loopback_inout.resp", type = "cosim"}} // CONN-NEXT: } // CONN-NEXT: [[r2:%.+]] = esi.cosim.from_host [[clk]], [[rst]], "loopback_inout.req" : !esi.channel @@ -194,7 +194,7 @@ hw.module @CallableFunc1() { // CONN-LABEL: hw.module @CallableAccel1(in %clk : !seq.clock, in %rst : i1) { // CONN-NEXT: hw.instance "func1" @CallableFunc1(func1: %bundle: !esi.bundle<[!esi.channel to "arg", !esi.channel from "result"]>) -> () -// CONN-NEXT: esi.manifest.service_impl #esi.appid<"funcComms"> svc @funcs std "esi.service.std.func" by "cosim" with {} { +// CONN-NEXT: esi.manifest.service_impl #esi.appid<"funcComms"> svc @funcs std "esi.service.std.func" by "cosim" engine with {} { // CONN-NEXT: esi.manifest.impl_conn [#esi.appid<"func1">] req <@funcs::@call>(!esi.bundle<[!esi.channel to "arg", !esi.channel from "result"]>) channels {arg = {name = "func1.arg", type = "cosim"}, result = {name = "func1.result", type = "cosim"}} // CONN-NEXT: } // CONN-NEXT: %0 = esi.cosim.from_host %clk, %rst, "func1.arg" : !esi.channel