Skip to content

Commit

Permalink
make values getter class/wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Nov 28, 2024
1 parent 7efb729 commit 1fb48f2
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 96 deletions.
3 changes: 2 additions & 1 deletion projects/eudsl/eudsl-tblgen/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
click==8.1.7
pytest==8.3.3
pytest==8.3.3
nanobind==2.2.0
28 changes: 24 additions & 4 deletions projects/eudsl/eudsl-tblgen/src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.29)
project(eudsl_tblgen CXX C)
set (CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 17)

find_package(LLVM REQUIRED CONFIG)
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
Expand All @@ -15,11 +15,31 @@ execute_process(
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_DIR)
find_package(nanobind CONFIG REQUIRED)

set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/eudsl_tblgen)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

nanobind_add_module(eudsl_tblgen_ext NB_STATIC eudsl_tblgen_ext.cpp TGParser.cpp TGLexer.cpp)
nanobind_add_module(eudsl_tblgen_ext NB_STATIC STABLE_ABI eudsl_tblgen_ext.cpp TGParser.cpp TGLexer.cpp)
target_link_libraries(eudsl_tblgen_ext PRIVATE LLVMTableGenCommon LLVMTableGen)
nanobind_add_stub(
eudsl_tblgen_ext_stub
MODULE eudsl_tblgen_ext
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/eudsl_tblgen/eudsl_tblgen_ext.pyi
PYTHON_PATH $<TARGET_FILE_DIR:eudsl_tblgen_ext>
DEPENDS eudsl_tblgen_ext
)
nanobind_add_stub(
eudsl_tblgen_stub
MODULE eudsl_tblgen
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/eudsl_tblgen/__init__.pyi
PYTHON_PATH ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS eudsl_tblgen_ext
)

install(TARGETS eudsl_tblgen_ext LIBRARY DESTINATION eudsl_tblgen)
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/eudsl_tblgen DESTINATION ${CMAKE_INSTALL_PREFIX})
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/eudsl_tblgen
DESTINATION ${CMAKE_INSTALL_PREFIX}
PATTERN "*.so" EXCLUDE
PATTERN "*.a" EXCLUDE
PATTERN ".gitignore" EXCLUDE
)
3 changes: 3 additions & 0 deletions projects/eudsl/eudsl-tblgen/src/eudsl_tblgen/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.so
*.a
*.pyi
129 changes: 89 additions & 40 deletions projects/eudsl/eudsl-tblgen/src/eudsl_tblgen_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ constexpr auto coerceReturn(Return (*pf)(Args...)) noexcept {
}

template <typename NewReturn, typename Return, typename Class, typename... Args>
constexpr auto coerceReturn(Return (Class:: *pmf)(Args...),
constexpr auto coerceReturn(Return (Class::*pmf)(Args...),
std::false_type = {}) noexcept {
return [&pmf](Class *cls, Args &&...args) -> NewReturn {
return (cls->*pmf)(std::forward<Args>(args)...);
Expand All @@ -40,7 +40,7 @@ constexpr auto coerceReturn(Return (Class:: *pmf)(Args...),
* passing the `this` pointer as the first arg
*/
template <typename NewReturn, typename Return, typename Class, typename... Args>
constexpr auto coerceReturn(Return (Class:: *pmf)(Args...) const,
constexpr auto coerceReturn(Return (Class::*pmf)(Args...) const,
std::true_type) noexcept {
// copy the *pmf, not capture by ref
return [pmf](const Class &cls, Args &&...args) -> NewReturn {
Expand All @@ -50,7 +50,7 @@ constexpr auto coerceReturn(Return (Class:: *pmf)(Args...) const,

template <>
struct nb::detail::type_caster<StringRef> {
NB_TYPE_CASTER(StringRef, const_name("str_ref"))
NB_TYPE_CASTER(StringRef, const_name("str"))

bool from_python(handle src, uint8_t, cleanup_list *) noexcept {
Py_ssize_t size;
Expand All @@ -74,7 +74,6 @@ struct HackInit : public Init {
};

NB_MODULE(eudsl_tblgen_ext, m) {

auto recty = nb::class_<RecTy>(m, "RecTy");

nb::enum_<RecTy::RecTyKind>(m, "RecTyKind")
Expand All @@ -90,13 +89,13 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def_prop_ro("record_keeper", &RecTy::getRecordKeeper)
.def_prop_ro("as_string", &RecTy::getAsString)
.def("__str__", &RecTy::getAsString)
.def("type_is_a", &RecTy::typeIsA)
.def("type_is_convertible_to", &RecTy::typeIsConvertibleTo);
.def("type_is_a", &RecTy::typeIsA, "rhs"_a)
.def("type_is_convertible_to", &RecTy::typeIsConvertibleTo, "rhs"_a);

nb::class_<RecordRecTy, RecTy>(m, "RecordRecTy")
.def_prop_ro("classes", coerceReturn<std::vector<const Record *>>(
&RecordRecTy::getClasses, nb::const_))
.def("is_sub_class_of", &RecordRecTy::isSubClassOf);
.def("is_sub_class_of", &RecordRecTy::isSubClassOf, "class_"_a);

nb::enum_<HackInit::InitKind>(m, "InitKind")
.value("IK_FirstTypedInit", HackInit::InitKind::IK_FirstTypedInit)
Expand Down Expand Up @@ -131,9 +130,10 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def("__str__", &Init::getAsUnquotedString)
.def("is_complete", &Init::isComplete)
.def("is_concrete", &Init::isConcrete)
.def("get_field_type", &Init::getFieldType,
.def("get_field_type", &Init::getFieldType, "field_name"_a,
nb::rv_policy::reference_internal)
.def("get_bit", &Init::getBit, nb::rv_policy::reference_internal);
.def("get_bit", &Init::getBit, "bit"_a,
nb::rv_policy::reference_internal);

nb::class_<TypedInit, Init>(m, "TypedInit")
.def_prop_ro("record_keeper", &TypedInit::getRecordKeeper)
Expand All @@ -149,7 +149,8 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def_prop_ro("name", &ArgumentInit::getName);

nb::class_<BitInit, TypedInit>(m, "BitInit")
.def_prop_ro("value", &BitInit::getValue);
.def_prop_ro("value", &BitInit::getValue)
.def("__bool__", &BitInit::getValue);

nb::class_<BitsInit, TypedInit>(m, "BitsInit")
.def_prop_ro("num_bits", &BitsInit::getNumBits)
Expand Down Expand Up @@ -189,13 +190,15 @@ NB_MODULE(eudsl_tblgen_ext, m) {
},
nb::rv_policy::reference_internal)
.def_prop_ro("element_type", &ListInit::getElementType)
.def("get_element_as_record", &ListInit::getElementAsRecord)
.def("get_element_as_record", &ListInit::getElementAsRecord, "i"_a,
nb::rv_policy::reference_internal)
.def_prop_ro("values", coerceReturn<std::vector<const Init *>>(
&ListInit::getValues, nb::const_));

nb::class_<OpInit, TypedInit>(m, "OpInit")
.def_prop_ro("num_operands", &OpInit::getNumOperands)
.def("operand", &OpInit::getOperand, nb::rv_policy::reference_internal);
.def("operand", &OpInit::getOperand, "i"_a,
nb::rv_policy::reference_internal);

auto unaryOpInit = nb::class_<UnOpInit, OpInit>(m, "UnOpInit");
nb::enum_<UnOpInit::UnaryOp>(m, "UnaryOp")
Expand Down Expand Up @@ -282,7 +285,8 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def_prop_ro("def_", &DefInit::getDef);

nb::class_<VarDefInit, TypedInit>(m, "VarDefInit")
.def("get_arg", &VarDefInit::getArg, nb::rv_policy::reference_internal)
.def("get_arg", &VarDefInit::getArg, "i"_a,
nb::rv_policy::reference_internal)
.def_prop_ro("args", coerceReturn<std::vector<const ArgumentInit *>>(
&VarDefInit::args, nb::const_))
.def("__len__", [](const VarDefInit &v) { return v.args_size(); })
Expand Down Expand Up @@ -311,11 +315,12 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def_prop_ro("name_init", &DagInit::getName)
.def_prop_ro("name_str", &DagInit::getNameStr)
.def_prop_ro("num_args", &DagInit::getNumArgs)
.def("get_arg", &DagInit::getArg, nb::rv_policy::reference_internal)
.def("get_arg_no", &DagInit::getArgNo)
.def("get_arg_name_init", &DagInit::getArgName,
.def("get_arg", &DagInit::getArg, "num"_a,
nb::rv_policy::reference_internal)
.def("get_arg_no", &DagInit::getArgNo, "name"_a)
.def("get_arg_name_init", &DagInit::getArgName, "num"_a,
nb::rv_policy::reference_internal)
.def("get_arg_name_str", &DagInit::getArgNameStr)
.def("get_arg_name_str", &DagInit::getArgNameStr, "num"_a)
.def("get_arg_name_inits",
coerceReturn<std::vector<const StringInit *>>(&DagInit::getArgNames,
nb::const_),
Expand Down Expand Up @@ -350,8 +355,41 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def_prop_ro("is_nonconcrete_ok", &RecordVal::isNonconcreteOK)
.def_prop_ro("is_template_arg", &RecordVal::isTemplateArg)
.def_prop_ro("value", &RecordVal::getValue)
.def("__str__",
[](const RecordVal &self) {
return self.getValue()->getAsUnquotedString();
})
.def_prop_ro("is_used", &RecordVal::isUsed);

nb::class_<ArrayRef<RecordVal>>(m, "ArrayRefofRecordVal");
struct RecordValues {};
auto valuesCl =
nb::class_<RecordValues>(m, "RecordValues", nb::dynamic_attr())
.def("__init__",
[](nb::object &self, ArrayRef<RecordVal> values) {
for (const RecordVal &recordVal : values) {
nb::setattr(self, recordVal.getName().str().c_str(),
nb::borrow(nb::cast(recordVal)));
}
})
.def("__repr__", [](const nb::object &self) {
nb::str s{"RecordValues("};
auto dic = nb::cast<nb::dict>(nb::getattr(self, "__dict__"));
int i = 0;
for (auto [key, value] : dic) {
s += key + nb::str("=") +
nb::str(nb::cast<RecordVal>(value)
.getValue()
->getAsUnquotedString()
.c_str());
if (i < dic.size() - 1)
s += nb::str(", ");
++i;
}
s += nb::str(")");
return s;
});

nb::class_<Record>(m, "Record")
.def_prop_ro("direct_super_classes",
[](const Record &self) -> std::vector<const Record *> {
Expand All @@ -365,43 +403,54 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def_prop_ro("records", &Record::getRecords)
.def_prop_ro("type", &Record::getType)
.def("get_value", nb::overload_cast<StringRef>(&Record::getValue),
nb::rv_policy::reference_internal)
.def("get_value_as_bit", &Record::getValueAsBit)
.def("get_value_as_def", &Record::getValueAsDef)
.def("get_value_as_int", &Record::getValueAsInt)
"name"_a, nb::rv_policy::reference_internal)
.def("get_value_as_bit", &Record::getValueAsBit, "field_name"_a)
.def("get_value_as_def", &Record::getValueAsDef, "field_name"_a)
.def("get_value_as_int", &Record::getValueAsInt, "field_name"_a)
.def("get_value_as_list_of_defs", &Record::getValueAsListOfDefs,
nb::rv_policy::reference_internal)
.def("get_value_as_list_of_ints", &Record::getValueAsListOfInts)
.def("get_value_as_list_of_strings", &Record::getValueAsListOfStrings)
"field_name"_a, nb::rv_policy::reference_internal)
.def("get_value_as_list_of_ints", &Record::getValueAsListOfInts,
"field_name"_a)
.def("get_value_as_list_of_strings", &Record::getValueAsListOfStrings,
"field_name"_a)
.def("get_value_as_optional_def", &Record::getValueAsOptionalDef,
nb::rv_policy::reference_internal)
.def("get_value_as_optional_string", &Record::getValueAsOptionalString)
.def("get_value_as_string", &Record::getValueAsString)
.def("get_value_as_bit_or_unset", &Record::getValueAsBitOrUnset)
"field_name"_a, nb::rv_policy::reference_internal)
.def("get_value_as_optional_string", &Record::getValueAsOptionalString,
nb::sig("def get_value_as_optional_string(self, field_name: str, /) "
"-> Optional[str]"))
.def("get_value_as_string", &Record::getValueAsString, "field_name"_a)
.def("get_value_as_bit_or_unset", &Record::getValueAsBitOrUnset,
"field_name"_a, "unset"_a)
.def("get_value_as_bits_init", &Record::getValueAsBitsInit,
nb::rv_policy::reference_internal)
.def("get_value_as_dag", &Record::getValueAsDag,
"field_name"_a, nb::rv_policy::reference_internal)
.def("get_value_as_dag", &Record::getValueAsDag, "field_name"_a,
nb::rv_policy::reference_internal)
.def("get_value_as_list_init", &Record::getValueAsListInit,
"field_name"_a, nb::rv_policy::reference_internal)
.def("get_value_init", &Record::getValueInit, "field_name"_a,
nb::rv_policy::reference_internal)
.def("get_value_init", &Record::getValueInit,
nb::rv_policy::reference_internal)
.def_prop_ro("values", coerceReturn<std::vector<RecordVal>>(
&Record::getValues, nb::const_))
.def("has_direct_super_class", &Record::hasDirectSuperClass)
.def_prop_ro("values",
[&valuesCl](Record &self) {
ArrayRef<RecordVal> values = self.getValues();
return valuesCl(values);
})
.def("has_direct_super_class", &Record::hasDirectSuperClass,
"super_class"_a)
.def_prop_ro("is_anonymous", &Record::isAnonymous)
.def_prop_ro("is_class", &Record::isClass)
.def_prop_ro("is_multi_class", &Record::isMultiClass)
.def("is_sub_class_of",
nb::overload_cast<const Record *>(&Record::isSubClassOf, nb::const_))
nb::overload_cast<const Record *>(&Record::isSubClassOf, nb::const_),
"r"_a)
.def("is_sub_class_of",
nb::overload_cast<StringRef>(&Record::isSubClassOf, nb::const_))
.def("is_value_unset", &Record::isValueUnset)
nb::overload_cast<StringRef>(&Record::isSubClassOf, nb::const_),
"name"_a)
.def("is_value_unset", &Record::isValueUnset, "field_name"_a)
.def_prop_ro("def_init", &Record::getDefInit)
.def_prop_ro("name_init", &Record::getNameInit)
.def_prop_ro("template_args", coerceReturn<std::vector<const Init *>>(
&Record::getTemplateArgs, nb::const_))
.def("is_template_arg", &Record::isTemplateArg);
.def("is_template_arg", &Record::isTemplateArg, "name"_a);

using RecordMap = std::map<std::string, std::unique_ptr<Record>, std::less<>>;
using GlobalMap = std::map<std::string, const Init *, std::less<>>;
Expand Down Expand Up @@ -475,5 +524,5 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def("get_all_derived_definitions",
coerceReturn<std::vector<const Record *>, ArrayRef<const Record *>>(
&RecordKeeper::getAllDerivedDefinitions, nb::const_),
nb::rv_policy::reference_internal);
"class_name"_a, nb::rv_policy::reference_internal);
}
Loading

0 comments on commit 1fb48f2

Please sign in to comment.