Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Metal backend #233

Open
wants to merge 14 commits into
base: devel
Choose a base branch
from
1 change: 1 addition & 0 deletions include/oklt/core/target_backends.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum struct TargetBackend : unsigned char {
CUDA, ///< CUDA backend.
HIP, ///< HIP backend.
DPCPP, ///< DPCPP backend.
METAL, ///< Metal backend.

_LAUNCHER, ///< Launcher backend.
};
Expand Down
14 changes: 14 additions & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ set (OCCA_TRANSPILER_SOURCES
attributes/backend/dpcpp/common.cpp
attributes/backend/dpcpp/common.h

# Metal
attributes/backend/metal/kernel.cpp
attributes/backend/metal/translation_unit.cpp
attributes/backend/metal/outer.cpp
attributes/backend/metal/inner.cpp
attributes/backend/metal/tile.cpp
attributes/backend/metal/shared.cpp
attributes/backend/metal/restrict.cpp
attributes/backend/metal/atomic.cpp
attributes/backend/metal/barrier.cpp
attributes/backend/metal/exclusive.cpp
attributes/backend/metal/common.cpp
attributes/backend/metal/common.h

# Serial subset
attributes/utils/serial_subset/empty.cpp
attributes/utils/serial_subset/kernel.cpp
Expand Down
24 changes: 24 additions & 0 deletions lib/attributes/backend/metal/atomic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "attributes/backend/metal/common.h"

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

/// Metal handler for the atomic attribute placed on a statement.
///
/// This handler only strips the attribute from the source; it does not
/// generate any replacement code for the annotated statement.
///
/// @param s     Session stage passed to removeAttribute.
/// @param stmt  Annotated statement (not inspected here).
/// @param a     The attribute being handled.
/// @return A default-constructed HandleResult.
HandleResult handleAtomicStmtAttribute(SessionStage& s, const Stmt& stmt, const Attr& a) {
    static_cast<void>(stmt);  // the statement itself is left untouched

    SPDLOG_DEBUG("Handle attribute [{}]", a.getNormalizedFullName());
    removeAttribute(s, a);

    return HandleResult{};
}

/// Load-time hook (runs via the `constructor` attribute) that registers the
/// Metal handler for ATOMIC_ATTR_NAME and logs an error if registration fails.
__attribute__((constructor)) void registerAttrBackend() {
    auto registered =
        registerBackendHandler(TargetBackend::METAL, ATOMIC_ATTR_NAME, handleAtomicStmtAttribute);
    if (registered) {
        return;
    }

    SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", ATOMIC_ATTR_NAME);
}
} // namespace
31 changes: 31 additions & 0 deletions lib/attributes/backend/metal/barrier.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "attributes/backend/metal/common.h"

#include <clang/AST/Attr.h>
#include <clang/AST/Stmt.h>

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

/// Metal handler for the [@barrier] attribute.
///
/// Replaces the full source range of the attribute with the text stored in
/// metal::SYNC_THREADS_BARRIER.
///
/// @param stage  Session stage whose rewriter receives the replacement.
/// @param stmt   Annotated statement (not inspected here).
/// @param attr   The [@barrier] attribute being replaced.
/// @return A default-constructed HandleResult.
oklt::HandleResult handleBarrierAttribute(SessionStage& stage,
                                          const clang::Stmt& stmt,
                                          const clang::Attr& attr) {
    static_cast<void>(stmt);
    SPDLOG_DEBUG("Handle [@barrier] attribute");

    const auto attrRange = getAttrFullSourceRange(attr);
    auto& rewriter = stage.getRewriter();
    rewriter.ReplaceText(attrRange, metal::SYNC_THREADS_BARRIER);

    return oklt::HandleResult{};
}

/// Load-time hook (runs via the `constructor` attribute) that registers the
/// Metal handler for BARRIER_ATTR_NAME and logs an error if registration fails.
__attribute__((constructor)) void registerAttrBackend() {
    auto registered =
        registerBackendHandler(TargetBackend::METAL, BARRIER_ATTR_NAME, handleBarrierAttribute);
    if (registered) {
        return;
    }

    SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", BARRIER_ATTR_NAME);
}
} // namespace
62 changes: 62 additions & 0 deletions lib/attributes/backend/metal/common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "attributes/backend/metal/common.h"
#include "core/sema/okl_sema_ctx.h"
#include "core/utils/range_to_string.h"
#include "util/string_utils.hpp"

#include <clang/Rewrite/Core/Rewriter.h>

namespace oklt::metal {
using namespace clang;

std::string axisToStr(const Axis& axis) {
static std::map<Axis, std::string> mapping{{Axis::X, "x"}, {Axis::Y, "y"}, {Axis::Z, "z"}};
return mapping[axis];
}

/// Builds the index expression that replaces the loop counter of an
/// attributed loop.
///
/// @param loop  Attributed loop description; its `type` selects the variable
///              and its `axis` selects the component.
/// @return "_occa_thread_position.<axis>" for inner loops,
///         " _occa_group_position.<axis>" for outer loops, and an empty
///         string for any other loop type.
std::string getIdxVariable(const AttributedLoop& loop) {
    auto component = axisToStr(loop.axis);

    if (loop.type == LoopType::Inner) {
        return util::fmt("_occa_thread_position.{}", component).value();
    }
    if (loop.type == LoopType::Outer) {
        // NOTE(review): the leading space in this format string is kept as is;
        // it looks unintentional compared with the inner variant - confirm.
        return util::fmt(" _occa_group_position.{}", component).value();
    }

    // Incorrect case: neither an inner nor an outer loop.
    return "";
}

/// Returns the loop variable name of `forLoop` prefixed with "_occa_tiled_".
std::string getTiledVariableName(const OklLoopInfo& forLoop) {
    std::string tiledName{"_occa_tiled_"};
    tiledName += forLoop.var.name;
    return tiledName;
}

/// Builds the declaration line that replaces the header of an inner/outer
/// loop with a computation based on the Metal position variable.
///
/// The generated line has one of two shapes (followed by a newline):
///   - unary increment:  `<type> <name> = (<start>) <op> <idx>;`
///   - otherwise:        `<type> <name> = (<start>) <op> ((<inc>) * <idx>);`
/// where `<op>` is "+" when forLoop.IsInc() is true and "-" otherwise, and
/// `<idx>` is the result of getIdxVariable(loop).
///
/// @param forLoop             Loop meta data (variable, start range, increment).
/// @param loop                Attributed loop description used to pick `<idx>`.
/// @param openedScopeCounter  Unused here; no scope is opened by this function.
/// @param rewriter            Rewriter used to read the latest source text.
/// @return The generated declaration line.
std::string buildInnerOuterLoopIdxLine(const OklLoopInfo& forLoop,
                                       const AttributedLoop& loop,
                                       int& openedScopeCounter,
                                       oklt::Rewriter& rewriter) {
    static_cast<void>(openedScopeCounter);

    auto positionVar = getIdxVariable(loop);
    const char* sign = forLoop.IsInc() ? "+" : "-";

    if (forLoop.isUnary()) {
        // Step of one: the position variable is added/subtracted directly.
        return util::fmt("{} {} = ({}) {} {};\n",
                         forLoop.var.typeName,
                         forLoop.var.name,
                         getLatestSourceText(forLoop.range.start, rewriter),
                         sign,
                         positionVar)
            .value();
    }

    // General step: scale the position variable by the increment expression.
    return util::fmt("{} {} = ({}) {} (({}) * {});\n",
                     forLoop.var.typeName,
                     forLoop.var.name,
                     getLatestSourceText(forLoop.range.start, rewriter),
                     sign,
                     getLatestSourceText(forLoop.inc.val, rewriter),
                     positionVar)
        .value();
}

} // namespace oklt::metal
36 changes: 36 additions & 0 deletions lib/attributes/backend/metal/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Shared declarations for the Metal backend attribute handlers.
//
// The include guard is required: this header defines a namespace-scope
// object (SYNC_THREADS_BARRIER), so including it twice in one translation
// unit would otherwise be a redefinition error.
#pragma once

#include "attributes/attribute_names.h"
#include "attributes/utils/code_gen.h"
#include "attributes/utils/default_handlers.h"
#include "attributes/utils/kernel_utils.h"
#include "attributes/utils/utils.h"
#include "core/handler_manager/backend_handler.h"
#include "core/rewriter/rewriter_proxy.h"
#include "core/sema/okl_sema_ctx.h"
#include "core/transpiler_session/session_stage.h"
#include "core/utils/attributes.h"
#include "core/utils/range_to_string.h"

#include <string>

namespace clang {
class Rewriter;
}

namespace oklt {
struct OklLoopInfo;
}

namespace oklt::metal {
/// Returns "x", "y" or "z" for Axis::X, Axis::Y and Axis::Z.
std::string axisToStr(const Axis& axis);

/// Returns the position variable expression for an attributed loop:
/// `_occa_thread_position.<axis>` for inner loops and
/// `_occa_group_position.<axis>` for outer loops; empty for other loop types.
std::string getIdxVariable(const AttributedLoop& loop);

/// Returns the loop variable name of `forLoop` prefixed with `_occa_tiled_`.
std::string getTiledVariableName(const OklLoopInfo& forLoop);

/// Builds the declaration that replaces an inner/outer loop header.
///
/// Produces something like: int i = (start) +- ((inc) * _occa_group_position.x);
///                      or: int i = (start) +- ((inc) * _occa_thread_position.x);
/// For a unary increment the `(inc) *` factor is omitted.
///
/// @param forLoop             Loop meta data (variable, start, increment).
/// @param loop                Attributed loop description (type and axis).
/// @param openedScopeCounter  Counter of scopes opened by the generated code;
///                            the current implementation does not modify it.
/// @param rewriter            Rewriter used to read the latest source text.
std::string buildInnerOuterLoopIdxLine(const OklLoopInfo& forLoop,
                                       const AttributedLoop& loop,
                                       int& openedScopeCounter,
                                       oklt::Rewriter& rewriter);

/// Text emitted for a thread synchronisation point (without trailing `;`).
/// Declared `inline` so that all translation units share a single definition.
inline const std::string SYNC_THREADS_BARRIER = "threadgroup_barrier(mem_flags::mem_threadgroup)";
}  // namespace oklt::metal
53 changes: 53 additions & 0 deletions lib/attributes/backend/metal/exclusive.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "attributes/backend/metal/common.h"

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

/// Metal handler for [@exclusive] on a generic declaration.
///
/// Only strips the attribute from the source; the declaration itself is not
/// inspected or rewritten.
///
/// @param s     Session stage passed to removeAttribute.
/// @param decl  Annotated declaration (not inspected here).
/// @param a     The [@exclusive] attribute being handled.
/// @return A default-constructed HandleResult.
HandleResult handleExclusiveDeclAttribute(SessionStage& s, const Decl& decl, const Attr& a) {
    static_cast<void>(decl);

    SPDLOG_DEBUG("Handle [@exclusive] attribute (Decl)");
    removeAttribute(s, a);

    return HandleResult{};
}

/// Metal handler for [@exclusive] on a variable declaration.
///
/// Removes the attribute, then validates where the variable is declared:
/// the nearest attributed loop (a regular loop is replaced by its attributed
/// parent) must be an [@outer] loop whose first attributed child is an
/// [@inner] loop. On success the work is delegated to
/// defaultHandleExclusiveDeclAttribute.
///
/// @param s     Session stage (rewriter and sema context owner).
/// @param decl  The annotated variable declaration.
/// @param a     The [@exclusive] attribute being handled.
/// @return The result of defaultHandleExclusiveDeclAttribute, or an Error
///         when the declaration is inside an [@inner] loop or not located
///         between [@outer] and [@inner] loops.
HandleResult handleExclusiveVarAttribute(SessionStage& s, const VarDecl& decl, const Attr& a) {
    SPDLOG_DEBUG("Handle [@exclusive] attribute");

    removeAttribute(s, a);

    const auto misplaced = []() -> HandleResult {
        return tl::make_unexpected(
            Error{{}, "Must define [@exclusive] variables between [@outer] and [@inner] loops"});
    };

    auto& semaCtx = s.tryEmplaceUserCtx<OklSemaCtx>();

    // Climb from a regular (non-attributed) loop to its attributed parent.
    auto enclosingLoop = semaCtx.getLoopInfo();
    if (enclosingLoop && enclosingLoop->isRegular()) {
        enclosingLoop = enclosingLoop->getAttributedParent();
    }

    if (!enclosingLoop) {
        return misplaced();
    }
    if (enclosingLoop->has(LoopType::Inner)) {
        return tl::make_unexpected(
            Error{{}, "Cannot define [@exclusive] variables inside an [@inner] loop"});
    }

    auto firstChild = enclosingLoop->getFirstAttributedChild();
    const bool innerFollows = firstChild && firstChild->has(LoopType::Inner);
    const bool insideOuter = enclosingLoop->has(LoopType::Outer);
    if (!(insideOuter && innerFollows)) {
        return misplaced();
    }

    return defaultHandleExclusiveDeclAttribute(s, decl, a);
}

/// Load-time hook (runs via the `constructor` attribute) that registers the
/// three Metal handlers for EXCLUSIVE_ATTR_NAME: the generic declaration
/// handler, the variable declaration handler and the default statement
/// handler. All three registrations are always attempted; a single error is
/// logged if any of them fails.
__attribute__((constructor)) void registerAttrBackend() {
    auto allRegistered = registerBackendHandler(
        TargetBackend::METAL, EXCLUSIVE_ATTR_NAME, handleExclusiveDeclAttribute);
    allRegistered &= registerBackendHandler(
        TargetBackend::METAL, EXCLUSIVE_ATTR_NAME, handleExclusiveVarAttribute);
    allRegistered &= registerBackendHandler(
        TargetBackend::METAL, EXCLUSIVE_ATTR_NAME, defaultHandleExclusiveStmtAttribute);

    if (allRegistered) {
        return;
    }

    SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", EXCLUSIVE_ATTR_NAME);
}
} // namespace
49 changes: 49 additions & 0 deletions lib/attributes/backend/metal/inner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "attributes/backend/metal/common.h"
#include "attributes/frontend/params/loop.h"

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

/// Metal handler for the [@inner] loop attribute.
///
/// Replaces the attributed `for` loop with a declaration of the loop variable
/// computed from the Metal thread position variable (see
/// metal::buildInnerOuterLoopIdxLine). When the loop info reports
/// shouldSync(), a barrier statement is appended after the closing brace.
///
/// @param s        Session stage (rewriter and sema context owner).
/// @param forStmt  The attributed loop statement.
/// @param a        The [@inner] attribute being handled.
/// @param params   Parsed attribute parameters; must not be null.
/// @return The result of replaceAttributedLoop, or an Error when the
///         parameters are missing, the loop meta data cannot be fetched, or
///         the loop has no axis assigned.
HandleResult handleInnerAttribute(SessionStage& s,
                                  const clang::ForStmt& forStmt,
                                  const clang::Attr& a,
                                  const AttributedLoop* params) {
    SPDLOG_DEBUG("Handle [@inner] attribute");
    handleChildAttr(s, forStmt, NO_BARRIER_ATTR_NAME);

    auto& sema = s.tryEmplaceUserCtx<OklSemaCtx>();
    auto loopInfo = sema.getLoopInfo(forStmt);
    if (!loopInfo) {
        return tl::make_unexpected(
            Error{std::error_code(), "@inner: failed to fetch loop meta data from sema"});
    }

    // Guard the dereference below instead of crashing on a missing argument.
    if (!params) {
        return tl::make_unexpected(
            Error{std::error_code(), "@inner: attribute parameters are missing"});
    }

    // front() on an empty container is undefined behaviour, so reject it.
    if (loopInfo->axis.empty()) {
        return tl::make_unexpected(
            Error{std::error_code(), "@inner: loop meta data has no axis assigned"});
    }

    // Auto Axis in loopInfo are replaced with specific.
    // TODO: maybe somehow update params earlier?
    auto updatedParams = *params;
    updatedParams.axis = loopInfo->axis.front();

    std::string afterRBraceCode;
    if (loopInfo->shouldSync()) {
        afterRBraceCode += metal::SYNC_THREADS_BARRIER + ";\n";
    }

    int openedScopeCounter = 0;
    auto prefixCode = metal::buildInnerOuterLoopIdxLine(
        *loopInfo, updatedParams, openedScopeCounter, s.getRewriter());
    auto suffixCode = buildCloseScopes(openedScopeCounter);

    return replaceAttributedLoop(s, forStmt, a, suffixCode, afterRBraceCode, prefixCode, true);
}

__attribute__((constructor)) void registerBackendHandler() {
auto ok = registerBackendHandler(TargetBackend::METAL, INNER_ATTR_NAME, handleInnerAttribute);

if (!ok) {
SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", INNER_ATTR_NAME);
}
}
} // namespace
Loading