[WIP] Set up CNN accelerator codegen skeleton + counting utilities #10

Open · wants to merge 3 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -233,3 +233,5 @@ conda/pkg
# nix files
.envrc
*.nix

tests/python/byo3la/*.params
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -341,6 +341,7 @@ include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/ILAVTA.cmake)
include(cmake/modules/contrib/ILACNN.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
9 changes: 9 additions & 0 deletions cmake/modules/contrib/ILACNN.cmake
@@ -0,0 +1,9 @@
if(USE_ILACNN_CODEGEN STREQUAL "ON")
add_definitions(-DUSE_ILACNN_RUNTIME=1)
file(GLOB ILACNN_RELAY_CONTRIB_SRC src/relay/backend/contrib/ilacnn/*.cc)
list(APPEND COMPILER_SRCS ${ILACNN_RELAY_CONTRIB_SRC})
list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})

file(GLOB ILACNN_CONTRIB_SRC src/runtime/contrib/ilacnn/ilacnn_runtime.cc)
list(APPEND RUNTIME_SRCS ${ILACNN_CONTRIB_SRC})
endif()
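To pick these sources up, USE_ILACNN_CODEGEN presumably has to be switched on in config.cmake before building TVM, e.g. set(USE_ILACNN_CODEGEN ON), mirroring how the sibling contrib codegens are enabled; that is an inference from the STREQUAL check above rather than something the PR spells out.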
63 changes: 63 additions & 0 deletions python/tvm/relay/op/contrib/ilacnn.py
@@ -0,0 +1,63 @@
"""
Python bindings and helpers for ILACNN codegen,
note that the accelerator does not do padding for Conv2D's,
so you should use remove_padding on the main function before pattern matching
(this converts conv2d's with padding to conv2d(pad(data)))
"""
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
import tvm.ir
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table

def remove_padding(func):
"""
The CNN accelerator cannot handle padding in conv2d,
so this will rewrite all conv2d's with padding into
conv2d on a separately padded tensor (i.e., handle padding in the host)
"""
class PaddingRemover(ExprMutator):
def visit_call(self, call):
if call.attrs is None:
return super().visit_call(call)
attrs = call.attrs
if not isinstance(attrs, relay.op.op_attrs.Conv2DAttrs):
return super().visit_call(call)
padding = attrs.padding
            # nothing to do if there is no padding
            # (padding elements are IntImms, so convert before comparing)
            if all(int(d) == 0 for d in padding):
return super().visit_call(call)

# otherwise rewrite as a padded call
data = self.visit(call.args[0])
weight = self.visit(call.args[1])

            # relay.nn.pad expects a (before, after) pair per axis,
            # while conv2d's four-element padding attr is (top, left, bottom, right)
            data_layout = attrs.data_layout
            # we are only padding the H and W dimensions
            pad_dims = [(0, 0), (0, 0), (padding[0], padding[2]), (padding[1], padding[3])]
            if data_layout == "NHWC":
                pad_dims = [(0, 0), (padding[0], padding[2]), (padding[1], padding[3]), (0, 0)]

padded_data = relay.nn.pad(data, pad_dims)
return relay.nn.conv2d(padded_data, weight,
strides=attrs.strides,
padding=0,
dilation=attrs.dilation,
groups=attrs.groups,
channels=attrs.channels,
kernel_size=attrs.kernel_size,
data_layout=attrs.data_layout,
kernel_layout=attrs.kernel_layout,
out_layout=attrs.out_layout,
out_dtype=attrs.out_dtype)

remover = PaddingRemover()
return remover.visit(func)


@register_pattern_table("ilacnn")
def pattern_table():
conv2d_pattern = ("ilacnn.conv2d", is_op('nn.conv2d')(wildcard(), wildcard()))
return [conv2d_pattern]
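
For illustration, a minimal sketch of how these helpers would be used together; the pass pipeline below is the standard TVM BYOC flow and is an assumption, not something this PR pins down:

import tvm
from tvm import relay
from tvm.relay.op.contrib.ilacnn import remove_padding, pattern_table

def partition_for_ilacnn(mod):
    # the accelerator cannot pad, so rewrite padded conv2d's first
    mod["main"] = remove_padding(mod["main"])
    seq = tvm.transform.Sequential([
        relay.transform.MergeComposite(pattern_table()),
        relay.transform.AnnotateTarget("ilacnn"),
        relay.transform.PartitionGraph(),
    ])
    return seq(mod)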
2 changes: 2 additions & 0 deletions python/tvm/relay/testing/__init__.py
@@ -50,6 +50,8 @@
# these are just for testing
from .exact_matcher import deduplicate_vars, check_compiler_call

from .op_summary import count_all_ops, count_all_overloads, count_all_ops_in_overloads

def run_opt_pass(expr, opt_pass, import_prelude=False):
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
72 changes: 72 additions & 0 deletions python/tvm/relay/testing/op_summary.py
@@ -0,0 +1,72 @@
"""
Utility functions for counting the number of operators
and BYOC overloads in modules.
"""
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprVisitor

def is_overload(func):
    """A BYOC overload is a function carrying a "Compiler" attribute."""
    if func.attrs is None:
        return False
    return "Compiler" in func.attrs


def get_count_expr(counter_class, expr):
counter = counter_class()
counter.visit(expr)
return counter.count


def get_count_mod(counter_class, mod):
total_count = 0
for gv in mod.get_global_vars():
total_count += get_count_expr(counter_class, mod[gv])
return total_count


class Counter(ExprVisitor):
    """
    Base class for the counters below: walks an expression and
    accumulates increment(expr) over every eligible node.
    """
    def __init__(self):
        super().__init__()
        self.count = 0

    def eligible(self, expr):
        """Should this node be counted?"""
        raise NotImplementedError()

    def increment(self, expr):
        """How much an eligible node adds to the count (1 by default)."""
        return 1

    def visit(self, expr):
        if self.eligible(expr):
            self.count += self.increment(expr)
        super().visit(expr)


class OpCounter(Counter):
def eligible(self, expr):
return isinstance(expr, tvm.ir.op.Op)


class OverloadCounter(Counter):
def eligible(self, expr):
return isinstance(expr, relay.Function) and is_overload(expr)


class OpInOverloadCounter(Counter):
def eligible(self, expr):
return isinstance(expr, relay.Function) and is_overload(expr)

def increment(self, expr):
return get_count_expr(OpCounter, expr)


def count_all_ops(mod):
return get_count_mod(OpCounter, mod)


def count_all_overloads(mod):
return get_count_mod(OverloadCounter, mod)


def count_all_ops_in_overloads(mod):
return get_count_mod(OpInOverloadCounter, mod)
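
As a usage illustration, counting ops in a freshly built (unpartitioned) module; the tiny conv2d function here is made up for the example:

import tvm
from tvm import relay
from tvm.relay.testing import count_all_ops, count_all_overloads, count_all_ops_in_overloads

data = relay.var("data", shape=(1, 3, 32, 32))
weight = relay.var("weight", shape=(8, 3, 3, 3))
func = relay.Function([data, weight], relay.nn.conv2d(data, weight))
mod = tvm.IRModule.from_expr(func)

print(count_all_ops(mod))               # 1: the nn.conv2d operator
print(count_all_overloads(mod))         # 0: nothing is partitioned yet
print(count_all_ops_in_overloads(mod))  # 0: likewise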
98 changes: 98 additions & 0 deletions src/relay/backend/contrib/ilacnn/ilacnn_codegen.cc
@@ -0,0 +1,98 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>

#include "../../utils.h"

#include "../../../../runtime/contrib/json/json_node.h"
#include "../codegen_json/codegen_json.h"

namespace tvm {
namespace relay {
namespace contrib {

using namespace backend;

class IlaCNNJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;

public:
IlaCNNJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}

std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override {
std::string name;

if (const auto* op_node = cn->op.as<OpNode>()) {
name = op_node->name;
} else if (const auto* fn = cn->op.as<FunctionNode>()) {
auto comp = fn->GetAttr<String>(attr::kComposite);
CHECK(comp.defined())
<< "JSON runtime only supports composite functions.";
name = comp.value();

if (name != "ilacnn.conv2d") {
LOG(FATAL) << "Unrecognized pattern: " << name;
}
} else {
LOG(FATAL) << "IlaCNN runtime does not support calls to "
<< cn->op->GetTypeKey();
}
    LOG(INFO) << "[Pattern Matching] Found annotated pattern: " << name;

std::vector<JSONGraphNodeEntry> inputs;
for (const auto& arg : cn->args) {
auto res = VisitExpr(arg);
inputs.insert(inputs.end(), res.begin(), res.end());
}
auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);

    // Note: conv2d has a lot of attrs that are relevant for codegen,
    // especially the stride size.
    // However, the pattern matcher will produce patterns in the form of
    //   fn(Compiler="ilacnn") {
    //     fn(Composite="ilacnn.conv2d") { nn.conv2d(...) }
    //   }
    // so we need to reach inside the inner function to get the conv2d attrs
    // (weird, yeah); see codegen_json.h:SetCallNodeAttribute

    if (cn->op.as<FunctionNode>()) {
      // composite case: pull the conv2d attrs off the call inside the inner function
      tvm::relay::backend::contrib::OpAttrExtractor extractor(node);
      auto inner_func = Downcast<Function>(cn->op);
      auto inner_call = Downcast<Call>(inner_func->body);
      const Object* inner_call_attr = inner_call->attrs.get();
      extractor.Extract(const_cast<Object*>(inner_call_attr));
    } else {
      // direct operator call: take the attrs from this call node itself
      SetCallNodeAttribute(node, cn);
    }
    return AddNode(node, GetRef<Expr>(cn));
}
}; // class IlaCNNJSONSerializer

runtime::Module IlaCNNCompiler(const ObjectRef& ref) {
CHECK(ref->IsInstance<FunctionNode>());
auto func = Downcast<Function>(ref);
auto func_name = GetExtSymbol(func);

IlaCNNJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

const auto* pf = runtime::Registry::Get("runtime.IlaCNNRuntimeCreate");
CHECK(pf != nullptr) << "Cannot find IlaCNN runtime module to create";
auto mod = (*pf)(func_name, graph_json, params);
return mod;
}

TVM_REGISTER_GLOBAL("relay.ext.ilacnn").set_body_typed(IlaCNNCompiler);

} // namespace contrib
} // namespace relay
} // namespace tvm
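
Once TVM is built with USE_ILACNN_CODEGEN enabled, the registration can be sanity-checked from Python (a quick check; get_global_func returns None rather than raising when allow_missing=True):

import tvm

assert tvm.get_global_func("relay.ext.ilacnn", allow_missing=True) is not None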