From 67749a7546389e71ebe1a2429d48a14a6818d684 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:41:49 +0100 Subject: [PATCH] also sanitize results, attributes, owned_regions, and successors. (#43) --- deps/tblgen/jl-generators.cc | 47 ++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/deps/tblgen/jl-generators.cc b/deps/tblgen/jl-generators.cc index 931f6668..6362ff88 100644 --- a/deps/tblgen/jl-generators.cc +++ b/deps/tblgen/jl-generators.cc @@ -133,6 +133,24 @@ namespace return dialect_name; } + std::string sanitizeName(std::string name) { + // check if name starts with digit: + if (std::isdigit(name[0])) + { + name = "_" + name; + } + // check if name colides with Julia keywords, generated module name, or "location": + // https://docs.julialang.org/en/v1/base/base/#Keywords + std::vector reservedKeywords = {"location", "baremodule", "begin", "break", "catch", "const", "continue", "do", "else", "elseif", "end", "export", "false", "finally", "for", "function", "global", "if", "import", "let", "local", "macro", "module", "public", "quote", "return", "struct", "true", "try", "using", "while"}; + if (std::find(reservedKeywords.begin(), reservedKeywords.end(), name) != reservedKeywords.end()) + { + name = name + "_"; + } + // replace all .'s with _'s + std::replace(name.begin(), name.end(), '.', '_'); + return name; + } + } // namespace bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper, @@ -193,15 +211,7 @@ end auto opname = op.getOperationName(); auto functionname = opname.substr(op.getDialectName().str().length() + 1); // get rid of "dialect." prefix. - // check if functionname colides with Julia keywords or generated module name: - // https://docs.julialang.org/en/v1/base/base/#Keywords - std::vector reservedKeywords = {modulename, "baremodule", "begin", "break", "catch", "const", "continue", "do", "else", "elseif", "end", "export", "false", "finally", "for", "function", "global", "if", "import", "let", "local", "macro", "module", "public", "quote", "return", "struct", "true", "try", "using", "while"}; - if (std::find(reservedKeywords.begin(), reservedKeywords.end(), functionname) != reservedKeywords.end()) - { - functionname = functionname + "_"; - } - // replace all .'s with _'s - std::replace(functionname.begin(), functionname.end(), '.', '_'); + functionname = sanitizeName(functionname); std::string description = ""; if (op.hasDescription()) @@ -220,8 +230,8 @@ end { operandname = "operand_" + std::to_string(i); } + operandname = sanitizeName(operandname); - // auto type = named_operand.constraint.getPredicate().getCondition(); std::string type = "Value"; bool optional = named_operand.isOptional(); @@ -292,6 +302,7 @@ end { resultname = "result_" + std::to_string(i); } + resultname = sanitizeName(resultname); std::string type = "MLIRType"; bool optional = named_result.isOptional() || inferrable; @@ -333,25 +344,23 @@ end std::string defaultvalue = ""; std::string attributename = named_attr.name.str(); - if (attributename.empty()) - { - attributename = "attribute_" + std::to_string(i); - } + assert(!attributename.empty() && "expected NamedAttribute to have a name"); + std::string sanitizedname = sanitizeName(attributename); bool optional = named_attr.attr.isOptional() || named_attr.attr.hasDefaultValue(); if (optional) { - optionals += llvm::formatv(R"(({0} != nothing) && push!(attributes, namedattribute("{0}", {0})) + optionals += llvm::formatv(R"(({0} != nothing) && push!(attributes, namedattribute("{0}", {1})) )", - attributename); + attributename, sanitizedname); defaultvalue = "=nothing"; } else { - attributecontainer += "namedattribute(\"" + attributename + "\", " + attributename + "), "; + attributecontainer += "namedattribute(\"" + attributename + "\", " + sanitizedname + "), "; } - attributearguments += attributename + defaultvalue + ", "; + attributearguments += sanitizedname + defaultvalue + ", "; } std::string regionarguments = ""; @@ -365,6 +374,7 @@ end { regionname = "region_" + std::to_string(i); } + regionname = sanitizeName(regionname); std::string type = "Region"; bool variadic = named_region.isVariadic(); @@ -389,6 +399,7 @@ end { successorname = "successor_" + std::to_string(i); } + successorname = sanitizeName(successorname); std::string type = "Block"; bool variadic = named_successor.isVariadic();