Skip to content

Commit

Permalink
support (variadic) regions
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Sep 26, 2023
1 parent 9c319a7 commit c2d4f98
Showing 1 changed file with 69 additions and 16 deletions.
85 changes: 69 additions & 16 deletions deps/tblgen/jl-generators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,61 @@ namespace
std::vector<mlir::tblgen::NamedTypeConstraint> operands;
};

class RegionsGenerator
{
RegionsGenerator(std::vector<std::string> binders,
std::vector<mlir::tblgen::NamedRegion> regions)
: binders(std::move(binders)),
regions(std::move(regions)) {}

public:
static std::optional<RegionsGenerator> buildFor(mlir::tblgen::Operator &op)
{
if (op.getNumRegions() == 0)
return RegionsGenerator({}, {});

std::vector<std::string> binders;
std::vector<mlir::tblgen::NamedRegion> regions;

for (const auto &named_region : op.getRegions())
{
binders.push_back(sanitizeName(named_region.name) + "_");
regions.push_back(named_region);
}
if (binders.empty())
return RegionsGenerator({}, {});
return RegionsGenerator(std::move(binders),
std::move(regions));
}

void print(llvm::raw_ostream &os) const
{
std::vector<std::string> required_region_creator;

for (size_t i = 0; i < regions.size(); ++i)
{
const mlir::tblgen::NamedRegion &nregion = regions[i];
const auto postfix = nregion.isVariadic() ? "..." : "";
required_region_creator.push_back(llvm::formatv("{0}{1}", binders[i], postfix));
}
const char *kRegionPattern = R"(regions = [{0:$[, ]}]
)";
os << llvm::formatv(kRegionPattern,
make_range(required_region_creator));
}

std::vector<std::string> types() const
{
return map_vector(regions, [](const mlir::tblgen::NamedRegion &p)
{ return std::string(p.isVariadic() ? "Vector{Region}" : "Region"); });
}

std::vector<std::string> binders;

private:
std::vector<mlir::tblgen::NamedRegion> regions;
};

class SuccessorsGenerator
{
SuccessorsGenerator(std::vector<std::string> binders,
Expand All @@ -515,7 +570,6 @@ namespace
return SuccessorsGenerator({}, {});

std::vector<std::string> binders;
std::vector<std::string> default_values;
std::vector<mlir::tblgen::NamedSuccessor> successors;
for (const auto &successor : op.getSuccessors())
{
Expand Down Expand Up @@ -680,24 +734,18 @@ namespace
return std::optional<std::string>();
};

// Skip currently unsupported cases
if (op.getNumVariadicRegions() != 0)
return fail("variadic regions");
// if (op.getNumSuccessors() != 0) return fail("successors");

const char *kPatternExplicitType = R"(create_operation(
"{0}", {1},
results = results,
operands = operands,
owned_regions = [{2:$[, ]}],
owned_regions = regions,
successors = successors,
attributes = attributes,
result_inference=false
))";
return llvm::formatv(kPatternExplicitType,
op.getOperationName(), // 0
location_expr, // 1
make_range(region_exprs)) // 2
op.getOperationName(), // 0
location_expr) // 1
.str();
}

Expand All @@ -723,7 +771,7 @@ namespace
* @brief Emit Julia function definition, managing attributes, operands, and successors, for creating an operation.
*/
void emitPattern(const llvm::Record *def, const ResultsGenerator &results,
const OperandsGenerator &operands, const SuccessorsGenerator &successors,
const OperandsGenerator &operands, const RegionsGenerator &regions, const SuccessorsGenerator &successors,
const AttributesGenerator &attr_pattern, llvm::raw_ostream &os)
{
mlir::tblgen::Operator op(def);
Expand All @@ -733,10 +781,6 @@ namespace
};

// Skip currently unsupported cases
// if (op.getNumVariableLengthResults() != 0)
// return fail("variadic results");
if (op.getNumRegions() != 0)
return fail("regions");
// if (op.getNumSuccessors() != 0) return fail("successors");
if (!def->getName().endswith("Op"))
return fail("unsupported name format");
Expand All @@ -759,6 +803,11 @@ namespace
pattern_arg_types.insert(pattern_arg_types.end(), operand_types.begin(),
operand_types.end());

// Prepare regions
auto region_types = regions.types();
pattern_arg_types.insert(pattern_arg_types.end(), region_types.begin(),
region_types.end());

// Prepare successors
auto successor_types = successors.types();
pattern_arg_types.insert(pattern_arg_types.end(), successor_types.begin(),
Expand All @@ -780,13 +829,15 @@ namespace
binders.push_back("location");
binders.insert(binders.end(), results.binders.begin(), results.binders.end());
binders.insert(binders.end(), operands.binders.begin(), operands.binders.end());
binders.insert(binders.end(), regions.binders.begin(), regions.binders.end());
binders.insert(binders.end(), successors.binders.begin(), successors.binders.end());
binders.insert(binders.end(), attr_pattern.binders.begin(), attr_pattern.binders.end());

// fill default values with empty strings except for attributes
default_values.push_back("");
default_values.insert(default_values.end(), results.default_values.begin(), results.default_values.end());
default_values.insert(default_values.end(), operands.default_values.begin(), operands.default_values.end());
default_values.insert(default_values.end(), regions.binders.size(), "");
default_values.insert(default_values.end(), successors.binders.size(), "");
default_values.insert(default_values.end(), attr_pattern.default_values.begin(), attr_pattern.default_values.end());

Expand All @@ -802,6 +853,7 @@ namespace

results.print(stream);
operands.print(stream);
regions.print(stream);
successors.print(stream);

attr_pattern.print(stream, attr_pattern_state);
Expand Down Expand Up @@ -888,10 +940,11 @@ end
}
std::optional<ResultsGenerator> results = ResultsGenerator::buildFor(op);
std::optional<OperandsGenerator> operands = OperandsGenerator::buildFor(op);
std::optional<RegionsGenerator> regions = RegionsGenerator::buildFor(op);
std::optional<SuccessorsGenerator> successors = SuccessorsGenerator::buildFor(op);
std::optional<AttributesGenerator> attr_pattern = AttributesGenerator::buildFor(op);

emitPattern(def, *results, *operands, *successors, *attr_pattern, os);
emitPattern(def, *results, *operands, *regions, *successors, *attr_pattern, os);
os << "\n";
}

Expand Down

0 comments on commit c2d4f98

Please sign in to comment.