diff --git a/deps/ReactantExtra/tblgen/jl-generators.cc b/deps/ReactantExtra/tblgen/jl-generators.cc index 4853a3dd1..01db867cb 100644 --- a/deps/ReactantExtra/tblgen/jl-generators.cc +++ b/deps/ReactantExtra/tblgen/jl-generators.cc @@ -13,32 +13,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include #include #include +#include #include "mlir/TableGen/Argument.h" -#include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" -#include "mlir/TableGen/Format.h" #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Region.h" -#include "mlir/TableGen/SideEffects.h" #include "mlir/TableGen/Trait.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatAdapters.h" -#include "llvm/Support/FormatCommon.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/Path.h" #include "llvm/Support/Signals.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" @@ -87,21 +78,45 @@ static bool canInferType(const Operator &op) { hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); } -std::string formatDescription(mlir::tblgen::Operator op) { - std::string description; - description = op.getDescription().str(); +std::string assemblyFormatToJulia( + std::string s, + const std::function &applyJuliaFormat) { + auto p = -1; + auto output = std::string(); + auto length = s.length() - 1; + for (auto [i, c] : llvm::enumerate(s)) { + if (c == '`') + continue; + if (c == '$') + p = i; + + if (p != -1 && (c == ' ' || length == i)) { + auto name = s.substr(p + 1, i - p - 1); + auto new_name = applyJuliaFormat(name); + output.append(new_name); + p = -1; + continue; + } + + if (p == -1 && c != ' ') + output.push_back(c); + } + return output; +} + +std::string formatDescription(std::string name, std::string description) { size_t pos = 0; while (description[pos] == '\n') ++pos; - size_t leading_spaces = 0; + size_t leadingSpaces = 0; while (description[pos++] == ' ') - ++leading_spaces; - if (leading_spaces) { - std::string leading_spaces_str; - for (size_t i = 0; i < leading_spaces; ++i) - leading_spaces_str += "[ ]"; - description = std::regex_replace( - description, std::regex("\n" + leading_spaces_str), "\n"); + ++leadingSpaces; + if (leadingSpaces) { + std::string leadingSpacesStr; + for (size_t i = 0; i < leadingSpaces; ++i) + leadingSpacesStr += "[ ]"; + description = std::regex_replace(description, + std::regex("\n" + leadingSpacesStr), "\n"); } description = std::regex_replace(description, std::regex(R"(\\)"), R"(\\)"); description = std::regex_replace(description, std::regex("(['\"$])"), "\\$1"); @@ -112,27 +127,31 @@ std::string formatDescription(mlir::tblgen::Operator op) { while (std::isspace(description.back())) { description.pop_back(); } - return description; + + return "\"\"\"\n`" + name + "`\n" + description + "\n\"\"\""; } -std::string getDialectName(llvm::ArrayRef op_defs) { - mlir::tblgen::Operator any_op(op_defs.front()); - assert( - std::all_of(op_defs.begin(), op_defs.end(), [&any_op](llvm::Record *op) { - return mlir::tblgen::Operator(op).getDialectName() == - any_op.getDialectName(); - })); - std::string dialect_name; +std::string getDialectName(llvm::ArrayRef 
opDefs) {
+  mlir::tblgen::Operator anyOp(opDefs.front());
+  assert(std::all_of(opDefs.begin(), opDefs.end(),
+                     [&anyOp](const llvm::Record *op) {
+                       return mlir::tblgen::Operator(op).getDialectName() ==
+                              anyOp.getDialectName();
+                     }));
+  std::string dialectName;
   if (DialectName.empty()) {
-    dialect_name = any_op.getDialectName().str();
+    dialectName = anyOp.getDialectName().str();
   } else {
-    dialect_name = DialectName;
+    dialectName = DialectName;
   }
-  return dialect_name;
+  return dialectName;
 }
 
 std::string sanitizeName(std::string name,
                          std::optional<std::string> modulename = std::nullopt) {
+  if (name.empty()) {
+    return "empty";
+  }
   // check if name starts with digit:
   if (std::isdigit(name[0])) {
     name = "_" + name;
@@ -163,39 +182,292 @@ std::string sanitizeName(std::string name,
 
 extern bool disableModuleWrap;
 
+template <typename T>
+std::optional<T> get(llvm::StringMap<T> m, std::string k) {
+  auto entry = m.find(k);
+  return (entry != m.end()) ? std::optional(entry->getValue())
+                            : std::nullopt;
+}
+
+std::string removeNamespace(std::string s) {
+  auto pos = s.rfind("::");
+  if (pos >= s.length())
+    return s;
+  return s.substr(pos + 2);
+}
+
+auto attribs = std::string();
+
+llvm::StringMap<std::string> attributeCache;
+
+std::string emitEnum(llvm::Record def, std::string dialect) {
+  EnumAttr e(def.isSubClassOf("EnumAttrInfo") ? def
+                                              : *def.getValueAsDef("enum"));
+  auto tableGenName = def.getName().str();
+  if (auto cached = get(attributeCache, tableGenName))
+    return *cached;
+
+  auto base = e.getBaseAttrClass();
+  auto enumJuliaType_ = e.getEnumClassName().str();
+  auto enumJuliaType = sanitizeName(enumJuliaType_);
+  auto juliaEnum = "@enumx " + enumJuliaType + ' ';
+  auto juliaStorage = enumJuliaType + "Storage";
+  enumJuliaType += ".T";
+  auto mlirAttributeDef = "IR.Attribute(e::" + enumJuliaType + ") = ";
+  auto isSpecialized = e.genSpecializedAttr();
+  if (!isSpecialized) { // parse the attribute using the name
+    auto juliaNameArray = "const " + juliaStorage + " = [";
+    auto mnemonic = def.getValueAsString("mnemonic");
+    for (auto c : e.getAllCases()) {
+      juliaEnum += sanitizeName(c.getSymbol().str()) + ' ';
+      juliaNameArray += '"' + c.getStr().str() + "\", ";
+    }
+
+    juliaEnum +=
+        "\n" + juliaNameArray.substr(0, juliaNameArray.size() - 2) + "]";
+    auto assemblyFormat = assemblyFormatToJulia(
+        def.getValueAsString("assemblyFormat").str(),
+        [&](std::string _) { return "$(" + juliaStorage + "[Int(e)+1])"; });
+
+    mlirAttributeDef += llvm::formatv(R"(parse(Attribute,"#{0}<{1} {2}>"))",
+                                      dialect, mnemonic, assemblyFormat);
+  } else {
+    for (auto c : e.getAllCases()) {
+      juliaEnum += sanitizeName(c.getSymbol().str()) + '=' +
+                   std::to_string(c.getValue()) + ' ';
+    }
+    mlirAttributeDef += "Int(e)";
+  }
+  attributeCache.insert({tableGenName, enumJuliaType});
+  if (auto description = def.getValueAsOptionalString("summary")) {
+    attribs +=
+        '\n' + formatDescription(enumJuliaType_, description->str()) + '\n';
+  }
+  attribs += juliaEnum + "\n\n" + mlirAttributeDef + "\n\n";
+  return enumJuliaType;
+}
+
+const llvm::StringMap<std::string> cppToJuliaTypeMap = {
+    {"int32_t", "Int32"},
+    {"int64_t", "Int64"},
+    {"uint32_t",
+     "Int32"}, // TODO: both are handled strangely => Int is working...
+    {"uint64_t", "Int64"},
+    {"bool", "Bool"},
+    {"Type", "IR.Type"},
+    {"FunctionType", "IR.Type"},
+    {"Attribute", "IR.AbstractAttribute"},
+    {"StringRef", "String"},
+    {"ArrayAttr", "Vector{<:IR.AbstractAttribute}"},
+    {"FlatSymbolRefAttr", "IR.FlatSymbolRefAttribute"},
+    {"DenseIntElementsAttr", "IR.AbstractDenseElementsAttribute{Int64}"},
+    {"ElementsAttr", "IR.AbstractDenseElementsAttribute"},
+};
+
+std::optional<std::string>
+cppToJuliaType(std::string t, std::optional<Attribute> attr = std::nullopt) {
+  return llvm::StringSwitch<std::function<std::optional<std::string>()>>(t)
+      .StartsWith("ArrayRef",
+                  [&]() -> std::optional<std::string> {
+                    auto outType = t.substr(9, t.length() - 10);
+                    outType = removeNamespace(outType);
+                    auto in = cppToJuliaType(outType);
+                    if (!in)
+                      return in;
+                    return llvm::formatv("IR.DenseAttribute{{{}}", *in).str();
+                  })
+      .Case("APFloat",
+            [&]() -> std::optional<std::string> {
+              if (!attr)
+                return std::nullopt;
+              auto type = attr->getDef().getValueAsOptionalDef("valueType");
+              if (!type)
+                return std::nullopt;
+              return "Float" + type->getName().substr(1).str();
+            })
+      .Default([&]() { return get(cppToJuliaTypeMap, t); })();
+}
+
+std::string toPascalCase(std::string s) {
+  std::string output = "";
+  auto nextUp = true;
+  for (auto c : s) {
+    if (nextUp) {
+      output += std::toupper(c);
+      nextUp = false;
+      continue;
+    }
+    if (c == '_') {
+      nextUp = true;
+      continue;
+    }
+    output += c;
+  }
+  return output;
+}
+
+std::string toSnakeCase(std::string s) {
+  std::string output = "";
+  output += llvm::toLower(s[0]);
+  auto nextUp = true;
+  for (auto c : s.substr(1)) {
+    if (llvm::isUpper(c)) {
+      output += '_';
+      output += llvm::toLower(c);
+    } else
+      output += c;
+  }
+  return output;
+}
+
+llvm::StringMap<std::string> blacklisted_struct = {
+    {"StableHLO_ConvDimensionNumbers", "WIP"},
+};
+
+// struct creation can fail if one of its fields cannot be translated
+std::optional<std::string> emitStruct(llvm::Record def, std::string dialect) {
+  auto tableGenName = def.getName().str();
+  if (auto cached = get(attributeCache, tableGenName))
+    return *cached;
+  auto assembly = def.getValueAsOptionalString("assemblyFormat");
+
+  auto customAssembly = def.getValueAsBit("hasCustomAssemblyFormat");
+  if (customAssembly) {
+    if (!assembly) {
+      llvm::errs() << "Custom C++ assembly for " << tableGenName << '\n';
+      // custom assembly without a format string means a hand-written C++
+      // parser/printer; anyway, a lot of structs have a C++ parser/printer
+      // equivalent to `<` struct(params) `>`
+      if (blacklisted_struct.contains(tableGenName)) {
+        attributeCache.insert({tableGenName, "Attribute"});
+        llvm::errs() << "don't emit for this attribute" << '\n';
+        return std::nullopt;
+      }
+      customAssembly = false; // hack
+    } else {
+      customAssembly = *assembly != "`<` struct(params) `>`";
+    };
+  }
+
+  auto standardStructAssembly = !customAssembly;
+  auto mnemonic = def.getValueAsString("mnemonic").str();
+  auto structName = toPascalCase(mnemonic);
+  auto params = def.getValueAsDag("parameters");
+  auto structDef = "struct " + structName + '\n';
+  auto mlirAttributeDef = "IR.Attribute(s::" + structName +
+                          ") = parse(Attribute,\"#" + dialect + "." + mnemonic;
+  if (standardStructAssembly)
+    mlirAttributeDef.push_back('<');
+  for (auto [arg, name_] :
+       llvm::zip(params->getArgs(), params->getArgNames())) {
+    auto name = toSnakeCase(name_->getAsUnquotedString());
+    // auto name = standardStructAssembly ?
toSnakeCase(name_) : name_; + auto sanitizedName = sanitizeName(name); + std::string cppType; + std::optional juliaType; + if (auto init = dyn_cast(arg)) { // not a cpp type + auto subDef = init->getDef(); + cppType = subDef->getValueAsString("cppType").str(); + auto type = subDef->getType()->getAsString(); + llvm::StringSwitch>(type) + .Case("APFloatParameter", [&]() { juliaType = "Float64"; }) + .Case("StringRefParameter", [&]() { juliaType = "String"; }) + .Case("EnumParameter", + [&]() { juliaType = removeNamespace(toPascalCase(cppType)); }) + .Case("ArrayRefParameter", + [&]() { + auto normalizedCppType = removeNamespace(cppType); + juliaType = cppToJuliaType(normalizedCppType); + }) + .Default([&]() { + llvm::errs() << "unknown pattern : " << type << '\n'; + })(); + } else + cppType = removeNamespace(arg->getAsUnquotedString()); + + if (!juliaType) { + if (auto juliaTypeEntry = cppToJuliaType(cppType)) + juliaType = juliaTypeEntry; + else { + llvm::errs() << cppType << '\n'; + return std::nullopt; + } + } + structDef += '\t' + sanitizedName + "::" + *juliaType + '\n'; + if (standardStructAssembly) + mlirAttributeDef += + llvm::formatv("{0} = $(c(s.{1})), ", name, sanitizedName); + } + structDef += "end"; + if (standardStructAssembly) { + mlirAttributeDef.resize(mlirAttributeDef.length() - 2); // remove , + mlirAttributeDef += ">"; + } else + mlirAttributeDef += assemblyFormatToJulia( + def.getValueAsString("assemblyFormat").str(), [](std::string name) { + return llvm::formatv( + "$(c(s.{}))", + sanitizeName( + name)); // TODO: add this function only for some args. The c + // function is here to deal with "Any[]", we want "[]" + }); + mlirAttributeDef += "\")"; + + if (auto description = def.getValueAsOptionalString("summary")) { + attribs += '\n' + formatDescription(mnemonic, description->str()) + '\n'; + } + attribs += structDef + "\n\n" + mlirAttributeDef + "\n\n"; + attributeCache.insert({tableGenName, structName}); + return structName; +} + bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper, llvm::raw_ostream &os) { + + llvm::StringMap attrMap; llvm::ArrayRef opdefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Op"); + std::string moduleName; + + if (!DialectName.empty()) { + moduleName = DialectName; + } else { + moduleName = getDialectName(opdefs); + DialectName = moduleName; + } + + llvm::ArrayRef attrdefs = + recordKeeper.getAllDerivedDefinitionsIfDefined("Attr"); const char *moduleTemplate; if (disableModuleWrap) { moduleTemplate = R"(import ...IR: IR, NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API - +using EnumX {0} )"; } else { moduleTemplate = R"(module {0} using ...IR import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX {1} end # {0} )"; } - const char *functiontemplate = R"( + const char *functionTemplate = R"( {3} -function {0}({1}location=Location()) +function {0}({1}location::Location=Location()) {2} end )"; // 0: functionname, 1: functionarguments, 2: functionbody - const char *functionbodytemplate = R"(op_ty_results = IR.Type[{0}] + const char *functionBodyTemplate = R"(op_ty_results = IR.Type[{0}] operands = Value[{1}] owned_regions = Region[{2}] successors = Block[{3}] @@ -209,33 
+481,25 @@ end ))"; // 0: results, 1: operands, 2: owned_regions, 3: successors, 4: // attributes, 5: optionals, 6: opname, 7: results expression, 8: // result_inference - - std::string modulecontents = ""; - - std::string modulename; - if (!DialectName.empty()) { - modulename = DialectName; - } else { - modulename = getDialectName(opdefs); - } - + + std::string moduleContents = ""; for (const auto *def : opdefs) { mlir::tblgen::Operator op(*def); - std::string operandarguments = ""; - std::string operandcontainer = ""; + std::string operandArguments = ""; + std::string operandContainer = ""; std::string optionals = ""; auto opname = op.getOperationName(); - auto functionname = opname.substr(op.getDialectName().str().length() + + + auto functionName = opname.substr(op.getDialectName().str().length() + 1); // get rid of "dialect." prefix. - functionname = sanitizeName(functionname, modulename); + functionName = sanitizeName(functionName, moduleName); std::string description = ""; - if (op.hasDescription()) { - description = "\"\"\"\n`" + functionname + "`\n" + formatDescription(op) + - "\n\"\"\""; - } + if (op.hasDescription()) + description = formatDescription(functionName, op.getDescription().str()); + bool inferrable = canInferType(op); bool alreadykeyword = @@ -243,18 +507,18 @@ end // is used to insert a single semicolon (;) instead of a comma // (,) as separator between positional and keyword arguments. for (int i = 0; i < op.getNumOperands(); i++) { - const auto &named_operand = op.getOperand(i); + const auto &namedOperand = op.getOperand(i); std::string defaultvalue = ""; - std::string operandname = named_operand.name.str(); - if (operandname.empty()) { - operandname = "operand_" + std::to_string(i); + std::string operandName = namedOperand.name.str(); + if (operandName.empty()) { + operandName = "operand_" + std::to_string(i); } - operandname = sanitizeName(operandname); + operandName = sanitizeName(operandName); std::string type = "Value"; - bool optional = named_operand.isOptional(); - bool variadic = named_operand.isVariadic(); + bool optional = namedOperand.isOptional(); + bool variadic = namedOperand.isVariadic(); if (variadic) { type = "Vector{" + type + "}"; @@ -264,7 +528,7 @@ end if (optional) { optionals += llvm::formatv(R"(!isnothing({0}) && push!(operands, {0}{1}) )", - operandname, (variadic ? "..." : "")); + operandName, (variadic ? "..." : "")); type = "Union{Nothing, " + type + "}"; defaultvalue = "=nothing"; @@ -273,15 +537,15 @@ end separator = "; "; } } else { - operandcontainer += operandname + (variadic ? "..." : "") + ", "; + operandContainer += operandName + (variadic ? "..." : "") + ", "; separator = (!alreadykeyword && i == op.getNumOperands() - 1) ? 
"; " : ", "; } - operandarguments += operandname + defaultvalue + "::" + type + separator; + operandArguments += operandName + "::" + type + defaultvalue + separator; } - if (operandarguments == "") { - operandarguments = "; "; + if (operandArguments == "") { + operandArguments = "; "; } if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { @@ -306,23 +570,24 @@ end operandsegmentsizes); } - std::string resultarguments = ""; - std::string resultcontainer = ""; + std::string resultArguments = ""; + std::string resultContainer = ""; for (int i = 0; i < op.getNumResults(); i++) { - const auto &named_result = op.getResult(i); + const auto &namedResult = op.getResult(i); std::string defaultvalue = ""; - std::string resultname = named_result.name.str(); + std::string resultname = namedResult.name.str(); if (resultname.empty()) { - resultname = "result_" + std::to_string(i); + resultname = + op.getNumResults() == 1 ? "result" : "result_" + std::to_string(i); } resultname = sanitizeName(resultname); std::string type = "IR.Type"; - bool optional = named_result.isOptional() || inferrable; - bool variadic = named_result.isVariadic(); + bool optional = namedResult.isOptional() || inferrable; + bool variadic = namedResult.isVariadic(); if (variadic) { - type = "Vector{" + type + "}"; + type = "Base.AbstractVecOrTuple{" + type + "}"; } if (optional) { @@ -333,108 +598,173 @@ end type = "Union{Nothing, " + type + "}"; defaultvalue = "=nothing"; } else { - resultcontainer += resultname + (variadic ? "..." : "") + ", "; + resultContainer += resultname + (variadic ? "..." : "") + ", "; } - resultarguments += resultname + defaultvalue + "::" + type + ", "; + resultArguments += resultname + "::" + type + defaultvalue + ", "; } std::string resultsexpression = - (inferrable ? "(length(op_ty_results) == 0 ? nothing : op_ty_results)" + (inferrable ? "(isempty(op_ty_results) ? nothing : op_ty_results)" : "op_ty_results"); - std::string resultinference = - (inferrable ? "(length(op_ty_results) == 0 ? true : false)" : "false"); + std::string resultInference = + (inferrable ? "isempty(op_ty_results)" : "false"); - std::string attributearguments = ""; - std::string attributecontainer = ""; + std::string attributeArguments = ""; + std::string attributeContainer = ""; for (int i = 0; i < op.getNumAttributes(); i++) { - const auto &named_attr = op.getAttribute(i); - + const auto &namedAttr = op.getAttribute(i); + auto attr = namedAttr.attr; // Derived attributes are never materialized and don't have to be // specified. - if (named_attr.attr.isDerivedAttr()) + if (attr.isDerivedAttr()) continue; - std::string defaultvalue = ""; - std::string attributename = named_attr.name.str(); - assert(!attributename.empty() && + std::string defaultValue = ""; + std::string attributeName = namedAttr.name.str(); + + assert(!attributeName.empty() && "expected NamedAttribute to have a name"); - std::string sanitizedname = sanitizeName(attributename); - bool optional = - named_attr.attr.isOptional() || named_attr.attr.hasDefaultValue(); + auto optional = attr.isOptional() || attr.hasDefaultValue(); + + std::string VarName = sanitizeName(attributeName); + std::string pushedExpression = VarName; + std::string varType = "<:Any"; + + attr = optional ? 
attr.getBaseAttr() : attr; + std::function closure_ = + [&closure_, &varType, &moduleName, &os](Attribute attr) -> void { + auto def = attr.getDef(); + // enum + if (attr.isSubClassOf("EnumAttr") || + attr.isSubClassOf("EnumAttrInfo")) { + + varType = emitEnum(def, moduleName); + return; + } + + // struct + if (attr.isSubClassOf("AttrDef")) { + auto structDef = emitStruct(def, moduleName); + if (structDef) + varType = *structDef; + return; + } + + if (attr.isSubClassOf("TypedArrayAttrBase")) { + auto e = attr.getDef().getValueAsDef("elementAttr"); + Attribute ArrayAttr(e); + closure_(ArrayAttr); + varType = llvm::formatv("IR.DenseAttribute{{{}}", varType); + return; + } + + // simple Attr -> Julia Type + if (auto attr_entry = cppToJuliaType(attr.getAttrDefName().str())) { + varType = *attr_entry; + return; + } + + // simple Attr using simple layout -> Julia Type + { + auto fullCppType = attr.getDef() + .getValue("returnType") + ->getValue() + ->getAsUnquotedString(); + auto cppType = removeNamespace(fullCppType); + cppType.erase(std::remove(cppType.begin(), cppType.end(), ' '), + cppType.end()); + + if (auto juliaType = cppToJuliaType(cppType, attr)) { + varType = *juliaType; + return; + } + // os << '#' << attr.getAttrDefName() << '\n'; + } + }; + closure_(attr); + + auto isAny = varType == "<:Any"; if (optional) { optionals += llvm::formatv( R"(!isnothing({0}) && push!(attributes, namedattribute("{0}", {1})) )", - attributename, sanitizedname); - defaultvalue = "=nothing"; + attributeName, pushedExpression); + defaultValue = "=nothing"; + varType = "Union{" + varType + ", Nothing}"; } else { - attributecontainer += "namedattribute(\"" + attributename + "\", " + - sanitizedname + "), "; + attributeContainer += "namedattribute(\"" + attributeName + "\", " + + pushedExpression + "), "; } - attributearguments += sanitizedname + defaultvalue + ", "; + std::string typeConstraint = " "; + if (!isAny) + typeConstraint = "::" + varType; + + attributeArguments += VarName + typeConstraint + defaultValue + ", "; } - std::string regionarguments = ""; - std::string regioncontainer = ""; + std::string regionArguments = ""; + std::string regionContainer = ""; for (size_t i = 0; i < op.getNumRegions(); i++) { - const auto &named_region = op.getRegion(i); + const auto &namedRegion = op.getRegion(i); std::string defaultvalue = ""; - std::string regionname = named_region.name.str(); - if (regionname.empty()) { - regionname = "region_" + std::to_string(i); + std::string regionName = namedRegion.name.str(); + if (regionName.empty()) { + regionName = "region_" + std::to_string(i); } - regionname = sanitizeName(regionname); + regionName = sanitizeName(regionName); std::string type = "Region"; - bool variadic = named_region.isVariadic(); + bool variadic = namedRegion.isVariadic(); if (variadic) { type = "Vector{" + type + "}"; } - regioncontainer += regionname + (variadic ? "..." : "") + ", "; - regionarguments += regionname + defaultvalue + "::" + type + ", "; + regionContainer += regionName + (variadic ? "..." 
: "") + ", "; + regionArguments += regionName + "::" + type + defaultvalue + ", "; } - std::string successorarguments = ""; - std::string successorcontainer = ""; + std::string successorArguments = ""; + std::string successorContainer = ""; for (size_t i = 0; i < op.getNumSuccessors(); i++) { - const auto &named_successor = op.getSuccessor(i); - std::string defaultvalue = ""; - std::string successorname = named_successor.name.str(); - if (successorname.empty()) { - successorname = "successor_" + std::to_string(i); + const auto &namedSuccessor = op.getSuccessor(i); + std::string defaultValue = ""; + std::string successorName = namedSuccessor.name.str(); + if (successorName.empty()) { + successorName = "successor_" + std::to_string(i); } - successorname = sanitizeName(successorname); + successorName = sanitizeName(successorName); std::string type = "Block"; - bool variadic = named_successor.isVariadic(); + bool variadic = namedSuccessor.isVariadic(); if (variadic) { type = "Vector{" + type + "}"; } - successorcontainer += successorname + (variadic ? "..." : "") + ", "; - successorarguments += successorname + defaultvalue + "::" + type + ", "; + successorContainer += successorName + (variadic ? "..." : "") + ", "; + successorArguments += successorName + "::" + type + defaultValue + ", "; } - std::string arguments = operandarguments + resultarguments + - attributearguments + regionarguments + - successorarguments; - std::string functionbody = - llvm::formatv(functionbodytemplate, resultcontainer, operandcontainer, - regioncontainer, successorcontainer, attributecontainer, - optionals, opname, resultsexpression, resultinference); + std::string arguments = operandArguments + resultArguments + + attributeArguments + regionArguments + + successorArguments; + std::string functionBody = + llvm::formatv(functionBodyTemplate, resultContainer, operandContainer, + regionContainer, successorContainer, attributeContainer, + optionals, opname, resultsexpression, resultInference); - modulecontents += llvm::formatv(functiontemplate, functionname, arguments, - functionbody, description); + moduleContents += llvm::formatv(functionTemplate, functionName, arguments, + functionBody, description); } + moduleContents = attribs + moduleContents; + if (disableModuleWrap) { - os << llvm::formatv(moduleTemplate, modulecontents); + os << llvm::formatv(moduleTemplate, moduleContents); } else { - os << llvm::formatv(moduleTemplate, modulename, modulecontents); + os << llvm::formatv(moduleTemplate, moduleName, moduleContents); } return false; diff --git a/docs/src/api/affine.md b/docs/src/api/affine.md index 64137c0f4..8a7afef58 100644 --- a/docs/src/api/affine.md +++ b/docs/src/api/affine.md @@ -10,3 +10,7 @@ details. ```@autodocs Modules = [Reactant.MLIR.Dialects.affine] ``` + +```@docs +Reactant.MLIR.Dialects.affine.AtomicRMWKind +``` diff --git a/docs/src/api/arith.md b/docs/src/api/arith.md index 1d9465a01..67f697a36 100644 --- a/docs/src/api/arith.md +++ b/docs/src/api/arith.md @@ -10,3 +10,11 @@ more details. ```@autodocs Modules = [Reactant.MLIR.Dialects.arith] ``` + +```@docs +Reactant.MLIR.Dialects.arith.CmpFPredicate +Reactant.MLIR.Dialects.arith.CmpIPredicate +Reactant.MLIR.Dialects.arith.FastMathFlags +Reactant.MLIR.Dialects.arith.IntegerOverflowFlags +Reactant.MLIR.Dialects.arith.RoundingMode +``` diff --git a/docs/src/api/chlo.md b/docs/src/api/chlo.md index fa96f4bae..dead09026 100644 --- a/docs/src/api/chlo.md +++ b/docs/src/api/chlo.md @@ -10,3 +10,9 @@ for more details. 
```@autodocs Modules = [Reactant.MLIR.Dialects.chlo] ``` + +```@docs +Reactant.MLIR.Dialects.chlo.ComparisonDirection +Reactant.MLIR.Dialects.chlo.ComparisonType +Reactant.MLIR.Dialects.chlo.Precision +``` diff --git a/docs/src/api/enzyme.md b/docs/src/api/enzyme.md index 5ebca6993..8c505eef8 100644 --- a/docs/src/api/enzyme.md +++ b/docs/src/api/enzyme.md @@ -7,3 +7,7 @@ CollapsedDocStrings = true ```@autodocs Modules = [Reactant.MLIR.Dialects.enzyme] ``` + +```@docs +Reactant.MLIR.Dialects.enzyme.Activity +``` diff --git a/docs/src/api/gpu.md b/docs/src/api/gpu.md index 9cdf91aac..12f0572af 100644 --- a/docs/src/api/gpu.md +++ b/docs/src/api/gpu.md @@ -10,3 +10,13 @@ more details. ```@autodocs Modules = [Reactant.MLIR.Dialects.gpu] ``` + +```@docs +Reactant.MLIR.Dialects.gpu.AllReduceOperation +Reactant.MLIR.Dialects.gpu.Dimension +Reactant.MLIR.Dialects.gpu.MMAElementwiseOp +Reactant.MLIR.Dialects.gpu.Prune2To4SpMatFlag +Reactant.MLIR.Dialects.gpu.ShuffleMode +Reactant.MLIR.Dialects.gpu.SpGEMMWorkEstimationOrComputeKind +Reactant.MLIR.Dialects.gpu.TransposeMode +``` diff --git a/docs/src/api/llvm.md b/docs/src/api/llvm.md index 48a715429..6f3dc4089 100644 --- a/docs/src/api/llvm.md +++ b/docs/src/api/llvm.md @@ -10,3 +10,15 @@ more details. ```@autodocs Modules = [Reactant.MLIR.Dialects.llvm] ``` + +```@docs +Reactant.MLIR.Dialects.llvm.AsmDialect +Reactant.MLIR.Dialects.llvm.AtomicBinOp +Reactant.MLIR.Dialects.llvm.AtomicOrdering +Reactant.MLIR.Dialects.llvm.Comdat +Reactant.MLIR.Dialects.llvm.FCmpPredicate +Reactant.MLIR.Dialects.llvm.FastmathFlags +Reactant.MLIR.Dialects.llvm.ICmpPredicate +Reactant.MLIR.Dialects.llvm.UnnamedAddr +Reactant.MLIR.Dialects.llvm.Visibility +``` diff --git a/docs/src/api/mpi.md b/docs/src/api/mpi.md index 5b0570714..674a4e746 100644 --- a/docs/src/api/mpi.md +++ b/docs/src/api/mpi.md @@ -10,3 +10,8 @@ more details. ```@autodocs Modules = [Reactant.MLIR.Dialects.mpi] ``` + +```@docs +Reactant.MLIR.Dialects.mpi.MPI_ErrorClassEnum +Reactant.MLIR.Dialects.mpi.MPI_OpClassEnum +``` diff --git a/docs/src/api/nvvm.md b/docs/src/api/nvvm.md index 28169dc7a..ea800b1e4 100644 --- a/docs/src/api/nvvm.md +++ b/docs/src/api/nvvm.md @@ -10,3 +10,26 @@ more details. 
```@autodocs Modules = [Reactant.MLIR.Dialects.nvvm] ``` + +```@docs +Reactant.MLIR.Dialects.nvvm.FPRoundingMode +Reactant.MLIR.Dialects.nvvm.LoadCacheModifierKind +Reactant.MLIR.Dialects.nvvm.MMAB1Op +Reactant.MLIR.Dialects.nvvm.MMAFrag +Reactant.MLIR.Dialects.nvvm.MMAIntOverflow +Reactant.MLIR.Dialects.nvvm.MMALayout +Reactant.MLIR.Dialects.nvvm.MMATypes +Reactant.MLIR.Dialects.nvvm.MemScopeKind +Reactant.MLIR.Dialects.nvvm.ProxyKind +Reactant.MLIR.Dialects.nvvm.ReduxKind +Reactant.MLIR.Dialects.nvvm.SaturationMode +Reactant.MLIR.Dialects.nvvm.SetMaxRegisterAction +Reactant.MLIR.Dialects.nvvm.SharedSpace +Reactant.MLIR.Dialects.nvvm.ShflKind +Reactant.MLIR.Dialects.nvvm.TMAReduxKind +Reactant.MLIR.Dialects.nvvm.TMAStoreMode +Reactant.MLIR.Dialects.nvvm.Tcgen05GroupKind +Reactant.MLIR.Dialects.nvvm.WGMMAScaleIn +Reactant.MLIR.Dialects.nvvm.WGMMAScaleOut +Reactant.MLIR.Dialects.nvvm.WGMMATypes +``` diff --git a/docs/src/api/shardy.md b/docs/src/api/shardy.md index 8e0192c5e..fc951f005 100644 --- a/docs/src/api/shardy.md +++ b/docs/src/api/shardy.md @@ -9,3 +9,7 @@ Refer to the [official documentation](https://openxla.org/shardy) for more detai ```@autodocs Modules = [Reactant.MLIR.Dialects.sdy] ``` + +```@docs +Reactant.MLIR.Dialects.sdy.PropagationDirection +``` diff --git a/docs/src/api/stablehlo.md b/docs/src/api/stablehlo.md index 61ebf1d45..7dbc75518 100644 --- a/docs/src/api/stablehlo.md +++ b/docs/src/api/stablehlo.md @@ -9,3 +9,14 @@ Refer to the [official documentation](https://openxla.org/stablehlo) for more de ```@autodocs Modules = [Reactant.MLIR.Dialects.stablehlo] ``` + +```@docs +Reactant.MLIR.Dialects.stablehlo.ComparisonDirection +Reactant.MLIR.Dialects.stablehlo.ComparisonType +Reactant.MLIR.Dialects.stablehlo.CustomCallApiVersion +Reactant.MLIR.Dialects.stablehlo.FftType +Reactant.MLIR.Dialects.stablehlo.Precision +Reactant.MLIR.Dialects.stablehlo.RngAlgorithm +Reactant.MLIR.Dialects.stablehlo.RngDistribution +Reactant.MLIR.Dialects.stablehlo.Transpose +``` diff --git a/docs/src/api/tpu.md b/docs/src/api/tpu.md index 9494cd965..727fc5125 100644 --- a/docs/src/api/tpu.md +++ b/docs/src/api/tpu.md @@ -10,3 +10,11 @@ more details. ```@autodocs Modules = [Reactant.MLIR.Dialects.tpu] ``` + +```@docs +Reactant.MLIR.Dialects.tpu.ContractPrecision +Reactant.MLIR.Dialects.tpu.CoreType +Reactant.MLIR.Dialects.tpu.PackFormat +Reactant.MLIR.Dialects.tpu.ReductionKind +Reactant.MLIR.Dialects.tpu.RoundingMode +``` diff --git a/docs/src/api/triton.md b/docs/src/api/triton.md index fdfb9654a..fa3697f90 100644 --- a/docs/src/api/triton.md +++ b/docs/src/api/triton.md @@ -10,3 +10,17 @@ more details. 
```@autodocs Modules = [Reactant.MLIR.Dialects.tt] ``` + +```@docs +Reactant.MLIR.Dialects.tt.CacheModifier +Reactant.MLIR.Dialects.tt.EvictionPolicy +Reactant.MLIR.Dialects.tt.InputPrecision +Reactant.MLIR.Dialects.tt.MemSemantic +Reactant.MLIR.Dialects.tt.MemSyncScope +Reactant.MLIR.Dialects.tt.PaddingOption +Reactant.MLIR.Dialects.tt.ProgramIDDim +Reactant.MLIR.Dialects.tt.PropagateNan +Reactant.MLIR.Dialects.tt.RMWOp +Reactant.MLIR.Dialects.tt.RoundingMode +Reactant.MLIR.Dialects.tt.ScaleDotElemType +``` diff --git a/ext/ReactantAbstractFFTsExt.jl b/ext/ReactantAbstractFFTsExt.jl index 52f504f4a..0578dcd95 100644 --- a/ext/ReactantAbstractFFTsExt.jl +++ b/ext/ReactantAbstractFFTsExt.jl @@ -1,5 +1,5 @@ module ReactantAbstractFFTsExt - +using Reactant.MLIR.Dialects: stablehlo using AbstractFFTs: AbstractFFTs using Reactant: Reactant, MLIR, Ops, TracedRArray @@ -31,58 +31,62 @@ function compute_correct_pdims(x::AbstractArray, dims) end end -for op in (:rfft, :fft, :ifft) - mode = uppercase(string(op)) - @eval function AbstractFFTs.$(op)(x::TracedRArray, dims) +for op in (stablehlo.FftType.RFFT, stablehlo.FftType.FFT, stablehlo.FftType.IFFT) + name = Symbol(lowercase(string(op))) + @eval function AbstractFFTs.$(name)(x::TracedRArray, dims) @assert maximum(dims) ≤ ndims(x) "dims out of range" if dims isa Integer if dims != 1 pdims = compute_correct_pdims(x, dims) return permutedims( - AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims) + AbstractFFTs.$(name)(permutedims(x, pdims), 1), invperm(pdims) ) end - return generalized_fft(x, $(mode), nothing, length(dims)) + return generalized_fft(x, $(op), nothing, length(dims)) end if !check_contiguous_innermost_dims(dims, ndims(x)) pdims = compute_correct_pdims(x, dims) return permutedims( - AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims) + AbstractFFTs.$(name)(permutedims(x, pdims), 1:length(dims)), invperm(pdims) ) end - return generalized_fft(x, $(mode), nothing, length(dims)) + return generalized_fft(x, $(op), nothing, length(dims)) end end -for op in (:irfft,) - mode = uppercase(string(op)) - @eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims) +for op in (stablehlo.FftType.IRFFT,) + name = Symbol(lowercase(string(op))) + @eval function AbstractFFTs.$(name)(x::TracedRArray, d::Int, dims) @assert maximum(dims) ≤ ndims(x) "dims out of range" if dims isa Integer if dims != 1 pdims = compute_correct_pdims(x, dims) return permutedims( - AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims) + AbstractFFTs.$(name)(permutedims(x, pdims), d, 1), invperm(pdims) ) end - return generalized_fft(x, $(mode), d, length(dims)) + return generalized_fft(x, $(op), d, length(dims)) end if !check_contiguous_innermost_dims(dims, ndims(x)) pdims = compute_correct_pdims(x, dims) return permutedims( - AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims) + AbstractFFTs.$(name)(permutedims(x, pdims), d, 1:length(dims)), + invperm(pdims), ) end - return generalized_fft(x, $(mode), d, length(dims)) + return generalized_fft(x, $(op), d, length(dims)) end end -function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N} +function generalized_fft( + x::TracedRArray{T,N}, mode::stablehlo.FftType.T, d, first_n::Int +) where {T,N} if d === nothing - @assert mode ∈ ("RFFT", "FFT", "IFFT") + @assert mode ∈ + (stablehlo.FftType.RFFT, stablehlo.FftType.FFT, stablehlo.FftType.IFFT) fft_length = [size(x, i) for i in 1:first_n] else - @assert mode == "IRFFT" + @assert mode == 
stablehlo.FftType.IRFFT fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n] end diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index bd8e854fa..bf0786e4e 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -833,7 +833,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do return MLIR.Dialects.llvm.func(; sym_name, - sym_visibility=MLIR.IR.Attribute("private"), + sym_visibility="private", function_type=wrapftype, body=MLIR.IR.Region(), CConv, @@ -889,10 +889,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( 1, ) alloc = MLIR.IR.result( - MLIR.Dialects.llvm.alloca( - c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr - ), - 1, + MLIR.Dialects.llvm.alloca(c1; elem_type=argty, res=llvmptr), 1 ) push!(allocs, (alloc, argty)) @@ -953,7 +950,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( MLIR.IR.Value[]; res=llvmptr, elem_type=i8, - rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]), + rawConstantIndices=[Int32(offset)], ), 1, ) @@ -976,13 +973,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), - op_bundle_sizes=MLIR.IR.Attribute(Int32[]), + op_bundle_sizes=Int32[], ) MLIR.Dialects.llvm.return_(nothing) end - output_operand_aliases = MLIR.IR.Attribute(aliases) - blk_operands = MLIR.IR.Value[] for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) @@ -997,9 +992,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( call = MLIR.Dialects.enzymexla.kernel_call( blk_operands..., mlir_args; - result_0=restys, + result=restys, fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), - output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), + output_operand_aliases=aliases, ) argidx = 1 diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index ee00463e2..fee42f88a 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -3,7 +3,7 @@ module ReactantNNlibExt using NNlib using GPUArraysCore: @allowscalar using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber - +using Reactant.MLIR.Dialects: stablehlo using Reactant.TracedUtils: TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data! @@ -94,7 +94,7 @@ function NNlib.conv!( Int64(output_batch_dim - 1), Int64(output_feature_dim - 1), length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims], - ) + )#TODO:deal with this using a custom parser in julia code generation #! 
format: on padding = Reactant.MLIR.IR.DenseElementsAttribute( @@ -110,11 +110,11 @@ function NNlib.conv!( conv = Reactant.MLIR.Dialects.stablehlo.convolution( get_mlir_data(x), get_mlir_data(weight); - result_0=result_type, + result=result_type, window_strides=collect(stride), padding, dimension_numbers, - lhs_dilation=1, + lhs_dilation=[1 for _ in dilation], rhs_dilation=collect(dilation), feature_group_count, batch_group_count=1, @@ -176,13 +176,14 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N} end attr = fill(Reactant.MLIR.IR.Attribute(init), unranked) + init_value = Reactant.MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.constant(; value=attr) ) reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window( [get_mlir_data(x)], [init_value]; - result_0=[result_type], + result=[result_type], window_dimensions, window_strides, window_dilations, @@ -415,7 +416,7 @@ function NNlib.∇conv_filter!( conv = MLIR.Dialects.stablehlo.convolution( get_mlir_data(x), get_mlir_data(dy); - result_0=result_type, + result=result_type, window_strides=collect(dilation), padding, dimension_numbers, @@ -532,8 +533,8 @@ function NNlib.∇conv_data!( conv = MLIR.Dialects.stablehlo.convolution( get_mlir_data(dy), get_mlir_data(w); - result_0=result_type, - window_strides=1, + result=result_type, + window_strides=[1 for _ in dilation], padding, lhs_dilation=collect(stride), rhs_dilation=collect(dilation), diff --git a/ext/ReactantRandom123Ext.jl b/ext/ReactantRandom123Ext.jl index d701fdc7e..9c576b318 100644 --- a/ext/ReactantRandom123Ext.jl +++ b/ext/ReactantRandom123Ext.jl @@ -2,10 +2,11 @@ module ReactantRandom123Ext using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x using Reactant: TracedRandom +using Reactant.MLIR.Dialects: stablehlo -TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY" -TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" -TracedRandom.rng_algorithm(::Philox4x) = "PHILOX" -TracedRandom.rng_algorithm(::Philox2x) = "PHILOX" +TracedRandom.rng_algorithm(::Threefry4x) = stablehlo.RngAlgorithm.THREE_FRY +TracedRandom.rng_algorithm(::Threefry2x) = stablehlo.RngAlgorithm.THREE_FRY +TracedRandom.rng_algorithm(::Philox4x) = stablehlo.RngAlgorithm.PHILOX +TracedRandom.rng_algorithm(::Philox2x) = stablehlo.RngAlgorithm.PHILOX end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index f3267ed35..61646904a 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -369,20 +369,14 @@ function overload_autodiff( end end - function act_attr(val) - val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet( - MLIR.IR.context()::MLIR.API.MlirContext, val::Int32 - )::MLIR.API.MlirAttribute - return MLIR.IR.Attribute(val) - end fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) res = (reverse ? 
MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( [TracedUtils.transpose_val(v) for v in ad_inputs]; outputs=outtys, fn=fname, - activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), - ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), + activity=[MLIR.Dialects.enzyme.Activity.T(a) for a in activity], + ret_activity=[MLIR.Dialects.enzyme.Activity.T(a) for a in ret_activity], ) residx = 1 diff --git a/src/Ops.jl b/src/Ops.jl index a873de4ca..9fc74b1b1 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -109,32 +109,17 @@ function fill(number::TracedRNumber{T}, shape::Vector{Int}; location) where {T} return Base.fill(number, Tuple(shape)) end -for (T, mlir_func) in ( - (Bool, :mlirDenseElementsAttrBoolSplatGet), - (UInt8, :mlirDenseElementsAttrUInt8SplatGet), - (Int8, :mlirDenseElementsAttrInt8SplatGet), - (UInt32, :mlirDenseElementsAttrUInt32SplatGet), - (Int32, :mlirDenseElementsAttrInt32SplatGet), - (UInt64, :mlirDenseElementsAttrUInt64SplatGet), - (Int64, :mlirDenseElementsAttrInt64SplatGet), - (Float32, :mlirDenseElementsAttrFloatSplatGet), - (Float64, :mlirDenseElementsAttrDoubleSplatGet), +@noinline function fill( + number::Union{Bool,UInt8,Int8,UInt32,Int32,UInt64,Int64,Float32,Float64}, + shape::Vector{Int}; + location=mlir_stacktrace("fill", @__FILE__, @__LINE__), ) - @eval begin - @noinline function fill( - number::$T, - shape::Vector{Int}; - location=mlir_stacktrace("fill", @__FILE__, @__LINE__), - ) - tt = MLIR.IR.TensorType(shape, MLIR.IR.Type($T); location=location) - - splatattr = MLIR.API.$mlir_func(tt, number) - cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) - cst = MLIR.IR.result(cst_op) - ta = TracedRArray{$T,length(shape)}((), cst, shape) - return ta - end - end + T = typeof(number) + tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T); location=location) + cst_op = stablehlo.constant(; output=tt, value=Base.fill(number, tt), location=location) + cst = MLIR.IR.result(cst_op) + ta = TracedRArray{T,length(shape)}((), cst, shape) + return ta end _fill_element_attr(x) = MLIR.IR.Attribute(x) @@ -148,7 +133,9 @@ end element::T, shape::Vector{Int}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) ) where {T} tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) - splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element)) + splatattr = MLIR.IR.DenseElementsAttribute( + MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element)) + ) cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) cst = MLIR.IR.result(cst_op) ta = TracedRArray{T,length(shape)}((), cst, shape) @@ -364,7 +351,7 @@ end # HLO reshape semantics collapse the opposite way res1 = transpose(x, Int64[N:-1:1...]) restype = mlir_type(TracedRArray{T,length(dims)}, collect(Base.reverse(dims))) - res = MLIR.IR.result(stablehlo.reshape(res1.mlir_data; result_0=restype, location)) + res = MLIR.IR.result(stablehlo.reshape(res1.mlir_data; result=restype, location)) result = TracedRArray{T,length(dims)}((), res, collect(Base.reverse(dims))) # NOTE this last `transpose` is required for consistency with Julia's column-major order # do not remove, as it will be optimized away by the compiler @@ -376,10 +363,10 @@ end dim; location=mlir_stacktrace("get_dimension_size", @__FILE__, @__LINE__), ) where {T,N} - dimension = MLIR.IR.Attribute(dim - 1) + dimension = dim - 1 res = MLIR.IR.result( stablehlo.get_dimension_size( - x.mlir_data; result_0=mlir_type(TracedRArray{Int32,0}, ()), dimension, location + x.mlir_data; 
result=mlir_type(TracedRArray{Int32,0}, ()), dimension, location ), ) return TracedRNumber{Int32}((), res) @@ -391,7 +378,7 @@ end dim::Int; location=mlir_stacktrace("set_dimension_size", @__FILE__, @__LINE__), ) where {T,N} - dimension = MLIR.IR.Attribute(dim - 1) + dimension = dim - 1 res = MLIR.IR.result( stablehlo.set_dimension_size( x.mlir_data, @@ -412,7 +399,6 @@ end rsize = permute!(collect(size(x)), permutation) permutation = permutation .- 1 result = mlir_type(TracedRArray{T,N}, rsize) - permutation = MLIR.IR.DenseArrayAttribute(permutation) res = MLIR.IR.result(stablehlo.transpose(x.mlir_data; result, permutation, location)) return TracedRArray{T,N}((), res, rsize) end @@ -431,9 +417,9 @@ end stablehlo.pad( x.mlir_data, padding_value.mlir_data; - edge_padding_low=MLIR.IR.DenseArrayAttribute(low), - edge_padding_high=MLIR.IR.DenseArrayAttribute(high), - interior_padding=MLIR.IR.DenseArrayAttribute(interior), + edge_padding_low=low, + edge_padding_high=high, + interior_padding=interior, location, ), ) @@ -455,10 +441,10 @@ end res = MLIR.IR.result( stablehlo.slice( x.mlir_data; - result_0=mlir_type(TracedRArray{T,N}, rsize), - start_indices=MLIR.IR.DenseArrayAttribute(start_indices), - limit_indices=MLIR.IR.DenseArrayAttribute(limit_indices), - strides=MLIR.IR.DenseArrayAttribute(strides), + result=mlir_type(TracedRArray{T,N}, rsize), + start_indices, + limit_indices, + strides, location, ), ) @@ -555,7 +541,7 @@ end ) where {T,U} res = MLIR.IR.result( stablehlo.bitcast_convert( - x.mlir_data; result_0=mlir_type(TracedRArray{U,0}, ()), location + x.mlir_data; result=mlir_type(TracedRArray{U,0}, ()), location ), ) return TracedRNumber{U}((), res) @@ -563,24 +549,24 @@ end @noinline function fft( x::TracedRArray{T,N}; - type::String, + type::stablehlo.FftType.T, length, location=mlir_stacktrace("fft", @__FILE__, @__LINE__), ) where {T,N} @assert 1 <= Base.length(length) <= 3 "fft only supports up to rank 3" - if type ∈ ("FFT", "IFFT") + if type ∈ (stablehlo.FftType.FFT, stablehlo.FftType.IFFT) @assert T <: Complex Tout = T rsize = size(x) - elseif type == "RFFT" + elseif type == stablehlo.FftType.RFFT @assert T <: Real Tout = Complex{T} rsize = let rsize = collect(size(x)) rsize[end] = rsize[end] == 0 ? 
0 : rsize[end] ÷ 2 + 1 Tuple(rsize) end - elseif type == "IRFFT" + elseif type == stablehlo.FftType.IRFFT @assert T <: Complex Tout = Base.real(T) rsize = let rsize = collect(size(x)) @@ -594,9 +580,9 @@ end res = MLIR.IR.result( stablehlo.fft( x.mlir_data; - result_0=mlir_type(TracedRArray{Tout,N}, rsize), - fft_type=MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), type), - fft_length=MLIR.IR.DenseArrayAttribute(length), + result=mlir_type(TracedRArray{Tout,N}, rsize), + fft_type=type, + fft_length=length, location, ), ) @@ -608,7 +594,6 @@ end lower::Bool=false, location=mlir_stacktrace("cholesky", @__FILE__, @__LINE__), ) where {T,N} - lower = MLIR.IR.Attribute(lower) res = MLIR.IR.result( stablehlo.cholesky( x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), lower, location @@ -767,29 +752,13 @@ end lhs_contracting_dimensions = lhs_contracting_dimensions .- 1 rhs_contracting_dimensions = rhs_contracting_dimensions .- 1 - dot_dimension_numbers = GC.@preserve lhs_contracting_dimensions rhs_contracting_dimensions lhs_batching_dimensions rhs_batching_dimensions begin - MLIR.IR.Attribute( - MLIR.API.stablehloDotDimensionNumbersGet( - ctx, - length(lhs_batching_dimensions), - lhs_batching_dimensions, - length(rhs_batching_dimensions), - rhs_batching_dimensions, - length(lhs_contracting_dimensions), - lhs_contracting_dimensions, - length(rhs_contracting_dimensions), - rhs_contracting_dimensions, - ), - ) - end - - if !isnothing(precision_config) - precision_config = MLIR.IR.Attribute([ - MLIR.API.stablehloPrecisionAttrGet(ctx, precision_config[1]), - MLIR.API.stablehloPrecisionAttrGet(ctx, precision_config[2]), - ]) - end - + dot_dimension_numbers = stablehlo.Dot( + lhs_batching_dimensions, + rhs_batching_dimensions, + lhs_contracting_dimensions, + rhs_contracting_dimensions, + ) + algorithm = nothing # all or nothing: if one is set, all must be set # TODO maybe be more flexible, by setting some defaults? 
if any( @@ -814,29 +783,22 @@ end ) lhs_precision_type, rhs_precision_type = precision_type lhs_component_count, rhs_component_count = component_count - algorithm = GC.@preserve begin - MLIR.IR.Attribute( - MLIR.API.stablehloDotAlgorithmGet( - ctx, - lhs_precision_type, - rhs_precision_type, - accumulation_type, - lhs_component_count, - rhs_component_count, - num_primitive_operations, - allow_imprecise_accumulation, - ), - ) - end - else - algorithm = nothing + algorithm = stablehlo.DotAlgorithm( + lhs_precision_type, + rhs_precision_type, + accumulation_type, + lhs_component_count, + rhs_component_count, + num_primitive_operations, + allow_imprecise_accumulation, + ) end res = MLIR.IR.result( stablehlo.dot_general( lhs.mlir_data, rhs.mlir_data; - result_0=mlir_type(TracedRArray{T,length(ressize)}, ressize), + result=mlir_type(TracedRArray{T,length(ressize)}, ressize), dot_dimension_numbers, precision_config, algorithm, @@ -865,15 +827,11 @@ end end rsize = Tuple(sizes[i] for i in ic) - result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize) + result = mlir_type(TracedRArray{T,length(ic)}, rsize) res = MLIR.IR.result( stablehlo.einsum( - lhs.mlir_data, - rhs.mlir_data; - result_0, - einsum_config=MLIR.IR.Attribute(equation), - location, + lhs.mlir_data, rhs.mlir_data; result, einsum_config=equation, location ), ) return TracedRArray{T,length(rsize)}((), res, rsize) @@ -889,11 +847,11 @@ end # ia, ic = split(equation, "->") # sizes = Dict(c => d for (c, d) in zip(ia, size(x))) # rsize = Tuple(sizes[i] for i in ic) -# result_0 = mlir_type(TracedRArray{T,length(ic)}, rsize) +# result = mlir_type(TracedRArray{T,length(ic)}, rsize) # res = MLIR.IR.result( # stablehlo.unary_einsum( -# x.mlir_data; result_0, einsum_config=MLIR.IR.Attribute(equation), location +# x.mlir_data; result, einsum_config=MLIR.IR.Attribute(equation), location # ), # ) # if length(rsize) == 0 @@ -1000,12 +958,10 @@ end else MLIR.IR.Attribute(is_host_transfer) end - result_0 = map(results) do (typ, shape) + result = map(results) do (typ, shape) MLIR.IR.TensorType(shape, mlir_type(typ)) end - op = stablehlo.recv( - token.mlir_data; result_0, channel_handle, is_host_transfer, location - ) + op = stablehlo.recv(token.mlir_data; result, channel_handle, is_host_transfer, location) return tuple( map(enumerate(results)) do (i, (typ, shape)) typ = MLIR.IR.TensorType(shape, mlir_type(typ)) @@ -1032,8 +988,8 @@ function broadcast_in_dim( res = MLIR.IR.result( stablehlo.broadcast_in_dim( x.mlir_data; - result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + result=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=dims .- 1, location, ), ) @@ -1051,8 +1007,8 @@ function broadcast_in_dim( res = MLIR.IR.result( stablehlo.broadcast_in_dim( x.mlir_data; - result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + result=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=dims .- 1, location, ), ) @@ -1102,12 +1058,11 @@ end MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) MLIR.IR.rmfromparent!(func) - dimension = MLIR.IR.Attribute(dimension - 1) - is_stable = MLIR.IR.Attribute(is_stable) + dimension = dimension - 1 op = stablehlo.sort( [x.mlir_data for x in xs]; - result_0=[mlir_type(typeof(x), size(x)) for x in xs], + result=[mlir_type(typeof(x), size(x)) for x in xs], dimension, is_stable, comparator, @@ -1161,7 +1116,7 @@ end ) N = length(shape) 
output = mlir_type(TracedRArray{T,N}, shape) - iota_dimension = MLIR.IR.Attribute(iota_dimension - 1) + iota_dimension = iota_dimension - 1 res = MLIR.IR.result(stablehlo.iota(; output, iota_dimension, location)) return TracedRArray{T,N}((), res, shape) end @@ -1175,7 +1130,7 @@ end stablehlo.reverse( x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), - dimensions=MLIR.IR.DenseArrayAttribute(collect(dimensions .- 1)), + dimensions=collect(dimensions .- 1), location, ), ) @@ -1188,7 +1143,7 @@ end ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rand", @__FILE__, @__LINE__), ) @@ -1210,21 +1165,19 @@ distribution between 0 and 1. Returns a NamedTuple with the following fields: ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), ) where {T<:Integer} - @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") - if algorithm == "PHILOX" + if algorithm == stablehlo.RngAlgorithm.PHILOX @assert length(seed) ∈ (2, 3) - elseif algorithm == "THREE_FRY" + elseif algorithm == stablehlo.RngAlgorithm.THREE_FRY @assert length(seed) == 2 end output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64)) - rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm) op = stablehlo.rng_bit_generator( - seed.mlir_data; output, output_state, rng_algorithm, location + seed.mlir_data; output, output_state, rng_algorithm=algorithm, location ) return (; output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)), @@ -1236,7 +1189,7 @@ end ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), ) where {T<:AbstractFloat} nbits = sizeof(T) * 8 @@ -1254,7 +1207,7 @@ end ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rand", @__FILE__, @__LINE__), ) @@ -1277,7 +1230,7 @@ fields: ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rand", @__FILE__, @__LINE__), ) where {T} res = rng_bit_generator(T, seed, shape; algorithm, location) @@ -1297,7 +1250,7 @@ end ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rand", @__FILE__, @__LINE__), ) @@ -1319,7 +1272,7 @@ distribution with rate 1. 
Returns a NamedTuple with the following fields: ::Type{T}, seed::TracedRArray{UInt64,1}, shape; - algorithm::String="DEFAULT", + algorithm::stablehlo.RngAlgorithm.T=stablehlo.RngAlgorithm.DEFAULT, location=mlir_stacktrace("rand", @__FILE__, @__LINE__), ) where {T} res = rng_bit_generator(T, seed, shape; algorithm, location) @@ -1375,22 +1328,15 @@ end @noinline function compare( lhs::AT, rhs::AT; - comparison_direction::String, + comparison_direction::stablehlo.ComparisonDirection.T, compare_type=nothing, location=mlir_stacktrace("compare", @__FILE__, @__LINE__), ) where {AT<:Union{TracedRArray,TracedRNumber}} - @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT") @assert size(lhs) == size(rhs) res = MLIR.IR.result( stablehlo.compare( - lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( - MLIR.IR.context(), comparison_direction - ), - compare_type, - location, + lhs.mlir_data, rhs.mlir_data; comparison_direction, compare_type, location ), 1, ) @@ -1533,7 +1479,7 @@ julia> Reactant.@jit( operands = [a.mlir_data for a in args] call = MLIR.Dialects.func.call( operands; - result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], + result=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call), location, ) @@ -1588,13 +1534,12 @@ instead. scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1)) index_vector_dim = Int64(1) - scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( - MLIR.IR.context(), - length(update_window_dims), update_window_dims, - length(inserted_window_dims), inserted_window_dims, - length(input_batching_dims), input_batching_dims, - length(scatter_indices_batching_dims), scatter_indices_batching_dims, - length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims, + scatter_dimension_numbers = stablehlo.Scatter( + update_window_dims, + inserted_window_dims, + input_batching_dims, + scatter_indices_batching_dims, + scatter_dims_to_operand_dims, index_vector_dim, ) #! format: on @@ -1606,7 +1551,7 @@ instead. [dest.mlir_data], scatter_indices.mlir_data, [updates.mlir_data]; - result_0=[mlir_type(TracedRArray{T,N}, size(dest))], + result=[mlir_type(TracedRArray{T,N}, size(dest))], update_computation, scatter_dimension_numbers, ), @@ -1636,14 +1581,13 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. start_index_map = collect(Int64, 0:(N - 1)) index_vector_dim = Int64(1) - dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( - MLIR.IR.context(), - Int64(length(offset_dims)), offset_dims, - Int64(length(collapsed_slice_dims)), collapsed_slice_dims, - Int64(length(operand_batching_dims)), operand_batching_dims, - Int64(length(start_indices_batching_dims)), start_indices_batching_dims, - Int64(length(start_index_map)), start_index_map, - Int64(index_vector_dim), + dimension_numbers = stablehlo.Gather( + offset_dims, + collapsed_slice_dims, + operand_batching_dims, + start_indices_batching_dims, + start_index_map, + index_vector_dim ) #! 
format: on @@ -1717,7 +1661,7 @@ end while_op = MLIR.Dialects.stablehlo.while_( MLIR.IR.Value[Reactant.TracedUtils.get_mlir_data(arg) for arg in linear_args]; - result_0=input_types, + result=input_types, cond=cond_reg, body=body_reg, ) @@ -1766,7 +1710,7 @@ end ] input_types = [mlir_type(arg) for arg in tb_linear_args] - sym_visibility = MLIR.IR.Attribute("private") + sym_visibility = "private" # compile the true branch without any returns first true_fn_mod = MLIR.IR.mmodule() @@ -2045,7 +1989,7 @@ end MLIR.IR.rmfromparent!(false_fn_compiled) if_compiled = MLIR.Dialects.stablehlo.if_( - cond.mlir_data; true_branch=tb_region, false_branch=fb_region, result_0=result_types + cond.mlir_data; true_branch=tb_region, false_branch=fb_region, result=result_types ) corrected_traced_results = fmap(traced_false_results, traced_true_results) do fr, tr @@ -2113,7 +2057,7 @@ end call_op = MLIR.Dialects.func.call( mlir_caller_args; - result_0=mlir_result_types, + result=mlir_result_types, callee=MLIR.IR.FlatSymbolRefAttribute(f_name), ) diff --git a/src/Overlay.jl b/src/Overlay.jl index 5d9b85c83..e437a467e 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -38,7 +38,8 @@ end @reactant_overlay @noinline function TracedRandom.default_rng() return TracedRNG( - TracedUtils.promote_to(TracedRArray{UInt64,1}, TracedRandom.make_seed()), "DEFAULT" + TracedUtils.promote_to(TracedRArray{UInt64,1}, TracedRandom.make_seed()), + stablehlo.RngAlgorithm.DEFAULT, ) end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index c2e2e0cb9..d049c1c4d 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -73,7 +73,7 @@ function Base.getindex( ) res2 = MLIR.IR.result( MLIR.Dialects.stablehlo.reshape( - res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1))) + res1; result=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1))) ), 1, ) @@ -542,9 +542,7 @@ function Base.mapreduce( body = MLIR.IR.Region() push!(body, fnbody) - red = MLIR.Dialects.stablehlo.reduce( - inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body - ) + red = MLIR.Dialects.stablehlo.reduce(inp, init; result=TT, dimensions=rdims, body) red = MLIR.IR.result(red, 1) redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red))) @@ -770,7 +768,7 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} # TODO maybe we should do some conversion? 
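Several call sites in these hunks (`stablehlo.reverse`, `stablehlo.reduce`, and the transpose helper further down) now pass plain `Vector{Int64}` values where a `MLIR.IR.DenseArrayAttribute` used to be built by hand; the regenerated dialect wrappers perform the attribute conversion themselves. A small sketch of the resulting call shape, assuming `val` is an existing tensor-typed `MLIR.IR.Value`:

```julia
using Reactant.MLIR
using Reactant.MLIR.Dialects: stablehlo

# Sketch only: swap the two leading axes of a 2-D value. The permutation is a
# plain Int64 vector; the wrapper turns it into the dense-array attribute.
function swap_axes(val::MLIR.IR.Value)
    permutation = Int64[1, 0]  # zero-based dimension numbers, as StableHLO expects
    return MLIR.IR.result(stablehlo.transpose(val; permutation), 1)
end
```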
MLIR.Dialects.stablehlo.concatenate( collect(TracedUtils.get_mlir_data.(X)); - result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), + result=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), dimension=dims - 1, # stablehlo expects this to be zero-indexed ), 1, diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index e7410a0b4..6d608cd75 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -1,7 +1,14 @@ module TracedRNumberOverrides using ..Reactant: - Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype + Reactant, + TracedRNumber, + TracedRArray, + TracedUtils, + Ops, + MLIR, + unwrapped_eltype, + MLIR.Dialects.stablehlo.ComparisonDirection using ReactantCore ReactantCore.is_traced(::TracedRNumber) = true @@ -149,13 +156,13 @@ function Base.:/( end for (jlop, hloop, hlocomp) in ( - (:(Base.:(==)), :compare, "EQ"), - (:(Base.:(!=)), :compare, "NE"), - (:(Base.:(>=)), :compare, "GE"), - (:(Base.:(>)), :compare, "GT"), - (:(Base.:(<=)), :compare, "LE"), - (:(Base.:(<)), :compare, "LT"), - (:(Base.isless), :compare, "LT"), + (:(Base.:(==)), :compare, ComparisonDirection.EQ), + (:(Base.:(!=)), :compare, ComparisonDirection.NE), + (:(Base.:(>=)), :compare, ComparisonDirection.GE), + (:(Base.:(>)), :compare, ComparisonDirection.GT), + (:(Base.:(<=)), :compare, ComparisonDirection.LE), + (:(Base.:(<)), :compare, ComparisonDirection.LT), + (:(Base.isless), :compare, ComparisonDirection.LT), ) @eval begin function $(jlop)( diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index b22e61dc3..6acd9561a 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -125,8 +125,8 @@ end function transpose_val(val) val_size = size(MLIR.IR.type(val)) val_size == () && return val - attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(val_size) - 1))...]) - return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) + permutation = Int64[reverse(0:(length(val_size) - 1))...] + return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation), 1) end mutable struct CompiledMlirFnResult{ @@ -210,10 +210,7 @@ function make_mlir_fn( [Ops.mlir_type(arg) for arg in linear_args] end - sym_visibility = nothing - if !concretein - sym_visibility = MLIR.IR.Attribute("private") - end + sym_visibility = concretein ? 
nothing : "private" ctx = MLIR.IR.context() mod = MLIR.IR.mmodule() @@ -561,10 +558,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} end res = MLIR.Dialects.enzyme.batch( - batch_inputs; - outputs=out_tys2, - fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), + batch_inputs; outputs=out_tys2, fn=fname, batch_shape=[Int64(i) for i in OutShape] ) residx = 1 diff --git a/src/Types.jl b/src/Types.jl index 5d73d2171..fe2ab2358 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -1,3 +1,5 @@ +using Reactant.MLIR.Dialects: stablehlo + abstract type RNumber{T<:ReactantPrimitive} <: Number end abstract type RArray{T,N} <: AbstractArray{T,N} end @@ -59,7 +61,7 @@ const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} ## TracedRNG mutable struct TracedRNG <: Random.AbstractRNG seed::TracedRArray{UInt64,1} - const algorithm::String + const algorithm::stablehlo.RngAlgorithm.T end # Concrete Types @@ -169,5 +171,5 @@ end ## ConcreteRNG mutable struct ConcreteRNG{D,S} <: Random.AbstractRNG seed::ConcreteRArray{UInt64,1,D,S} - const algorithm::String + const algorithm::stablehlo.RngAlgorithm.T end diff --git a/src/mlir/Dialects.jl b/src/mlir/Dialects.jl index 87c63b199..fa40297d4 100644 --- a/src/mlir/Dialects.jl +++ b/src/mlir/Dialects.jl @@ -1,12 +1,12 @@ module Dialects -import ..IR: Attribute, NamedAttribute, context +import ..IR: Attribute, AbstractAttribute, NamedAttribute, context import ..API +import ....Reactant using Reactant_jll - namedattribute(name, val) = namedattribute(name, Attribute(val)) -namedattribute(name, val::Attribute) = NamedAttribute(name, val) +namedattribute(name, val::API.MlirAttribute) = NamedAttribute(name, Attribute(val)) function namedattribute(name, val::NamedAttribute) @assert true # TODO(jm): check whether name of attribute is correct, getting the name might need to be added to IR.jl? return val @@ -16,6 +16,9 @@ function operandsegmentsizes(segments) return namedattribute("operand_segment_sizes", Attribute(Int32.(segments))) end +c(a::AbstractArray) = isempty(a) ? "[]" : a +c(x) = x + for file in readdir(joinpath(@__DIR__, "Dialects")) endswith(file, ".jl") || continue include(joinpath(@__DIR__, "Dialects", file)) diff --git a/src/mlir/Dialects/Affine.jl b/src/mlir/Dialects/Affine.jl index 9ce90aa90..747209604 100755 --- a/src/mlir/Dialects/Affine.jl +++ b/src/mlir/Dialects/Affine.jl @@ -10,8 +10,18 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`AtomicRMWKind` +allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 +""" +@enumx AtomicRMWKind addf = 0 addi = 1 assign = 2 maximumf = 3 maxs = 4 maxu = 5 minimumf = + 6 mins = 7 minu = 8 mulf = 9 muli = 10 ori = 11 andi = 12 maxnumf = 13 minnumf = 14 + +IR.Attribute(e::AtomicRMWKind.T) = Int(e) """ `apply` @@ -37,16 +47,16 @@ have ‘index’ type. 
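Across the regenerated Affine wrappers the catch-all `result_0` keyword becomes `result`, optional keywords carry concrete types, and result inference is driven by `isempty(op_ty_results)`, as in the `apply` wrapper that follows. A caller-side sketch of the renamed keyword; `operands` (index `Value`s) and `index_map` (an affine-map attribute) are assumed to have been built earlier and are not part of this diff:

```julia
using Reactant.MLIR
using Reactant.MLIR.Dialects: affine

# Sketch: `result` was previously spelled `result_0`. It still defaults to
# `nothing`, in which case the wrapper asks MLIR to infer the result type;
# here an explicit index type is passed instead.
function apply_affine_map(operands::Vector{MLIR.IR.Value}, index_map)
    op = affine.apply(operands; result=MLIR.IR.IndexType(), map=index_map)
    return MLIR.IR.result(op, 1)
end
```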
""" function apply( mapOperands::Vector{Value}; - result_0=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, map, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[mapOperands...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("map", map),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "affine.apply", @@ -55,8 +65,8 @@ function apply( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -113,9 +123,9 @@ undefined behavior. function delinearize_index( linear_index::Value, dynamic_basis::Vector{Value}; - multi_index::Vector{IR.Type}, - static_basis, - location=Location(), + multi_index::Base.AbstractVecOrTuple{IR.Type}, + static_basis::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[multi_index...,] operands = Value[linear_index, dynamic_basis...] @@ -246,12 +256,12 @@ function for_( lowerBoundOperands::Vector{Value}, upperBoundOperands::Vector{Value}, inits::Vector{Value}; - results::Vector{IR.Type}, + results::Base.AbstractVecOrTuple{IR.Type}, lowerBoundMap, upperBoundMap, step, region::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[lowerBoundOperands..., upperBoundOperands..., inits...] @@ -353,11 +363,11 @@ func.func @pad_edges(%I : memref<10x10xf32>) -> (memref<12x12xf32) { """ function if_( operand_0::Vector{Value}; - results::Vector{IR.Type}, + results::Base.AbstractVecOrTuple{IR.Type}, condition, thenRegion::Region, elseRegion::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operand_0...,] @@ -429,9 +439,9 @@ In the above example, `%linear_index` conceptually holds the following: function linearize_index( multi_index::Vector{Value}, dynamic_basis::Vector{Value}; - linear_index=nothing::Union{Nothing,IR.Type}, - static_basis, - location=Location(), + linear_index::Union{Nothing,IR.Type}=nothing, + static_basis::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[multi_index..., dynamic_basis...] @@ -448,8 +458,8 @@ function linearize_index( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -482,7 +492,11 @@ Example 2: Uses `symbol` keyword for symbols `%n` and `%m`. ``` """ function load( - memref::Value, indices::Vector{Value}; result::IR.Type, map, location=Location() + memref::Value, + indices::Vector{Value}; + result::IR.Type, + map, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[memref, indices...] @@ -516,16 +530,16 @@ affine map. 
""" function max( operands::Vector{Value}; - result_0=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, map, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operands...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("map", map),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "affine.max", @@ -534,8 +548,8 @@ function max( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -563,16 +577,16 @@ input operands and result must all have \'index\' type. """ function min( operands::Vector{Value}; - result_0=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, map, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operands...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("map", map),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "affine.min", @@ -581,8 +595,8 @@ function min( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -653,15 +667,15 @@ affine.parallel (%ii, %jj) = (0, 0) to (%N, %M) step (32, 32) { """ function parallel( mapOperands::Vector{Value}; - results::Vector{IR.Type}, - reductions, + results::Base.AbstractVecOrTuple{IR.Type}, + reductions::IR.DenseAttribute{AtomicRMWKind.T}, lowerBoundsMap, - lowerBoundsGroups, + lowerBoundsGroups::IR.AbstractDenseElementsAttribute{Int64}, upperBoundsMap, - upperBoundsGroups, - steps, + upperBoundsGroups::IR.AbstractDenseElementsAttribute{Int64}, + steps::IR.DenseAttribute{Int64}, region::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[mapOperands...,] @@ -709,11 +723,11 @@ instruction cache. function prefetch( memref::Value, indices::Vector{Value}; - isWrite, - localityHint, - isDataCache, + isWrite::Bool, + localityHint::Int32, + isDataCache::Bool, map, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[memref, indices...] @@ -767,7 +781,7 @@ affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32> ``` """ function store( - value::Value, memref::Value, indices::Vector{Value}; map, location=Location() + value::Value, memref::Value, indices::Vector{Value}; map, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[value, memref, indices...] @@ -827,7 +841,11 @@ TODOs: (see [vector.transfer_read](../Vector/#vectortransfer_read-mlirvectortransferreadop)). """ function vector_load( - memref::Value, indices::Vector{Value}; result::IR.Type, map, location=Location() + memref::Value, + indices::Vector{Value}; + result::IR.Type, + map, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[memref, indices...] 
@@ -889,7 +907,7 @@ TODOs: (see [vector.transfer_write](../Vector/#vectortransfer_write-mlirvectortransferwriteop)). """ function vector_store( - value::Value, memref::Value, indices::Vector{Value}; map, location=Location() + value::Value, memref::Value, indices::Vector{Value}; map, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[value, memref, indices...] @@ -922,7 +940,7 @@ left out in the custom syntax and the builders will insert one implicitly. Otherwise, it has to be present in the syntax to indicate which values are yielded. """ -function yield(operands::Vector{Value}; location=Location()) +function yield(operands::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[operands...,] owned_regions = Region[] diff --git a/src/mlir/Dialects/Arith.jl b/src/mlir/Dialects/Arith.jl index 01289a12f..ad4ab5a8a 100755 --- a/src/mlir/Dialects/Arith.jl +++ b/src/mlir/Dialects/Arith.jl @@ -10,8 +10,60 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`FastMathFlags` +Floating point fast math flags +""" +@enumx FastMathFlags none reassoc nnan ninf nsz arcp contract afn fast +const FastMathFlagsStorage = [ + "none", "reassoc", "nnan", "ninf", "nsz", "arcp", "contract", "afn", "fast" +] + +function IR.Attribute(e::FastMathFlags.T) + return parse(Attribute, "#arith>") +end + +""" +`IntegerOverflowFlags` +Integer overflow arith flags +""" +@enumx IntegerOverflowFlags none nsw nuw +const IntegerOverflowFlagsStorage = ["none", "nsw", "nuw"] + +function IR.Attribute(e::IntegerOverflowFlags.T) + return parse(Attribute, "#arith>") +end + +""" +`CmpFPredicate` +allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +""" +@enumx CmpFPredicate AlwaysFalse = 0 OEQ = 1 OGT = 2 OGE = 3 OLT = 4 OLE = 5 ONE = 6 ORD = 7 UEQ = + 8 UGT = 9 UGE = 10 ULT = 11 ULE = 12 UNE = 13 UNO = 14 AlwaysTrue = 15 + +IR.Attribute(e::CmpFPredicate.T) = Int(e) + +""" +`CmpIPredicate` +allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 +""" +@enumx CmpIPredicate eq = 0 ne = 1 slt = 2 sle = 3 sgt = 4 sge = 5 ult = 6 ule = 7 ugt = 8 uge = + 9 + +IR.Attribute(e::CmpIPredicate.T) = Int(e) + +""" +`RoundingMode` +Floating point rounding mode +""" +@enumx RoundingMode to_nearest_even = 0 downward = 1 upward = 2 toward_zero = 3 to_nearest_away = + 4 + +IR.Attribute(e::RoundingMode.T) = Int(e) """ `addf` @@ -40,9 +92,9 @@ math, contraction, rounding mode, and other controls. function addf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -59,8 +111,8 @@ function addf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -101,9 +153,9 @@ This op supports `nuw`/`nsw` overflow flags which stands stand for function addi( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - overflowFlags=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + overflowFlags::Union{IntegerOverflowFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -121,8 +173,8 @@ function addi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -148,7 +200,7 @@ indicates no overflow. ``` """ function addui_extended( - lhs::Value, rhs::Value; sum::IR.Type, overflow::IR.Type, location=Location() + lhs::Value, rhs::Value; sum::IR.Type, overflow::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[sum, overflow] operands = Value[lhs, rhs] @@ -190,7 +242,10 @@ has no standard attributes. ``` """ function andi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -206,8 +261,8 @@ function andi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -229,7 +284,7 @@ endianness for the source and target types (e.g. float is big-endian and integer is little-endian) a proper lowering would add operations to swap the order of words in addition to the bitcast. """ -function bitcast(in::Value; out::IR.Type, location=Location()) +function bitcast(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -266,7 +321,10 @@ signed division overflow. ``` """ function ceildivsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -282,8 +340,8 @@ function ceildivsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -306,7 +364,10 @@ zero. ``` """ function ceildivui( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -322,8 +383,8 @@ function ceildivui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -357,10 +418,10 @@ attribute by the parser. 
function cmpf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - predicate, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + predicate::CmpFPredicate.T, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -377,8 +438,8 @@ function cmpf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -450,9 +511,9 @@ complement or large positives function cmpi( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - predicate, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + predicate::CmpIPredicate.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -468,8 +529,8 @@ function cmpi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -490,7 +551,9 @@ forms simple integer and floating point constants. %1 = \"arith.constant\"() {value = 42 : i32} : () -> i32 ``` """ -function constant(; result=nothing::Union{Nothing,IR.Type}, value, location=Location()) +function constant(; + result::Union{Nothing,IR.Type}=nothing, value, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -505,17 +568,17 @@ function constant(; result=nothing::Union{Nothing,IR.Type}, value, location=Loca owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function divf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -532,8 +595,8 @@ function divf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -562,7 +625,10 @@ signed division overflow. ``` """ function divsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -578,8 +644,8 @@ function divsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -608,7 +674,10 @@ zero. 
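The Arith wrappers now take their enum attributes as EnumX values (`CmpFPredicate`, `CmpIPredicate`, `FastMathFlags`, `IntegerOverflowFlags`, `RoundingMode`), so the accepted cases are visible in the signature. A hedged sketch of a comparison built through the new interface; it assumes `lhs`/`rhs` are float-typed `MLIR.IR.Value`s and that the enums are reachable through the `arith` dialect module, as the StableHLO ones are elsewhere in this diff:

```julia
using Reactant.MLIR
using Reactant.MLIR.Dialects: arith

# Sketch: ordered less-than with no fast-math relaxation. The predicate is an
# enum case rather than a raw integer attribute.
function ordered_lt(lhs::MLIR.IR.Value, rhs::MLIR.IR.Value)
    op = arith.cmpf(
        lhs, rhs;
        result=MLIR.IR.Type(Bool),           # i1
        predicate=arith.CmpFPredicate.OLT,
        fastmath=arith.FastMathFlags.none,   # optional; may be omitted entirely
    )
    return MLIR.IR.result(op, 1)
end
```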
``` """ function divui( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -624,8 +693,8 @@ function divui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -636,7 +705,12 @@ Cast a floating-point value to a larger floating-point-typed value. The destination type must to be strictly wider than the source type. When operating on vectors, casts elementwise. """ -function extf(in::Value; out::IR.Type, fastmath=nothing, location=Location()) +function extf( + in::Value; + out::IR.Type, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -676,7 +750,7 @@ of the most-significant bit of the input. %5 = arith.extsi %0 : vector<2 x i32> to vector<2 x i64> ``` """ -function extsi(in::Value; out::IR.Type, location=Location()) +function extsi(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -714,7 +788,7 @@ The top-most (N - M) bits of the output are filled with zeros. %5 = arith.extui %0 : vector<2 x i32> to vector<2 x i64> ``` """ -function extui(in::Value; out::IR.Type, location=Location()) +function extui(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -740,7 +814,7 @@ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) signed integer value. When operating on vectors, casts elementwise. """ -function fptosi(in::Value; out::IR.Type, location=Location()) +function fptosi(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -766,7 +840,7 @@ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) unsigned integer value. When operating on vectors, casts elementwise. """ -function fptoui(in::Value; out::IR.Type, location=Location()) +function fptoui(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -804,7 +878,10 @@ signed division overflow. ``` """ function floordivsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -820,8 +897,8 @@ function floordivsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -833,7 +910,7 @@ vectors. Index is an integer of platform-specific bit width. If casting to a wider integer, the value is sign-extended. If casting to a narrower integer, the value is truncated. 
""" -function index_cast(in::Value; out::IR.Type, location=Location()) +function index_cast(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -860,7 +937,7 @@ vectors. Index is an integer of platform-specific bit width. If casting to a wider integer, the value is zero-extended. If casting to a narrower integer, the value is truncated. """ -function index_castui(in::Value; out::IR.Type, location=Location()) +function index_castui(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -896,9 +973,9 @@ If one of the arguments is NaN, then the result is the other argument. function maxnumf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -915,13 +992,16 @@ function maxnumf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function maxsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -937,13 +1017,16 @@ function maxsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function maxui( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -959,8 +1042,8 @@ function maxui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -980,9 +1063,9 @@ If one of the arguments is NaN, then the result is also NaN. function maximumf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -999,8 +1082,8 @@ function maximumf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1021,9 +1104,9 @@ If one of the arguments is NaN, then the result is the other argument. 
function minnumf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1040,13 +1123,16 @@ function minnumf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function minsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1062,13 +1148,16 @@ function minsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function minui( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1084,8 +1173,8 @@ function minui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1105,9 +1194,9 @@ If one of the arguments is NaN, then the result is also NaN. function minimumf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1124,8 +1213,8 @@ function minimumf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1156,9 +1245,9 @@ math, contraction, rounding mode, and other controls. function mulf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1175,8 +1264,8 @@ function mulf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1217,9 +1306,9 @@ This op supports `nuw`/`nsw` overflow flags which stands stand for function muli( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - overflowFlags=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + overflowFlags::Union{IntegerOverflowFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1237,8 +1326,8 @@ function muli( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1266,9 +1355,9 @@ the same operands. function mulsi_extended( lhs::Value, rhs::Value; - low=nothing::Union{Nothing,IR.Type}, - high=nothing::Union{Nothing,IR.Type}, - location=Location(), + low::Union{Nothing,IR.Type}=nothing, + high::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1285,8 +1374,8 @@ function mulsi_extended( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1314,9 +1403,9 @@ the same operands. function mului_extended( lhs::Value, rhs::Value; - low=nothing::Union{Nothing,IR.Type}, - high=nothing::Union{Nothing,IR.Type}, - location=Location(), + low::Union{Nothing,IR.Type}=nothing, + high::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1333,8 +1422,8 @@ function mului_extended( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1361,9 +1450,9 @@ It has no standard attributes. """ function negf( operand::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -1380,8 +1469,8 @@ function negf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1407,7 +1496,10 @@ standard attributes. ``` """ function ori( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1423,8 +1515,8 @@ function ori( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1437,9 +1529,9 @@ The remainder has the same sign as the dividend (lhs operand). function remf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1456,8 +1548,8 @@ function remf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1485,7 +1577,10 @@ zero. ``` """ function remsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1501,8 +1596,8 @@ function remsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1530,7 +1625,10 @@ zero. ``` """ function remui( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1546,8 +1644,8 @@ function remui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1559,7 +1657,7 @@ floating-point value. If the value cannot be exactly represented, it is rounded using the default rounding mode. When operating on vectors, casts elementwise. """ -function sitofp(in::Value; out::IR.Type, location=Location()) +function sitofp(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -1604,9 +1702,9 @@ This op supports `nuw`/`nsw` overflow flags which stands stand for function shli( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - overflowFlags=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + overflowFlags::Union{IntegerOverflowFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1624,8 +1722,8 @@ function shli( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1651,7 +1749,10 @@ returns poison. 
``` """ function shrsi( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1667,8 +1768,8 @@ function shrsi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1690,7 +1791,10 @@ bitwidth of the first operand, then the operation returns poison. ``` """ function shrui( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1706,8 +1810,8 @@ function shrui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1738,9 +1842,9 @@ math, contraction, rounding mode, and other controls. function subf( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - fastmath=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1757,8 +1861,8 @@ function subf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1799,9 +1903,9 @@ This op supports `nuw`/`nsw` overflow flags which stands stand for function subi( lhs::Value, rhs::Value; - result=nothing::Union{Nothing,IR.Type}, - overflowFlags=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + overflowFlags::Union{IntegerOverflowFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1819,8 +1923,8 @@ function subi( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1834,7 +1938,11 @@ provided rounding mode or the default one if no rounding mode is provided. When operating on vectors, casts elementwise. """ function truncf( - in::Value; out::IR.Type, roundingmode=nothing, fastmath=nothing, location=Location() + in::Value; + out::IR.Type, + roundingmode::Union{RoundingMode.T,Nothing}=nothing, + fastmath::Union{FastMathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[out,] operands = Value[in,] @@ -1875,7 +1983,7 @@ The top-most (N - M) bits of the input are discarded. 
%5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> ``` """ -function trunci(in::Value; out::IR.Type, location=Location()) +function trunci(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -1902,7 +2010,7 @@ floating-point value. If the value cannot be exactly represented, it is rounded using the default rounding mode. When operating on vectors, casts elementwise. """ -function uitofp(in::Value; out::IR.Type, location=Location()) +function uitofp(in::Value; out::IR.Type, location::Location=Location()) op_ty_results = IR.Type[out,] operands = Value[in,] owned_regions = Region[] @@ -1943,7 +2051,10 @@ has no standard attributes. ``` """ function xori( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1959,8 +2070,8 @@ function xori( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2005,8 +2116,8 @@ function select( condition::Value, true_value::Value, false_value::Value; - result=nothing::Union{Nothing,IR.Type}, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[condition, true_value, false_value] @@ -2022,8 +2133,8 @@ function select( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end diff --git a/src/mlir/Dialects/Builtin.jl b/src/mlir/Dialects/Builtin.jl index df4b46607..b772c57bd 100755 --- a/src/mlir/Dialects/Builtin.jl +++ b/src/mlir/Dialects/Builtin.jl @@ -10,8 +10,9 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX """ `module_` @@ -33,7 +34,10 @@ module { ``` """ function module_(; - sym_name=nothing, sym_visibility=nothing, bodyRegion::Region, location=Location() + sym_name::Union{String,Nothing}=nothing, + sym_visibility::Union{String,Nothing}=nothing, + bodyRegion::Region, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -91,7 +95,9 @@ operands of arity 0-N. ``` """ function unrealized_conversion_cast( - inputs::Vector{Value}; outputs::Vector{IR.Type}, location=Location() + inputs::Vector{Value}; + outputs::Base.AbstractVecOrTuple{IR.Type}, + location::Location=Location(), ) op_ty_results = IR.Type[outputs...,] operands = Value[inputs...,] diff --git a/src/mlir/Dialects/CHLO.jl b/src/mlir/Dialects/CHLO.jl index 7696a6556..6d9f97f20 100755 --- a/src/mlir/Dialects/CHLO.jl +++ b/src/mlir/Dialects/CHLO.jl @@ -10,8 +10,64 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`ComparisonDirection` +Which comparison operation to perform. 
+""" +@enumx ComparisonDirection EQ NE GE GT LE LT +const ComparisonDirectionStorage = ["EQ", "NE", "GE", "GT", "LE", "LT"] + +function IR.Attribute(e::ComparisonDirection.T) + return parse( + Attribute, "#chlo" + ) +end + +""" +`ComparisonType` +Which comparison type to use. +""" +@enumx ComparisonType NOTYPE FLOAT TOTALORDER SIGNED UNSIGNED +const ComparisonTypeStorage = ["NOTYPE", "FLOAT", "TOTALORDER", "SIGNED", "UNSIGNED"] + +function IR.Attribute(e::ComparisonType.T) + return parse(Attribute, "#chlo") +end + +""" +`ragged_dot` +Attribute that models the dimension information for ragged dot. +""" +struct RaggedDot + lhs_batching_dimensions::IR.DenseAttribute{Int64} + rhs_batching_dimensions::IR.DenseAttribute{Int64} + lhs_contracting_dimensions::IR.DenseAttribute{Int64} + rhs_contracting_dimensions::IR.DenseAttribute{Int64} + lhs_ragged_dimensions::IR.DenseAttribute{Int64} + rhs_group_dimensions::IR.DenseAttribute{Int64} +end + +function IR.Attribute(s::RaggedDot) + return parse( + Attribute, + "#chlo.ragged_dot", + ) +end + +""" +`Precision` +XLA precision for an operand. Has backend specific meaning. +""" +@enumx Precision DEFAULT HIGH HIGHEST +const PrecisionStorage = ["DEFAULT", "HIGH", "HIGHEST"] + +function IR.Attribute(e::Precision.T) + return parse(Attribute, "#chlo") +end """ `acos` @@ -23,7 +79,9 @@ Returns `Acos(operand)` element-wise. = pi if x == -1 \$\$ """ -function acos(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function acos( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -38,8 +96,8 @@ function acos(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -53,7 +111,9 @@ Returns `Acosh(operand)` element-wise. \\acosh(x) = nan if x < -1 \$\$ """ -function acosh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function acosh( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -68,8 +128,8 @@ function acosh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location= owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -92,7 +152,7 @@ should never be constructed directly by frameworks or consumed by backends. """ function _asin_acos_kernel( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -108,8 +168,8 @@ function _asin_acos_kernel( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -122,7 +182,9 @@ Returns `Asin(operand)` element-wise. 
\\asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) \$\$ """ -function asin(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function asin( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -137,8 +199,8 @@ function asin(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -151,7 +213,9 @@ Returns `Asinh(operand)` element-wise. \\asinh(x) = log(x + sqrt(x^2 + 1)) \$\$ """ -function asinh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function asinh( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -166,8 +230,8 @@ function asinh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location= owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -180,7 +244,9 @@ Returns `Atan(operand)` element-wise. \\atan(x) = \\atan2(x, 1) \$\$ """ -function atan(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function atan( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -195,8 +261,8 @@ function atan(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -210,7 +276,9 @@ Returns `Atanh(operand)` element-wise. = nan otherwise \$\$ """ -function atanh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function atanh( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -225,8 +293,8 @@ function atanh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location= owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -236,7 +304,7 @@ end Returns `bessel_i1e(operand)` element-wise. """ function bessel_i1e( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -252,8 +320,8 @@ function bessel_i1e( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -268,16 +336,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_add( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -288,8 +356,8 @@ function broadcast_add( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -304,16 +372,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_and( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -324,8 +392,8 @@ function broadcast_and( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -340,16 +408,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_atan2( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -360,8 +428,8 @@ function broadcast_atan2( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -379,11 +447,11 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_opera function broadcast_compare( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - comparison_direction, - compare_type=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + comparison_direction::ComparisonDirection.T, + compare_type::Union{ComparisonType.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -392,7 +460,7 @@ function broadcast_compare( attributes = NamedAttribute[namedattribute( "comparison_direction", comparison_direction ),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) !isnothing(compare_type) && @@ -405,8 +473,8 @@ function broadcast_compare( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -419,16 +487,16 @@ a complex value. function broadcast_complex( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -439,8 +507,8 @@ function broadcast_complex( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -455,16 +523,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_divide( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -475,8 +543,8 @@ function broadcast_divide( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -491,16 +559,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_maximum( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -511,8 +579,8 @@ function broadcast_maximum( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -527,16 +595,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_minimum( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -547,8 +615,8 @@ function broadcast_minimum( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -563,16 +631,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_multiply( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -583,8 +651,8 @@ function broadcast_multiply( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -599,16 +667,16 @@ Equivalent to the C++ std::nextafter function. 
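The pattern repeated throughout the chlo hunks above: every optional result keyword is renamed from `result_0` to `result` and annotated `Union{Nothing,IR.Type}`, and the old `length(...) == 0` ternaries collapse to `isempty`, so result inference is requested exactly when no result type is supplied. A minimal call-site sketch, assuming the wrappers live in a `chlo` module under `Reactant.MLIR.Dialects` and that `lhs`/`rhs` are existing tensor-typed `Value`s (both assumptions, not shown in this diff):

```julia
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: chlo   # assumed module path for these generated wrappers

# Hypothetical helper around the regenerated builder.
function add_broadcasted(lhs::IR.Value, rhs::IR.Value;
                         result_type::Union{Nothing,IR.Type}=nothing)
    # result_type === nothing  -> op_ty_results stays empty, so the builder passes
    #   results=nothing / result_inference=true and lets MLIR infer the result type.
    # result_type isa IR.Type  -> the type is recorded and inference is skipped.
    # Behaviour is unchanged; only the keyword was renamed from `result_0` to `result`.
    return chlo.broadcast_add(lhs, rhs; result=result_type)
end
```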
function broadcast_next_after( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -619,8 +687,8 @@ function broadcast_next_after( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -635,16 +703,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_or( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -655,8 +723,8 @@ function broadcast_or( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -668,16 +736,16 @@ Returns `Polygamma(operand, operand)` element-wise. function broadcast_polygamma( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -688,8 +756,8 @@ function broadcast_polygamma( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -704,16 +772,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_power( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -724,8 +792,8 @@ function broadcast_power( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -740,16 +808,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_remainder( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -760,8 +828,8 @@ function broadcast_remainder( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -777,15 +845,15 @@ function broadcast_select( pred::Value, on_true::Value, on_false::Value; - result_0=nothing::Union{Nothing,IR.Type}, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[pred, on_true, on_false] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "chlo.broadcast_select", @@ -794,8 +862,8 @@ function broadcast_select( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -810,16 +878,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_shift_left( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -830,8 +898,8 @@ function broadcast_shift_left( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -846,16 +914,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_shift_right_arithmetic( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -866,8 +934,8 @@ function broadcast_shift_right_arithmetic( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -882,16 +950,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_shift_right_logical( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -902,8 +970,8 @@ function broadcast_shift_right_logical( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -918,16 +986,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_subtract( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -938,8 +1006,8 @@ function broadcast_subtract( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -954,16 +1022,16 @@ https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmeti function broadcast_xor( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -974,8 +1042,8 @@ function broadcast_xor( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -991,16 +1059,16 @@ Returns `Zeta(operand, operand)` element-wise. function broadcast_zeta( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_dimensions=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) @@ -1011,8 +1079,8 @@ function broadcast_zeta( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1025,7 +1093,9 @@ Returns `Conj(operand)` element-wise. 
\\conj(x) = (\\real(x), \\neg(\\imag(x))) \$\$ """ -function conj(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function conj( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1040,8 +1110,8 @@ function conj(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1051,14 +1121,17 @@ end Returns a splat constant of the same shape as the operand. """ function constant_like( - operand::Value; result_0=nothing::Union{Nothing,IR.Type}, value, location=Location() + operand::Value; + result::Union{Nothing,IR.Type}=nothing, + value, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("value", value),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "chlo.constant_like", @@ -1067,8 +1140,8 @@ function constant_like( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1077,7 +1150,11 @@ end Represents a constant value. """ -function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Location()) +function constant(; + output::Union{Nothing,IR.Type}=nothing, + value::IR.AbstractDenseElementsAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1092,8 +1169,8 @@ function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Loca owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1106,7 +1183,9 @@ Returns `Cosh(operand)` element-wise. \\cosh(x) = (e^x + e^-x) / 2 \$\$ """ -function cosh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function cosh( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1121,8 +1200,8 @@ function cosh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1132,7 +1211,7 @@ end Returns `Digamma(operand)` element-wise. 
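For the newly typed `value` keyword on `chlo.constant` (`IR.AbstractDenseElementsAttribute`), a hedged sketch follows; the `IR.DenseElementsAttribute` array constructor used here is an assumption about the IR module, not something introduced by this diff:

```julia
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: chlo   # assumed module path

# Assumed: the IR module can build a dense-elements attribute from a Julia array;
# any concrete subtype of IR.AbstractDenseElementsAttribute satisfies the new bound.
splat = IR.DenseElementsAttribute(fill(1.0f0, 4))
cst = chlo.constant(; value=splat)   # output type left to result inference
```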
""" function digamma( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -1148,8 +1227,8 @@ function digamma( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1159,7 +1238,7 @@ end Returns `ErfInv(operand)` element-wise. """ function erf_inv( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -1175,8 +1254,8 @@ function erf_inv( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1188,7 +1267,9 @@ Computes the Gauss error function of `x` element-wise. erf(x) = erf_impl(x) if |x| < 1 = 1 - erfc_impl(x) otherwise """ -function erf(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function erf( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1203,8 +1284,8 @@ function erf(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Lo owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1216,7 +1297,9 @@ Computes an approximation of the error function complement (1 - erf(x)). erfc(x) = erfc_impl(x) if |x| > 1 = 1 - erf_impl(x) otherwise """ -function erfc(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function erfc( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1231,8 +1314,8 @@ function erfc(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1241,7 +1324,9 @@ end Returns if a value is +/-inf element-wise. """ -function is_inf(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function is_inf( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1256,8 +1341,8 @@ function is_inf(operand::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1267,7 +1352,7 @@ end Returns if a value is -inf element-wise. """ function is_neg_inf( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -1283,8 +1368,8 @@ function is_neg_inf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1294,7 +1379,7 @@ end Returns if a value is +inf element-wise. """ function is_pos_inf( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -1310,8 +1395,8 @@ function is_pos_inf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1320,7 +1405,9 @@ end Returns `Lgamma(operand)` element-wise. """ -function lgamma(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function lgamma( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1335,8 +1422,8 @@ function lgamma(operand::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1349,7 +1436,10 @@ element-wise. It can also return a subnormal number. Equivalent to the C++ std::nextafter function. """ function next_after( - x::Value, y::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + x::Value, + y::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[x, y] @@ -1365,8 +1455,8 @@ function next_after( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1376,7 +1466,10 @@ end Returns `Polygamma(operand, operand)` element-wise. """ function polygamma( - n::Value, x::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + n::Value, + x::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[n, x] @@ -1392,8 +1485,8 @@ function polygamma( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1423,10 +1516,10 @@ function ragged_dot( lhs::Value, rhs::Value, group_sizes::Value; - result=nothing::Union{Nothing,IR.Type}, - ragged_dot_dimension_numbers, - precision_config=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + ragged_dot_dimension_numbers::RaggedDot, + precision_config::Union{IR.DenseAttribute{Precision.T},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs, group_sizes] @@ -1446,8 +1539,8 @@ function ragged_dot( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1461,7 +1554,9 @@ Returns `Sinh(operand)` element-wise. = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. \$\$ """ -function sinh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function sinh( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1476,8 +1571,8 @@ function sinh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1491,7 +1586,9 @@ Returns `Square(operand)` element-wise. = x * x otherwise \$\$ """ -function square(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function square( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1506,8 +1603,8 @@ function square(operand::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1520,7 +1617,9 @@ Returns `Tan(operand)` element-wise. \\tan(x) = \\sin(x) / \\cos(x) \$\$ """ -function tan(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function tan( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1535,8 +1634,8 @@ function tan(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Lo owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1556,10 +1655,10 @@ If two elements are equal, the lower-index element appears first. 
""" function top_k( operand::Value; - values=nothing::Union{Nothing,IR.Type}, - indices=nothing::Union{Nothing,IR.Type}, - k, - location=Location(), + values::Union{Nothing,IR.Type}=nothing, + indices::Union{Nothing,IR.Type}=nothing, + k::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -1576,8 +1675,8 @@ function top_k( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1591,7 +1690,10 @@ Returns `Zeta(operand, operand)` element-wise. \$\$ """ function zeta( - x::Value, q::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + x::Value, + q::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[x, q] @@ -1607,8 +1709,8 @@ function zeta( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index f922304da..5893f29ec 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -10,15 +10,34 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`Activity` +Possible activity states for variables +""" +@enumx Activity enzyme_active enzyme_dup enzyme_const enzyme_dupnoneed enzyme_activenoneed enzyme_constnoneed +const ActivityStorage = [ + "enzyme_active", + "enzyme_dup", + "enzyme_const", + "enzyme_dupnoneed", + "enzyme_activenoneed", + "enzyme_constnoneed", +] + +function IR.Attribute(e::Activity.T) + return parse(Attribute, "#enzyme") +end """ `addTo` TODO """ -function addTo(values::Vector{Value}; location=Location()) +function addTo(values::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[values...,] owned_regions = Region[] @@ -39,12 +58,12 @@ end function autodiff( inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), + outputs::Base.AbstractVecOrTuple{IR.Type}, + fn::IR.FlatSymbolRefAttribute, + activity::IR.DenseAttribute{Activity.T}, + ret_activity::IR.DenseAttribute{Activity.T}, + width::Union{Int64,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[outputs...,] operands = Value[inputs...,] @@ -70,7 +89,11 @@ function autodiff( end function batch( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() + inputs::Vector{Value}; + outputs::Base.AbstractVecOrTuple{IR.Type}, + fn::IR.FlatSymbolRefAttribute, + batch_shape::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[outputs...,] operands = Value[inputs...,] @@ -100,7 +123,12 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. 
""" -function broadcast(input::Value; output::IR.Type, shape, location=Location()) +function broadcast( + input::Value; + output::IR.Type, + shape::IR.DenseAttribute{Int64}, + location::Location=Location(), +) op_ty_results = IR.Type[output,] operands = Value[input,] owned_regions = Region[] @@ -121,12 +149,12 @@ end function fwddiff( inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), + outputs::Base.AbstractVecOrTuple{IR.Type}, + fn::IR.FlatSymbolRefAttribute, + activity::IR.DenseAttribute{Activity.T}, + ret_activity::IR.DenseAttribute{Activity.T}, + width::Union{Int64,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[outputs...,] operands = Value[inputs...,] @@ -154,13 +182,13 @@ end function genericAdjoint( inputs::Vector{Value}, outputs::Vector{Value}; - result_tensors::Vector{IR.Type}, - indexing_maps, - iterator_types, - doc=nothing, - library_call=nothing, + result_tensors::Base.AbstractVecOrTuple{IR.Type}, + indexing_maps::IR.DenseAttribute{<:Any}, + iterator_types::Vector{<:IR.AbstractAttribute}, + doc::Union{String,Nothing}=nothing, + library_call::Union{String,Nothing}=nothing, region::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result_tensors...,] operands = Value[inputs..., outputs...] @@ -187,8 +215,8 @@ function genericAdjoint( ) end -function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] +function get(gradient::Value; result::IR.Type, location::Location=Location()) + op_ty_results = IR.Type[result,] operands = Value[gradient,] owned_regions = Region[] successors = Block[] @@ -206,8 +234,8 @@ function get(gradient::Value; result_0::IR.Type, location=Location()) ) end -function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] +function init(; result::IR.Type, location::Location=Location()) + op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] successors = Block[] @@ -225,7 +253,7 @@ function init(; result_0::IR.Type, location=Location()) ) end -function placeholder(; output::IR.Type, location=Location()) +function placeholder(; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -244,7 +272,7 @@ function placeholder(; output::IR.Type, location=Location()) ) end -function pop(cache::Value; output::IR.Type, location=Location()) +function pop(cache::Value; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[cache,] owned_regions = Region[] @@ -263,7 +291,7 @@ function pop(cache::Value; output::IR.Type, location=Location()) ) end -function push(cache::Value, value::Value; location=Location()) +function push(cache::Value, value::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[cache, value] owned_regions = Region[] @@ -282,7 +310,7 @@ function push(cache::Value, value::Value; location=Location()) ) end -function set(gradient::Value, value::Value; location=Location()) +function set(gradient::Value, value::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[gradient, value] owned_regions = Region[] diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl old mode 100644 new mode 100755 index a7b126148..b1eeca65c --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -10,11 +10,15 @@ import ...IR: 
create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX function scope( - operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location() + operands::Vector{Value}; + results::Base.AbstractVecOrTuple{IR.Type}, + region::Region, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operands...,] @@ -34,7 +38,7 @@ function scope( ) end -function get_stream(; result::IR.Type, location=Location()) +function get_stream(; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] @@ -55,15 +59,15 @@ end function jit_call( inputs::Vector{Value}; - result_0::Vector{IR.Type}, - fn, - backend_config=nothing, - operand_layouts=nothing, - result_layouts=nothing, - output_operand_aliases=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + fn::IR.FlatSymbolRefAttribute, + backend_config::Union{String,Nothing}=nothing, + operand_layouts::Union{IR.AbstractAttribute,Nothing}=nothing, + result_layouts::Union{IR.AbstractAttribute,Nothing}=nothing, + output_operand_aliases::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs...,] owned_regions = Region[] successors = Block[] @@ -98,15 +102,15 @@ function kernel_call( blockz::Value, shmem::Value, inputs::Vector{Value}; - result_0::Vector{IR.Type}, - fn, - backend_config=nothing, - operand_layouts=nothing, - result_layouts=nothing, - output_operand_aliases=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + fn::IR.FlatSymbolRefAttribute, + backend_config::Union{String,Nothing}=nothing, + operand_layouts::Union{IR.AbstractAttribute,Nothing}=nothing, + result_layouts::Union{IR.AbstractAttribute,Nothing}=nothing, + output_operand_aliases::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs...] 
owned_regions = Region[] successors = Block[] @@ -132,7 +136,7 @@ function kernel_call( ) end -function memref2pointer(source::Value; result::IR.Type, location=Location()) +function memref2pointer(source::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[source,] owned_regions = Region[] @@ -151,7 +155,7 @@ function memref2pointer(source::Value; result::IR.Type, location=Location()) ) end -function pointer2memref(source::Value; result::IR.Type, location=Location()) +function pointer2memref(source::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[source,] owned_regions = Region[] diff --git a/src/mlir/Dialects/Func.jl b/src/mlir/Dialects/Func.jl index dcadfd219..c90960d93 100755 --- a/src/mlir/Dialects/Func.jl +++ b/src/mlir/Dialects/Func.jl @@ -10,8 +10,9 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX """ `call_indirect` @@ -33,10 +34,10 @@ Function values can be created with the function call_indirect( callee::Value, callee_operands::Vector{Value}; - results::Vector{IR.Type}, - arg_attrs=nothing, - res_attrs=nothing, - location=Location(), + results::Base.AbstractVecOrTuple{IR.Type}, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[callee, callee_operands...] @@ -74,14 +75,14 @@ symbol reference attribute named \"callee\". """ function call( operands::Vector{Value}; - result_0::Vector{IR.Type}, - callee, - arg_attrs=nothing, - res_attrs=nothing, - no_inline=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + callee::IR.FlatSymbolRefAttribute, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + no_inline::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[operands...,] owned_regions = Region[] successors = Block[] @@ -123,8 +124,10 @@ the compiler is multithreaded, and disallowing SSA values to directly reference a function simplifies this ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). 
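The same renaming applies in the Func dialect hunks here (`result_0` becomes `result`, and attributes gain concrete types such as `IR.FlatSymbolRefAttribute`). A small sketch under the same assumptions as above; how `callee_ref` is constructed is deliberately left out:

```julia
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: func   # assumed module path

# Assumed pre-existing: args::Vector{IR.Value}, ret_types::Vector{IR.Type},
# callee_ref::IR.FlatSymbolRefAttribute pointing at the target func.func.
call_op = func.call(args; result=ret_types, callee=callee_ref)

# Thanks to the Base.AbstractVecOrTuple{IR.Type} bound, a tuple of result
# types is accepted as well:
call_op1 = func.call(args; result=(ret_types[1],), callee=callee_ref)
```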
""" -function constant(; result_0::IR.Type, value, location=Location()) - op_ty_results = IR.Type[result_0,] +function constant(; + result::IR.Type, value::IR.FlatSymbolRefAttribute, location::Location=Location() +) + op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] successors = Block[] @@ -182,14 +185,14 @@ func.func private @example_fn_attr() attributes {dialectName.attrName = false} ``` """ function func_(; - sym_name, - function_type, - sym_visibility=nothing, - arg_attrs=nothing, - res_attrs=nothing, - no_inline=nothing, + sym_name::String, + function_type::IR.Type, + sym_visibility::Union{String,Nothing}=nothing, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + no_inline::Union{Bool,Nothing}=nothing, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -233,7 +236,7 @@ func.func @foo() -> (i32, f8) { } ``` """ -function return_(operands::Vector{Value}; location=Location()) +function return_(operands::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[operands...,] owned_regions = Region[] diff --git a/src/mlir/Dialects/Gpu.jl b/src/mlir/Dialects/Gpu.jl index 6a7a8615c..83e23177b 100755 --- a/src/mlir/Dialects/Gpu.jl +++ b/src/mlir/Dialects/Gpu.jl @@ -10,8 +10,118 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`AllReduceOperation` +built-in reduction operations supported by gpu.allreduce. +""" +@enumx AllReduceOperation ADD MUL MINUI MINSI MINNUMF MAXUI MAXSI MAXNUMF AND OR XOR MINIMUMF MAXIMUMF +const AllReduceOperationStorage = [ + "add", + "mul", + "minui", + "minsi", + "minnumf", + "maxui", + "maxsi", + "maxnumf", + "and", + "or", + "xor", + "minimumf", + "maximumf", +] + +function IR.Attribute(e::AllReduceOperation.T) + return parse(Attribute, "#gpu") +end + +""" +`Dimension` +a dimension, either \'x\', \'y\', or \'z\' +""" +@enumx Dimension x y z +const DimensionStorage = ["x", "y", "z"] + +IR.Attribute(e::Dimension.T) = parse(Attribute, "#gpu") + +""" +`Prune2To4SpMatFlag` +pruning strategy for 2:4 sparse matrix +""" +@enumx Prune2To4SpMatFlag NONE PRUNE_ONLY PRUNE_AND_CHECK +const Prune2To4SpMatFlagStorage = ["NONE", "PRUNE_ONLY", "PRUNE_AND_CHECK"] + +function IR.Attribute(e::Prune2To4SpMatFlag.T) + return parse( + Attribute, "#gpu" + ) +end + +""" +`TransposeMode` +transpose mode of sparse matrix supported by sparse tensor ops +""" +@enumx TransposeMode NON_TRANSPOSE TRANSPOSE CONJUGATE_TRANSPOSE +const TransposeModeStorage = ["NON_TRANSPOSE", "TRANSPOSE", "CONJUGATE_TRANSPOSE"] + +function IR.Attribute(e::TransposeMode.T) + return parse(Attribute, "#gpu") +end + +""" +`ShuffleMode` +Indexing modes supported by gpu.shuffle. 
+""" +@enumx ShuffleMode XOR UP DOWN IDX +const ShuffleModeStorage = ["xor", "up", "down", "idx"] + +function IR.Attribute(e::ShuffleMode.T) + return parse(Attribute, "#gpu") +end + +""" +`SpGEMMWorkEstimationOrComputeKind` +choose whether spgemm_work_estimation_or_compute does work estimation or compute +""" +@enumx SpGEMMWorkEstimationOrComputeKind WORK_ESTIMATION COMPUTE +const SpGEMMWorkEstimationOrComputeKindStorage = ["WORK_ESTIMATION", "COMPUTE"] + +function IR.Attribute(e::SpGEMMWorkEstimationOrComputeKind.T) + return parse( + Attribute, + "#gpu", + ) +end + +""" +`MMAElementwiseOp` +elementwise operation to apply to mma matrix +""" +@enumx MMAElementwiseOp ADDF MULF SUBF MAXF MINF DIVF ADDI MULI SUBI DIVS DIVU NEGATEF NEGATES EXTF +const MMAElementwiseOpStorage = [ + "addf", + "mulf", + "subf", + "maxf", + "minf", + "divf", + "addi", + "muli", + "subi", + "divs", + "divu", + "negatef", + "negates", + "extf", +] + +function IR.Attribute(e::MMAElementwiseOp.T) + return parse(Attribute, "#gpu") +end """ `all_reduce` @@ -43,11 +153,11 @@ need to execute this op in convergence. """ function all_reduce( value::Value; - result=nothing::Union{Nothing,IR.Type}, - op=nothing, - uniform=nothing, + result::Union{Nothing,IR.Type}=nothing, + op::Union{AllReduceOperation.T,Nothing}=nothing, + uniform::Union{Bool,Nothing}=nothing, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value,] @@ -65,8 +175,8 @@ function all_reduce( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -97,9 +207,9 @@ function alloc( dynamicSizes::Vector{Value}, symbolOperands::Vector{Value}; memref::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - hostShared=nothing, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + hostShared::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[memref,] operands = Value[asyncDependencies..., dynamicSizes..., symbolOperands...] @@ -146,7 +256,7 @@ in-between these accesses. Either none or all work items of a workgroup need to execute this op in convergence. """ -function barrier(; location=Location()) +function barrier(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -191,7 +301,12 @@ Examples: gpu.binary @myobject <#gpu.select_object<#rocdl.target>> [#gpu.object<...>, #gpu.object<#rocdl.target, ...>] ``` """ -function binary(; sym_name, offloadingHandler=nothing, objects, location=Location()) +function binary(; + sym_name::String, + offloadingHandler::Union{IR.AbstractAttribute,Nothing}=nothing, + objects::Vector{<:IR.AbstractAttribute}, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -240,17 +355,17 @@ exceeds `upper_bound` cause undefined behavior. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). 
""" function block_dim(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -260,8 +375,8 @@ function block_dim(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -285,17 +400,17 @@ takes priority over bounds inferrable from context. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). """ function block_id(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -305,8 +420,8 @@ function block_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -328,17 +443,17 @@ is greater than `upper_bound` causes undefined behavior. There is an implicit upper bound of `kMaxClusterDim` (currently 8). """ function cluster_block_id(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -348,8 +463,8 @@ function cluster_block_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -372,17 +487,17 @@ causes undefined behavior. There is an implicit upper bound of `kMaxClusterDim` (currently 8). 
""" function cluster_dim_blocks(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -392,8 +507,8 @@ function cluster_dim_blocks(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -416,17 +531,17 @@ undefined behavior. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). """ function cluster_dim(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -436,8 +551,8 @@ function cluster_dim(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -460,17 +575,17 @@ greater than `upper_bound` causes undefined behavior. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). """ function cluster_id(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -480,8 +595,8 @@ function cluster_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -510,9 +625,9 @@ function create_2to4_spmat( cols::Value, memref::Value; spMat::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - pruneFlag=nothing, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + pruneFlag::Union{Prune2To4SpMatFlag.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[spMat,] operands = Value[asyncDependencies..., rows, cols, memref] @@ -571,8 +686,8 @@ function create_bsr( bColIdxs::Value, values::Value; spmat::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[spmat,] operands = Value[ @@ -633,8 +748,8 @@ function create_coo_aos( idxs::Value, values::Value; spmat::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[spmat,] operands = Value[asyncDependencies..., rows, cols, nnz, idxs, values] @@ -684,8 +799,8 @@ function create_coo( colIdxs::Value, values::Value; spmat::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[spmat,] operands = Value[asyncDependencies..., rows, cols, nnz, rowIdxs, colIdxs, values] @@ -738,8 +853,8 @@ function create_csc( rowIdxs::Value, values::Value; spmat::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[spmat,] operands = Value[asyncDependencies..., rows, cols, nnz, colPos, rowIdxs, values] @@ -792,8 +907,8 @@ function create_csr( colIdxs::Value, values::Value; spmat::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[spmat,] operands = Value[asyncDependencies..., rows, cols, nnz, rowPos, colIdxs, values] @@ -837,8 +952,8 @@ function create_dn_tensor( memref::Value, dims::Vector{Value}; dnTensor::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[dnTensor,] operands = Value[asyncDependencies..., memref, dims...] @@ -883,8 +998,8 @@ that case, it returns a !gpu.async.token. function dealloc( asyncDependencies::Vector{Value}, memref::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., memref] @@ -925,8 +1040,8 @@ that case, it returns a !gpu.async.token in addition to the environment. function destroy_dn_tensor( asyncDependencies::Vector{Value}, dnTensor::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., dnTensor] @@ -967,8 +1082,8 @@ that case, it returns a !gpu.async.token in addition to the environment. 
function destroy_sp_mat( asyncDependencies::Vector{Value}, spmat::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., spmat] @@ -1007,7 +1122,7 @@ Examples: to memref<32x64xf32, #gpu.address_space> ``` """ -function dynamic_shared_memory(; resultMemref::IR.Type, location=Location()) +function dynamic_shared_memory(; resultMemref::IR.Type, location::Location=Location()) op_ty_results = IR.Type[resultMemref,] operands = Value[] owned_regions = Region[] @@ -1096,15 +1211,15 @@ Note the non-default memory spaces used in memref types in memory attribution. """ function func(; - function_type, - arg_attrs=nothing, - res_attrs=nothing, - workgroup_attrib_attrs=nothing, - private_attrib_attrs=nothing, - known_block_size=nothing, - known_grid_size=nothing, + function_type::IR.Type, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + workgroup_attrib_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + private_attrib_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + known_block_size::Union{IR.DenseAttribute{Int32},Nothing}=nothing, + known_grid_size::Union{IR.DenseAttribute{Int32},Nothing}=nothing, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1174,11 +1289,11 @@ gpu.module @symbol_name2 <#gpu.select_object<1>> [ ``` """ function module_(; - sym_name, - targets=nothing, - offloadingHandler=nothing, + sym_name::String, + targets::Union{IR.DenseAttribute{IR.AbstractAttribute},Nothing}=nothing, + offloadingHandler::Union{IR.AbstractAttribute,Nothing}=nothing, bodyRegion::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1220,17 +1335,17 @@ The `upper_bound` attribute defines an upper bound analogously to the ones on a combination of `known_block_size` and `known_grid_size`-type annotations. """ function global_id(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -1240,8 +1355,8 @@ function global_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1272,17 +1387,17 @@ exceed `upper_bound` cause undefined behavior. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). 
""" function grid_dim(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -1292,8 +1407,8 @@ function grid_dim(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1309,7 +1424,7 @@ Writes from the host are guaranteed to be visible to device kernels that are launched afterwards. Writes from the device are guaranteed to be visible on the host after synchronizing with the device kernel completion. """ -function host_register(value::Value; location=Location()) +function host_register(value::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[value,] owned_regions = Region[] @@ -1336,7 +1451,7 @@ This op unmaps the provided host buffer from the device address space. This operation may not be supported in every environment, there is not yet a way to check at runtime whether this feature is supported. """ -function host_unregister(value::Value; location=Location()) +function host_unregister(value::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[value,] owned_regions = Region[] @@ -1371,7 +1486,9 @@ the lane id is still assumed to be non-negative and less than the target-independent `kMaxSubgroupSize` (currently 128). """ function lane_id(; - result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() + result::Union{Nothing,IR.Type}=nothing, + upper_bound=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1388,8 +1505,8 @@ function lane_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1498,15 +1615,15 @@ function launch_func( blockSizeX::Value, blockSizeY::Value, blockSizeZ::Value, - clusterSizeX=nothing::Union{Nothing,Value}; - clusterSizeY=nothing::Union{Nothing,Value}, - clusterSizeZ=nothing::Union{Nothing,Value}, - dynamicSharedMemorySize=nothing::Union{Nothing,Value}, + clusterSizeX::Union{Nothing,Value}=nothing; + clusterSizeY::Union{Nothing,Value}=nothing, + clusterSizeZ::Union{Nothing,Value}=nothing, + dynamicSharedMemorySize::Union{Nothing,Value}=nothing, kernelOperands::Vector{Value}, - asyncObject=nothing::Union{Nothing,Value}, - asyncToken=nothing::Union{Nothing,IR.Type}, + asyncObject::Union{Nothing,Value}=nothing, + asyncToken::Union{Nothing,IR.Type}=nothing, kernel, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ @@ -1681,15 +1798,15 @@ function launch( blockSizeX::Value, blockSizeY::Value, blockSizeZ::Value, - clusterSizeX=nothing::Union{Nothing,Value}; - clusterSizeY=nothing::Union{Nothing,Value}, - clusterSizeZ=nothing::Union{Nothing,Value}, - dynamicSharedMemorySize=nothing::Union{Nothing,Value}, - asyncToken=nothing::Union{Nothing,IR.Type}, + clusterSizeX::Union{Nothing,Value}=nothing; + clusterSizeY::Union{Nothing,Value}=nothing, + clusterSizeZ::Union{Nothing,Value}=nothing, + dynamicSharedMemorySize::Union{Nothing,Value}=nothing, + asyncToken::Union{Nothing,IR.Type}=nothing, kernelFunc=nothing, kernelModule=nothing, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ @@ -1770,8 +1887,8 @@ function memcpy( asyncDependencies::Vector{Value}, dst::Value, src::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., dst, src] @@ -1814,8 +1931,8 @@ function memset( asyncDependencies::Vector{Value}, dst::Value, value::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., dst, value] @@ -1852,7 +1969,9 @@ per workgroup cause undefined behavior. There is a default upper bound of `kMaxDim` (currently uint32_t::max). """ function num_subgroups(; - result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() + result::Union{Nothing,IR.Type}=nothing, + upper_bound=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1869,8 +1988,8 @@ function num_subgroups(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1883,7 +2002,7 @@ scalar arguments that should be printed. The format string is a C-style printf string, subject to any restrictions imposed by one\'s target platform. """ -function printf(args::Vector{Value}; format, location=Location()) +function printf(args::Vector{Value}; format::String, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[args...,] owned_regions = Region[] @@ -1909,7 +2028,7 @@ A terminator operation for regions that appear in the body of `gpu.func` functions. 
The operands to the `gpu.return` are the result values returned by an invocation of the `gpu.func`. """ -function return_(operands::Vector{Value}; location=Location()) +function return_(operands::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[operands...,] owned_regions = Region[] @@ -1956,11 +2075,11 @@ function sddmm_buffer_size( dnmatB::Value, spmatC::Value; bufferSz::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - modeB=nothing, - computeType, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + modeB::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[bufferSz,] operands = Value[asyncDependencies..., dnmatA, dnmatB, spmatC] @@ -2011,11 +2130,11 @@ function sddmm( dnmatB::Value, spmatC::Value, buffer::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - modeB=nothing, - computeType, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + modeB::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., dnmatA, dnmatB, spmatC, buffer] @@ -2062,8 +2181,8 @@ function set_csr_pointers( positions::Value, coordinates::Value, values::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., spmat, positions, coordinates, values] @@ -2091,7 +2210,7 @@ Operation that sets the current default GPU, using a zero-based index into the set of GPUs on the system. The default GPU setting may be thread-local. """ -function set_default_device(devIndex::Value; location=Location()) +function set_default_device(devIndex::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[devIndex,] owned_regions = Region[] @@ -2165,10 +2284,10 @@ function shuffle( value::Value, offset::Value, width::Value; - shuffleResult=nothing::Union{Nothing,IR.Type}, - valid=nothing::Union{Nothing,IR.Type}, - mode, - location=Location(), + shuffleResult::Union{Nothing,IR.Type}=nothing, + valid::Union{Nothing,IR.Type}=nothing, + mode::ShuffleMode.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value, offset, width] @@ -2185,8 +2304,8 @@ function shuffle( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2216,11 +2335,11 @@ function spgemm_copy( spmatA::Value, spmatB::Value, spmatC::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - modeB=nothing, - computeType, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + modeB::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., desc, spmatA, spmatB, spmatC] @@ -2264,8 +2383,8 @@ that case, it returns a `!gpu.async.token` in addition to the environment. 
function spgemm_create_descr( asyncDependencies::Vector{Value}; desc::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[desc,] operands = Value[asyncDependencies...,] @@ -2304,8 +2423,8 @@ that case, it returns a `!gpu.async.token` in addition to the environment. function spgemm_destroy_descr( asyncDependencies::Vector{Value}, desc::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., desc] @@ -2364,12 +2483,12 @@ function spgemm_work_estimation_or_compute( bufferSz::Value, buffer::Value; bufferSzNew::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - modeB=nothing, - computeType, - kind, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + modeB::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + kind::SpGEMMWorkEstimationOrComputeKind.T, + location::Location=Location(), ) op_ty_results = IR.Type[bufferSzNew,] operands = Value[asyncDependencies..., desc, spmatA, spmatB, spmatC, bufferSz, buffer] @@ -2421,12 +2540,12 @@ function spmm_buffer_size( spmatA::Value, dnmatB::Value, dnmatC::Value; - bufferSzs::Vector{IR.Type}, - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - modeB=nothing, - computeType, - location=Location(), + bufferSzs::Base.AbstractVecOrTuple{IR.Type}, + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + modeB::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[bufferSzs...,] operands = Value[asyncDependencies..., spmatA, dnmatB, dnmatC] @@ -2477,11 +2596,11 @@ function spmm( dnmatB::Value, dnmatC::Value, buffers::Vector{Value}; - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - modeB=nothing, - computeType, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + modeB::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., spmatA, dnmatB, dnmatC, buffers...] 
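# --- Reviewer note (illustrative sketch, not part of the generated diff) ---
# Why the keyword signatures in these wrappers were rewritten from
# `x=nothing::Union{...}` to `x::Union{...}=nothing`: the old spelling only
# type-asserts the *default value*, so callers could still pass anything; the
# new spelling types the keyword itself, so a wrong argument fails at dispatch
# instead of deep inside create_operation. Standalone, runnable illustration
# (old_style/new_style are hypothetical stand-ins):

old_style(; result=nothing::Union{Nothing,Int}) = result   # old generated form
new_style(; result::Union{Nothing,Int}=nothing) = result   # new generated form

old_style(result="oops")          # silently accepted by the old form
try
    new_style(result="oops")      # rejected immediately by the new form
catch err
    @assert err isa TypeError
end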
@@ -2536,10 +2655,10 @@ function spmv_buffer_size( dnX::Value, dnY::Value; bufferSz::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - computeType, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[bufferSz,] operands = Value[asyncDependencies..., spmatA, dnX, dnY] @@ -2589,10 +2708,10 @@ function spmv( dnX::Value, dnY::Value, buffer::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - modeA=nothing, - computeType, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + modeA::Union{TransposeMode.T,Nothing}=nothing, + computeType::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies..., spmatA, dnX, dnY, buffer] @@ -2636,8 +2755,8 @@ function spmat_get_size( rows::IR.Type, cols::IR.Type, nnz::IR.Type, - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[rows, cols, nnz] operands = Value[asyncDependencies..., spmat] @@ -2675,7 +2794,9 @@ cause undefined behavior. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). """ function subgroup_id(; - result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() + result::Union{Nothing,IR.Type}=nothing, + upper_bound=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -2692,8 +2813,8 @@ function subgroup_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2732,10 +2853,10 @@ function subgroup_mma_compute( opA::Value, opB::Value, opC::Value; - res=nothing::Union{Nothing,IR.Type}, - a_transpose=nothing, - b_transpose=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + a_transpose::Union{Bool,Nothing}=nothing, + b_transpose::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[opA, opB, opC] @@ -2753,8 +2874,8 @@ function subgroup_mma_compute( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2781,7 +2902,9 @@ This op is meant to be used along with `gpu.subgroup_mma_compute`. !gpu.mma_matrix<16x16xf32, \"COp\"> ``` """ -function subgroup_mma_constant_matrix(value::Value; res::IR.Type, location=Location()) +function subgroup_mma_constant_matrix( + value::Value; res::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[value,] owned_regions = Region[] @@ -2821,7 +2944,10 @@ This op is meant to be used along with `gpu.subgroup_mma_compute`. 
``` """ function subgroup_mma_elementwise( - args::Vector{Value}; res::IR.Type, opType, location=Location() + args::Vector{Value}; + res::IR.Type, + opType::MMAElementwiseOp.T, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[args...,] @@ -2874,8 +3000,8 @@ function subgroup_mma_load_matrix( indices::Vector{Value}; res::IR.Type, leadDimension, - transpose=nothing, - location=Location(), + transpose::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[srcMemref, indices...] @@ -2924,8 +3050,8 @@ function subgroup_mma_store_matrix( dstMemref::Value, indices::Vector{Value}; leadDimension, - transpose=nothing, - location=Location(), + transpose::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[src, dstMemref, indices...] @@ -2982,12 +3108,12 @@ The reduction operation must be one of: """ function subgroup_reduce( value::Value; - result=nothing::Union{Nothing,IR.Type}, - op, - uniform=nothing, - cluster_size=nothing, - cluster_stride=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + op::AllReduceOperation.T, + uniform::Union{Bool,Nothing}=nothing, + cluster_size::Union{Int32,Nothing}=nothing, + cluster_stride::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value,] @@ -3008,8 +3134,8 @@ function subgroup_reduce( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3030,7 +3156,9 @@ similar machinery assume the default bound of `kMaxSubgroupSize`, currently 128. """ function subgroup_size(; - result=nothing::Union{Nothing,IR.Type}, upper_bound=nothing, location=Location() + result::Union{Nothing,IR.Type}=nothing, + upper_bound=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -3047,8 +3175,8 @@ function subgroup_size(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3059,7 +3187,7 @@ A terminator operation for regions that appear in the body of `gpu.launch` operation. These regions are not expected to return any value so the terminator takes no operands. """ -function terminator(; location=Location()) +function terminator(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -3097,17 +3225,17 @@ than or equal to that bound cause undefined behavior. There is an implicit upper bound of `kMaxDim` (currently uint32_t::max). 
""" function thread_id(; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, + result::Union{Nothing,IR.Type}=nothing, + dimension::Dimension.T, upper_bound=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(upper_bound) && push!(attributes, namedattribute("upper_bound", upper_bound)) return create_operation( @@ -3117,8 +3245,8 @@ function thread_id(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3157,8 +3285,8 @@ gpu.wait [%t0, %t1] """ function wait( asyncDependencies::Vector{Value}; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), + asyncToken::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[asyncDependencies...,] @@ -3284,10 +3412,10 @@ some_synchronization_primitive function warp_execute_on_lane_0( laneid::Value, args::Vector{Value}; - results::Vector{IR.Type}, - warp_size, + results::Base.AbstractVecOrTuple{IR.Type}, + warp_size::Int64, warpRegion::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[laneid, args...] @@ -3319,7 +3447,7 @@ in gpu ops. It returns values to the immediately enclosing gpu op. gpu.yield %f0, %f1 : f32, f32 ``` """ -function yield(values::Vector{Value}; location=Location()) +function yield(values::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[values...,] owned_regions = Region[] diff --git a/src/mlir/Dialects/Llvm.jl b/src/mlir/Dialects/Llvm.jl index 38ff9f89f..ee5a0bad9 100755 --- a/src/mlir/Dialects/Llvm.jl +++ b/src/mlir/Dialects/Llvm.jl @@ -10,15 +10,98 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`UnnamedAddr` +LLVM GlobalValue UnnamedAddr +""" +@enumx UnnamedAddr None = 0 Local = 1 Global = 2 + +IR.Attribute(e::UnnamedAddr.T) = Int(e) + +""" +`Visibility` +LLVM GlobalValue Visibility +""" +@enumx Visibility Default = 0 Hidden = 1 Protected = 2 + +IR.Attribute(e::Visibility.T) = Int(e) + +""" +`AtomicOrdering` +Atomic ordering for LLVM\'s memory model +""" +@enumx AtomicOrdering not_atomic = 0 unordered = 1 monotonic = 2 acquire = 4 release = 5 acq_rel = + 6 seq_cst = 7 + +IR.Attribute(e::AtomicOrdering.T) = Int(e) + +""" +`AtomicBinOp` +llvm.atomicrmw binary operations +""" +@enumx AtomicBinOp xchg = 0 add = 1 sub = 2 _and = 3 nand = 4 _or = 5 _xor = 6 max = 7 min = + 8 umax = 9 umin = 10 fadd = 11 fsub = 12 fmax = 13 fmin = 14 uinc_wrap = 15 udec_wrap = + 16 usub_cond = 17 usub_sat = 18 + +IR.Attribute(e::AtomicBinOp.T) = Int(e) + +""" +`FastmathFlags` +LLVM fastmath flags +""" +@enumx FastmathFlags none nnan ninf nsz arcp contract afn reassoc fast +const FastmathFlagsStorage = [ + "none", "nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc", "fast" +] + +function IR.Attribute(e::FastmathFlags.T) + return parse(Attribute, "#llvm>") +end + +""" +`Comdat` +LLVM Comdat 
Types +""" +@enumx Comdat Any = 0 ExactMatch = 1 Largest = 2 NoDeduplicate = 3 SameSize = 4 + +IR.Attribute(e::Comdat.T) = Int(e) + +""" +`FCmpPredicate` +llvm.fcmp comparison predicate +""" +@enumx FCmpPredicate _false = 0 oeq = 1 ogt = 2 oge = 3 olt = 4 ole = 5 one = 6 ord = 7 ueq = + 8 ugt = 9 uge = 10 ult = 11 ule = 12 une = 13 uno = 14 _true = 15 + +IR.Attribute(e::FCmpPredicate.T) = Int(e) + +""" +`ICmpPredicate` +lvm.icmp comparison predicate +""" +@enumx ICmpPredicate eq = 0 ne = 1 slt = 2 sle = 3 sgt = 4 sge = 5 ult = 6 ule = 7 ugt = 8 uge = + 9 + +IR.Attribute(e::ICmpPredicate.T) = Int(e) + +""" +`AsmDialect` +ATT (0) or Intel (1) asm dialect +""" +@enumx AsmDialect AD_ATT = 0 AD_Intel = 1 + +IR.Attribute(e::AsmDialect.T) = Int(e) function ashr( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - isExact=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + isExact::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -35,13 +118,16 @@ function ashr( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function add( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -57,12 +143,12 @@ function add( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function addrspacecast(arg::Value; res::IR.Type, location=Location()) +function addrspacecast(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -119,7 +205,9 @@ llvm.mlir.alias @const_alias : i32 { } ``` """ -function mlir_addressof(; res::IR.Type, global_name, location=Location()) +function mlir_addressof(; + res::IR.Type, global_name::IR.FlatSymbolRefAttribute, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -175,15 +263,15 @@ llvm.mlir.alias linkonce_odr hidden @glob ``` """ function mlir_alias(; - alias_type, - sym_name, + alias_type::IR.Type, + sym_name::String, linkage, - dso_local=nothing, - thread_local_=nothing, - unnamed_addr=nothing, - visibility_=nothing, + dso_local::Union{Bool,Nothing}=nothing, + thread_local_::Union{Bool,Nothing}=nothing, + unnamed_addr::Union{UnnamedAddr.T,Nothing}=nothing, + visibility_::Union{Visibility.T,Nothing}=nothing, initializer::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -216,10 +304,10 @@ end function alloca( arraySize::Value; res::IR.Type, - alignment=nothing, - elem_type, - inalloca=nothing, - location=Location(), + alignment::Union{Int64,Nothing}=nothing, + elem_type::IR.Type, + inalloca::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[arraySize,] @@ -242,7 +330,10 @@ function alloca( end function and( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -258,8 +349,8 @@ function and( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -267,18 +358,18 @@ function cmpxchg( ptr::Value, cmp::Value, val::Value; - res=nothing::Union{Nothing,IR.Type}, - success_ordering, - failure_ordering, - syncscope=nothing, - alignment=nothing, - weak=nothing, - volatile_=nothing, - access_groups=nothing, - alias_scopes=nothing, - noalias_scopes=nothing, - tbaa=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + success_ordering::AtomicOrdering.T, + failure_ordering::AtomicOrdering.T, + syncscope::Union{String,Nothing}=nothing, + alignment::Union{Int64,Nothing}=nothing, + weak::Union{Bool,Nothing}=nothing, + volatile_::Union{Bool,Nothing}=nothing, + access_groups::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + alias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + noalias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + tbaa::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ptr, cmp, val] @@ -308,25 +399,25 @@ function cmpxchg( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function atomicrmw( ptr::Value, val::Value; - res=nothing::Union{Nothing,IR.Type}, - bin_op, - ordering, - syncscope=nothing, - alignment=nothing, - volatile_=nothing, - access_groups=nothing, - alias_scopes=nothing, - noalias_scopes=nothing, - tbaa=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + bin_op::AtomicBinOp.T, + ordering::AtomicOrdering.T, + syncscope::Union{String,Nothing}=nothing, + alignment::Union{Int64,Nothing}=nothing, + volatile_::Union{Bool,Nothing}=nothing, + access_groups::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + alias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + noalias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + tbaa::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ptr, val] @@ -354,12 +445,12 @@ function atomicrmw( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function bitcast(arg::Value; res::IR.Type, location=Location()) +function bitcast(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -379,7 +470,10 @@ function bitcast(arg::Value; res::IR.Type, location=Location()) end function br( - destOperands::Vector{Value}; loop_annotation=nothing, dest::Block, location=Location() + destOperands::Vector{Value}; + loop_annotation=nothing, + dest::Block, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[destOperands...,] @@ -410,12 +504,12 @@ the MLIR function type of this op to determine which intrinsic to call. function call_intrinsic( args::Vector{Value}, op_bundle_operands::Vector{Value}; - results=nothing::Union{Nothing,IR.Type}, - intrin, - fastmathFlags=nothing, - op_bundle_sizes, - op_bundle_tags=nothing, - location=Location(), + results::Union{Nothing,IR.Type}=nothing, + intrin::String, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + op_bundle_sizes::IR.DenseAttribute{Int32}, + op_bundle_tags::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[args..., op_bundle_operands...] 
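# --- Reviewer note (illustrative sketch, not part of the generated diff) ---
# The `FastmathFlags.T`, `AtomicOrdering.T`, `AtomicBinOp.T`, ... keyword types
# used above come from the EnumX-based definitions added at the top of this
# file, and the wrappers rely on the accompanying `IR.Attribute(e) = Int(e)`
# overloads when building the attribute list. Reduced to a standalone, runnable
# example (DemoOrdering and to_attribute_int are hypothetical stand-ins):

using EnumX

@enumx DemoOrdering not_atomic = 0 acquire = 4 release = 5   # mirrors AtomicOrdering
to_attribute_int(e::DemoOrdering.T) = Int(e)                 # mirrors IR.Attribute(e) = Int(e)

@assert to_attribute_int(DemoOrdering.acquire) == 4
@assert DemoOrdering.release isa DemoOrdering.T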
@@ -485,26 +579,26 @@ llvm.call %1(%0) vararg(!llvm.func) : !llvm.ptr, (i32) -> () function call( callee_operands::Vector{Value}, op_bundle_operands::Vector{Value}; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, var_callee_type=nothing, - callee=nothing, - fastmathFlags=nothing, - branch_weights=nothing, + callee::Union{IR.FlatSymbolRefAttribute,Nothing}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + branch_weights::Union{IR.DenseAttribute{Int32},Nothing}=nothing, CConv=nothing, TailCallKind=nothing, memory_effects=nothing, - convergent=nothing, - no_unwind=nothing, - will_return=nothing, - op_bundle_sizes, - op_bundle_tags=nothing, - arg_attrs=nothing, - res_attrs=nothing, - access_groups=nothing, - alias_scopes=nothing, - noalias_scopes=nothing, - tbaa=nothing, - location=Location(), + convergent::Union{Bool,Nothing}=nothing, + no_unwind::Union{Bool,Nothing}=nothing, + will_return::Union{Bool,Nothing}=nothing, + op_bundle_sizes::IR.DenseAttribute{Int32}, + op_bundle_tags::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + access_groups::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + alias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + noalias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + tbaa::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[callee_operands..., op_bundle_operands...] @@ -568,7 +662,7 @@ llvm.comdat @__llvm_comdat { llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 ``` """ -function comdat(; sym_name, body::Region, location=Location()) +function comdat(; sym_name::String, body::Region, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[body,] @@ -600,7 +694,9 @@ llvm.comdat @__llvm_comdat { llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 ``` """ -function comdat_selector(; sym_name, comdat, location=Location()) +function comdat_selector(; + sym_name::String, comdat::Comdat.T, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -625,11 +721,11 @@ function cond_br( condition::Value, trueDestOperands::Vector{Value}, falseDestOperands::Vector{Value}; - branch_weights=nothing, + branch_weights::Union{IR.DenseAttribute{Int32},Nothing}=nothing, loop_annotation=nothing, trueDest::Block, falseDest::Block, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[condition, trueDestOperands..., falseDestOperands...] 
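# --- Reviewer note (illustrative sketch, not part of the generated diff) ---
# Optional keywords such as `branch_weights` above only become attributes when
# the caller actually supplies them: the wrappers guard every optional field
# with the same `!isnothing(x) && push!(attributes, namedattribute("x", x))`
# pattern visible for `upper_bound` in the GPU dialect earlier in this diff.
# Standalone sketch (collect_attrs is a hypothetical stand-in for a wrapper):

function collect_attrs(; branch_weights::Union{Vector{Int32},Nothing}=nothing)
    attributes = Pair{String,Any}[]   # stand-in for NamedAttribute[]
    !isnothing(branch_weights) &&
        push!(attributes, "branch_weights" => branch_weights)
    return attributes
end

@assert isempty(collect_attrs())
@assert collect_attrs(branch_weights=Int32[1, 99]) == ["branch_weights" => Int32[1, 99]]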
@@ -703,7 +799,9 @@ Examples: %3 = llvm.mlir.constant(dense<1.0> : vector<4xf32>) : vector<4xf32> ``` """ -function mlir_constant(; res::IR.Type, value, location=Location()) +function mlir_constant(; + res::IR.Type, value::IR.AbstractAttribute, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -723,7 +821,10 @@ function mlir_constant(; res::IR.Type, value, location=Location()) end function extractelement( - vector::Value, position::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + vector::Value, + position::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[vector, position] @@ -739,12 +840,17 @@ function extractelement( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function extractvalue(container::Value; res::IR.Type, position, location=Location()) +function extractvalue( + container::Value; + res::IR.Type, + position::IR.DenseAttribute{Int64}, + location::Location=Location(), +) op_ty_results = IR.Type[res,] operands = Value[container,] owned_regions = Region[] @@ -766,9 +872,9 @@ end function fadd( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -786,18 +892,18 @@ function fadd( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function fcmp( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - predicate, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + predicate::FCmpPredicate.T, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -815,17 +921,17 @@ function fcmp( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function fdiv( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -843,17 +949,17 @@ function fdiv( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function fmul( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -871,16 +977,16 @@ function fmul( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function fneg( operand::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -898,12 +1004,12 @@ function fneg( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function fpext(arg::Value; res::IR.Type, location=Location()) +function fpext(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -922,7 +1028,7 @@ function fpext(arg::Value; res::IR.Type, location=Location()) ) end -function fptosi(arg::Value; res::IR.Type, location=Location()) +function fptosi(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -941,7 +1047,7 @@ function fptosi(arg::Value; res::IR.Type, location=Location()) ) end -function fptoui(arg::Value; res::IR.Type, location=Location()) +function fptoui(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -960,7 +1066,7 @@ function fptoui(arg::Value; res::IR.Type, location=Location()) ) end -function fptrunc(arg::Value; res::IR.Type, location=Location()) +function fptrunc(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -982,9 +1088,9 @@ end function frem( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1002,17 +1108,17 @@ function frem( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function fsub( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1030,12 +1136,16 @@ function fsub( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? 
nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function fence(; ordering, syncscope=nothing, location=Location()) +function fence(; + ordering::AtomicOrdering.T, + syncscope::Union{String,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1055,7 +1165,9 @@ function fence(; ordering, syncscope=nothing, location=Location()) ) end -function freeze(val::Value; res=nothing::Union{Nothing,IR.Type}, location=Location()) +function freeze( + val::Value; res::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[val,] owned_regions = Region[] @@ -1070,8 +1182,8 @@ function freeze(val::Value; res=nothing::Union{Nothing,IR.Type}, location=Locati owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1106,10 +1218,10 @@ function getelementptr( base::Value, dynamicIndices::Vector{Value}; res::IR.Type, - rawConstantIndices, - elem_type, - inbounds=nothing, - location=Location(), + rawConstantIndices::IR.DenseAttribute{Int32}, + elem_type::IR.Type, + inbounds::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[base, dynamicIndices...] @@ -1155,7 +1267,11 @@ llvm.func @ctor() { } ``` """ -function mlir_global_ctors(; ctors, priorities, location=Location()) +function mlir_global_ctors(; + ctors::IR.DenseAttribute{IR.FlatSymbolRefAttribute}, + priorities::IR.DenseAttribute{Int32}, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1195,7 +1311,11 @@ llvm.func @dtor() { llvm.mlir.global_dtors {@dtor} ``` """ -function mlir_global_dtors(; dtors, priorities, location=Location()) +function mlir_global_dtors(; + dtors::IR.DenseAttribute{IR.FlatSymbolRefAttribute}, + priorities::IR.DenseAttribute{Int32}, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1314,23 +1434,23 @@ llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) { alignment = 3 ``` """ function mlir_global(; - global_type, - constant=nothing, - sym_name, + global_type::IR.Type, + constant::Union{Bool,Nothing}=nothing, + sym_name::String, linkage, - dso_local=nothing, - thread_local_=nothing, - externally_initialized=nothing, - value=nothing, - alignment=nothing, - addr_space=nothing, - unnamed_addr=nothing, - section=nothing, + dso_local::Union{Bool,Nothing}=nothing, + thread_local_::Union{Bool,Nothing}=nothing, + externally_initialized::Union{Bool,Nothing}=nothing, + value::Union{IR.AbstractAttribute,Nothing}=nothing, + alignment::Union{Int64,Nothing}=nothing, + addr_space::Union{Int32,Nothing}=nothing, + unnamed_addr::Union{UnnamedAddr.T,Nothing}=nothing, + section::Union{String,Nothing}=nothing, comdat=nothing, - dbg_exprs=nothing, - visibility_=nothing, + dbg_exprs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + visibility_::Union{Visibility.T,Nothing}=nothing, initializer::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1372,9 +1492,9 @@ end function icmp( lhs::Value, rhs::Value; - 
res=nothing::Union{Nothing,IR.Type}, - predicate, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + predicate::ICmpPredicate.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1390,8 +1510,8 @@ function icmp( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1407,14 +1527,14 @@ considered undefined behavior at this time. """ function inline_asm( operands::Vector{Value}; - res=nothing::Union{Nothing,IR.Type}, - asm_string, - constraints, - has_side_effects=nothing, - is_align_stack=nothing, - asm_dialect=nothing, - operand_attrs=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + asm_string::String, + constraints::String, + has_side_effects::Union{Bool,Nothing}=nothing, + is_align_stack::Union{Bool,Nothing}=nothing, + asm_dialect::Union{AsmDialect.T,Nothing}=nothing, + operand_attrs::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operands...,] @@ -1448,8 +1568,8 @@ function insertelement( vector::Value, value::Value, position::Value; - res=nothing::Union{Nothing,IR.Type}, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[vector, value, position] @@ -1465,17 +1585,17 @@ function insertelement( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function insertvalue( container::Value, value::Value; - res=nothing::Union{Nothing,IR.Type}, - position, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + position::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[container, value] @@ -1491,12 +1611,12 @@ function insertvalue( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function inttoptr(arg::Value; res::IR.Type, location=Location()) +function inttoptr(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -1520,18 +1640,18 @@ function invoke( normalDestOperands::Vector{Value}, unwindDestOperands::Vector{Value}, op_bundle_operands::Vector{Value}; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, var_callee_type=nothing, - callee=nothing, - arg_attrs=nothing, - res_attrs=nothing, - branch_weights=nothing, + callee::Union{IR.FlatSymbolRefAttribute,Nothing}=nothing, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + branch_weights::Union{IR.DenseAttribute{Int32},Nothing}=nothing, CConv=nothing, - op_bundle_sizes, - op_bundle_tags=nothing, + op_bundle_sizes::IR.DenseAttribute{Int32}, + op_bundle_tags::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, normalDest::Block, unwindDest::Block, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ @@ -1606,57 +1726,57 @@ llvm.func internal @internal_func() { ``` """ function func(; - sym_name, - sym_visibility=nothing, + sym_name::String, + sym_visibility::Union{String,Nothing}=nothing, function_type, linkage=nothing, - dso_local=nothing, + dso_local::Union{Bool,Nothing}=nothing, CConv=nothing, comdat=nothing, - convergent=nothing, - personality=nothing, - garbageCollector=nothing, - passthrough=nothing, - arg_attrs=nothing, - res_attrs=nothing, - function_entry_count=nothing, + convergent::Union{Bool,Nothing}=nothing, + personality::Union{IR.FlatSymbolRefAttribute,Nothing}=nothing, + garbageCollector::Union{String,Nothing}=nothing, + passthrough::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + function_entry_count::Union{Int64,Nothing}=nothing, memory_effects=nothing, - visibility_=nothing, - arm_streaming=nothing, - arm_locally_streaming=nothing, - arm_streaming_compatible=nothing, - arm_new_za=nothing, - arm_in_za=nothing, - arm_out_za=nothing, - arm_inout_za=nothing, - arm_preserves_za=nothing, - section=nothing, - unnamed_addr=nothing, - alignment=nothing, + visibility_::Union{Visibility.T,Nothing}=nothing, + arm_streaming::Union{Bool,Nothing}=nothing, + arm_locally_streaming::Union{Bool,Nothing}=nothing, + arm_streaming_compatible::Union{Bool,Nothing}=nothing, + arm_new_za::Union{Bool,Nothing}=nothing, + arm_in_za::Union{Bool,Nothing}=nothing, + arm_out_za::Union{Bool,Nothing}=nothing, + arm_inout_za::Union{Bool,Nothing}=nothing, + arm_preserves_za::Union{Bool,Nothing}=nothing, + section::Union{String,Nothing}=nothing, + unnamed_addr::Union{UnnamedAddr.T,Nothing}=nothing, + alignment::Union{Int64,Nothing}=nothing, vscale_range=nothing, frame_pointer=nothing, - target_cpu=nothing, - tune_cpu=nothing, + target_cpu::Union{String,Nothing}=nothing, + tune_cpu::Union{String,Nothing}=nothing, target_features=nothing, - unsafe_fp_math=nothing, - no_infs_fp_math=nothing, - no_nans_fp_math=nothing, - approx_func_fp_math=nothing, - no_signed_zeros_fp_math=nothing, - denormal_fp_math=nothing, - denormal_fp_math_f32=nothing, - fp_contract=nothing, - no_inline=nothing, - always_inline=nothing, - no_unwind=nothing, - will_return=nothing, - optimize_none=nothing, + 
unsafe_fp_math::Union{Bool,Nothing}=nothing, + no_infs_fp_math::Union{Bool,Nothing}=nothing, + no_nans_fp_math::Union{Bool,Nothing}=nothing, + approx_func_fp_math::Union{Bool,Nothing}=nothing, + no_signed_zeros_fp_math::Union{Bool,Nothing}=nothing, + denormal_fp_math::Union{String,Nothing}=nothing, + denormal_fp_math_f32::Union{String,Nothing}=nothing, + fp_contract::Union{String,Nothing}=nothing, + no_inline::Union{Bool,Nothing}=nothing, + always_inline::Union{Bool,Nothing}=nothing, + no_unwind::Union{Bool,Nothing}=nothing, + will_return::Union{Bool,Nothing}=nothing, + optimize_none::Union{Bool,Nothing}=nothing, vec_type_hint=nothing, - work_group_size_hint=nothing, - reqd_work_group_size=nothing, - intel_reqd_sub_group_size=nothing, + work_group_size_hint::Union{IR.DenseAttribute{Int32},Nothing}=nothing, + reqd_work_group_size::Union{IR.DenseAttribute{Int32},Nothing}=nothing, + intel_reqd_sub_group_size::Union{Int32,Nothing}=nothing, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1758,9 +1878,9 @@ end function lshr( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - isExact=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + isExact::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1777,13 +1897,16 @@ function lshr( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function landingpad( - operand_0::Vector{Value}; res::IR.Type, cleanup=nothing, location=Location() + operand_0::Vector{Value}; + res::IR.Type, + cleanup::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[operand_0...,] @@ -1820,7 +1943,7 @@ llvm.linker_options [\"/DEFAULTLIB:\", \"libcmt\"] llvm.linker_options [\"-l\", \"clang_rt.builtins-aarch64\"] ``` """ -function linker_options(; options, location=Location()) +function linker_options(; options::IR.DenseAttribute{String}, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1868,18 +1991,18 @@ https://llvm.org/docs/LangRef.html#load-instruction function load( addr::Value; res::IR.Type, - alignment=nothing, - volatile_=nothing, - nontemporal=nothing, - invariant=nothing, - invariantGroup=nothing, - ordering=nothing, - syncscope=nothing, - access_groups=nothing, - alias_scopes=nothing, - noalias_scopes=nothing, - tbaa=nothing, - location=Location(), + alignment::Union{Int64,Nothing}=nothing, + volatile_::Union{Bool,Nothing}=nothing, + nontemporal::Union{Bool,Nothing}=nothing, + invariant::Union{Bool,Nothing}=nothing, + invariantGroup::Union{Bool,Nothing}=nothing, + ordering::Union{AtomicOrdering.T,Nothing}=nothing, + syncscope::Union{String,Nothing}=nothing, + access_groups::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + alias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + noalias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + tbaa::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[addr,] @@ -1915,7 +2038,10 @@ function load( end function mul( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + 
res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1931,8 +2057,8 @@ function mul( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1950,7 +2076,7 @@ Examples: %0 = llvm.mlir.none : !llvm.token ``` """ -function mlir_none(; res=nothing::Union{Nothing,IR.Type}, location=Location()) +function mlir_none(; res::Union{Nothing,IR.Type}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1965,17 +2091,17 @@ function mlir_none(; res=nothing::Union{Nothing,IR.Type}, location=Location()) owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function or( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - isDisjoint=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + isDisjoint::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1992,8 +2118,8 @@ function or( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2013,7 +2139,7 @@ IR dialect type. %0 = llvm.mlir.poison : !llvm.struct<(i32, f32)> ``` """ -function mlir_poison(; res::IR.Type, location=Location()) +function mlir_poison(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2032,7 +2158,7 @@ function mlir_poison(; res::IR.Type, location=Location()) ) end -function ptrtoint(arg::Value; res::IR.Type, location=Location()) +function ptrtoint(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2051,7 +2177,7 @@ function ptrtoint(arg::Value; res::IR.Type, location=Location()) ) end -function resume(value::Value; location=Location()) +function resume(value::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[value,] owned_regions = Region[] @@ -2070,7 +2196,7 @@ function resume(value::Value; location=Location()) ) end -function return_(arg=nothing::Union{Nothing,Value}; location=Location()) +function return_(arg::Union{Nothing,Value}=nothing; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -2093,9 +2219,9 @@ end function sdiv( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - isExact=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + isExact::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2112,12 +2238,12 @@ function sdiv( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function sext(arg::Value; res::IR.Type, location=Location()) +function sext(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2136,7 +2262,7 @@ function sext(arg::Value; res::IR.Type, location=Location()) ) end -function sitofp(arg::Value; res::IR.Type, location=Location()) +function sitofp(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2156,7 +2282,10 @@ function sitofp(arg::Value; res::IR.Type, location=Location()) end function srem( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2172,8 +2301,8 @@ function srem( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2181,9 +2310,9 @@ function select( condition::Value, trueValue::Value, falseValue::Value; - res=nothing::Union{Nothing,IR.Type}, - fastmathFlags=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + fastmathFlags::Union{FastmathFlags.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[condition, trueValue, falseValue] @@ -2201,13 +2330,16 @@ function select( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function shl( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2223,12 +2355,18 @@ function shl( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function shufflevector(v1::Value, v2::Value; res::IR.Type, mask, location=Location()) +function shufflevector( + v1::Value, + v2::Value; + res::IR.Type, + mask::IR.DenseAttribute{Int32}, + location::Location=Location(), +) op_ty_results = IR.Type[res,] operands = Value[v1, v2] owned_regions = Region[] @@ -2276,17 +2414,17 @@ https://llvm.org/docs/LangRef.html#store-instruction function store( value::Value, addr::Value; - alignment=nothing, - volatile_=nothing, - nontemporal=nothing, - invariantGroup=nothing, - ordering=nothing, - syncscope=nothing, - access_groups=nothing, - alias_scopes=nothing, - noalias_scopes=nothing, - tbaa=nothing, - location=Location(), + alignment::Union{Int64,Nothing}=nothing, + volatile_::Union{Bool,Nothing}=nothing, + nontemporal::Union{Bool,Nothing}=nothing, + invariantGroup::Union{Bool,Nothing}=nothing, + ordering::Union{AtomicOrdering.T,Nothing}=nothing, + syncscope::Union{String,Nothing}=nothing, + access_groups::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + alias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + noalias_scopes::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + tbaa::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value, addr] @@ -2321,7 +2459,10 @@ function store( end function sub( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2337,8 +2478,8 @@ function sub( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2346,12 +2487,12 @@ function switch( value::Value, defaultOperands::Vector{Value}, caseOperands::Vector{Value}; - case_values=nothing, - case_operand_segments, - branch_weights=nothing, + case_values::Union{IR.AbstractDenseElementsAttribute{Int64},Nothing}=nothing, + case_operand_segments::IR.DenseAttribute{Int32}, + branch_weights::Union{IR.DenseAttribute{Int32},Nothing}=nothing, defaultDestination::Block, caseDestinations::Vector{Block}, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value, defaultOperands..., caseOperands...] @@ -2379,7 +2520,7 @@ function switch( ) end -function trunc(arg::Value; res::IR.Type, location=Location()) +function trunc(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2401,9 +2542,9 @@ end function udiv( lhs::Value, rhs::Value; - res=nothing::Union{Nothing,IR.Type}, - isExact=nothing, - location=Location(), + res::Union{Nothing,IR.Type}=nothing, + isExact::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2420,12 +2561,17 @@ function udiv( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function uitofp(arg::Value; res::IR.Type, nonNeg=nothing, location=Location()) +function uitofp( + arg::Value; + res::IR.Type, + nonNeg::Union{Bool,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2446,7 +2592,10 @@ function uitofp(arg::Value; res::IR.Type, nonNeg=nothing, location=Location()) end function urem( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2462,8 +2611,8 @@ function urem( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2482,7 +2631,7 @@ IR dialect type. %0 = llvm.mlir.undef : !llvm.struct<(i32, f32)> ``` """ -function mlir_undef(; res::IR.Type, location=Location()) +function mlir_undef(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2501,7 +2650,7 @@ function mlir_undef(; res::IR.Type, location=Location()) ) end -function unreachable(; location=Location()) +function unreachable(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -2520,7 +2669,7 @@ function unreachable(; location=Location()) ) end -function va_arg(arg::Value; res::IR.Type, location=Location()) +function va_arg(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2540,7 +2689,10 @@ function va_arg(arg::Value; res::IR.Type, location=Location()) end function xor( - lhs::Value, rhs::Value; res=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + res::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2556,12 +2708,17 @@ function xor( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function zext(arg::Value; res::IR.Type, nonNeg=nothing, location=Location()) +function zext( + arg::Value; + res::IR.Type, + nonNeg::Union{Bool,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2597,7 +2754,7 @@ value of the specified LLVM IR dialect type. 
%0 = llvm.mlir.zero : !llvm.struct<(i32, f32)> ``` """ -function mlir_zero(; res::IR.Type, location=Location()) +function mlir_zero(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] diff --git a/src/mlir/Dialects/MPI.jl b/src/mlir/Dialects/MPI.jl old mode 100644 new mode 100755 index 4bb1eb16c..2fddc33b2 --- a/src/mlir/Dialects/MPI.jl +++ b/src/mlir/Dialects/MPI.jl @@ -10,8 +10,110 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`MPI_OpClassEnum` +MPI operation class +""" +@enumx MPI_OpClassEnum MPI_OP_NULL MPI_MAX MPI_MIN MPI_SUM MPI_PROD MPI_LAND MPI_BAND MPI_LOR MPI_BOR MPI_LXOR MPI_BXOR MPI_MINLOC MPI_MAXLOC MPI_REPLACE +const MPI_OpClassEnumStorage = [ + "MPI_OP_NULL", + "MPI_MAX", + "MPI_MIN", + "MPI_SUM", + "MPI_PROD", + "MPI_LAND", + "MPI_BAND", + "MPI_LOR", + "MPI_BOR", + "MPI_LXOR", + "MPI_BXOR", + "MPI_MINLOC", + "MPI_MAXLOC", + "MPI_REPLACE", +] + +function IR.Attribute(e::MPI_OpClassEnum.T) + return parse(Attribute, "#mpi>") +end + +""" +`MPI_ErrorClassEnum` +MPI error class name +""" +@enumx MPI_ErrorClassEnum MPI_SUCCESS MPI_ERR_ACCESS MPI_ERR_AMODE MPI_ERR_ARG MPI_ERR_ASSERT MPI_ERR_BAD_FILE MPI_ERR_BASE MPI_ERR_BUFFER MPI_ERR_COMM MPI_ERR_CONVERSION MPI_ERR_COUNT MPI_ERR_DIMS MPI_ERR_DISP MPI_ERR_DUP_DATAREP MPI_ERR_ERRHANDLER MPI_ERR_FILE MPI_ERR_FILE_EXISTS MPI_ERR_FILE_IN_USE MPI_ERR_GROUP MPI_ERR_INFO MPI_ERR_INFO_KEY MPI_ERR_INFO_NOKEY MPI_ERR_INFO_VALUE MPI_ERR_IN_STATUS MPI_ERR_INTERN MPI_ERR_IO MPI_ERR_KEYVAL MPI_ERR_LOCKTYPE MPI_ERR_NAME MPI_ERR_NO_MEM MPI_ERR_NO_SPACE MPI_ERR_NO_SUCH_FILE MPI_ERR_NOT_SAME MPI_ERR_OP MPI_ERR_OTHER MPI_ERR_PENDING MPI_ERR_PORT MPI_ERR_PROC_ABORTED MPI_ERR_QUOTA MPI_ERR_RANK MPI_ERR_READ_ONLY MPI_ERR_REQUEST MPI_ERR_RMA_ATTACH MPI_ERR_RMA_CONFLICT MPI_ERR_RMA_FLAVOR MPI_ERR_RMA_RANGE MPI_ERR_RMA_SHARED MPI_ERR_RMA_SYNC MPI_ERR_ROOT MPI_ERR_SERVICE MPI_ERR_SESSION MPI_ERR_SIZE MPI_ERR_SPAWN MPI_ERR_TAG MPI_ERR_TOPOLOGY MPI_ERR_TRUNCATE MPI_ERR_TYPE MPI_ERR_UNKNOWN MPI_ERR_UNSUPPORTED_DATAREP MPI_ERR_UNSUPPORTED_OPERATION MPI_ERR_VALUE_TOO_LARGE MPI_ERR_WIN MPI_ERR_LASTCODE +const MPI_ErrorClassEnumStorage = [ + "MPI_SUCCESS", + "MPI_ERR_ACCESS", + "MPI_ERR_AMODE", + "MPI_ERR_ARG", + "MPI_ERR_ASSERT", + "MPI_ERR_BAD_FILE", + "MPI_ERR_BASE", + "MPI_ERR_BUFFER", + "MPI_ERR_COMM", + "MPI_ERR_CONVERSION", + "MPI_ERR_COUNT", + "MPI_ERR_DIMS", + "MPI_ERR_DISP", + "MPI_ERR_DUP_DATAREP", + "MPI_ERR_ERRHANDLER", + "MPI_ERR_FILE", + "MPI_ERR_FILE_EXISTS", + "MPI_ERR_FILE_IN_USE", + "MPI_ERR_GROUP", + "MPI_ERR_INFO", + "MPI_ERR_INFO_KEY", + "MPI_ERR_INFO_NOKEY", + "MPI_ERR_INFO_VALUE", + "MPI_ERR_IN_STATUS", + "MPI_ERR_INTERN", + "MPI_ERR_IO", + "MPI_ERR_KEYVAL", + "MPI_ERR_LOCKTYPE", + "MPI_ERR_NAME", + "MPI_ERR_NO_MEM", + "MPI_ERR_NO_SPACE", + "MPI_ERR_NO_SUCH_FILE", + "MPI_ERR_NOT_SAME", + "MPI_ERR_OP", + "MPI_ERR_OTHER", + "MPI_ERR_PENDING", + "MPI_ERR_PORT", + "MPI_ERR_PROC_ABORTED", + "MPI_ERR_QUOTA", + "MPI_ERR_RANK", + "MPI_ERR_READ_ONLY", + "MPI_ERR_REQUEST", + "MPI_ERR_RMA_ATTACH", + "MPI_ERR_RMA_CONFLICT", + "MPI_ERR_RMA_FLAVOR", + "MPI_ERR_RMA_RANGE", + "MPI_ERR_RMA_SHARED", + "MPI_ERR_RMA_SYNC", + "MPI_ERR_ROOT", + "MPI_ERR_SERVICE", + "MPI_ERR_SESSION", + "MPI_ERR_SIZE", + "MPI_ERR_SPAWN", + "MPI_ERR_TAG", + "MPI_ERR_TOPOLOGY", + "MPI_ERR_TRUNCATE", + "MPI_ERR_TYPE", + "MPI_ERR_UNKNOWN", + 
"MPI_ERR_UNSUPPORTED_DATAREP", + "MPI_ERR_UNSUPPORTED_OPERATION", + "MPI_ERR_VALUE_TOO_LARGE", + "MPI_ERR_WIN", + "MPI_ERR_LASTCODE", +] + +function IR.Attribute(e::MPI_ErrorClassEnum.T) + return parse(Attribute, "#mpi>") +end """ `allreduce` @@ -32,9 +134,9 @@ to check for errors. function allreduce( sendbuf::Value, recvbuf::Value; - retval=nothing::Union{Nothing,IR.Type}, - op, - location=Location(), + retval::Union{Nothing,IR.Type}=nothing, + op::MPI_OpClassEnum.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[sendbuf, recvbuf] @@ -66,7 +168,7 @@ Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. """ -function barrier(; retval=nothing::Union{Nothing,IR.Type}, location=Location()) +function barrier(; retval::Union{Nothing,IR.Type}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -95,7 +197,7 @@ This operation can optionally return an `!mpi.retval` value that can be used to check for errors. """ function comm_rank(; - retval=nothing::Union{Nothing,IR.Type}, rank::IR.Type, location=Location() + retval::Union{Nothing,IR.Type}=nothing, rank::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[rank,] operands = Value[] @@ -125,7 +227,7 @@ This operation can optionally return an `!mpi.retval` value that can be used to check for errors. """ function comm_size(; - retval=nothing::Union{Nothing,IR.Type}, size::IR.Type, location=Location() + retval::Union{Nothing,IR.Type}=nothing, size::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[size,] operands = Value[] @@ -152,7 +254,7 @@ end `MPI_Error_class` maps return values from MPI calls to a set of well-known MPI error classes. """ -function error_class(val::Value; errclass::IR.Type, location=Location()) +function error_class(val::Value; errclass::IR.Type, location::Location=Location()) op_ty_results = IR.Type[errclass,] operands = Value[val,] owned_regions = Region[] @@ -181,7 +283,7 @@ Notably, MPI_Init cannot be called again in the same program. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. """ -function finalize(; retval=nothing::Union{Nothing,IR.Type}, location=Location()) +function finalize(; retval::Union{Nothing,IR.Type}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -218,9 +320,9 @@ function irecv( ref::Value, tag::Value, rank::Value; - retval=nothing::Union{Nothing,IR.Type}, + retval::Union{Nothing,IR.Type}=nothing, req::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[req,] operands = Value[ref, tag, rank] @@ -258,9 +360,9 @@ function isend( ref::Value, tag::Value, rank::Value; - retval=nothing::Union{Nothing,IR.Type}, + retval::Union{Nothing,IR.Type}=nothing, req::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[req,] operands = Value[ref, tag, rank] @@ -292,7 +394,7 @@ Passing &argc, &argv is not supported currently. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. 
""" -function init(; retval=nothing::Union{Nothing,IR.Type}, location=Location()) +function init(; retval::Union{Nothing,IR.Type}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -331,8 +433,8 @@ function recv( ref::Value, tag::Value, rank::Value; - retval=nothing::Union{Nothing,IR.Type}, - location=Location(), + retval::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ref, tag, rank] @@ -359,7 +461,9 @@ end This operation compares MPI status codes to known error class constants such as `MPI_SUCCESS`, or `MPI_ERR_COMM`. """ -function retval_check(val::Value; res::IR.Type, errclass, location=Location()) +function retval_check( + val::Value; res::IR.Type, errclass::MPI_ErrorClassEnum.T, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[val,] owned_regions = Region[] @@ -394,8 +498,8 @@ function send( ref::Value, tag::Value, rank::Value; - retval=nothing::Union{Nothing,IR.Type}, - location=Location(), + retval::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ref, tag, rank] @@ -427,7 +531,9 @@ is not yet ported to MLIR. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. """ -function wait(req::Value; retval=nothing::Union{Nothing,IR.Type}, location=Location()) +function wait( + req::Value; retval::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[req,] owned_regions = Region[] diff --git a/src/mlir/Dialects/Nvvm.jl b/src/mlir/Dialects/Nvvm.jl index 374ed1d02..a4fee2a1c 100755 --- a/src/mlir/Dialects/Nvvm.jl +++ b/src/mlir/Dialects/Nvvm.jl @@ -10,10 +10,239 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX -function barrier0(; location=Location()) +""" +`TMAReduxKind` +NVVM TMA redux kind +""" +@enumx TMAReduxKind ADD MAX MIN INC DEC AND OR XOR +const TMAReduxKindStorage = ["add", "max", "min", "inc", "dec", "and", "or", "xor"] + +function IR.Attribute(e::TMAReduxKind.T) + return parse(Attribute, "#nvvm>") +end + +""" +`TMAStoreMode` +NVVM TMA Store Mode +""" +@enumx TMAStoreMode TILE IM2COL +const TMAStoreModeStorage = ["tile", "im2col"] + +function IR.Attribute(e::TMAStoreMode.T) + return parse(Attribute, "#nvvm>") +end + +""" +`LoadCacheModifierKind` +NVVM load cache modifier kind +""" +@enumx LoadCacheModifierKind CA CG CS LU CV +const LoadCacheModifierKindStorage = ["ca", "cg", "cs", "lu", "cv"] + +function IR.Attribute(e::LoadCacheModifierKind.T) + return parse( + Attribute, "#nvvm" + ) +end + +""" +`FPRoundingMode` +NVVM FPRoundingMode kind +""" +@enumx FPRoundingMode NONE RN RM RP RZ RNA +const FPRoundingModeStorage = ["none", "rn", "rm", "rp", "rz", "rna"] + +function IR.Attribute(e::FPRoundingMode.T) + return parse(Attribute, "#nvvm>") +end + +""" +`SaturationMode` +NVVM SaturationMode kind +""" +@enumx SaturationMode NONE SATFINITE +const SaturationModeStorage = ["none", "satfinite"] + +function IR.Attribute(e::SaturationMode.T) + return parse(Attribute, "#nvvm>") +end + +""" +`MemScopeKind` +NVVM Memory Scope kind +""" +@enumx MemScopeKind CTA CLUSTER GPU SYS +const MemScopeKindStorage = ["cta", "cluster", "gpu", "sys"] + +function IR.Attribute(e::MemScopeKind.T) + return parse(Attribute, "#nvvm>") +end + +""" +`ProxyKind` 
+Proxy kind +""" +@enumx ProxyKind alias async async_global async_shared TENSORMAP GENERIC +const ProxyKindStorage = [ + "alias", "async", "async.global", "async.shared", "tensormap", "generic" +] + +function IR.Attribute(e::ProxyKind.T) + return parse(Attribute, "#nvvm>") +end + +""" +`SharedSpace` +Shared memory space +""" +@enumx SharedSpace shared_cta shared_cluster +const SharedSpaceStorage = ["cta", "cluster"] + +function IR.Attribute(e::SharedSpace.T) + return parse(Attribute, "#nvvm>") +end + +""" +`MMALayout` +NVVM MMA layout +""" +@enumx MMALayout row col +const MMALayoutStorage = ["row", "col"] + +function IR.Attribute(e::MMALayout.T) + return parse(Attribute, "#nvvm>") +end + +""" +`MMAB1Op` +MMA binary operations +""" +@enumx MMAB1Op none xor_popc and_popc +const MMAB1OpStorage = ["none", "xor_popc", "and_popc"] + +function IR.Attribute(e::MMAB1Op.T) + return parse(Attribute, "#nvvm>") +end + +""" +`MMAIntOverflow` +MMA overflow options +""" +@enumx MMAIntOverflow satfinite wrapped +const MMAIntOverflowStorage = ["satfinite", "wrapped"] + +function IR.Attribute(e::MMAIntOverflow.T) + return parse(Attribute, "#nvvm>") +end + +""" +`MMATypes` +NVVM MMA types +""" +@enumx MMATypes f16 f32 tf32 bf16 s8 u8 s32 s4 u4 b1 f64 +const MMATypesStorage = [ + "f16", "f32", "tf32", "bf16", "s8", "u8", "s32", "s4", "u4", "b1", "f64" +] + +function IR.Attribute(e::MMATypes.T) + return parse(Attribute, "#nvvm>") +end + +""" +`ReduxKind` +NVVM redux kind +""" +@enumx ReduxKind ADD AND MAX MIN OR UMAX UMIN XOR +const ReduxKindStorage = ["add", "and", "max", "min", "or", "umax", "umin", "xor"] + +function IR.Attribute(e::ReduxKind.T) + return parse(Attribute, "#nvvm") +end + +""" +`SetMaxRegisterAction` +NVVM set max register action +""" +@enumx SetMaxRegisterAction decrease increase +const SetMaxRegisterActionStorage = ["decrease", "increase"] + +function IR.Attribute(e::SetMaxRegisterAction.T) + return parse(Attribute, "#nvvm") +end + +""" +`ShflKind` +NVVM shuffle kind +""" +@enumx ShflKind bfly up down idx +const ShflKindStorage = ["bfly", "up", "down", "idx"] + +function IR.Attribute(e::ShflKind.T) + return parse(Attribute, "#nvvm") +end + +""" +`Tcgen05GroupKind` +NVVM Tcgen05 group kind +""" +@enumx Tcgen05GroupKind CTA_1 CTA_2 +const Tcgen05GroupKindStorage = ["cta_1", "cta_2"] + +function IR.Attribute(e::Tcgen05GroupKind.T) + return parse(Attribute, "#nvvm>") +end + +""" +`MMAFrag` +NVVM MMA frag type +""" +@enumx MMAFrag a b c +const MMAFragStorage = ["a", "b", "c"] + +function IR.Attribute(e::MMAFrag.T) + return parse(Attribute, "#nvvm>") +end + +""" +`WGMMATypes` +NVVM WGMMA types +""" +@enumx WGMMATypes f16 tf32 u8 s8 b1 bf16 e4m3 e5m2 f32 s32 +const WGMMATypesStorage = [ + "f16", "tf32", "u8", "s8", "b1", "bf16", "e4m3", "e5m2", "f32", "s32" +] + +function IR.Attribute(e::WGMMATypes.T) + return parse(Attribute, "#nvvm>") +end + +""" +`WGMMAScaleOut` +WGMMA input predicate +""" +@enumx WGMMAScaleOut zero one +const WGMMAScaleOutStorage = ["zero", "one"] + +function IR.Attribute(e::WGMMAScaleOut.T) + return parse(Attribute, "#nvvm>") +end + +""" +`WGMMAScaleIn` +WGMMA overflow options +""" +@enumx WGMMAScaleIn one neg +const WGMMAScaleInStorage = ["one", "neg"] + +function IR.Attribute(e::WGMMAScaleIn.T) + return parse(Attribute, "#nvvm>") +end + +function barrier0(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -44,7 +273,9 @@ The default barrier id is 0 that is similar to `nvvm.barrier` Op. 
When [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar) """ function barrier_arrive( - barrierId=nothing::Union{Nothing,Value}; numberOfThreads::Value, location=Location() + barrierId::Union{Nothing,Value}=nothing; + numberOfThreads::Value, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[numberOfThreads,] @@ -66,9 +297,9 @@ function barrier_arrive( end function barrier( - barrierId=nothing::Union{Nothing,Value}; - numberOfThreads=nothing::Union{Nothing,Value}, - location=Location(), + barrierId::Union{Nothing,Value}=nothing; + numberOfThreads::Union{Nothing,Value}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -99,7 +330,7 @@ function barrier( ) end -function read_ptx_sreg_ntid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_ntid_x(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -119,7 +350,7 @@ function read_ptx_sreg_ntid_x(; res::IR.Type, range=nothing, location=Location() ) end -function read_ptx_sreg_ntid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_ntid_y(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -139,7 +370,7 @@ function read_ptx_sreg_ntid_y(; res::IR.Type, range=nothing, location=Location() ) end -function read_ptx_sreg_ntid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_ntid_z(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -159,7 +390,7 @@ function read_ptx_sreg_ntid_z(; res::IR.Type, range=nothing, location=Location() ) end -function read_ptx_sreg_ctaid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_ctaid_x(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -179,7 +410,7 @@ function read_ptx_sreg_ctaid_x(; res::IR.Type, range=nothing, location=Location( ) end -function read_ptx_sreg_ctaid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_ctaid_y(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -199,7 +430,7 @@ function read_ptx_sreg_ctaid_y(; res::IR.Type, range=nothing, location=Location( ) end -function read_ptx_sreg_ctaid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_ctaid_z(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -219,7 +450,9 @@ function read_ptx_sreg_ctaid_z(; res::IR.Type, range=nothing, location=Location( ) end -function read_ptx_sreg_cluster_ctaid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_ctaid_x(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -239,7 +472,9 @@ function read_ptx_sreg_cluster_ctaid_x(; res::IR.Type, range=nothing, location=L ) end -function read_ptx_sreg_cluster_ctaid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_ctaid_y(; + res::IR.Type, range=nothing, 
location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -259,7 +494,9 @@ function read_ptx_sreg_cluster_ctaid_y(; res::IR.Type, range=nothing, location=L ) end -function read_ptx_sreg_cluster_ctaid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_ctaid_z(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -285,7 +522,7 @@ end Breakpoint suspends execution of the program for debugging. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-brkpt) """ -function breakpoint(; location=Location()) +function breakpoint(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -304,7 +541,7 @@ function breakpoint(; location=Location()) ) end -function read_ptx_sreg_clock64(; res::IR.Type, location=Location()) +function read_ptx_sreg_clock64(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -323,7 +560,7 @@ function read_ptx_sreg_clock64(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_clock(; res::IR.Type, location=Location()) +function read_ptx_sreg_clock(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -353,7 +590,9 @@ The `aligned` attribute, when provided, generates the .aligned version of the PT [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) """ -function cluster_arrive(; aligned=nothing, location=Location()) +function cluster_arrive(; + aligned::Union{Bool,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -387,7 +626,9 @@ ordering and visibility guarantees provided for the memory accesses performed pr [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) """ -function cluster_arrive_relaxed(; aligned=nothing, location=Location()) +function cluster_arrive_relaxed(; + aligned::Union{Bool,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -407,7 +648,9 @@ function cluster_arrive_relaxed(; aligned=nothing, location=Location()) ) end -function read_ptx_sreg_cluster_nctarank(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_nctarank(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -427,7 +670,9 @@ function read_ptx_sreg_cluster_nctarank(; res::IR.Type, range=nothing, location= ) end -function read_ptx_sreg_cluster_nctaid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_nctaid_x(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -447,7 +692,9 @@ function read_ptx_sreg_cluster_nctaid_x(; res::IR.Type, range=nothing, location= ) end -function read_ptx_sreg_cluster_nctaid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_nctaid_y(; + res::IR.Type, range=nothing, 
location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -467,7 +714,9 @@ function read_ptx_sreg_cluster_nctaid_y(; res::IR.Type, range=nothing, location= ) end -function read_ptx_sreg_cluster_nctaid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_nctaid_z(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -487,7 +736,9 @@ function read_ptx_sreg_cluster_nctaid_z(; res::IR.Type, range=nothing, location= ) end -function read_ptx_sreg_nclusterid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nclusterid_x(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -507,7 +758,9 @@ function read_ptx_sreg_nclusterid_x(; res::IR.Type, range=nothing, location=Loca ) end -function read_ptx_sreg_nclusterid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nclusterid_y(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -527,7 +780,9 @@ function read_ptx_sreg_nclusterid_y(; res::IR.Type, range=nothing, location=Loca ) end -function read_ptx_sreg_nclusterid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nclusterid_z(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -547,7 +802,9 @@ function read_ptx_sreg_nclusterid_z(; res::IR.Type, range=nothing, location=Loca ) end -function read_ptx_sreg_cluster_ctarank(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_cluster_ctarank(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -567,7 +824,9 @@ function read_ptx_sreg_cluster_ctarank(; res::IR.Type, range=nothing, location=L ) end -function read_ptx_sreg_clusterid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_clusterid_x(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -587,7 +846,9 @@ function read_ptx_sreg_clusterid_x(; res::IR.Type, range=nothing, location=Locat ) end -function read_ptx_sreg_clusterid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_clusterid_y(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -607,7 +868,9 @@ function read_ptx_sreg_clusterid_y(; res::IR.Type, range=nothing, location=Locat ) end -function read_ptx_sreg_clusterid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_clusterid_z(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -636,7 +899,7 @@ generates the .aligned version of the PTX instruction. 
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) """ -function cluster_wait(; aligned=nothing, location=Location()) +function cluster_wait(; aligned::Union{Bool,Nothing}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -664,7 +927,7 @@ instructions into a cp.async.bulk-group. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group) """ -function cp_async_bulk_commit_group(; location=Location()) +function cp_async_bulk_commit_group(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -705,9 +968,9 @@ function cp_async_bulk_shared_cluster_global( srcMem::Value, mbar::Value, size::Value, - multicastMask=nothing::Union{Nothing,Value}; - l2CacheHint=nothing::Union{Nothing,Value}, - location=Location(), + multicastMask::Union{Nothing,Value}=nothing; + l2CacheHint::Union{Nothing,Value}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[dstMem, srcMem, mbar, size] @@ -757,8 +1020,8 @@ function cp_async_bulk_global_shared_cta( dstMem::Value, srcMem::Value, size::Value, - l2CacheHint=nothing::Union{Nothing,Value}; - location=Location(), + l2CacheHint::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[dstMem, srcMem, size] @@ -788,7 +1051,7 @@ cluster memory. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) """ function cp_async_bulk_shared_cluster_shared_cta( - dstMem::Value, srcMem::Value, mbar::Value, size::Value; location=Location() + dstMem::Value, srcMem::Value, mbar::Value, size::Value; location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[dstMem, srcMem, mbar, size] @@ -840,10 +1103,10 @@ function cp_async_bulk_tensor_shared_cluster_global( coordinates::Vector{Value}, mbar::Value, im2colOffsets::Vector{Value}, - multicastMask=nothing::Union{Nothing,Value}; - l2CacheHint=nothing::Union{Nothing,Value}, - predicate=nothing::Union{Nothing,Value}, - location=Location(), + multicastMask::Union{Nothing,Value}=nothing; + l2CacheHint::Union{Nothing,Value}=nothing, + predicate::Union{Nothing,Value}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[dstMem, tmaDescriptor, coordinates..., mbar, im2colOffsets...] @@ -909,8 +1172,8 @@ function cp_async_bulk_tensor_prefetch( tmaDescriptor::Value, coordinates::Vector{Value}, im2colOffsets::Vector{Value}, - l2CacheHint=nothing::Union{Nothing,Value}; - location=Location(), + l2CacheHint::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tmaDescriptor, coordinates..., im2colOffsets...] @@ -957,10 +1220,10 @@ function cp_async_bulk_tensor_reduce( tmaDescriptor::Value, srcMem::Value, coordinates::Vector{Value}, - l2CacheHint=nothing::Union{Nothing,Value}; - redKind, - mode=nothing, - location=Location(), + l2CacheHint::Union{Nothing,Value}=nothing; + redKind::TMAReduxKind.T, + mode::Union{TMAStoreMode.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tmaDescriptor, srcMem, coordinates...] 
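Reviewer note (not part of the diff): the non-specialized enum attributes introduced above all follow the same generated pattern — an `EnumX` scoped enum, a parallel storage array holding the MLIR spelling of each case, and an `IR.Attribute` overload that splices that spelling into the dialect attribute syntax before parsing it. Below is a minimal, self-contained sketch of that pattern; `DemoKind`, `demo_kind`, and the `#nvvm<...>` text are placeholders for illustration only — the real mnemonic and assembly format come from the TableGen definitions, and the actual generated code calls `parse(Attribute, text)` inside an MLIR context.

```julia
using EnumX  # provides the @enumx scoped-enum macro used by the generated dialects

# Toy enum mirroring the generated pattern (compare TMAReduxKind above).
@enumx DemoKind ADD MAX MIN
# Parallel array with the MLIR spelling of each case, indexed by Int(e) + 1.
const DemoKindStorage = ["add", "max", "min"]

# The generated wrappers parse this text back into an MLIR attribute; here we
# only build the string, since parsing needs an MLIR context. The mnemonic
# "demo_kind" is hypothetical.
attribute_text(e::DemoKind.T) = "#nvvm<demo_kind $(DemoKindStorage[Int(e) + 1])>"

@assert attribute_text(DemoKind.MAX) == "#nvvm<demo_kind max>"
```

In the regenerated wrappers this is what lets a typed keyword such as `redKind::TMAReduxKind.T` be passed as a plain Julia enum value and still lower to the correct dialect attribute.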
@@ -990,8 +1253,8 @@ function cp_async_bulk_tensor_global_shared_cta( tmaDescriptor::Value, srcMem::Value, coordinates::Vector{Value}, - predicate=nothing::Union{Nothing,Value}; - location=Location(), + predicate::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tmaDescriptor, srcMem, coordinates...] @@ -1031,7 +1294,9 @@ from their source locations. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group) """ -function cp_async_bulk_wait_group(; group, read=nothing, location=Location()) +function cp_async_bulk_wait_group(; + group::Int32, read::Union{Bool,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1051,7 +1316,7 @@ function cp_async_bulk_wait_group(; group, read=nothing, location=Location()) ) end -function cp_async_commit_group(; location=Location()) +function cp_async_commit_group(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1081,7 +1346,9 @@ mbarrier\'s state is updated. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive) """ -function cp_async_mbarrier_arrive(addr::Value; noinc=nothing, location=Location()) +function cp_async_mbarrier_arrive( + addr::Value; noinc::Union{Bool,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[addr,] owned_regions = Region[] @@ -1112,7 +1379,9 @@ is updated. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive) """ -function cp_async_mbarrier_arrive_shared(addr::Value; noinc=nothing, location=Location()) +function cp_async_mbarrier_arrive_shared( + addr::Value; noinc::Union{Bool,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[addr,] owned_regions = Region[] @@ -1135,10 +1404,10 @@ end function cp_async_shared_global( dst::Value, src::Value, - cpSize=nothing::Union{Nothing,Value}; - size, - modifier, - location=Location(), + cpSize::Union{Nothing,Value}=nothing; + size::Int32, + modifier::LoadCacheModifierKind.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[dst, src] @@ -1161,7 +1430,7 @@ function cp_async_shared_global( ) end -function cp_async_wait_group(; n, location=Location()) +function cp_async_wait_group(; n::Int32, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1192,7 +1461,12 @@ the rounding and saturation modes respectively. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) """ function cvt_float_to_tf32( - src::Value; res::IR.Type, rnd=nothing, sat=nothing, relu=nothing, location=Location() + src::Value; + res::IR.Type, + rnd::Union{FPRoundingMode.T,Nothing}=nothing, + sat::Union{SaturationMode.T,Nothing}=nothing, + relu::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[src,] @@ -1226,7 +1500,7 @@ leader thread, and `False` for all other threads. 
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync) """ -function elect_sync(; pred::IR.Type, location=Location()) +function elect_sync(; pred::IR.Type, location::Location=Location()) op_ty_results = IR.Type[pred,] operands = Value[] owned_regions = Region[] @@ -1245,7 +1519,7 @@ function elect_sync(; pred::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg0(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg0(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1264,7 +1538,7 @@ function read_ptx_sreg_envreg0(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg1(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg1(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1283,7 +1557,7 @@ function read_ptx_sreg_envreg1(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg2(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg2(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1302,7 +1576,7 @@ function read_ptx_sreg_envreg2(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg3(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg3(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1321,7 +1595,7 @@ function read_ptx_sreg_envreg3(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg4(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg4(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1340,7 +1614,7 @@ function read_ptx_sreg_envreg4(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg5(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg5(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1359,7 +1633,7 @@ function read_ptx_sreg_envreg5(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg6(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg6(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1378,7 +1652,7 @@ function read_ptx_sreg_envreg6(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg7(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg7(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1397,7 +1671,7 @@ function read_ptx_sreg_envreg7(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg8(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg8(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1416,7 +1690,7 @@ function read_ptx_sreg_envreg8(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg9(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg9(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ 
-1435,7 +1709,7 @@ function read_ptx_sreg_envreg9(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg10(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg10(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1454,7 +1728,7 @@ function read_ptx_sreg_envreg10(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg11(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg11(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1473,7 +1747,7 @@ function read_ptx_sreg_envreg11(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg12(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg12(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1492,7 +1766,7 @@ function read_ptx_sreg_envreg12(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg13(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg13(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1511,7 +1785,7 @@ function read_ptx_sreg_envreg13(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg14(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg14(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1530,7 +1804,7 @@ function read_ptx_sreg_envreg14(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg15(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg15(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1549,7 +1823,7 @@ function read_ptx_sreg_envreg15(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg16(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg16(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1568,7 +1842,7 @@ function read_ptx_sreg_envreg16(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg17(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg17(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1587,7 +1861,7 @@ function read_ptx_sreg_envreg17(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg18(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg18(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1606,7 +1880,7 @@ function read_ptx_sreg_envreg18(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg19(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg19(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1625,7 +1899,7 @@ function read_ptx_sreg_envreg19(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg20(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg20(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1644,7 +1918,7 @@ 
function read_ptx_sreg_envreg20(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg21(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg21(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1663,7 +1937,7 @@ function read_ptx_sreg_envreg21(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg22(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg22(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1682,7 +1956,7 @@ function read_ptx_sreg_envreg22(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg23(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg23(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1701,7 +1975,7 @@ function read_ptx_sreg_envreg23(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg24(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg24(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1720,7 +1994,7 @@ function read_ptx_sreg_envreg24(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg25(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg25(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1739,7 +2013,7 @@ function read_ptx_sreg_envreg25(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg26(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg26(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1758,7 +2032,7 @@ function read_ptx_sreg_envreg26(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg27(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg27(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1777,7 +2051,7 @@ function read_ptx_sreg_envreg27(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg28(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg28(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1796,7 +2070,7 @@ function read_ptx_sreg_envreg28(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg29(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg29(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1815,7 +2089,7 @@ function read_ptx_sreg_envreg29(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg30(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg30(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1834,7 +2108,7 @@ function read_ptx_sreg_envreg30(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_envreg31(; res::IR.Type, location=Location()) +function read_ptx_sreg_envreg31(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -1859,7 +2133,7 @@ end Ends execution of a 
thread. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-exit) """ -function exit(; location=Location()) +function exit(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1885,7 +2159,7 @@ Fence operation that applies on the prior nvvm.mbarrier.init [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) """ -function fence_mbarrier_init(; location=Location()) +function fence_mbarrier_init(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1921,7 +2195,12 @@ fall within the `.global` state space. Otherwise, the behavior is undefined [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) """ function fence_proxy_acquire( - addr::Value, size::Value; scope, fromProxy=nothing, toProxy=nothing, location=Location() + addr::Value, + size::Value; + scope::MemScopeKind.T, + fromProxy::Union{ProxyKind.T,Nothing}=nothing, + toProxy::Union{ProxyKind.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[addr, size] @@ -1951,7 +2230,11 @@ that may happen through different proxies. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) """ -function fence_proxy(; kind, space=nothing, location=Location()) +function fence_proxy(; + kind::ProxyKind.T, + space::Union{SharedSpace.T,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1983,7 +2266,10 @@ sequence that contains the fence.proxy.acquire proxy fence operation [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) """ function fence_proxy_release(; - scope, fromProxy=nothing, toProxy=nothing, location=Location() + scope::MemScopeKind.T, + fromProxy::Union{ProxyKind.T,Nothing}=nothing, + toProxy::Union{ProxyKind.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -2005,7 +2291,7 @@ function fence_proxy_release(; ) end -function fence_sc_cluster(; location=Location()) +function fence_sc_cluster(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -2024,7 +2310,7 @@ function fence_sc_cluster(; location=Location()) ) end -function read_ptx_sreg_globaltimer(; res::IR.Type, location=Location()) +function read_ptx_sreg_globaltimer(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2043,7 +2329,9 @@ function read_ptx_sreg_globaltimer(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_nctaid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nctaid_x(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2063,7 +2351,9 @@ function read_ptx_sreg_nctaid_x(; res::IR.Type, range=nothing, location=Location ) end -function read_ptx_sreg_nctaid_y(; res::IR.Type, range=nothing, location=Location()) +function 
read_ptx_sreg_nctaid_y(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2083,7 +2373,9 @@ function read_ptx_sreg_nctaid_y(; res::IR.Type, range=nothing, location=Location ) end -function read_ptx_sreg_nctaid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nctaid_z(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2103,7 +2395,7 @@ function read_ptx_sreg_nctaid_z(; res::IR.Type, range=nothing, location=Location ) end -function read_ptx_sreg_gridid(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_gridid(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2133,7 +2425,7 @@ issue the same instruction or have completed. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol) """ -function griddepcontrol_launch_dependents(; location=Location()) +function griddepcontrol_launch_dependents(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -2162,7 +2454,7 @@ are performed and made visible to the current grid. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol) """ -function griddepcontrol_wait(; location=Location()) +function griddepcontrol_wait(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -2181,7 +2473,7 @@ function griddepcontrol_wait(; location=Location()) ) end -function read_ptx_sreg_laneid(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_laneid(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2201,7 +2493,7 @@ function read_ptx_sreg_laneid(; res::IR.Type, range=nothing, location=Location() ) end -function read_ptx_sreg_lanemask_eq(; res::IR.Type, location=Location()) +function read_ptx_sreg_lanemask_eq(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2220,7 +2512,7 @@ function read_ptx_sreg_lanemask_eq(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_lanemask_ge(; res::IR.Type, location=Location()) +function read_ptx_sreg_lanemask_ge(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2239,7 +2531,7 @@ function read_ptx_sreg_lanemask_ge(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_lanemask_gt(; res::IR.Type, location=Location()) +function read_ptx_sreg_lanemask_gt(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2258,7 +2550,7 @@ function read_ptx_sreg_lanemask_gt(; res::IR.Type, location=Location()) ) end -function read_ptx_sreg_lanemask_le(; res::IR.Type, location=Location()) +function read_ptx_sreg_lanemask_le(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2277,7 +2569,7 @@ function read_ptx_sreg_lanemask_le(; res::IR.Type, location=Location()) ) end -function 
read_ptx_sreg_lanemask_lt(; res::IR.Type, location=Location()) +function read_ptx_sreg_lanemask_lt(; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2296,7 +2588,9 @@ function read_ptx_sreg_lanemask_lt(; res::IR.Type, location=Location()) ) end -function ldmatrix(ptr::Value; res::IR.Type, num, layout, location=Location()) +function ldmatrix( + ptr::Value; res::IR.Type, num::Int32, layout::MMALayout.T, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[ptr,] owned_regions = Region[] @@ -2320,8 +2614,8 @@ end function mbarrier_arrive_expect_tx( addr::Value, txcount::Value, - predicate=nothing::Union{Nothing,Value}; - location=Location(), + predicate::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[addr, txcount] @@ -2345,8 +2639,8 @@ end function mbarrier_arrive_expect_tx_shared( addr::Value, txcount::Value, - predicate=nothing::Union{Nothing,Value}; - location=Location(), + predicate::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[addr, txcount] @@ -2368,7 +2662,7 @@ function mbarrier_arrive_expect_tx_shared( end function mbarrier_arrive_nocomplete( - addr::Value, count::Value; res::IR.Type, location=Location() + addr::Value, count::Value; res::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[res,] operands = Value[addr, count] @@ -2389,7 +2683,7 @@ function mbarrier_arrive_nocomplete( end function mbarrier_arrive_nocomplete_shared( - addr::Value, count::Value; res::IR.Type, location=Location() + addr::Value, count::Value; res::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[res,] operands = Value[addr, count] @@ -2409,7 +2703,7 @@ function mbarrier_arrive_nocomplete_shared( ) end -function mbarrier_arrive(addr::Value; res::IR.Type, location=Location()) +function mbarrier_arrive(addr::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[addr,] owned_regions = Region[] @@ -2428,7 +2722,7 @@ function mbarrier_arrive(addr::Value; res::IR.Type, location=Location()) ) end -function mbarrier_arrive_shared(addr::Value; res::IR.Type, location=Location()) +function mbarrier_arrive_shared(addr::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[addr,] owned_regions = Region[] @@ -2448,7 +2742,10 @@ function mbarrier_arrive_shared(addr::Value; res::IR.Type, location=Location()) end function mbarrier_init( - addr::Value, count::Value, predicate=nothing::Union{Nothing,Value}; location=Location() + addr::Value, + count::Value, + predicate::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[addr, count] @@ -2470,7 +2767,10 @@ function mbarrier_init( end function mbarrier_init_shared( - addr::Value, count::Value, predicate=nothing::Union{Nothing,Value}; location=Location() + addr::Value, + count::Value, + predicate::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[addr, count] @@ -2491,7 +2791,7 @@ function mbarrier_init_shared( ) end -function mbarrier_inval(addr::Value; location=Location()) +function mbarrier_inval(addr::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[addr,] owned_regions = Region[] @@ -2510,7 +2810,7 @@ function mbarrier_inval(addr::Value; location=Location()) ) end -function 
mbarrier_inval_shared(addr::Value; location=Location()) +function mbarrier_inval_shared(addr::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[addr,] owned_regions = Region[] @@ -2529,7 +2829,9 @@ function mbarrier_inval_shared(addr::Value; location=Location()) ) end -function mbarrier_test_wait(addr::Value, state::Value; res::IR.Type, location=Location()) +function mbarrier_test_wait( + addr::Value, state::Value; res::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[addr, state] owned_regions = Region[] @@ -2549,7 +2851,7 @@ function mbarrier_test_wait(addr::Value, state::Value; res::IR.Type, location=Lo end function mbarrier_test_wait_shared( - addr::Value, state::Value; res::IR.Type, location=Location() + addr::Value, state::Value; res::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[res,] operands = Value[addr, state] @@ -2570,7 +2872,7 @@ function mbarrier_test_wait_shared( end function mbarrier_try_wait_parity( - addr::Value, phase::Value, ticks::Value; location=Location() + addr::Value, phase::Value, ticks::Value; location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[addr, phase, ticks] @@ -2591,7 +2893,7 @@ function mbarrier_try_wait_parity( end function mbarrier_try_wait_parity_shared( - addr::Value, phase::Value, ticks::Value; location=Location() + addr::Value, phase::Value, ticks::Value; location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[addr, phase, ticks] @@ -2611,7 +2913,7 @@ function mbarrier_try_wait_parity_shared( ) end -function mapa(a::Value, b::Value; res::IR.Type, location=Location()) +function mapa(a::Value, b::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[a, b] owned_regions = Region[] @@ -2704,13 +3006,13 @@ function mma_sync( operandC::Vector{Value}; res::IR.Type, shape, - b1Op=nothing, - intOverflowBehavior=nothing, - layoutA, - layoutB, - multiplicandAPtxType=nothing, - multiplicandBPtxType=nothing, - location=Location(), + b1Op::Union{MMAB1Op.T,Nothing}=nothing, + intOverflowBehavior::Union{MMAIntOverflow.T,Nothing}=nothing, + layoutA::MMALayout.T, + layoutB::MMALayout.T, + multiplicandAPtxType::Union{MMATypes.T,Nothing}=nothing, + multiplicandBPtxType::Union{MMATypes.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[operandA..., operandB..., operandC...] 
@@ -2746,7 +3048,9 @@ function mma_sync( end function prefetch_tensormap( - tmaDescriptor::Value, predicate=nothing::Union{Nothing,Value}; location=Location() + tmaDescriptor::Value, + predicate::Union{Nothing,Value}=nothing; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tmaDescriptor,] @@ -2767,7 +3071,7 @@ function prefetch_tensormap( ) end -function rcp_approx_ftz_f(arg::Value; res::IR.Type, location=Location()) +function rcp_approx_ftz_f(arg::Value; res::IR.Type, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[arg,] owned_regions = Region[] @@ -2787,7 +3091,11 @@ function rcp_approx_ftz_f(arg::Value; res::IR.Type, location=Location()) end function redux_sync( - val::Value, mask_and_clamp::Value; res::IR.Type, kind, location=Location() + val::Value, + mask_and_clamp::Value; + res::IR.Type, + kind::ReduxKind.T, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[val, mask_and_clamp] @@ -2807,7 +3115,9 @@ function redux_sync( ) end -function setmaxregister(; regCount, action, location=Location()) +function setmaxregister(; + regCount::Int32, action::SetMaxRegisterAction.T, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -2848,9 +3158,9 @@ function shfl_sync( offset::Value, mask_and_clamp::Value; res::IR.Type, - kind, - return_value_and_is_valid=nothing, - location=Location(), + kind::ShflKind.T, + return_value_and_is_valid::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[thread_mask, val, offset, mask_and_clamp] @@ -2874,7 +3184,7 @@ function shfl_sync( ) end -function read_ptx_sreg_nsmid(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nsmid(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2894,7 +3204,7 @@ function read_ptx_sreg_nsmid(; res::IR.Type, range=nothing, location=Location()) ) end -function read_ptx_sreg_smid(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_smid(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -2922,7 +3232,9 @@ location indicated by the address operand \$ptr in shared memory. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) """ -function stmatrix(ptr::Value, sources::Vector{Value}; layout, location=Location()) +function stmatrix( + ptr::Value, sources::Vector{Value}; layout::MMALayout.T, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[ptr, sources...] owned_regions = Region[] @@ -2941,7 +3253,7 @@ function stmatrix(ptr::Value, sources::Vector{Value}; layout, location=Location( ) end -function bar_warp_sync(mask::Value; location=Location()) +function bar_warp_sync(mask::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[mask,] owned_regions = Region[] @@ -2970,7 +3282,12 @@ number of columns to be allocated and it must be a power-of-two. 
[For more information, refer to the PTX ISA] (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions) """ -function tcgen05_alloc(addr::Value, nCols::Value; group=nothing, location=Location()) +function tcgen05_alloc( + addr::Value, + nCols::Value; + group::Union{Tcgen05GroupKind.T,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[addr, nCols] owned_regions = Region[] @@ -3000,7 +3317,12 @@ of columns to be de-allocated, and it must be a power-of-two. [For more information, refer to the PTX ISA] (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions) """ -function tcgen05_dealloc(taddr::Value, nCols::Value; group=nothing, location=Location()) +function tcgen05_dealloc( + taddr::Value, + nCols::Value; + group::Union{Tcgen05GroupKind.T,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[taddr, nCols] owned_regions = Region[] @@ -3030,7 +3352,9 @@ after any of its constituent threads execute `tcgen05.relinquish_alloc_permit`. [For more information, refer to the PTX ISA] (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions) """ -function tcgen05_relinquish_alloc_permit(; group=nothing, location=Location()) +function tcgen05_relinquish_alloc_permit(; + group::Union{Tcgen05GroupKind.T,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -3050,7 +3374,7 @@ function tcgen05_relinquish_alloc_permit(; group=nothing, location=Location()) ) end -function read_ptx_sreg_tid_x(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_tid_x(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -3070,7 +3394,7 @@ function read_ptx_sreg_tid_x(; res::IR.Type, range=nothing, location=Location()) ) end -function read_ptx_sreg_tid_y(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_tid_y(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -3090,7 +3414,7 @@ function read_ptx_sreg_tid_y(; res::IR.Type, range=nothing, location=Location()) ) end -function read_ptx_sreg_tid_z(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_tid_z(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -3110,7 +3434,9 @@ function read_ptx_sreg_tid_z(; res::IR.Type, range=nothing, location=Location()) ) end -function vote_ballot_sync(mask::Value, pred::Value; res::IR.Type, location=Location()) +function vote_ballot_sync( + mask::Value, pred::Value; res::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[mask, pred] owned_regions = Region[] @@ -3133,13 +3459,13 @@ function wmma_load( ptr::Value, stride::Value; res::IR.Type, - m, - n, - k, - layout, - eltype, - frag, - location=Location(), + m::Int32, + n::Int32, + k::Int32, + layout::MMALayout.T, + eltype::MMATypes.T, + frag::MMAFrag.T, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[ptr, stride] @@ -3169,14 +3495,14 @@ end function wmma_mma( args::Vector{Value}; res::IR.Type, - m, - n, - k, - layoutA, - layoutB, - eltypeA, - eltypeB, - location=Location(), + m::Int32, + n::Int32, + k::Int32, + 
layoutA::MMALayout.T, + layoutB::MMALayout.T, + eltypeA::MMATypes.T, + eltypeB::MMATypes.T, + location::Location=Location(), ) op_ty_results = IR.Type[res,] operands = Value[args...,] @@ -3208,12 +3534,12 @@ function wmma_store( ptr::Value, args::Vector{Value}, stride::Value; - m, - n, - k, - layout, - eltype, - location=Location(), + m::Int32, + n::Int32, + k::Int32, + layout::MMALayout.T, + eltype::MMATypes.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ptr, args..., stride] @@ -3239,7 +3565,7 @@ function wmma_store( ) end -function read_ptx_sreg_nwarpid(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_nwarpid(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -3259,7 +3585,7 @@ function read_ptx_sreg_nwarpid(; res::IR.Type, range=nothing, location=Location( ) end -function read_ptx_sreg_warpid(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_warpid(; res::IR.Type, range=nothing, location::Location=Location()) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -3279,7 +3605,9 @@ function read_ptx_sreg_warpid(; res::IR.Type, range=nothing, location=Location() ) end -function read_ptx_sreg_warpsize(; res::IR.Type, range=nothing, location=Location()) +function read_ptx_sreg_warpsize(; + res::IR.Type, range=nothing, location::Location=Location() +) op_ty_results = IR.Type[res,] operands = Value[] owned_regions = Region[] @@ -3307,7 +3635,7 @@ multiplication and other operations. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence) """ -function wgmma_fence_aligned(; location=Location()) +function wgmma_fence_aligned(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -3333,7 +3661,7 @@ Commits all prior uncommitted warpgroup level matrix multiplication operations. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group) """ -function wgmma_commit_group_sync_aligned(; location=Location()) +function wgmma_commit_group_sync_aligned(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -3417,16 +3745,16 @@ function wgmma_mma_async( descriptorB::Value; results::IR.Type, shape, - typeA, - typeB, - typeD, - scaleD, - scaleA, - scaleB, - layoutA, - layoutB, - satfinite=nothing, - location=Location(), + typeA::WGMMATypes.T, + typeB::WGMMATypes.T, + typeD::WGMMATypes.T, + scaleD::WGMMAScaleOut.T, + scaleA::WGMMAScaleIn.T, + scaleB::WGMMAScaleIn.T, + layoutA::MMALayout.T, + layoutB::MMALayout.T, + satfinite::Union{MMAIntOverflow.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[results,] operands = Value[inouts, descriptorA, descriptorB] @@ -3464,7 +3792,7 @@ Signal the completion of a preceding warpgroup operation. 
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group) """ -function wgmma_wait_group_sync_aligned(; group, location=Location()) +function wgmma_wait_group_sync_aligned(; group::Int64, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] diff --git a/src/mlir/Dialects/Shardy.jl b/src/mlir/Dialects/Shardy.jl old mode 100644 new mode 100755 index baaaa6c1b..ddd59c3e1 --- a/src/mlir/Dialects/Shardy.jl +++ b/src/mlir/Dialects/Shardy.jl @@ -10,8 +10,17 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`PropagationDirection` +propagation direction enum +""" +@enumx PropagationDirection NONE = 0 FORWARD = 1 BACKWARD = 2 BOTH = 3 + +IR.Attribute(e::PropagationDirection.T) = Int(e) """ `all_gather` @@ -43,10 +52,10 @@ inferred sharding. """ function all_gather( tensor::Value; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, gathering_axes, out_sharding, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tensor,] @@ -65,8 +74,8 @@ function all_gather( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -84,10 +93,10 @@ affect the order of the corresponding replica groups. """ function all_reduce( tensor::Value; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, reduction_axes, out_sharding, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tensor,] @@ -106,8 +115,8 @@ function all_reduce( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -142,10 +151,10 @@ inferred sharding. """ function all_slice( tensor::Value; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, slicing_axes, out_sharding, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tensor,] @@ -164,8 +173,8 @@ function all_slice( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -205,12 +214,12 @@ this inferred sharding. """ function all_to_all( tensor::Value; - result=nothing::Union{Nothing,IR.Type}, - src_dim, - tgt_dim, + result::Union{Nothing,IR.Type}=nothing, + src_dim::Int64, + tgt_dim::Int64, axes, out_sharding, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tensor,] @@ -231,8 +240,8 @@ function all_to_all( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -266,7 +275,10 @@ sdy.mesh @mesh = <[\"a\"=2, \"b\"=2, \"c\"=4, \"d\"=2, \"e\"=2, \"f\"=2]> must match that of the corresponding operand dimension sharding. """ function collective_permute( - tensor::Value; result=nothing::Union{Nothing,IR.Type}, out_sharding, location=Location() + tensor::Value; + result::Union{Nothing,IR.Type}=nothing, + out_sharding, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[tensor,] @@ -282,8 +294,8 @@ function collective_permute( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -306,7 +318,11 @@ is done between constants (or constant expressions). %output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> ``` """ -function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Location()) +function constant(; + output::Union{Nothing,IR.Type}=nothing, + value::IR.AbstractDenseElementsAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -321,8 +337,8 @@ function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Loca owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -384,9 +400,9 @@ responsible for providing this information. """ function data_flow_edge( input::Value; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, sharding=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[input,] @@ -403,8 +419,8 @@ function data_flow_edge( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -428,12 +444,12 @@ the body on any free axes - those not in the manual_axes list. """ function manual_computation( tensors::Vector{Value}; - results::Vector{IR.Type}, + results::Base.AbstractVecOrTuple{IR.Type}, in_shardings, out_shardings, manual_axes, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[tensors...,] @@ -465,7 +481,7 @@ of devices (except for meshes with a single device_id). The mesh is a `Symbol` operation that appears in the module\'s `SymbolTable` and can be referenced by its `name`. """ -function mesh(; sym_name, mesh, location=Location()) +function mesh(; sym_name::String, mesh, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -510,14 +526,14 @@ the same as the type of the operands and results type of the op. 
""" function named_computation( operands::Vector{Value}; - result_0::Vector{IR.Type}, - name, + result::Base.AbstractVecOrTuple{IR.Type}, + name::String, in_shardings=nothing, out_shardings=nothing, body::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[operands...,] owned_regions = Region[body,] successors = Block[] @@ -556,9 +572,9 @@ of the barrier op and its operand. """ function propagation_barrier( input::Value; - result=nothing::Union{Nothing,IR.Type}, - allowed_direction, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + allowed_direction::PropagationDirection.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[input,] @@ -574,8 +590,8 @@ function propagation_barrier( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -598,7 +614,10 @@ lifespan is: // reshard ops. """ function reshard( - input::Value; result=nothing::Union{Nothing,IR.Type}, sharding, location=Location() + input::Value; + result::Union{Nothing,IR.Type}=nothing, + sharding, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[input,] @@ -614,12 +633,12 @@ function reshard( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function return_(results::Vector{Value}; location=Location()) +function return_(results::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[results...,] owned_regions = Region[] @@ -657,7 +676,10 @@ This op can either: uses then the behavior is the same as the no uses case). """ function sharding_constraint( - input::Value; result=nothing::Union{Nothing,IR.Type}, sharding, location=Location() + input::Value; + result::Union{Nothing,IR.Type}=nothing, + sharding, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[input,] @@ -673,8 +695,8 @@ function sharding_constraint( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -689,7 +711,7 @@ argument group ID and returns no result, but instead modifies the internal sharding group representation to add the input tensor to the group with the given ID. """ -function sharding_group(input::Value; group_id, location=Location()) +function sharding_group(input::Value; group_id::Int64, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[input,] owned_regions = Region[] @@ -703,8 +725,8 @@ function sharding_group(input::Value; group_id, location=Location()) owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end diff --git a/src/mlir/Dialects/StableHLO.jl b/src/mlir/Dialects/StableHLO.jl index f7d45ef92..2ddedc795 100755 --- a/src/mlir/Dialects/StableHLO.jl +++ b/src/mlir/Dialects/StableHLO.jl @@ -10,8 +10,213 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`channel_handle` +two 64-bit integers \'handle\' and \'type\' +""" +struct ChannelHandle + handle::Int64 + type::Int64 +end + +function IR.Attribute(s::ChannelHandle) + return parse( + Attribute, "#stablehlo.channel_handle<handle = $(s.handle), type = $(s.type)>" + ) +end + +""" +`ComparisonDirection` +Which comparison operation to perform. +""" +@enumx ComparisonDirection EQ NE GE GT LE LT +const ComparisonDirectionStorage = ["EQ", "NE", "GE", "GT", "LE", "LT"] + +function IR.Attribute(e::ComparisonDirection.T) + return parse( + Attribute, + "#stablehlo<comparison_direction $(ComparisonDirectionStorage[Int(e)+1])>", + ) +end + +""" +`ComparisonType` +Which comparison type to use. +""" +@enumx ComparisonType NOTYPE FLOAT TOTALORDER SIGNED UNSIGNED +const ComparisonTypeStorage = ["NOTYPE", "FLOAT", "TOTALORDER", "SIGNED", "UNSIGNED"] + +function IR.Attribute(e::ComparisonType.T) + return parse( + Attribute, "#stablehlo<comparison_type $(ComparisonTypeStorage[Int(e)+1])>" + ) +end + +""" +`Precision` +XLA precision for an operand. Has backend specific meaning. +""" +@enumx Precision DEFAULT HIGH HIGHEST +const PrecisionStorage = ["DEFAULT", "HIGH", "HIGHEST"] + +function IR.Attribute(e::Precision.T) + return parse(Attribute, "#stablehlo<precision $(PrecisionStorage[Int(e)+1])>") +end + +""" +`CustomCallApiVersion` +Custom call API version +""" +@enumx CustomCallApiVersion API_VERSION_UNSPECIFIED = 0 API_VERSION_ORIGINAL = 1 API_VERSION_STATUS_RETURNING = + 2 API_VERSION_STATUS_RETURNING_UNIFIED = 3 API_VERSION_TYPED_FFI = 4 + +IR.Attribute(e::CustomCallApiVersion.T) = Int(e) + +""" +`output_operand_alias` +Attribute that models the alias relationship of output and operand of a CustomCall op +""" +struct OutputOperandAlias + output_tuple_indices::IR.DenseAttribute{Int64} + operand_index::Int64 + operand_tuple_indices::IR.DenseAttribute{Int64} +end + +function IR.Attribute(s::OutputOperandAlias) + return parse( + Attribute, + "#stablehlo.output_operand_alias<output_tuple_indices = $(s.output_tuple_indices), operand_index = $(s.operand_index), operand_tuple_indices = $(s.operand_tuple_indices)>", + ) +end + +""" +`dot` +Attribute that models the dimension information for dot. +""" +struct Dot + lhs_batching_dimensions::IR.DenseAttribute{Int64} + rhs_batching_dimensions::IR.DenseAttribute{Int64} + lhs_contracting_dimensions::IR.DenseAttribute{Int64} + rhs_contracting_dimensions::IR.DenseAttribute{Int64} +end + +function IR.Attribute(s::Dot) + return parse( + Attribute, + "#stablehlo.dot<lhs_batching_dimensions = $(s.lhs_batching_dimensions), rhs_batching_dimensions = $(s.rhs_batching_dimensions), lhs_contracting_dimensions = $(s.lhs_contracting_dimensions), rhs_contracting_dimensions = $(s.rhs_contracting_dimensions)>", + ) +end + +""" +`dot_algorithm` +Attribute that models the algorithm constraints to use for computing dot. 
+""" +struct DotAlgorithm + lhs_precision_type::IR.Type + rhs_precision_type::IR.Type + accumulation_type::IR.Type + lhs_component_count::Int64 + rhs_component_count::Int64 + num_primitive_operations::Int64 + allow_imprecise_accumulation::Bool +end + +function IR.Attribute(s::DotAlgorithm) + return parse( + Attribute, + "#stablehlo.dot_algorithm", + ) +end + +""" +`gather` +Attribute that models the dimension information for gather +""" +struct Gather + offset_dims::IR.DenseAttribute{Int64} + collapsed_slice_dims::IR.DenseAttribute{Int64} + operand_batching_dims::IR.DenseAttribute{Int64} + start_indices_batching_dims::IR.DenseAttribute{Int64} + start_index_map::IR.DenseAttribute{Int64} + index_vector_dim::Int64 +end + +function IR.Attribute(s::Gather) + return parse( + Attribute, + "#stablehlo.gather", + ) +end + +""" +`FftType` +XLA fast fourier transform type. +""" +@enumx FftType FFT IFFT RFFT IRFFT +const FftTypeStorage = ["FFT", "IFFT", "RFFT", "IRFFT"] + +function IR.Attribute(e::FftType.T) + return parse(Attribute, "#stablehlo") +end + +""" +`RngAlgorithm` +XLA PRNG algorithm to be used. +""" +@enumx RngAlgorithm DEFAULT THREE_FRY PHILOX +const RngAlgorithmStorage = ["DEFAULT", "THREE_FRY", "PHILOX"] + +function IR.Attribute(e::RngAlgorithm.T) + return parse(Attribute, "#stablehlo") +end + +""" +`RngDistribution` +XLA PRNG distribution to be used. +""" +@enumx RngDistribution UNIFORM NORMAL +const RngDistributionStorage = ["UNIFORM", "NORMAL"] + +function IR.Attribute(e::RngDistribution.T) + return parse( + Attribute, "#stablehlo" + ) +end + +""" +`scatter` +Attribute that models the dimension information for scatter +""" +struct Scatter + update_window_dims::IR.DenseAttribute{Int64} + inserted_window_dims::IR.DenseAttribute{Int64} + input_batching_dims::IR.DenseAttribute{Int64} + scatter_indices_batching_dims::IR.DenseAttribute{Int64} + scatter_dims_to_operand_dims::IR.DenseAttribute{Int64} + index_vector_dim::Int64 +end + +function IR.Attribute(s::Scatter) + return parse( + Attribute, + "#stablehlo.scatter", + ) +end + +""" +`Transpose` +Transpose options +""" +@enumx Transpose TRANSPOSE_INVALID NO_TRANSPOSE TRANSPOSE ADJOINT +const TransposeStorage = ["TRANSPOSE_INVALID", "NO_TRANSPOSE", "TRANSPOSE", "ADJOINT"] + +function IR.Attribute(e::Transpose.T) + return parse(Attribute, "#stablehlo") +end """ `abs` @@ -27,7 +232,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#abs %result = stablehlo.abs %operand : tensor<3xi32> ``` """ -function abs(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function abs( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -42,8 +249,8 @@ function abs(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Lo owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -62,7 +269,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#add ``` """ function add( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -78,8 +288,8 @@ function add( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -98,7 +308,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all ``` """ function after_all( - inputs::Vector{Value}; result=nothing::Union{Nothing,IR.Type}, location=Location() + inputs::Vector{Value}; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[inputs...,] @@ -114,8 +326,8 @@ function after_all( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -140,14 +352,14 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather """ function all_gather( operands::Vector{Value}; - result_0::Vector{IR.Type}, - all_gather_dim, - replica_groups, - channel_handle=nothing, - use_global_device_ids=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + all_gather_dim::Int64, + replica_groups::IR.AbstractDenseElementsAttribute{Int64}, + channel_handle::Union{ChannelHandle,Nothing}=nothing, + use_global_device_ids::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[operands...,] owned_regions = Region[] successors = Block[] @@ -196,14 +408,14 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce """ function all_reduce( operands::Vector{Value}; - result_0::Vector{IR.Type}, - replica_groups, - channel_handle=nothing, - use_global_device_ids=nothing, + result::Base.AbstractVecOrTuple{IR.Type}, + replica_groups::IR.AbstractDenseElementsAttribute{Int64}, + channel_handle::Union{ChannelHandle,Nothing}=nothing, + use_global_device_ids::Union{Bool,Nothing}=nothing, computation::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[operands...,] owned_regions = Region[computation,] successors = Block[] @@ -248,13 +460,13 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_to_all """ function all_to_all( operands::Vector{Value}; - result_0=nothing::Union{Nothing,Vector{IR.Type}}, - split_dimension, - concat_dimension, - split_count, - replica_groups, - channel_handle=nothing, - location=Location(), + result::Union{Nothing,Base.AbstractVecOrTuple{IR.Type}}=nothing, + split_dimension::Int64, + concat_dimension::Int64, + split_count::Int64, + replica_groups::IR.AbstractDenseElementsAttribute{Int64}, + channel_handle::Union{ChannelHandle,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operands...,] @@ -266,7 +478,7 @@ function all_to_all( 
namedattribute("split_count", split_count), namedattribute("replica_groups", replica_groups), ] - !isnothing(result_0) && push!(op_ty_results, result_0...) + !isnothing(result) && push!(op_ty_results, result...) !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) @@ -277,8 +489,8 @@ function all_to_all( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -297,7 +509,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#and ``` """ function and( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -313,8 +528,8 @@ function and( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -333,7 +548,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#atan2 ``` """ function atan2( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -349,8 +567,8 @@ function atan2( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -380,12 +598,12 @@ function batch_norm_grad( mean::Value, variance::Value, grad_output::Value; - grad_operand=nothing::Union{Nothing,IR.Type}, - grad_scale=nothing::Union{Nothing,IR.Type}, - grad_offset=nothing::Union{Nothing,IR.Type}, - epsilon, - feature_index, - location=Location(), + grad_operand::Union{Nothing,IR.Type}=nothing, + grad_scale::Union{Nothing,IR.Type}=nothing, + grad_offset::Union{Nothing,IR.Type}=nothing, + epsilon::Float32, + feature_index::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, scale, mean, variance, grad_output] @@ -405,8 +623,8 @@ function batch_norm_grad( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -433,10 +651,10 @@ function batch_norm_inference( offset::Value, mean::Value, variance::Value; - result=nothing::Union{Nothing,IR.Type}, - epsilon, - feature_index, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + epsilon::Float32, + feature_index::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, scale, offset, mean, variance] @@ -454,8 +672,8 @@ function batch_norm_inference( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -482,12 +700,12 @@ function batch_norm_training( operand::Value, scale::Value, offset::Value; - output=nothing::Union{Nothing,IR.Type}, - batch_mean=nothing::Union{Nothing,IR.Type}, - batch_var=nothing::Union{Nothing,IR.Type}, - epsilon, - feature_index, - location=Location(), + output::Union{Nothing,IR.Type}=nothing, + batch_mean::Union{Nothing,IR.Type}=nothing, + batch_var::Union{Nothing,IR.Type}=nothing, + epsilon::Float32, + feature_index::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, scale, offset] @@ -507,8 +725,8 @@ function batch_norm_training( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -527,8 +745,8 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#bitcast_convert %result = stablehlo.bitcast_convert %operand : (tensor) -> tensor<4xf16> ``` """ -function bitcast_convert(operand::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] +function bitcast_convert(operand::Value; result::IR.Type, location::Location=Location()) + op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] successors = Block[] @@ -561,9 +779,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim ``` """ function broadcast_in_dim( - operand::Value; result_0::IR.Type, broadcast_dimensions, location=Location() + operand::Value; + result::IR.Type, + broadcast_dimensions::IR.DenseAttribute{Int64}, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] successors = Block[] @@ -599,16 +820,16 @@ https://www.tensorflow.org/xla/operation_semantics#broadcast """ function broadcast( operand::Value; - result_0=nothing::Union{Nothing,IR.Type}, - broadcast_sizes, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + broadcast_sizes::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("broadcast_sizes", broadcast_sizes),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.broadcast", @@ -617,8 +838,8 @@ function broadcast( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -641,9 +862,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case ``` """ function case( - index::Value; result_0::Vector{IR.Type}, branches::Vector{Region}, location=Location() + index::Value; + result::Base.AbstractVecOrTuple{IR.Type}, + branches::Vector{Region}, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[index,] owned_regions = Region[branches...,] successors = Block[] @@ -675,7 +899,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cbrt %result = stablehlo.cbrt %operand : tensor<4xf64> ``` """ -function cbrt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function cbrt( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -690,8 +916,8 @@ function cbrt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -708,7 +934,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#ceil %result = stablehlo.ceil %operand : tensor<5xf32> ``` """ -function ceil(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function ceil( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -723,8 +951,8 @@ function ceil(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -742,7 +970,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cholesky ``` """ function cholesky( - a::Value; result=nothing::Union{Nothing,IR.Type}, lower=nothing, location=Location() + a::Value; + result::Union{Nothing,IR.Type}=nothing, + lower::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[a,] @@ -759,8 +990,8 @@ function cholesky( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -782,8 +1013,8 @@ function clamp( min::Value, operand::Value, max::Value; - result=nothing::Union{Nothing,IR.Type}, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[min, operand, max] @@ -799,8 +1030,8 @@ function clamp( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -819,7 +1050,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros ``` """ function count_leading_zeros( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -835,8 +1066,8 @@ function count_leading_zeros( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -860,17 +1091,17 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast """ function collective_broadcast( operand::Value; - result_0=nothing::Union{Nothing,IR.Type}, - replica_groups, - channel_handle=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + replica_groups::IR.AbstractDenseElementsAttribute{Int64}, + channel_handle::Union{ChannelHandle,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("replica_groups", replica_groups),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) @@ -881,8 +1112,8 @@ function collective_broadcast( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -906,17 +1137,17 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_permute """ function collective_permute( operand::Value; - result_0=nothing::Union{Nothing,IR.Type}, - source_target_pairs, - channel_handle=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + source_target_pairs::IR.AbstractDenseElementsAttribute{Int64}, + channel_handle::Union{ChannelHandle,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("source_target_pairs", source_target_pairs),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) @@ -927,8 +1158,8 @@ function collective_permute( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -949,10 +1180,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#compare function compare( lhs::Value, rhs::Value; - result_0=nothing::Union{Nothing,IR.Type}, - comparison_direction, - compare_type=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + comparison_direction::ComparisonDirection.T, + compare_type::Union{ComparisonType.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -961,7 +1192,7 @@ function compare( attributes = NamedAttribute[namedattribute( "comparison_direction", comparison_direction ),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(compare_type) && push!(attributes, namedattribute("compare_type", compare_type)) @@ -972,8 +1203,8 @@ function compare( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -990,7 +1221,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex ``` """ function complex( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1006,8 +1240,8 @@ function complex( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1040,14 +1274,14 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite """ function composite( inputs::Vector{Value}; - result_0::Vector{IR.Type}, - name, + result::Base.AbstractVecOrTuple{IR.Type}, + name::String, composite_attributes=nothing, - decomposition, - version=nothing, - location=Location(), + decomposition::IR.FlatSymbolRefAttribute, + version::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs...,] owned_regions = Region[] successors = Block[] @@ -1087,16 +1321,16 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate """ function concatenate( inputs::Vector{Value}; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + dimension::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[inputs...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.concatenate", @@ -1105,8 +1339,8 @@ function concatenate( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1123,7 +1357,11 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant %output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> ``` """ -function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Location()) +function constant(; + output::Union{Nothing,IR.Type}=nothing, + value::IR.AbstractDenseElementsAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1138,8 +1376,8 @@ function constant(; output=nothing::Union{Nothing,IR.Type}, value, location=Loca owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1157,7 +1395,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex> ``` """ -function convert(operand::Value; result::IR.Type, location=Location()) +function convert(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1206,19 +1444,19 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution function convolution( lhs::Value, rhs::Value; - result_0::IR.Type, - window_strides=nothing, - padding=nothing, - lhs_dilation=nothing, - rhs_dilation=nothing, - window_reversal=nothing, + result::IR.Type, + window_strides::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + padding::Union{IR.AbstractDenseElementsAttribute{Int64},Nothing}=nothing, + lhs_dilation::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + rhs_dilation::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + window_reversal::Union{IR.DenseAttribute{Bool},Nothing}=nothing, dimension_numbers, - feature_group_count, - batch_group_count, - precision_config=nothing, - location=Location(), + feature_group_count::Int64, + batch_group_count::Int64, + precision_config::Union{IR.DenseAttribute{Precision.T},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] @@ -1265,7 +1503,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine %result = stablehlo.cosine %operand : tensor<2xf32> ``` """ -function cosine(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function cosine( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -1280,8 +1520,8 @@ function cosine(operand::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1299,7 +1539,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all %output = stablehlo.create_token : !stablehlo.token ``` """ -function create_token(; output=nothing::Union{Nothing,IR.Type}, location=Location()) +function create_token(; + output::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1314,8 +1556,8 @@ function create_token(; output=nothing::Union{Nothing,IR.Type}, location=Locatio owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1339,16 +1581,16 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce """ function cross_replica_sum( operand::Value; - result_0=nothing::Union{Nothing,IR.Type}, - replica_groups, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + replica_groups::IR.AbstractDenseElementsAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("replica_groups", replica_groups),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.cross-replica-sum", @@ -1357,8 +1599,8 @@ function cross_replica_sum( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1387,18 +1629,22 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call """ function custom_call( inputs::Vector{Value}; - result_0::Vector{IR.Type}, - call_target_name, - has_side_effect=nothing, - backend_config=nothing, - api_version=nothing, - called_computations=nothing, - operand_layouts=nothing, - result_layouts=nothing, - output_operand_aliases=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + call_target_name::String, + has_side_effect::Union{Bool,Nothing}=nothing, + backend_config::Union{IR.AbstractAttribute,Nothing}=nothing, + api_version::Union{CustomCallApiVersion.T,Nothing}=nothing, + called_computations::Union{IR.DenseAttribute{IR.FlatSymbolRefAttribute},Nothing}=nothing, + operand_layouts::Union{ + IR.DenseAttribute{IR.AbstractDenseElementsAttribute{Int64}},Nothing + }=nothing, + result_layouts::Union{ + IR.DenseAttribute{IR.AbstractDenseElementsAttribute{Int64}},Nothing + }=nothing, + output_operand_aliases::Union{IR.DenseAttribute{OutputOperandAlias},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs...,] owned_regions = Region[] successors = Block[] @@ -1444,7 +1690,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide ``` """ function divide( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -1460,8 +1709,8 @@ function divide( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1487,13 +1736,13 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general function dot_general( lhs::Value, rhs::Value; - result_0::IR.Type, - dot_dimension_numbers, - precision_config=nothing, - algorithm=nothing, - location=Location(), + result::IR.Type, + dot_dimension_numbers::Dot, + precision_config::Union{IR.DenseAttribute{Precision.T},Nothing}=nothing, + algorithm::Union{DotAlgorithm,Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] @@ -1531,9 +1780,13 @@ https://www.tensorflow.org/xla/operation_semantics#dot ``` """ function dot( - lhs::Value, rhs::Value; result_0::IR.Type, precision_config=nothing, location=Location() + lhs::Value, + rhs::Value; + result::IR.Type, + precision_config::Union{IR.DenseAttribute{Precision.T},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] @@ -1582,13 +1835,13 @@ See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadca function dynamic_broadcast_in_dim( operand::Value, output_dimensions::Value; - result_0::IR.Type, - broadcast_dimensions, - known_expanding_dimensions=nothing, - known_nonexpanding_dimensions=nothing, - location=Location(), + result::IR.Type, + broadcast_dimensions::IR.DenseAttribute{Int64}, + known_expanding_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + known_nonexpanding_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[operand, output_dimensions] owned_regions = Region[] successors = Block[] @@ -1642,18 +1895,18 @@ function dynamic_conv( lhs::Value, rhs::Value, padding::Value; - result_0::IR.Type, - window_strides=nothing, - lhs_dilation=nothing, - rhs_dilation=nothing, - window_reversal=nothing, - dimension_numbers, - feature_group_count, - batch_group_count, - precision_config=nothing, - location=Location(), + result::IR.Type, + window_strides::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + lhs_dilation::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + rhs_dilation::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + window_reversal::Union{IR.DenseAttribute{Bool},Nothing}=nothing, + dimension_numbers::Attribute, + feature_group_count::Int64, + batch_group_count::Int64, + precision_config::Union{IR.DenseAttribute{Precision.T},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[lhs, rhs, padding] owned_regions = Region[] successors = Block[] @@ -1709,17 +1962,17 @@ function dynamic_gather( operand::Value, start_indices::Value, slice_sizes::Value; - result_0=nothing::Union{Nothing,IR.Type}, - dimension_numbers, - indices_are_sorted=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + dimension_numbers::Gather, + indices_are_sorted::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, start_indices, slice_sizes] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension_numbers", dimension_numbers),] - !isnothing(result_0) && push!(op_ty_results, result_0) + 
!isnothing(result) && push!(op_ty_results, result) !isnothing(indices_are_sorted) && push!(attributes, namedattribute("indices_are_sorted", indices_are_sorted)) @@ -1730,8 +1983,8 @@ function dynamic_gather( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1752,7 +2005,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_iota ``` """ function dynamic_iota( - output_shape::Value; result::IR.Type, iota_dimension, location=Location() + output_shape::Value; + result::IR.Type, + iota_dimension::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[output_shape,] @@ -1800,7 +2056,7 @@ function dynamic_pad( edge_padding_high::Value, interior_padding::Value; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[ @@ -1839,7 +2095,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_reshape ``` """ function dynamic_reshape( - operand::Value, output_shape::Value; result::IR.Type, location=Location() + operand::Value, output_shape::Value; result::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[result,] operands = Value[operand, output_shape] @@ -1877,9 +2133,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice function dynamic_slice( operand::Value, start_indices::Vector{Value}; - result=nothing::Union{Nothing,IR.Type}, - slice_sizes, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + slice_sizes::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, start_indices...] @@ -1895,8 +2151,8 @@ function dynamic_slice( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1920,8 +2176,8 @@ function dynamic_update_slice( operand::Value, update::Value, start_indices::Vector{Value}; - result=nothing::Union{Nothing,IR.Type}, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, update, start_indices...] @@ -1937,8 +2193,8 @@ function dynamic_update_slice( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1959,9 +2215,13 @@ https://www.tensorflow.org/api_docs/python/tf/einsum ``` """ function einsum( - lhs::Value, rhs::Value; result_0::IR.Type, einsum_config, location=Location() + lhs::Value, + rhs::Value; + result::IR.Type, + einsum_config::String, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] successors = Block[] @@ -1995,9 +2255,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential """ function exponential( operand::Value; - result=nothing::Union{Nothing,IR.Type}, + result::Union{Nothing,IR.Type}=nothing, result_accuracy=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -2015,8 +2275,8 @@ function exponential( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2035,7 +2295,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_on ``` """ function exponential_minus_one( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -2051,8 +2311,8 @@ function exponential_minus_one( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2072,10 +2332,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#fft """ function fft( operand::Value; - result_0=nothing::Union{Nothing,IR.Type}, - fft_type, - fft_length, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + fft_type::FftType.T, + fft_length::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -2084,7 +2344,7 @@ function fft( attributes = NamedAttribute[ namedattribute("fft_type", fft_type), namedattribute("fft_length", fft_length) ] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.fft", @@ -2093,8 +2353,8 @@ function fft( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2112,7 +2372,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor %result = stablehlo.floor %operand : tensor<2xf32> ``` """ -function floor(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function floor( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -2127,8 +2389,8 @@ function floor(operand::Value; result=nothing::Union{Nothing,IR.Type}, location= owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? 
nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2159,11 +2421,11 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather function gather( operand::Value, start_indices::Value; - result=nothing::Union{Nothing,IR.Type}, - dimension_numbers, - slice_sizes, - indices_are_sorted=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + dimension_numbers::Gather, + slice_sizes::IR.DenseAttribute{Int64}, + indices_are_sorted::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, start_indices] @@ -2184,8 +2446,8 @@ function gather( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2203,14 +2465,17 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_dimension_size ``` """ function get_dimension_size( - operand::Value; result_0=nothing::Union{Nothing,IR.Type}, dimension, location=Location() + operand::Value; + result::Union{Nothing,IR.Type}=nothing, + dimension::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.get_dimension_size", @@ -2219,8 +2484,8 @@ function get_dimension_size( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2239,14 +2504,17 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_tuple_element ``` """ function get_tuple_element( - operand::Value; result_0=nothing::Union{Nothing,IR.Type}, index, location=Location() + operand::Value; + result::Union{Nothing,IR.Type}=nothing, + index::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("index", index),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.get_tuple_element", @@ -2255,8 +2523,8 @@ function get_tuple_element( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2278,12 +2546,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if """ function if_( pred::Value; - result_0::Vector{IR.Type}, + result::Base.AbstractVecOrTuple{IR.Type}, true_branch::Region, false_branch::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[pred,] owned_regions = Region[true_branch, false_branch] successors = Block[] @@ -2315,7 +2583,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag %result = stablehlo.imag %operand : (tensor<2xcomplex>) -> tensor<2xf32> ``` """ -function imag(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function imag( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -2330,8 +2600,8 @@ function imag(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2351,12 +2621,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#infeed """ function infeed( token::Value; - result_0::Vector{IR.Type}, - infeed_config=nothing, - layout=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + infeed_config::Union{String,Nothing}=nothing, + layout::Union{Vector{<:IR.AbstractAttribute},Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[token,] owned_regions = Region[] successors = Block[] @@ -2391,7 +2661,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota %output = stablehlo.iota dim = 0 : tensor<4x5xi32> ``` """ -function iota(; output::IR.Type, iota_dimension, location=Location()) +function iota(; output::IR.Type, iota_dimension::Int64, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -2424,7 +2694,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1> ``` """ -function is_finite(x::Value; y=nothing::Union{Nothing,IR.Type}, location=Location()) +function is_finite( + x::Value; y::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[x,] owned_regions = Region[] @@ -2439,8 +2711,8 @@ function is_finite(x::Value; y=nothing::Union{Nothing,IR.Type}, location=Locatio owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2459,7 +2731,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one ``` """ function log_plus_one( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -2475,8 +2747,8 @@ function log_plus_one( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2494,7 +2766,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log %result = stablehlo.log %operand : tensor<2x2xf64> ``` """ -function log(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function log( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -2509,8 +2783,8 @@ function log(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Lo owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2529,7 +2803,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic ``` """ function logistic( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -2545,8 +2819,8 @@ function logistic( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2572,12 +2846,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map """ function map( inputs::Vector{Value}; - result_0::IR.Type, - dimensions, + result::IR.Type, + dimensions::IR.DenseAttribute{Int64}, computation::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[inputs...,] owned_regions = Region[computation,] successors = Block[] @@ -2610,7 +2884,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum ``` """ function maximum( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2626,8 +2903,8 @@ function maximum( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2646,7 +2923,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum ``` """ function minimum( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2662,8 +2942,8 @@ function minimum( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2682,7 +2962,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#multiply ``` """ function multiply( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2698,8 +2981,8 @@ function multiply( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2717,7 +3000,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate %result = stablehlo.negate %operand : tensor<2x3xi32> ``` """ -function negate(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function negate( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -2732,8 +3017,8 @@ function negate(operand::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2751,7 +3036,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not %result = stablehlo.not %operand : tensor<5x3x1xi1> ``` """ -function not(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function not( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -2766,8 +3053,8 @@ function not(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Lo owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2789,8 +3076,8 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier """ function optimization_barrier( operand::Vector{Value}; - result=nothing::Union{Nothing,Vector{IR.Type}}, - location=Location(), + result::Union{Nothing,Base.AbstractVecOrTuple{IR.Type}}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand...,] @@ -2806,8 +3093,8 @@ function optimization_barrier( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2826,7 +3113,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#or ``` """ function or( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -2842,8 +3132,8 @@ function or( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2864,16 +3154,16 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#outfeed function outfeed( inputs::Vector{Value}, token::Value; - result_0=nothing::Union{Nothing,IR.Type}, - outfeed_config=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + outfeed_config::Union{String,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[inputs..., token] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(outfeed_config) && push!(attributes, namedattribute("outfeed_config", outfeed_config)) @@ -2884,8 +3174,8 @@ function outfeed( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2907,11 +3197,11 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad function pad( operand::Value, padding_value::Value; - result_0=nothing::Union{Nothing,IR.Type}, - edge_padding_low, - edge_padding_high, - interior_padding, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + edge_padding_low::IR.DenseAttribute{Int64}, + edge_padding_high::IR.DenseAttribute{Int64}, + interior_padding::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, padding_value] @@ -2922,7 +3212,7 @@ function pad( namedattribute("edge_padding_high", edge_padding_high), namedattribute("interior_padding", interior_padding), ] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.pad", @@ -2931,8 +3221,8 @@ function pad( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? 
true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2949,13 +3239,15 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#partition_id %result = stablehlo.partition_id : tensor ``` """ -function partition_id(; result_0=nothing::Union{Nothing,IR.Type}, location=Location()) +function partition_id(; + result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.partition_id", @@ -2964,8 +3256,8 @@ function partition_id(; result_0=nothing::Union{Nothing,IR.Type}, location=Locat owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -2983,7 +3275,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt %result = stablehlo.popcnt %operand : tensor<4xi64> ``` """ -function popcnt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function popcnt( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -2998,8 +3292,8 @@ function popcnt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3018,7 +3312,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power ``` """ function power( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -3034,8 +3331,8 @@ function power( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3062,7 +3359,7 @@ function real_dynamic_slice( limit_indices::Value, strides::Value; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices, limit_indices, strides] @@ -3096,7 +3393,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real %result = stablehlo.real %operand : (tensor<2xcomplex>) -> tensor<2xf32> ``` """ -function real(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function real( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -3111,8 +3410,8 @@ function real(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? 
nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3134,12 +3433,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#recv """ function recv( token::Value; - result_0::Vector{IR.Type}, - channel_handle, - is_host_transfer=nothing, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + channel_handle::ChannelHandle, + is_host_transfer::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[token,] owned_regions = Region[] successors = Block[] @@ -3182,12 +3481,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce function reduce( inputs::Vector{Value}, init_values::Vector{Value}; - result_0::Vector{IR.Type}, - dimensions, + result::Base.AbstractVecOrTuple{IR.Type}, + dimensions::IR.DenseAttribute{Int64}, body::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs..., init_values...] owned_regions = Region[body,] successors = Block[] @@ -3222,10 +3521,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision """ function reduce_precision( operand::Value; - output=nothing::Union{Nothing,IR.Type}, - exponent_bits, - mantissa_bits, - location=Location(), + output::Union{Nothing,IR.Type}=nothing, + exponent_bits::Int32, + mantissa_bits::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -3244,8 +3543,8 @@ function reduce_precision( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3275,15 +3574,15 @@ scatters the split parts between the processes to produce the `result`. 
""" function reduce_scatter( operand::Value; - result_0::IR.Type, - scatter_dimension, - replica_groups, - channel_handle=nothing, - use_global_device_ids=nothing, + result::IR.Type, + scatter_dimension::Int64, + replica_groups::IR.AbstractDenseElementsAttribute{Int64}, + channel_handle::Union{ChannelHandle,Nothing}=nothing, + use_global_device_ids::Union{Bool,Nothing}=nothing, computation::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[computation,] successors = Block[] @@ -3335,16 +3634,16 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window function reduce_window( inputs::Vector{Value}, init_values::Vector{Value}; - result_0::Vector{IR.Type}, - window_dimensions, - window_strides=nothing, - base_dilations=nothing, - window_dilations=nothing, - padding=nothing, + result::Base.AbstractVecOrTuple{IR.Type}, + window_dimensions::IR.DenseAttribute{Int64}, + window_strides::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + base_dilations::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + window_dilations::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + padding::Union{IR.AbstractDenseElementsAttribute{Int64},Nothing}=nothing, body::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs..., init_values...] owned_regions = Region[body,] successors = Block[] @@ -3384,7 +3683,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder ``` """ function remainder( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -3400,8 +3702,8 @@ function remainder( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3418,13 +3720,13 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#replica_id %result = stablehlo.replica_id : tensor ``` """ -function replica_id(; result_0=nothing::Union{Nothing,IR.Type}, location=Location()) +function replica_id(; result::Union{Nothing,IR.Type}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.replica_id", @@ -3433,8 +3735,8 @@ function replica_id(; result_0=nothing::Union{Nothing,IR.Type}, location=Locatio owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3451,8 +3753,8 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32> ``` """ -function reshape(operand::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] +function reshape(operand::Value; result::IR.Type, location::Location=Location()) + op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] successors = Block[] @@ -3470,7 +3772,7 @@ function reshape(operand::Value; result_0::IR.Type, location=Location()) ) end -function return_(results::Vector{Value}; location=Location()) +function return_(results::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[results...,] owned_regions = Region[] @@ -3504,7 +3806,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reverse ``` """ function reverse( - operand::Value; result=nothing::Union{Nothing,IR.Type}, dimensions, location=Location() + operand::Value; + result::Union{Nothing,IR.Type}=nothing, + dimensions::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -3520,8 +3825,8 @@ function reverse( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3544,8 +3849,8 @@ function rng_bit_generator( initial_state::Value; output_state::IR.Type, output::IR.Type, - rng_algorithm, - location=Location(), + rng_algorithm::RngAlgorithm.T, + location::Location=Location(), ) op_ty_results = IR.Type[output_state, output] operands = Value[initial_state,] @@ -3583,9 +3888,9 @@ function rng( a::Value, b::Value, shape::Value; - result=nothing::Union{Nothing,IR.Type}, - rng_distribution, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + rng_distribution::RngDistribution.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[a, b, shape] @@ -3601,8 +3906,8 @@ function rng( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3622,7 +3927,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even ``` """ function round_nearest_even( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -3638,8 +3943,8 @@ function round_nearest_even( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3658,7 +3963,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz ``` """ function round_nearest_afz( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -3674,8 +3979,8 @@ function round_nearest_afz( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3694,7 +3999,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt %result = stablehlo.rsqrt %operand : tensor<2x2xf32> ``` """ -function rsqrt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function rsqrt( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -3709,8 +4016,8 @@ function rsqrt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location= owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3747,14 +4054,14 @@ function scatter( inputs::Vector{Value}, scatter_indices::Value, updates::Vector{Value}; - result_0::Vector{IR.Type}, - scatter_dimension_numbers, - indices_are_sorted=nothing, - unique_indices=nothing, + result::Base.AbstractVecOrTuple{IR.Type}, + scatter_dimension_numbers::Scatter, + indices_are_sorted::Union{Bool,Nothing}=nothing, + unique_indices::Union{Bool,Nothing}=nothing, update_computation::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs..., scatter_indices, updates...] owned_regions = Region[update_computation,] successors = Block[] @@ -3809,15 +4116,15 @@ function select_and_scatter( operand::Value, source::Value, init_value::Value; - result_0::IR.Type, - window_dimensions=nothing, - window_strides=nothing, - padding=nothing, + result::IR.Type, + window_dimensions::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + window_strides::Union{IR.DenseAttribute{Int64},Nothing}=nothing, + padding::Union{IR.AbstractDenseElementsAttribute{Int64},Nothing}=nothing, select::Region, scatter::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[operand, source, init_value] owned_regions = Region[select, scatter] successors = Block[] @@ -3858,8 +4165,8 @@ function select( pred::Value, on_true::Value, on_false::Value; - result=nothing::Union{Nothing,IR.Type}, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[pred, on_true, on_false] @@ -3875,8 +4182,8 @@ function select( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3899,17 +4206,17 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#send function send( inputs::Vector{Value}, token::Value; - result_0=nothing::Union{Nothing,IR.Type}, - channel_handle, - is_host_transfer=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + channel_handle::ChannelHandle, + is_host_transfer::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[inputs..., token] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("channel_handle", channel_handle),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) !isnothing(is_host_transfer) && push!(attributes, namedattribute("is_host_transfer", is_host_transfer)) @@ -3920,8 +4227,8 @@ function send( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3942,16 +4249,16 @@ https://www.tensorflow.org/xla/operation_semantics#setdimensionsize function set_dimension_size( operand::Value, size::Value; - result_0=nothing::Union{Nothing,IR.Type}, - dimension, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + dimension::Int64, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand, size] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("dimension", dimension),] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.set_dimension_size", @@ -3960,8 +4267,8 @@ function set_dimension_size( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -3980,7 +4287,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left ``` """ function shift_left( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -3996,8 +4306,8 @@ function shift_left( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4016,7 +4326,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmet ``` """ function shift_right_arithmetic( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -4032,8 +4345,8 @@ function shift_right_arithmetic( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4052,7 +4365,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical ``` """ function shift_right_logical( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -4068,8 +4384,8 @@ function shift_right_logical( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4087,7 +4403,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign %result = stablehlo.sign %operand : tensor<5xf64> ``` """ -function sign(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function sign( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -4102,8 +4420,8 @@ function sign(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4121,7 +4439,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine %result = stablehlo.sine %operand : tensor<2xf32> ``` """ -function sine(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function sine( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -4136,8 +4456,8 @@ function sine(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4168,11 +4488,11 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice """ function slice( operand::Value; - result_0=nothing::Union{Nothing,IR.Type}, - start_indices, - limit_indices, - strides, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + start_indices::IR.DenseAttribute{Int64}, + limit_indices::IR.DenseAttribute{Int64}, + strides::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -4183,7 +4503,7 @@ function slice( namedattribute("limit_indices", limit_indices), namedattribute("strides", strides), ] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.slice", @@ -4192,8 +4512,8 @@ function slice( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4220,13 +4540,13 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sort """ function sort( inputs::Vector{Value}; - result_0::Vector{IR.Type}, - dimension=nothing, - is_stable=nothing, + result::Base.AbstractVecOrTuple{IR.Type}, + dimension::Union{Int64,Nothing}=nothing, + is_stable::Union{Bool,Nothing}=nothing, comparator::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[inputs...,] owned_regions = Region[comparator,] successors = Block[] @@ -4260,7 +4580,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt %result = stablehlo.sqrt %operand : tensor<2x2xf32> ``` """ -function sqrt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function sqrt( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -4275,8 +4597,8 @@ function sqrt(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4295,7 +4617,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#subtract ``` """ function subtract( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -4311,8 +4636,8 @@ function subtract( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4330,7 +4655,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan %result = stablehlo.tan %operand : tensor<2x2xf64> ``` """ -function tan(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function tan( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -4345,8 +4672,8 @@ function tan(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Lo owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4364,7 +4691,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh %result = stablehlo.tanh %operand : tensor<2xf32> ``` """ -function tanh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function tanh( + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[operand,] owned_regions = Region[] @@ -4379,8 +4708,8 @@ function tanh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? 
nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4407,9 +4736,14 @@ the index. ``` """ function torch_index_select( - operand::Value, index::Value; result_0::IR.Type, dim, batch_dims, location=Location() + operand::Value, + index::Value; + result::IR.Type, + dim::Int64, + batch_dims::Int64, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result,] operands = Value[operand, index] owned_regions = Region[] successors = Block[] @@ -4444,7 +4778,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose ``` """ function transpose( - operand::Value; result=nothing::Union{Nothing,IR.Type}, permutation, location=Location() + operand::Value; + result::Union{Nothing,IR.Type}=nothing, + permutation::IR.DenseAttribute{Int64}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -4460,8 +4797,8 @@ function transpose( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4487,12 +4824,12 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#triangular_solve function triangular_solve( a::Value, b::Value; - result_0=nothing::Union{Nothing,IR.Type}, - left_side, - lower, - unit_diagonal, - transpose_a, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + left_side::Bool, + lower::Bool, + unit_diagonal::Bool, + transpose_a::Transpose.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[a, b] @@ -4504,7 +4841,7 @@ function triangular_solve( namedattribute("unit_diagonal", unit_diagonal), namedattribute("transpose_a", transpose_a), ] - !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(result) && push!(op_ty_results, result) return create_operation( "stablehlo.triangular_solve", @@ -4513,8 +4850,8 @@ function triangular_solve( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4532,7 +4869,9 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tuple ``` """ function tuple( - val::Vector{Value}; result=nothing::Union{Nothing,IR.Type}, location=Location() + val::Vector{Value}; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[val...,] @@ -4548,8 +4887,8 @@ function tuple( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4569,8 +4908,10 @@ https://www.tensorflow.org/api_docs/python/tf/einsum } : (tensor<4x16xf32>) -> tensor<4xf32> ``` """ -function unary_einsum(operand::Value; result_0::IR.Type, einsum_config, location=Location()) - op_ty_results = IR.Type[result_0,] +function unary_einsum( + operand::Value; result::IR.Type, einsum_config::String, location::Location=Location() +) + op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] successors = Block[] @@ -4604,7 +4945,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize ``` """ function uniform_dequantize( - operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + operand::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[operand,] @@ -4620,8 +4961,8 @@ function uniform_dequantize( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -4640,7 +4981,7 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize %result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform> ``` """ -function uniform_quantize(operand::Value; result::IR.Type, location=Location()) +function uniform_quantize(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -4683,12 +5024,12 @@ cond { """ function while_( operand::Vector{Value}; - result_0::Vector{IR.Type}, + result::Base.AbstractVecOrTuple{IR.Type}, cond::Region, body::Region, - location=Location(), + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[operand...,] owned_regions = Region[cond, body] successors = Block[] @@ -4721,7 +5062,10 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#xor ``` """ function xor( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -4737,8 +5081,8 @@ function xor( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end diff --git a/src/mlir/Dialects/TPU.jl b/src/mlir/Dialects/TPU.jl old mode 100644 new mode 100755 index 109fb6f98..bdba8bb3f --- a/src/mlir/Dialects/TPU.jl +++ b/src/mlir/Dialects/TPU.jl @@ -10,11 +10,73 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`ReductionKind` +Reduction kind +""" +@enumx ReductionKind SUM MAX MIN +const ReductionKindStorage = ["sum", "max", "min"] + +function IR.Attribute(e::ReductionKind.T) + return parse(Attribute, "#tpu>") +end + +""" +`RoundingMode` +Rounding mode +""" +@enumx RoundingMode kTowardsZero kToNearestEven +const RoundingModeStorage = ["towards_zero", "to_nearest_even"] + +function IR.Attribute(e::RoundingMode.T) + return parse(Attribute, "#tpu>") +end + +""" +`ContractPrecision` +Contraction precision +""" +@enumx ContractPrecision kBF16 kFP32 +const ContractPrecisionStorage = ["bf16", "fp32"] + +function IR.Attribute(e::ContractPrecision.T) + return parse( + Attribute, "#tpu>" + ) +end + +""" +`PackFormat` +Pack format +""" +@enumx PackFormat kCompressed kInterleaved +const PackFormatStorage = ["compressed", "interleaved"] + +function IR.Attribute(e::PackFormat.T) + return parse(Attribute, "#tpu>") +end + +""" +`CoreType` +Core type +""" +@enumx CoreType kTc kScScalarSubcore kScVectorSubcore +const CoreTypeStorage = ["tc", "sc_scalar_subcore", "sc_vector_subcore"] + +function IR.Attribute(e::CoreType.T) + return parse(Attribute, "#tpu>") +end function all_reduce( - input::Value; output=nothing::Union{Nothing,IR.Type}, dim, kind, location=Location() + input::Value; + output::Union{Nothing,IR.Type}=nothing, + dim::Int64, + kind::ReductionKind.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[input,] @@ -30,12 +92,12 @@ function all_reduce( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function sem_alloc(; result::IR.Type, location=Location()) +function sem_alloc(; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] @@ -54,7 +116,7 @@ function sem_alloc(; result::IR.Type, location=Location()) ) end -function assume_layout(input::Value; result::IR.Type, location=Location()) +function assume_layout(input::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[input,] owned_regions = Region[] @@ -74,7 +136,10 @@ function assume_layout(input::Value; result::IR.Type, location=Location()) end function assume_multiple( - value::Value; result=nothing::Union{Nothing,IR.Type}, multiple, location=Location() + value::Value; + result::Union{Nothing,IR.Type}=nothing, + multiple::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value,] @@ -90,12 +155,12 @@ function assume_multiple( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
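The TPU dialect now declares its enum attributes with EnumX and threads them through as typed keywords, so the wrappers accept enum cases directly. A hedged sketch of `all_reduce` (placeholder `input::Value`, assumed MLIR context):

```julia
# ReductionKind.T is the EnumX value type; cases are ReductionKind.SUM, .MAX, .MIN.
op = all_reduce(input; dim=1, kind=ReductionKind.SUM)
# IR.Attribute(::ReductionKind.T), defined above, turns the case into the dialect attribute.
```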
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function bitcast(input::Value; output::IR.Type, location=Location()) +function bitcast(input::Value; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[input,] owned_regions = Region[] @@ -114,7 +179,7 @@ function bitcast(input::Value; output::IR.Type, location=Location()) ) end -function bitcast_vreg(input::Value; output::IR.Type, location=Location()) +function bitcast_vreg(input::Value; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[input,] owned_regions = Region[] @@ -140,7 +205,9 @@ For each sublane `i`, broadcasts the value in lane `lane + i` along the entire sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i` is not defined (can be anything). """ -function broadcast_in_sublanes(source::Value; output::IR.Type, lane, location=Location()) +function broadcast_in_sublanes( + source::Value; output::IR.Type, lane::Int32, location::Location=Location() +) op_ty_results = IR.Type[output,] operands = Value[source,] owned_regions = Region[] @@ -160,7 +227,7 @@ function broadcast_in_sublanes(source::Value; output::IR.Type, lane, location=Lo end function concatenate( - sources::Vector{Value}; output::IR.Type, dimension, location=Location() + sources::Vector{Value}; output::IR.Type, dimension::Int32, location::Location=Location() ) op_ty_results = IR.Type[output,] operands = Value[sources...,] @@ -181,7 +248,7 @@ function concatenate( end function create_mask( - low::Vector{Value}, high::Vector{Value}; output::IR.Type, location=Location() + low::Vector{Value}, high::Vector{Value}; output::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[output,] operands = Value[low..., high...] @@ -229,7 +296,9 @@ It is currently only supported: - In TPU v4, for `num_subelems` of 1 and 2. - In TPU v5, for `num_subelems` of 1, 2, and 4. """ -function create_subelement_mask(; output::IR.Type, from, to, location=Location()) +function create_subelement_mask(; + output::IR.Type, from::Int32, to::Int32, location::Location=Location() +) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -248,7 +317,7 @@ function create_subelement_mask(; output::IR.Type, from, to, location=Location() ) end -function delay(nanos::Value; location=Location()) +function delay(nanos::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[nanos,] owned_regions = Region[] @@ -267,7 +336,7 @@ function delay(nanos::Value; location=Location()) ) end -function device_id(; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function device_id(; result::Union{Nothing,IR.Type}=nothing, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -282,13 +351,17 @@ function device_id(; result=nothing::Union{Nothing,IR.Type}, location=Location() owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
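Because `from` and `to` are now declared as `Int32`, callers pass 32-bit values to match the keyword types. A sketch, with `mask_ty` a placeholder `IR.Type` for the resulting mask:

```julia
m = create_subelement_mask(; output=mask_ty, from=Int32(0), to=Int32(4))
```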
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function dynamic_gather( - source::Value, indices::Value; output::IR.Type, dimension, location=Location() + source::Value, + indices::Value; + output::IR.Type, + dimension::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[output,] operands = Value[source, indices] @@ -312,10 +385,10 @@ function dynamic_rotate( value::Value, amount::Value; result::IR.Type, - dimension, - stride=nothing, - stride_dimension=nothing, - location=Location(), + dimension::Int32, + stride::Union{Int32,Nothing}=nothing, + stride_dimension::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[value, amount] @@ -340,12 +413,12 @@ end function enqueue_dma( source::Value, - source_semaphore=nothing::Union{Nothing,Value}; + source_semaphore::Union{Nothing,Value}=nothing; target::Value, target_semaphore::Value, - device_id=nothing::Union{Nothing,Value}, - core_id=nothing::Union{Nothing,Value}, - location=Location(), + device_id::Union{Nothing,Value}=nothing, + core_id::Union{Nothing,Value}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[source, target, target_semaphore] @@ -383,7 +456,7 @@ function enqueue_dma( ) end -function erase_memref_layout(operand::Value; result::IR.Type, location=Location()) +function erase_memref_layout(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -402,7 +475,12 @@ function erase_memref_layout(operand::Value; result::IR.Type, location=Location( ) end -function fptosi(input::Value; output::IR.Type, rounding_mode, location=Location()) +function fptosi( + input::Value; + output::IR.Type, + rounding_mode::RoundingMode.T, + location::Location=Location(), +) op_ty_results = IR.Type[output,] operands = Value[input,] owned_regions = Region[] @@ -421,7 +499,13 @@ function fptosi(input::Value; output::IR.Type, rounding_mode, location=Location( ) end -function gather(source::Value; output::IR.Type, indices, dimension, location=Location()) +function gather( + source::Value; + output::IR.Type, + indices::IR.DenseAttribute{Int32}, + dimension::Int32, + location::Location=Location(), +) op_ty_results = IR.Type[output,] operands = Value[source,] owned_regions = Region[] @@ -442,7 +526,7 @@ function gather(source::Value; output::IR.Type, indices, dimension, location=Loc ) end -function sem_barrier(; semaphore::IR.Type, location=Location()) +function sem_barrier(; semaphore::IR.Type, location::Location=Location()) op_ty_results = IR.Type[semaphore,] operands = Value[] owned_regions = Region[] @@ -461,7 +545,7 @@ function sem_barrier(; semaphore::IR.Type, location=Location()) ) end -function internal_scratch(; result::IR.Type, location=Location()) +function internal_scratch(; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] @@ -480,7 +564,9 @@ function internal_scratch(; result::IR.Type, location=Location()) ) end -function iteration_bound(; result=nothing::Union{Nothing,IR.Type}, dim, location=Location()) +function iteration_bound(; + result::Union{Nothing,IR.Type}=nothing, dim::Int32, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -495,12 +581,14 @@ function iteration_bound(; result=nothing::Union{Nothing,IR.Type}, dim, location owned_regions, successors, attributes, - 
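`fptosi` now takes its rounding mode as a typed enum rather than an untyped keyword. A sketch (placeholders `x` and `out_ty`):

```julia
y = fptosi(x; output=out_ty, rounding_mode=RoundingMode.kTowardsZero)
```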
results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function iota(; output::IR.Type, dimension=nothing, location=Location()) +function iota(; + output::IR.Type, dimension::Union{Int32,Nothing}=nothing, location::Location=Location() +) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -524,9 +612,9 @@ function load( base::Value, indices::Vector{Value}; result::IR.Type, - sublane_mask, - sublane_stride=nothing, - location=Location(), + sublane_mask::IR.DenseAttribute{Bool}, + sublane_stride::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[base, indices...] @@ -548,7 +636,12 @@ function load( ) end -function log_buffer(input::Value; shape, tag, location=Location()) +function log_buffer( + input::Value; + shape::IR.DenseAttribute{Int64}, + tag::String, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[input,] owned_regions = Region[] @@ -567,7 +660,12 @@ function log_buffer(input::Value; shape, tag, location=Location()) ) end -function log(inputs::Vector{Value}; tag, formatted=nothing, location=Location()) +function log( + inputs::Vector{Value}; + tag::String, + formatted::Union{Bool,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[inputs...,] owned_regions = Region[] @@ -587,7 +685,7 @@ function log(inputs::Vector{Value}; tag, formatted=nothing, location=Location()) ) end -function mask_cast(input::Value; result::IR.Type, location=Location()) +function mask_cast(input::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[input,] owned_regions = Region[] @@ -611,11 +709,11 @@ function matmul( rhs::Value, acc::Value; result::IR.Type, - transpose_lhs=nothing, - transpose_rhs=nothing, - precision=nothing, + transpose_lhs::Union{Bool,Nothing}=nothing, + transpose_rhs::Union{Bool,Nothing}=nothing, + precision::Union{ContractPrecision.T,Nothing}=nothing, dimension_numbers=nothing, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs, acc] @@ -642,7 +740,7 @@ function matmul( ) end -function memref_bitcast(input::Value; result::IR.Type, location=Location()) +function memref_bitcast(input::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[input,] owned_regions = Region[] @@ -661,7 +759,7 @@ function memref_bitcast(input::Value; result::IR.Type, location=Location()) ) end -function memref_reshape(input::Value; result::IR.Type, location=Location()) +function memref_reshape(input::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[input,] owned_regions = Region[] @@ -685,7 +783,7 @@ function memref_slice( base_idx::Vector{Value}, dynamic_sizes::Vector{Value}; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[mem_ref, base_idx..., dynamic_sizes...] 
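The matmul precision keyword is likewise an optional `ContractPrecision.T` now. A sketch (placeholders `lhs`, `rhs`, `acc`, `acc_ty`):

```julia
r = matmul(lhs, rhs, acc; result=acc_ty, precision=ContractPrecision.kBF16)
```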
@@ -706,7 +804,7 @@ function memref_slice( ) end -function memref_squeeze(input::Value; result::IR.Type, location=Location()) +function memref_squeeze(input::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[input,] owned_regions = Region[] @@ -725,7 +823,7 @@ function memref_squeeze(input::Value; result::IR.Type, location=Location()) ) end -function prng_random_bits(; output::IR.Type, location=Location()) +function prng_random_bits(; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -744,7 +842,7 @@ function prng_random_bits(; output::IR.Type, location=Location()) ) end -function prng_set_seed_32(seeds::Vector{Value}; location=Location()) +function prng_set_seed_32(seeds::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[seeds...,] owned_regions = Region[] @@ -763,7 +861,7 @@ function prng_set_seed_32(seeds::Vector{Value}; location=Location()) ) end -function pack_vmsk(low::Value, high::Value; output::IR.Type, location=Location()) +function pack_vmsk(low::Value, high::Value; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[low, high] owned_regions = Region[] @@ -783,7 +881,11 @@ function pack_vmsk(low::Value, high::Value; output::IR.Type, location=Location() end function pack_subelements( - sources::Vector{Value}; output::IR.Type, positions, pack_format, location=Location() + sources::Vector{Value}; + output::IR.Type, + positions::IR.DenseAttribute{Int32}, + pack_format::PackFormat.T, + location::Location=Location(), ) op_ty_results = IR.Type[output,] operands = Value[sources...,] @@ -805,7 +907,9 @@ function pack_subelements( ) end -function region(; results::Vector{IR.Type}, region::Region, location=Location()) +function region(; + results::Base.AbstractVecOrTuple{IR.Type}, region::Region, location::Location=Location() +) op_ty_results = IR.Type[results...,] operands = Value[] owned_regions = Region[region,] @@ -824,7 +928,7 @@ function region(; results::Vector{IR.Type}, region::Region, location=Location()) ) end -function reinterpret_cast(input::Value; result::IR.Type, location=Location()) +function reinterpret_cast(input::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[input,] owned_regions = Region[] @@ -843,7 +947,9 @@ function reinterpret_cast(input::Value; result::IR.Type, location=Location()) ) end -function relayout(input::Value; output=nothing::Union{Nothing,IR.Type}, location=Location()) +function relayout( + input::Value; output::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[input,] owned_regions = Region[] @@ -858,12 +964,18 @@ function relayout(input::Value; output=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function repeat(source::Value; output::IR.Type, dimension, times, location=Location()) +function repeat( + source::Value; + output::IR.Type, + dimension::Int32, + times::Int32, + location::Location=Location(), +) op_ty_results = IR.Type[output,] operands = Value[source,] owned_regions = Region[] @@ -884,7 +996,7 @@ function repeat(source::Value; output::IR.Type, dimension, times, location=Locat ) end -function roll_vectors(input::Vector{Value}; output::IR.Type, location=Location()) +function roll_vectors(input::Vector{Value}; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[input...,] owned_regions = Region[] @@ -905,12 +1017,12 @@ end function rotate( value::Value; - result=nothing::Union{Nothing,IR.Type}, - amount, - dimension, - stride=nothing, - stride_dimension=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + amount::Int32, + dimension::Int32, + stride::Union{Int32,Nothing}=nothing, + stride_dimension::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[value,] @@ -931,13 +1043,13 @@ function rotate( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function sem_read( - semaphore::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + semaphore::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[semaphore,] @@ -953,18 +1065,18 @@ function sem_read( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function sem_signal( semaphore::Value, amount::Value, - device_id=nothing::Union{Nothing,Value}; - core_id=nothing::Union{Nothing,Value}, - core_type=nothing, - location=Location(), + device_id::Union{Nothing,Value}=nothing; + core_id::Union{Nothing,Value}=nothing, + core_type::Union{CoreType.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[semaphore, amount] @@ -998,7 +1110,7 @@ function sem_signal( ) end -function sem_wait(semaphore::Value, amount::Value; location=Location()) +function sem_wait(semaphore::Value, amount::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[semaphore, amount] owned_regions = Region[] @@ -1021,9 +1133,9 @@ function shuffled_load( base::Value, indices::Vector{Value}; result::IR.Type, - sublane_mask, - sublane_offsets, - location=Location(), + sublane_mask::IR.DenseAttribute{Bool}, + sublane_offsets::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[base, indices...] @@ -1050,9 +1162,9 @@ function shuffled_store( valueToStore::Value, base::Value, indices::Vector{Value}; - sublane_mask, - sublane_offsets, - location=Location(), + sublane_mask::IR.DenseAttribute{Bool}, + sublane_offsets::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[valueToStore, base, indices...] 
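`sem_signal` can optionally target a core type through the new `CoreType` enum. A sketch (placeholder Values `sem` and `amt`; `device_id` and `core_id` omitted):

```julia
sem_signal(sem, amt; core_type=CoreType.kScVectorSubcore)
```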
@@ -1079,10 +1191,10 @@ function store( valueToStore::Value, base::Value, indices::Vector{Value}, - mask=nothing::Union{Nothing,Value}; - sublane_mask, - sublane_stride=nothing, - location=Location(), + mask::Union{Nothing,Value}=nothing; + sublane_mask::IR.DenseAttribute{Bool}, + sublane_stride::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[valueToStore, base, indices...] @@ -1109,7 +1221,11 @@ function store( end function strided_load( - base::Value, indices::Vector{Value}; result::IR.Type, strides, location=Location() + base::Value, + indices::Vector{Value}; + result::IR.Type, + strides::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[base, indices...] @@ -1130,7 +1246,11 @@ function strided_load( end function strided_store( - valueToStore::Value, base::Value, indices::Vector{Value}; strides, location=Location() + valueToStore::Value, + base::Value, + indices::Vector{Value}; + strides::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[valueToStore, base, indices...] @@ -1151,7 +1271,11 @@ function strided_store( end function trace(; - results::Vector{IR.Type}, message, level, region::Region, location=Location() + results::Base.AbstractVecOrTuple{IR.Type}, + message::String, + level::Int32, + region::Region, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[] @@ -1173,7 +1297,7 @@ function trace(; ) end -function trace_start(; message, level, location=Location()) +function trace_start(; message::String, level::Int32, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1194,7 +1318,7 @@ function trace_start(; message, level, location=Location()) ) end -function trace_stop(; location=Location()) +function trace_stop(; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -1214,7 +1338,11 @@ function trace_stop(; location=Location()) end function unpack_subelements( - source::Value; output::IR.Type, index, pack_format, location=Location() + source::Value; + output::IR.Type, + index::Int32, + pack_format::PackFormat.T, + location::Location=Location(), ) op_ty_results = IR.Type[output,] operands = Value[source,] @@ -1236,7 +1364,9 @@ function unpack_subelements( ) end -function unroll_vectors(input::Value; output::Vector{IR.Type}, location=Location()) +function unroll_vectors( + input::Value; output::Base.AbstractVecOrTuple{IR.Type}, location::Location=Location() +) op_ty_results = IR.Type[output...,] operands = Value[input,] owned_regions = Region[] @@ -1259,9 +1389,9 @@ function vector_store( valueToStore::Value, base::Value, indices::Vector{Value}, - mask=nothing::Union{Nothing,Value}; - strides, - location=Location(), + mask::Union{Nothing,Value}=nothing; + strides::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[valueToStore, base, indices...] 
@@ -1285,7 +1415,7 @@ function vector_store( ) end -function wait_dma(semaphore::Value, ref::Value; location=Location()) +function wait_dma(semaphore::Value, ref::Value; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[semaphore, ref] owned_regions = Region[] @@ -1304,7 +1434,7 @@ function wait_dma(semaphore::Value, ref::Value; location=Location()) ) end -function weird(input::Value; output::IR.Type, location=Location()) +function weird(input::Value; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[input,] owned_regions = Region[] @@ -1323,7 +1453,7 @@ function weird(input::Value; output::IR.Type, location=Location()) ) end -function yield(results::Vector{Value}; location=Location()) +function yield(results::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[results...,] owned_regions = Region[] diff --git a/src/mlir/Dialects/Triton.jl b/src/mlir/Dialects/Triton.jl index 36122eeee..9542b26a6 100755 --- a/src/mlir/Dialects/Triton.jl +++ b/src/mlir/Dialects/Triton.jl @@ -10,8 +10,98 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX + +""" +`MemSemantic` +allowed 32-bit signless integer cases: 1, 2, 3, 4 +""" +@enumx MemSemantic RELAXED = 1 ACQUIRE = 2 RELEASE = 3 ACQUIRE_RELEASE = 4 + +IR.Attribute(e::MemSemantic.T) = Int(e) + +""" +`MemSyncScope` +allowed 32-bit signless integer cases: 1, 2, 3 +""" +@enumx MemSyncScope GPU = 1 CTA = 2 SYSTEM = 3 + +IR.Attribute(e::MemSyncScope.T) = Int(e) + +""" +`RMWOp` +allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 +""" +@enumx RMWOp AND = 1 OR = 2 XOR = 3 ADD = 4 FADD = 5 MAX = 6 MIN = 7 UMAX = 8 UMIN = 9 XCHG = + 10 + +IR.Attribute(e::RMWOp.T) = Int(e) + +""" +`PropagateNan` +allowed 32-bit signless integer cases: 0, 65535 +""" +@enumx PropagateNan NONE = 0 ALL = 65535 + +IR.Attribute(e::PropagateNan.T) = Int(e) + +""" +`InputPrecision` +allowed 32-bit signless integer cases: 0, 1, 2 +""" +@enumx InputPrecision TF32 = 0 TF32x3 = 1 IEEE = 2 + +IR.Attribute(e::InputPrecision.T) = Int(e) + +""" +`ScaleDotElemType` +allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6 +""" +@enumx ScaleDotElemType E4M3 = 0 E5M2 = 1 E2M3 = 2 E3M2 = 3 E2M1 = 4 BF16 = 5 FP16 = 6 + +IR.Attribute(e::ScaleDotElemType.T) = Int(e) + +""" +`CacheModifier` +allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7 +""" +@enumx CacheModifier NONE = 1 CA = 2 CG = 3 WB = 4 CS = 5 WT = 6 CV = 7 + +IR.Attribute(e::CacheModifier.T) = Int(e) + +""" +`EvictionPolicy` +allowed 32-bit signless integer cases: 1, 2, 3 +""" +@enumx EvictionPolicy NORMAL = 1 EVICT_FIRST = 2 EVICT_LAST = 3 + +IR.Attribute(e::EvictionPolicy.T) = Int(e) + +""" +`RoundingMode` +allowed 32-bit signless integer cases: 0, 1 +""" +@enumx RoundingMode RTZ = 0 RTNE = 1 + +IR.Attribute(e::RoundingMode.T) = Int(e) + +""" +`ProgramIDDim` +allowed 32-bit signless integer cases: 0, 1, 2 +""" +@enumx ProgramIDDim X = 0 Y = 1 Z = 2 + +IR.Attribute(e::ProgramIDDim.T) = Int(e) + +""" +`PaddingOption` +allowed 32-bit signless integer cases: 1, 2 +""" +@enumx PaddingOption PAD_ZERO = 1 PAD_NAN = 2 + +IR.Attribute(e::PaddingOption.T) = Int(e) """ `call` @@ -28,9 +118,12 @@ symbol reference attribute named \"callee\". 
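The Triton enums above are plain 32-bit integer cases, so their `IR.Attribute` overloads simply return `Int(e)`. Spot-checking the mapping against the declared values:

```julia
Int(MemSemantic.ACQUIRE) == 2      # true
Int(RMWOp.FADD)          == 5      # true
Int(PropagateNan.ALL)    == 65535  # true
```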
``` """ function call( - operands::Vector{Value}; result_0::Vector{IR.Type}, callee, location=Location() + operands::Vector{Value}; + result::Base.AbstractVecOrTuple{IR.Type}, + callee::IR.FlatSymbolRefAttribute, + location::Location=Location(), ) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result...,] operands = Value[operands...,] owned_regions = Region[] successors = Block[] @@ -88,13 +181,13 @@ tt.func @example_fn_attr() attributes {dialectName.attrName = false} ``` """ function func(; - sym_name, - function_type, - sym_visibility=nothing, - arg_attrs=nothing, - res_attrs=nothing, + sym_name::String, + function_type::IR.Type, + sym_visibility::Union{String,Nothing}=nothing, + arg_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, + res_attrs::Union{IR.DenseAttribute{<:Any},Nothing}=nothing, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -126,7 +219,9 @@ end This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects. Ideally, we can remove this once the APIs are fully fleshed out. """ -function reinterpret_tensor_descriptor(rawDesc::Value; result::IR.Type, location=Location()) +function reinterpret_tensor_descriptor( + rawDesc::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[rawDesc,] owned_regions = Region[] @@ -162,7 +257,7 @@ tt.func @foo() : (i32, f8) { } ``` """ -function return_(srcs::Vector{Value}; location=Location()) +function return_(srcs::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[srcs...,] owned_regions = Region[] @@ -181,7 +276,7 @@ function return_(srcs::Vector{Value}; location=Location()) ) end -function addptr(ptr::Value, offset::Value; result::IR.Type, location=Location()) +function addptr(ptr::Value, offset::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[ptr, offset] owned_regions = Region[] @@ -200,7 +295,9 @@ function addptr(ptr::Value, offset::Value; result::IR.Type, location=Location()) ) end -function advance(ptr::Value, offsets::Vector{Value}; result::IR.Type, location=Location()) +function advance( + ptr::Value, offsets::Vector{Value}; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[ptr, offsets...] owned_regions = Region[] @@ -225,7 +322,7 @@ end `tt.assert` takes a condition tensor and a message string. If the condition is false, the message is printed, and the program is aborted. 
""" -function assert(condition::Value; message, location=Location()) +function assert(condition::Value; message::String, location::Location=Location()) op_ty_results = IR.Type[] operands = Value[condition,] owned_regions = Region[] @@ -256,7 +353,13 @@ else store \$old to \$ptr, return \$old """ function atomic_cas( - ptr::Value, cmp::Value, val::Value; result::IR.Type, sem, scope, location=Location() + ptr::Value, + cmp::Value, + val::Value; + result::IR.Type, + sem::MemSemantic.T, + scope::MemSyncScope.T, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[ptr, cmp, val] @@ -286,12 +389,12 @@ return old value at \$ptr function atomic_rmw( ptr::Value, val::Value, - mask=nothing::Union{Nothing,Value}; + mask::Union{Nothing,Value}=nothing; result::IR.Type, - atomic_rmw_op, - sem, - scope, - location=Location(), + atomic_rmw_op::RMWOp.T, + sem::MemSemantic.T, + scope::MemSyncScope.T, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[ptr, val] @@ -316,7 +419,7 @@ function atomic_rmw( ) end -function bitcast(src::Value; result::IR.Type, location=Location()) +function bitcast(src::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -342,7 +445,7 @@ For a given tensor, broadcast changes one or more dimensions with size 1 to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot change the size of a non-1 dimension. """ -function broadcast(src::Value; result::IR.Type, location=Location()) +function broadcast(src::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -361,7 +464,7 @@ function broadcast(src::Value; result::IR.Type, location=Location()) ) end -function cat(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function cat(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -391,9 +494,9 @@ function clampf( x::Value, min::Value, max::Value; - result=nothing::Union{Nothing,IR.Type}, - propagateNan, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + propagateNan::PropagateNan.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[x, min, max] @@ -409,8 +512,8 @@ function clampf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -428,10 +531,10 @@ function dot( a::Value, b::Value, c::Value; - d=nothing::Union{Nothing,IR.Type}, - inputPrecision=nothing, - maxNumImpreciseAcc=nothing, - location=Location(), + d::Union{Nothing,IR.Type}=nothing, + inputPrecision::Union{InputPrecision.T,Nothing}=nothing, + maxNumImpreciseAcc::Union{Int32,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[a, b, c] @@ -451,8 +554,8 @@ function dot( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -466,13 +569,13 @@ function dot_scaled( lhs::Value, rhs::Value, c::Value, - lhs_scale=nothing::Union{Nothing,Value}; - rhs_scale=nothing::Union{Nothing,Value}, + lhs_scale::Union{Nothing,Value}=nothing; + rhs_scale::Union{Nothing,Value}=nothing, d::IR.Type, - lhs_type, - rhs_type, - fastMath, - location=Location(), + lhs_type::ScaleDotElemType.T, + rhs_type::ScaleDotElemType.T, + fastMath::Bool, + location::Location=Location(), ) op_ty_results = IR.Type[d,] operands = Value[lhs, rhs, c] @@ -520,12 +623,12 @@ elems it receives is unspecified. """ function elementwise_inline_asm( args::Vector{Value}; - result::Vector{IR.Type}, - asm_string, - constraints, - pure, - packed_element, - location=Location(), + result::Base.AbstractVecOrTuple{IR.Type}, + asm_string::String, + constraints::String, + pure::Bool, + packed_element::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[result...,] operands = Value[args...,] @@ -551,7 +654,10 @@ function elementwise_inline_asm( end function expand_dims( - src::Value; result=nothing::Union{Nothing,IR.Type}, axis, location=Location() + src::Value; + result::Union{Nothing,IR.Type}=nothing, + axis::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[src,] @@ -567,8 +673,8 @@ function expand_dims( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -586,7 +692,11 @@ This is an escape hatch and is only there for testing/experimenting. This op will be removed in the future. """ function experimental_descriptor_gather( - desc::Value, x_offsets::Value, y_offset::Value; result::IR.Type, location=Location() + desc::Value, + x_offsets::Value, + y_offset::Value; + result::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[desc, x_offsets, y_offset] @@ -620,9 +730,9 @@ function experimental_descriptor_load( desc::Value, indices::Vector{Value}; result::IR.Type, - cache=nothing, - evict=nothing, - location=Location(), + cache::Union{CacheModifier.T,Nothing}=nothing, + evict::Union{EvictionPolicy.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[desc, indices...] @@ -645,7 +755,11 @@ function experimental_descriptor_load( end function experimental_descriptor_scatter( - desc::Value, x_offsets::Value, y_offset::Value, src::Value; location=Location() + desc::Value, + x_offsets::Value, + y_offset::Value, + src::Value; + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[desc, x_offsets, y_offset, src] @@ -676,7 +790,7 @@ This is an escape hatch and is only there for testing/experimenting. This op will be removed in the future. """ function experimental_descriptor_store( - desc::Value, src::Value, indices::Vector{Value}; location=Location() + desc::Value, src::Value, indices::Vector{Value}; location::Location=Location() ) op_ty_results = IR.Type[] operands = Value[desc, src, indices...] 
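`dot_scaled` takes its element types as `ScaleDotElemType` cases. A sketch (placeholders `lhs`, `rhs`, `c`, `out_ty`; both scale operands left as `nothing`):

```julia
r = dot_scaled(
    lhs, rhs, c;
    d=out_ty,
    lhs_type=ScaleDotElemType.E4M3,
    rhs_type=ScaleDotElemType.BF16,
    fastMath=true,
)
```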
@@ -703,11 +817,11 @@ function experimental_tensormap_create( global_dim::Vector{Value}, global_stride::Vector{Value}, element_stride::Vector{Value}; - elem_type, - interleave_layout, - swizzle_mode, - fill_mode, - location=Location(), + elem_type::Int32, + interleave_layout::Int32, + swizzle_mode::Int32, + fill_mode::Int32, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ @@ -750,7 +864,9 @@ function experimental_tensormap_create( ) end -function experimental_tensormap_fenceproxy_acquire(desc_ptr::Value; location=Location()) +function experimental_tensormap_fenceproxy_acquire( + desc_ptr::Value; location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[desc_ptr,] owned_regions = Region[] @@ -778,11 +894,11 @@ return \$libpath/\$libname:\$symbol(\$args...) function extern_elementwise( srcs::Vector{Value}; result::IR.Type, - libname, - libpath, - symbol, - pure, - location=Location(), + libname::String, + libpath::String, + symbol::String, + pure::Bool, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[srcs...,] @@ -814,7 +930,12 @@ Floating point casting for custom types (F8), and non-default rounding modes. F8 <-> FP16, BF16, FP32, FP64 """ -function fp_to_fp(src::Value; result::IR.Type, rounding=nothing, location=Location()) +function fp_to_fp( + src::Value; + result::IR.Type, + rounding::Union{RoundingMode.T,Nothing}=nothing, + location::Location=Location(), +) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -851,10 +972,10 @@ changed. function gather( src::Value, indices::Value; - result=nothing::Union{Nothing,IR.Type}, - axis, - efficient_layout=nothing, - location=Location(), + result::Union{Nothing,IR.Type}=nothing, + axis::Int32, + efficient_layout::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[src, indices] @@ -872,13 +993,15 @@ function gather( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function get_num_programs(; - result=nothing::Union{Nothing,IR.Type}, axis, location=Location() + result::Union{Nothing,IR.Type}=nothing, + axis::ProgramIDDim.T, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -894,12 +1017,16 @@ function get_num_programs(; owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end -function get_program_id(; result=nothing::Union{Nothing,IR.Type}, axis, location=Location()) +function get_program_id(; + result::Union{Nothing,IR.Type}=nothing, + axis::ProgramIDDim.T, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] @@ -914,8 +1041,8 @@ function get_program_id(; result=nothing::Union{Nothing,IR.Type}, axis, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
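`get_program_id` and `get_num_programs` now take the grid axis as a `ProgramIDDim` case, with the result type inferred when omitted. A sketch:

```julia
pid = get_program_id(; axis=ProgramIDDim.X)
```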
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -926,7 +1053,7 @@ Return the histogram of the input tensor. The number of bins is equal to the dimension of the output tensor. Each bins has a width of 1 and bins start at 0. """ -function histogram(src::Value; result::IR.Type, location=Location()) +function histogram(src::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -945,7 +1072,7 @@ function histogram(src::Value; result::IR.Type, location=Location()) ) end -function int_to_ptr(src::Value; result::IR.Type, location=Location()) +function int_to_ptr(src::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -974,7 +1101,10 @@ Because Triton tensors always have a power-of-two number of elements, the two input tensors must have the same shape. """ function join( - lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + lhs::Value, + rhs::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[lhs, rhs] @@ -990,22 +1120,22 @@ function join( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function load( ptr::Value, - mask=nothing::Union{Nothing,Value}; - other=nothing::Union{Nothing,Value}, - result=nothing::Union{Nothing,IR.Type}, - boundaryCheck=nothing, - padding=nothing, - cache=nothing, - evict=nothing, - isVolatile=nothing, - location=Location(), + mask::Union{Nothing,Value}=nothing; + other::Union{Nothing,Value}=nothing, + result::Union{Nothing,IR.Type}=nothing, + boundaryCheck::Union{IR.DenseAttribute{Int32},Nothing}=nothing, + padding::Union{PaddingOption.T,Nothing}=nothing, + cache::Union{CacheModifier.T,Nothing}=nothing, + evict::Union{EvictionPolicy.T,Nothing}=nothing, + isVolatile::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ptr,] @@ -1039,8 +1169,8 @@ function load( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1051,7 +1181,9 @@ Returns an 1D int32 tensor. Values span from \$start to \$end (exclusive), with step = 1 """ -function make_range(; result::IR.Type, start, end_, location=Location()) +function make_range(; + result::IR.Type, start::Int32, end_::Int32, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] @@ -1081,7 +1213,7 @@ function make_tensor_descriptor( shape::Vector{Value}, strides::Vector{Value}; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[base, shape..., strides...] @@ -1113,8 +1245,8 @@ function make_tensor_ptr( strides::Vector{Value}, offsets::Vector{Value}; result::IR.Type, - order, - location=Location(), + order::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[base, shape..., strides..., offsets...] 
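A `load` call can now combine the typed cache, eviction, and padding enums; keywords left out keep their `nothing` defaults. A sketch (placeholder Values `ptr` and `mask`):

```julia
x = load(
    ptr, mask;
    padding=PaddingOption.PAD_ZERO,
    cache=CacheModifier.CA,
    evict=EvictionPolicy.EVICT_FIRST,
    isVolatile=false,
)
```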
@@ -1140,7 +1272,10 @@ end Most significant N bits of the 2N-bit product of two integers. """ function mulhiui( - x::Value, y::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + x::Value, + y::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[x, y] @@ -1156,8 +1291,8 @@ function mulhiui( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1167,7 +1302,10 @@ end Precise div for floating point types. """ function precise_divf( - x::Value, y::Value; result=nothing::Union{Nothing,IR.Type}, location=Location() + x::Value, + y::Value; + result::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[x, y] @@ -1183,8 +1321,8 @@ function precise_divf( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1193,7 +1331,9 @@ end Precise sqrt for floating point types. """ -function precise_sqrt(x::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) +function precise_sqrt( + x::Value; result::Union{Nothing,IR.Type}=nothing, location::Location=Location() +) op_ty_results = IR.Type[] operands = Value[x,] owned_regions = Region[] @@ -1208,8 +1348,8 @@ function precise_sqrt(x::Value; result=nothing::Union{Nothing,IR.Type}, location owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end @@ -1219,7 +1359,13 @@ end `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. format are generated automatically from the arguments. """ -function print(args::Vector{Value}; prefix, hex, isSigned, location=Location()) +function print( + args::Vector{Value}; + prefix::String, + hex::Bool, + isSigned::IR.DenseAttribute{Int32}, + location::Location=Location(), +) op_ty_results = IR.Type[] operands = Value[args...,] owned_regions = Region[] @@ -1242,7 +1388,7 @@ function print(args::Vector{Value}; prefix, hex, isSigned, location=Location()) ) end -function ptr_to_int(src::Value; result::IR.Type, location=Location()) +function ptr_to_int(src::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -1263,10 +1409,10 @@ end function reduce( srcs::Vector{Value}; - result::Vector{IR.Type}, - axis, + result::Base.AbstractVecOrTuple{IR.Type}, + axis::Int32, combineOp::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result...,] operands = Value[srcs...,] @@ -1286,7 +1432,7 @@ function reduce( ) end -function reduce_return(result::Vector{Value}; location=Location()) +function reduce_return(result::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[result...,] owned_regions = Region[] @@ -1319,9 +1465,9 @@ The compiler is still free to change it for better performance. 
function reshape( src::Value; result::IR.Type, - allow_reorder=nothing, - efficient_layout=nothing, - location=Location(), + allow_reorder::Union{Bool,Nothing}=nothing, + efficient_layout::Union{Bool,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[src,] @@ -1347,11 +1493,11 @@ end function scan( srcs::Vector{Value}; - result::Vector{IR.Type}, - axis, - reverse, + result::Base.AbstractVecOrTuple{IR.Type}, + axis::Int32, + reverse::Bool, combineOp::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result...,] operands = Value[srcs...,] @@ -1373,7 +1519,7 @@ function scan( ) end -function scan_return(result::Vector{Value}; location=Location()) +function scan_return(result::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[result...,] owned_regions = Region[] @@ -1392,7 +1538,7 @@ function scan_return(result::Vector{Value}; location=Location()) ) end -function splat(src::Value; result::IR.Type, location=Location()) +function splat(src::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[src,] owned_regions = Region[] @@ -1422,9 +1568,9 @@ shape 4x8xf32. """ function split( src::Value; - outLHS=nothing::Union{Nothing,IR.Type}, - outRHS=nothing::Union{Nothing,IR.Type}, - location=Location(), + outLHS::Union{Nothing,IR.Type}=nothing, + outRHS::Union{Nothing,IR.Type}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[src,] @@ -1441,19 +1587,19 @@ function split( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end function store( ptr::Value, value::Value, - mask=nothing::Union{Nothing,Value}; - boundaryCheck=nothing, - cache=nothing, - evict=nothing, - location=Location(), + mask::Union{Nothing,Value}=nothing; + boundaryCheck::Union{IR.DenseAttribute{Int32},Nothing}=nothing, + cache::Union{CacheModifier.T,Nothing}=nothing, + evict::Union{EvictionPolicy.T,Nothing}=nothing, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[ptr, value] @@ -1508,7 +1654,10 @@ We do this so that you can chain multiple data-movement ops (e.g. transpose+reshape+concat) without going to shared memory after each one. """ function trans( - src::Value; result=nothing::Union{Nothing,IR.Type}, order, location=Location() + src::Value; + result::Union{Nothing,IR.Type}=nothing, + order::IR.DenseAttribute{Int32}, + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[src,] @@ -1524,8 +1673,8 @@ function trans( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=(isempty(op_ty_results) ? 
nothing : op_ty_results), + result_inference=isempty(op_ty_results), ) end diff --git a/src/mlir/Dialects/VHLO.jl b/src/mlir/Dialects/VHLO.jl index 1f706fcba..b01a395c2 100755 --- a/src/mlir/Dialects/VHLO.jl +++ b/src/mlir/Dialects/VHLO.jl @@ -10,10 +10,11 @@ import ...IR: create_operation, context, IndexType -import ..Dialects: namedattribute, operandsegmentsizes +import ..Dialects: namedattribute, operandsegmentsizes, c import ...API +using EnumX -function abs_v1(operand::Value; result::IR.Type, location=Location()) +function abs_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -32,7 +33,7 @@ function abs_v1(operand::Value; result::IR.Type, location=Location()) ) end -function add_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function add_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -51,7 +52,7 @@ function add_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) ) end -function after_all_v1(inputs::Vector{Value}; result::IR.Type, location=Location()) +function after_all_v1(inputs::Vector{Value}; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[inputs...,] owned_regions = Region[] @@ -73,11 +74,11 @@ end function all_gather_v1( operand::Value; result::IR.Type, - all_gather_dim, - replica_groups, - channel_id, - use_global_device_ids, - location=Location(), + all_gather_dim::IR.AbstractAttribute, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + use_global_device_ids::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -104,12 +105,12 @@ end function all_gather_v2( operands::Vector{Value}; - results::Vector{IR.Type}, - all_gather_dim, - replica_groups, - channel_id, - use_global_device_ids, - location=Location(), + results::Base.AbstractVecOrTuple{IR.Type}, + all_gather_dim::IR.AbstractAttribute, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + use_global_device_ids::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operands...,] @@ -137,11 +138,11 @@ end function all_reduce_v1( operand::Value; result::IR.Type, - replica_groups, - channel_id, - use_global_device_ids, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + use_global_device_ids::IR.AbstractAttribute, computation::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -167,12 +168,12 @@ end function all_reduce_v2( operands::Vector{Value}; - results::Vector{IR.Type}, - replica_groups, - channel_id, - use_global_device_ids, + results::Base.AbstractVecOrTuple{IR.Type}, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + use_global_device_ids::IR.AbstractAttribute, computation::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operands...,] @@ -199,12 +200,12 @@ end function all_to_all_v1( operand::Value; result::IR.Type, - split_dimension, - concat_dimension, - split_count, - replica_groups, - channel_id, - location=Location(), + split_dimension::IR.AbstractAttribute, + concat_dimension::IR.AbstractAttribute, + split_count::IR.AbstractAttribute, + 
replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -232,13 +233,13 @@ end function all_to_all_v2( operands::Vector{Value}; - results::Vector{IR.Type}, - split_dimension, - concat_dimension, - split_count, - replica_groups, - channel_id, - location=Location(), + results::Base.AbstractVecOrTuple{IR.Type}, + split_dimension::IR.AbstractAttribute, + concat_dimension::IR.AbstractAttribute, + split_count::IR.AbstractAttribute, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operands...,] @@ -264,7 +265,7 @@ function all_to_all_v2( ) end -function and_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function and_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -283,7 +284,7 @@ function and_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) ) end -function atan2_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function atan2_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -311,9 +312,9 @@ function batch_norm_grad_v1( grad_operand::IR.Type, grad_scale::IR.Type, grad_offset::IR.Type, - epsilon, - feature_index, - location=Location(), + epsilon::IR.AbstractAttribute, + feature_index::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[grad_operand, grad_scale, grad_offset] operands = Value[operand, scale, mean, variance, grad_output] @@ -342,9 +343,9 @@ function batch_norm_inference_v1( mean::Value, variance::Value; result::IR.Type, - epsilon, - feature_index, - location=Location(), + epsilon::IR.AbstractAttribute, + feature_index::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, scale, offset, mean, variance] @@ -373,9 +374,9 @@ function batch_norm_training_v1( output::IR.Type, batch_mean::IR.Type, batch_var::IR.Type, - epsilon, - feature_index, - location=Location(), + epsilon::IR.AbstractAttribute, + feature_index::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[output, batch_mean, batch_var] operands = Value[operand, scale, offset] @@ -397,7 +398,7 @@ function batch_norm_training_v1( ) end -function bitcast_convert_v1(operand::Value; result::IR.Type, location=Location()) +function bitcast_convert_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -417,7 +418,10 @@ function bitcast_convert_v1(operand::Value; result::IR.Type, location=Location() end function broadcast_in_dim_v1( - operand::Value; result::IR.Type, broadcast_dimensions, location=Location() + operand::Value; + result::IR.Type, + broadcast_dimensions::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -439,7 +443,12 @@ function broadcast_in_dim_v1( ) end -function broadcast_v1(operand::Value; result::IR.Type, broadcast_sizes, location=Location()) +function broadcast_v1( + operand::Value; + result::IR.Type, + broadcast_sizes::IR.AbstractAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[result,] operands = 
Value[operand,] owned_regions = Region[] @@ -459,7 +468,10 @@ function broadcast_v1(operand::Value; result::IR.Type, broadcast_sizes, location end function call_v1( - operands::Vector{Value}; results::Vector{IR.Type}, callee, location=Location() + operands::Vector{Value}; + results::Base.AbstractVecOrTuple{IR.Type}, + callee::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operands...,] @@ -480,7 +492,10 @@ function call_v1( end function case_v1( - index::Value; results::Vector{IR.Type}, branches::Vector{Region}, location=Location() + index::Value; + results::Base.AbstractVecOrTuple{IR.Type}, + branches::Vector{Region}, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[index,] @@ -500,7 +515,7 @@ function case_v1( ) end -function cbrt_v1(operand::Value; result::IR.Type, location=Location()) +function cbrt_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -519,7 +534,7 @@ function cbrt_v1(operand::Value; result::IR.Type, location=Location()) ) end -function ceil_v1(operand::Value; result::IR.Type, location=Location()) +function ceil_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -538,7 +553,9 @@ function ceil_v1(operand::Value; result::IR.Type, location=Location()) ) end -function cholesky_v1(a::Value; result::IR.Type, lower, location=Location()) +function cholesky_v1( + a::Value; result::IR.Type, lower::IR.AbstractAttribute, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[a,] owned_regions = Region[] @@ -558,7 +575,7 @@ function cholesky_v1(a::Value; result::IR.Type, lower, location=Location()) end function clamp_v1( - min::Value, operand::Value, max::Value; result::IR.Type, location=Location() + min::Value, operand::Value, max::Value; result::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[result,] operands = Value[min, operand, max] @@ -578,7 +595,9 @@ function clamp_v1( ) end -function count_leading_zeros_v1(operand::Value; result::IR.Type, location=Location()) +function count_leading_zeros_v1( + operand::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -598,7 +617,11 @@ function count_leading_zeros_v1(operand::Value; result::IR.Type, location=Locati end function collective_broadcast_v1( - operand::Value; result::IR.Type, replica_groups, channel_id, location=Location() + operand::Value; + result::IR.Type, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -622,7 +645,11 @@ function collective_broadcast_v1( end function collective_permute_v1( - operand::Value; result::IR.Type, source_target_pairs, channel_id, location=Location() + operand::Value; + result::IR.Type, + source_target_pairs::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -649,9 +676,9 @@ function compare_v1( lhs::Value, rhs::Value; result::IR.Type, - comparison_direction, - compare_type, - location=Location(), + comparison_direction::IR.AbstractAttribute, + compare_type::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = 
IR.Type[result,] operands = Value[lhs, rhs] @@ -674,7 +701,7 @@ function compare_v1( ) end -function complex_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function complex_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -695,12 +722,12 @@ end function composite_v1( inputs::Vector{Value}; - results::Vector{IR.Type}, - name, - composite_attributes, - decomposition, - version, - location=Location(), + results::Base.AbstractVecOrTuple{IR.Type}, + name::IR.AbstractAttribute, + composite_attributes::IR.AbstractAttribute, + decomposition::IR.AbstractAttribute, + version::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs...,] @@ -726,7 +753,10 @@ function composite_v1( end function concatenate_v1( - inputs::Vector{Value}; result::IR.Type, dimension, location=Location() + inputs::Vector{Value}; + result::IR.Type, + dimension::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[inputs...,] @@ -746,7 +776,9 @@ function concatenate_v1( ) end -function constant_v1(; output::IR.Type, value, location=Location()) +function constant_v1(; + output::IR.Type, value::IR.AbstractAttribute, location::Location=Location() +) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -765,7 +797,7 @@ function constant_v1(; output::IR.Type, value, location=Location()) ) end -function convert_v1(operand::Value; result::IR.Type, location=Location()) +function convert_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -788,24 +820,24 @@ function convolution_v1( lhs::Value, rhs::Value; result::IR.Type, - window_strides, - padding, - lhs_dilation, - rhs_dilation, - window_reversal, - input_batch_dimension, - input_feature_dimension, - input_spatial_dimensions, - kernel_input_feature_dimension, - kernel_output_feature_dimension, - kernel_spatial_dimensions, - output_batch_dimension, - output_feature_dimension, - output_spatial_dimensions, - feature_group_count, - batch_group_count, - precision_config, - location=Location(), + window_strides::IR.AbstractAttribute, + padding::IR.AbstractAttribute, + lhs_dilation::IR.AbstractAttribute, + rhs_dilation::IR.AbstractAttribute, + window_reversal::IR.AbstractAttribute, + input_batch_dimension::IR.AbstractAttribute, + input_feature_dimension::IR.AbstractAttribute, + input_spatial_dimensions::IR.AbstractAttribute, + kernel_input_feature_dimension::IR.AbstractAttribute, + kernel_output_feature_dimension::IR.AbstractAttribute, + kernel_spatial_dimensions::IR.AbstractAttribute, + output_batch_dimension::IR.AbstractAttribute, + output_feature_dimension::IR.AbstractAttribute, + output_spatial_dimensions::IR.AbstractAttribute, + feature_group_count::IR.AbstractAttribute, + batch_group_count::IR.AbstractAttribute, + precision_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -843,7 +875,7 @@ function convolution_v1( ) end -function cosine_v1(operand::Value; result::IR.Type, location=Location()) +function cosine_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -862,7 +894,7 @@ function cosine_v1(operand::Value; result::IR.Type, 
location=Location()) ) end -function create_token_v1(; output::IR.Type, location=Location()) +function create_token_v1(; output::IR.Type, location::Location=Location()) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -882,7 +914,10 @@ function create_token_v1(; output::IR.Type, location=Location()) end function cross_replica_sum_v1( - operand::Value; result::IR.Type, replica_groups, location=Location() + operand::Value; + result::IR.Type, + replica_groups::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -904,16 +939,16 @@ end function custom_call_v1( inputs::Vector{Value}; - results::Vector{IR.Type}, - call_target_name, - has_side_effect, - backend_config, - api_version, - called_computations, - operand_layouts, - result_layouts, - output_operand_aliases, - location=Location(), + results::Base.AbstractVecOrTuple{IR.Type}, + call_target_name::IR.AbstractAttribute, + has_side_effect::IR.AbstractAttribute, + backend_config::IR.AbstractAttribute, + api_version::IR.AbstractAttribute, + called_computations::IR.AbstractAttribute, + operand_layouts::IR.AbstractAttribute, + result_layouts::IR.AbstractAttribute, + output_operand_aliases::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs...,] @@ -942,7 +977,7 @@ function custom_call_v1( ) end -function divide_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function divide_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -965,12 +1000,12 @@ function dot_general_v1( lhs::Value, rhs::Value; result::IR.Type, - lhs_batching_dimensions, - rhs_batching_dimensions, - lhs_contracting_dimensions, - rhs_contracting_dimensions, - precision_config, - location=Location(), + lhs_batching_dimensions::IR.AbstractAttribute, + rhs_batching_dimensions::IR.AbstractAttribute, + lhs_contracting_dimensions::IR.AbstractAttribute, + rhs_contracting_dimensions::IR.AbstractAttribute, + precision_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -1000,19 +1035,19 @@ function dot_general_v2( lhs::Value, rhs::Value; result::IR.Type, - lhs_batching_dimensions, - rhs_batching_dimensions, - lhs_contracting_dimensions, - rhs_contracting_dimensions, - precision_config, - lhs_precision_type, - rhs_precision_type, - accumulation_type, - lhs_component_count, - rhs_component_count, - num_primitive_operations, - allow_imprecise_accumulation, - location=Location(), + lhs_batching_dimensions::IR.AbstractAttribute, + rhs_batching_dimensions::IR.AbstractAttribute, + lhs_contracting_dimensions::IR.AbstractAttribute, + rhs_contracting_dimensions::IR.AbstractAttribute, + precision_config::IR.AbstractAttribute, + lhs_precision_type::IR.AbstractAttribute, + rhs_precision_type::IR.AbstractAttribute, + accumulation_type::IR.AbstractAttribute, + lhs_component_count::IR.AbstractAttribute, + rhs_component_count::IR.AbstractAttribute, + num_primitive_operations::IR.AbstractAttribute, + allow_imprecise_accumulation::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -1046,7 +1081,11 @@ function dot_general_v2( end function dot_v1( - lhs::Value, rhs::Value; result::IR.Type, precision_config, location=Location() + lhs::Value, + rhs::Value; + result::IR.Type, + 
precision_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -1070,10 +1109,10 @@ function dynamic_broadcast_in_dim_v1( operand::Value, output_dimensions::Value; result::IR.Type, - broadcast_dimensions, - known_expanding_dimensions, - known_nonexpanding_dimensions, - location=Location(), + broadcast_dimensions::IR.AbstractAttribute, + known_expanding_dimensions::IR.AbstractAttribute, + known_nonexpanding_dimensions::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, output_dimensions] @@ -1102,24 +1141,24 @@ function dynamic_conv_v1( rhs::Value, d_padding::Value; result::IR.Type, - window_strides, - padding, - lhs_dilation, - rhs_dilation, - window_reversal, - input_batch_dimension, - input_feature_dimension, - input_spatial_dimensions, - kernel_input_feature_dimension, - kernel_output_feature_dimension, - kernel_spatial_dimensions, - output_batch_dimension, - output_feature_dimension, - output_spatial_dimensions, - feature_group_count, - batch_group_count, - precision_config, - location=Location(), + window_strides::IR.AbstractAttribute, + padding::IR.AbstractAttribute, + lhs_dilation::IR.AbstractAttribute, + rhs_dilation::IR.AbstractAttribute, + window_reversal::IR.AbstractAttribute, + input_batch_dimension::IR.AbstractAttribute, + input_feature_dimension::IR.AbstractAttribute, + input_spatial_dimensions::IR.AbstractAttribute, + kernel_input_feature_dimension::IR.AbstractAttribute, + kernel_output_feature_dimension::IR.AbstractAttribute, + kernel_spatial_dimensions::IR.AbstractAttribute, + output_batch_dimension::IR.AbstractAttribute, + output_feature_dimension::IR.AbstractAttribute, + output_spatial_dimensions::IR.AbstractAttribute, + feature_group_count::IR.AbstractAttribute, + batch_group_count::IR.AbstractAttribute, + precision_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs, d_padding] @@ -1162,23 +1201,23 @@ function dynamic_conv_v2( rhs::Value, padding::Value; result::IR.Type, - window_strides, - lhs_dilation, - rhs_dilation, - window_reversal, - input_batch_dimension, - input_feature_dimension, - input_spatial_dimensions, - kernel_input_feature_dimension, - kernel_output_feature_dimension, - kernel_spatial_dimensions, - output_batch_dimension, - output_feature_dimension, - output_spatial_dimensions, - feature_group_count, - batch_group_count, - precision_config, - location=Location(), + window_strides::IR.AbstractAttribute, + lhs_dilation::IR.AbstractAttribute, + rhs_dilation::IR.AbstractAttribute, + window_reversal::IR.AbstractAttribute, + input_batch_dimension::IR.AbstractAttribute, + input_feature_dimension::IR.AbstractAttribute, + input_spatial_dimensions::IR.AbstractAttribute, + kernel_input_feature_dimension::IR.AbstractAttribute, + kernel_output_feature_dimension::IR.AbstractAttribute, + kernel_spatial_dimensions::IR.AbstractAttribute, + output_batch_dimension::IR.AbstractAttribute, + output_feature_dimension::IR.AbstractAttribute, + output_spatial_dimensions::IR.AbstractAttribute, + feature_group_count::IR.AbstractAttribute, + batch_group_count::IR.AbstractAttribute, + precision_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs, padding] @@ -1220,12 +1259,12 @@ function dynamic_gather_v1( start_indices::Value, slice_sizes::Value; result::IR.Type, - offset_dims, - collapsed_slice_dims, 
- start_index_map, - index_vector_dim, - indices_are_sorted, - location=Location(), + offset_dims::IR.AbstractAttribute, + collapsed_slice_dims::IR.AbstractAttribute, + start_index_map::IR.AbstractAttribute, + index_vector_dim::IR.AbstractAttribute, + indices_are_sorted::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices, slice_sizes] @@ -1256,14 +1295,14 @@ function dynamic_gather_v2( start_indices::Value, slice_sizes::Value; result::IR.Type, - offset_dims, - collapsed_slice_dims, - operand_batching_dims, - start_indices_batching_dims, - start_index_map, - index_vector_dim, - indices_are_sorted, - location=Location(), + offset_dims::IR.AbstractAttribute, + collapsed_slice_dims::IR.AbstractAttribute, + operand_batching_dims::IR.AbstractAttribute, + start_indices_batching_dims::IR.AbstractAttribute, + start_index_map::IR.AbstractAttribute, + index_vector_dim::IR.AbstractAttribute, + indices_are_sorted::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices, slice_sizes] @@ -1292,7 +1331,10 @@ function dynamic_gather_v2( end function dynamic_iota_v1( - output_shape::Value; result::IR.Type, iota_dimension, location=Location() + output_shape::Value; + result::IR.Type, + iota_dimension::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[output_shape,] @@ -1319,7 +1361,7 @@ function dynamic_pad_v1( edge_padding_high::Value, interior_padding::Value; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[ @@ -1342,7 +1384,7 @@ function dynamic_pad_v1( end function dynamic_reshape_v1( - operand::Value, output_shape::Value; result::IR.Type, location=Location() + operand::Value, output_shape::Value; result::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[result,] operands = Value[operand, output_shape] @@ -1366,8 +1408,8 @@ function dynamic_slice_v1( operand::Value, start_indices::Vector{Value}; result::IR.Type, - slice_sizes, - location=Location(), + slice_sizes::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices...] @@ -1392,7 +1434,7 @@ function dynamic_update_slice_v1( update::Value, start_indices::Vector{Value}; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, update, start_indices...] 
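# --- Illustrative call-site sketch (hypothetical, not part of the patch): the wrappers
# above now annotate attribute keywords as IR.AbstractAttribute and locations as Location,
# so attributes must be built or parsed up front instead of passed as raw Julia values.
# The module paths, the attribute text, and every variable name below are assumptions.
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: vhlo   # assumed name of the generated VHLO wrapper module

# assumes an active MLIR context; the exact attribute syntax depends on what the op expects
slice_sizes = parse(IR.Attribute, "array<i64: 2, 2>")
op = vhlo.dynamic_slice_v1(
    operand, start_indices;          # hypothetical Values produced earlier
    result=result_type,              # an IR.Type
    slice_sizes=slice_sizes,         # must now be an IR.AbstractAttribute
    location=IR.Location(),
)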
@@ -1413,7 +1455,11 @@ function dynamic_update_slice_v1( end function einsum_v1( - lhs::Value, rhs::Value; result::IR.Type, einsum_config, location=Location() + lhs::Value, + rhs::Value; + result::IR.Type, + einsum_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -1433,7 +1479,7 @@ function einsum_v1( ) end -function exponential_v1(operand::Value; result::IR.Type, location=Location()) +function exponential_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1453,7 +1499,10 @@ function exponential_v1(operand::Value; result::IR.Type, location=Location()) end function exponential_v2( - operand::Value; result::IR.Type, result_accuracy, location=Location() + operand::Value; + result::IR.Type, + result_accuracy::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -1473,7 +1522,9 @@ function exponential_v2( ) end -function exponential_minus_one_v1(operand::Value; result::IR.Type, location=Location()) +function exponential_minus_one_v1( + operand::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1492,7 +1543,13 @@ function exponential_minus_one_v1(operand::Value; result::IR.Type, location=Loca ) end -function fft_v1(operand::Value; result::IR.Type, fft_type, fft_length, location=Location()) +function fft_v1( + operand::Value; + result::IR.Type, + fft_type::IR.AbstractAttribute, + fft_length::IR.AbstractAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1513,7 +1570,7 @@ function fft_v1(operand::Value; result::IR.Type, fft_type, fft_length, location= ) end -function floor_v1(operand::Value; result::IR.Type, location=Location()) +function floor_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1533,13 +1590,13 @@ function floor_v1(operand::Value; result::IR.Type, location=Location()) end function func_v1(; - sym_name, - function_type, - sym_visibility, - arg_attrs, - res_attrs, + sym_name::IR.AbstractAttribute, + function_type::IR.AbstractAttribute, + sym_visibility::IR.AbstractAttribute, + arg_attrs::IR.AbstractAttribute, + res_attrs::IR.AbstractAttribute, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[] operands = Value[] @@ -1569,13 +1626,13 @@ function gather_v1( operand::Value, start_indices::Value; result::IR.Type, - offset_dims, - collapsed_slice_dims, - start_index_map, - index_vector_dim, - slice_sizes, - indices_are_sorted, - location=Location(), + offset_dims::IR.AbstractAttribute, + collapsed_slice_dims::IR.AbstractAttribute, + start_index_map::IR.AbstractAttribute, + index_vector_dim::IR.AbstractAttribute, + slice_sizes::IR.AbstractAttribute, + indices_are_sorted::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices] @@ -1606,15 +1663,15 @@ function gather_v2( operand::Value, start_indices::Value; result::IR.Type, - offset_dims, - collapsed_slice_dims, - operand_batching_dims, - start_indices_batching_dims, - start_index_map, - index_vector_dim, - slice_sizes, - indices_are_sorted, - location=Location(), + 
offset_dims::IR.AbstractAttribute, + collapsed_slice_dims::IR.AbstractAttribute, + operand_batching_dims::IR.AbstractAttribute, + start_indices_batching_dims::IR.AbstractAttribute, + start_index_map::IR.AbstractAttribute, + index_vector_dim::IR.AbstractAttribute, + slice_sizes::IR.AbstractAttribute, + indices_are_sorted::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices] @@ -1644,7 +1701,10 @@ function gather_v2( end function get_dimension_size_v1( - operand::Value; result::IR.Type, dimension, location=Location() + operand::Value; + result::IR.Type, + dimension::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -1664,7 +1724,12 @@ function get_dimension_size_v1( ) end -function get_tuple_element_v1(operand::Value; result::IR.Type, index, location=Location()) +function get_tuple_element_v1( + operand::Value; + result::IR.Type, + index::IR.AbstractAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1685,10 +1750,10 @@ end function if_v1( pred::Value; - results::Vector{IR.Type}, + results::Base.AbstractVecOrTuple{IR.Type}, true_branch::Region, false_branch::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[pred,] @@ -1708,7 +1773,7 @@ function if_v1( ) end -function imag_v1(operand::Value; result::IR.Type, location=Location()) +function imag_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1728,7 +1793,11 @@ function imag_v1(operand::Value; result::IR.Type, location=Location()) end function infeed_v1( - token::Value; results::Vector{IR.Type}, infeed_config, layout, location=Location() + token::Value; + results::Base.AbstractVecOrTuple{IR.Type}, + infeed_config::IR.AbstractAttribute, + layout::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[token,] @@ -1750,7 +1819,9 @@ function infeed_v1( ) end -function iota_v1(; output::IR.Type, iota_dimension, location=Location()) +function iota_v1(; + output::IR.Type, iota_dimension::IR.AbstractAttribute, location::Location=Location() +) op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] @@ -1769,7 +1840,7 @@ function iota_v1(; output::IR.Type, iota_dimension, location=Location()) ) end -function is_finite_v1(x::Value; y::IR.Type, location=Location()) +function is_finite_v1(x::Value; y::IR.Type, location::Location=Location()) op_ty_results = IR.Type[y,] operands = Value[x,] owned_regions = Region[] @@ -1788,7 +1859,7 @@ function is_finite_v1(x::Value; y::IR.Type, location=Location()) ) end -function log_plus_one_v1(operand::Value; result::IR.Type, location=Location()) +function log_plus_one_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1807,7 +1878,7 @@ function log_plus_one_v1(operand::Value; result::IR.Type, location=Location()) ) end -function log_v1(operand::Value; result::IR.Type, location=Location()) +function log_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1826,7 +1897,7 @@ function log_v1(operand::Value; result::IR.Type, 
location=Location()) ) end -function logistic_v1(operand::Value; result::IR.Type, location=Location()) +function logistic_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1848,9 +1919,9 @@ end function map_v1( inputs::Vector{Value}; result::IR.Type, - dimensions, + dimensions::IR.AbstractAttribute, computation::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[inputs...,] @@ -1870,7 +1941,7 @@ function map_v1( ) end -function maximum_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function maximum_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -1889,7 +1960,7 @@ function maximum_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location() ) end -function minimum_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function minimum_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -1908,7 +1979,7 @@ function minimum_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location() ) end -function multiply_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function multiply_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -1927,7 +1998,7 @@ function multiply_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location( ) end -function negate_v1(operand::Value; result::IR.Type, location=Location()) +function negate_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1946,7 +2017,7 @@ function negate_v1(operand::Value; result::IR.Type, location=Location()) ) end -function not_v1(operand::Value; result::IR.Type, location=Location()) +function not_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -1966,7 +2037,9 @@ function not_v1(operand::Value; result::IR.Type, location=Location()) end function optimization_barrier_v1( - operand::Vector{Value}; result::Vector{IR.Type}, location=Location() + operand::Vector{Value}; + result::Base.AbstractVecOrTuple{IR.Type}, + location::Location=Location(), ) op_ty_results = IR.Type[result...,] operands = Value[operand...,] @@ -1986,7 +2059,7 @@ function optimization_barrier_v1( ) end -function or_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function or_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -2009,8 +2082,8 @@ function outfeed_v1( inputs::Vector{Value}, token::Value; result::IR.Type, - outfeed_config, - location=Location(), + outfeed_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[inputs..., token] @@ -2034,10 +2107,10 @@ function pad_v1( operand::Value, padding_value::Value; result::IR.Type, - edge_padding_low, - edge_padding_high, - interior_padding, - location=Location(), + edge_padding_low::IR.AbstractAttribute, + edge_padding_high::IR.AbstractAttribute, + 
interior_padding::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, padding_value] @@ -2061,7 +2134,7 @@ function pad_v1( ) end -function partition_id_v1(; result::IR.Type, location=Location()) +function partition_id_v1(; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] @@ -2080,7 +2153,7 @@ function partition_id_v1(; result::IR.Type, location=Location()) ) end -function popcnt_v1(operand::Value; result::IR.Type, location=Location()) +function popcnt_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2099,7 +2172,7 @@ function popcnt_v1(operand::Value; result::IR.Type, location=Location()) ) end -function power_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function power_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -2124,7 +2197,7 @@ function real_dynamic_slice_v1( limit_indices::Value, strides::Value; result::IR.Type, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, start_indices, limit_indices, strides] @@ -2144,7 +2217,7 @@ function real_dynamic_slice_v1( ) end -function real_v1(operand::Value; result::IR.Type, location=Location()) +function real_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2165,11 +2238,11 @@ end function recv_v1( token::Value; - results::Vector{IR.Type}, - channel_id, - channel_type, - is_host_transfer, - location=Location(), + results::Base.AbstractVecOrTuple{IR.Type}, + channel_id::IR.AbstractAttribute, + channel_type::IR.AbstractAttribute, + is_host_transfer::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[token,] @@ -2196,10 +2269,10 @@ end function reduce_v1( inputs::Vector{Value}, init_values::Vector{Value}; - results::Vector{IR.Type}, - dimensions, + results::Base.AbstractVecOrTuple{IR.Type}, + dimensions::IR.AbstractAttribute, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs..., init_values...] 
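# --- Sketch of the widened result-type keywords (same assumed imports as the earlier
# sketch: Reactant.MLIR.IR and Dialects.vhlo): results::Base.AbstractVecOrTuple{IR.Type}
# accepts a tuple as well as a Vector, since both splat into the IR.Type[...] constructor
# in the wrapper body, and location is now strictly a Location. All names below are
# hypothetical placeholders.
vhlo.reduce_v1(
    inputs, init_values;             # Vector{Value}s
    results=(reduced_type,),         # a tuple avoids allocating a Vector{IR.Type}
    dimensions=dimensions_attr,      # an IR.AbstractAttribute
    body=body_region,                # a Region
    location=IR.Location(),
)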
@@ -2220,7 +2293,11 @@ function reduce_v1( end function reduce_precision_v1( - operand::Value; output::IR.Type, exponent_bits, mantissa_bits, location=Location() + operand::Value; + output::IR.Type, + exponent_bits::IR.AbstractAttribute, + mantissa_bits::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[output,] operands = Value[operand,] @@ -2246,12 +2323,12 @@ end function reduce_scatter_v1( operand::Value; result::IR.Type, - scatter_dimension, - replica_groups, - channel_id, - use_global_device_ids, + scatter_dimension::IR.AbstractAttribute, + replica_groups::IR.AbstractAttribute, + channel_id::IR.AbstractAttribute, + use_global_device_ids::IR.AbstractAttribute, computation::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -2279,14 +2356,14 @@ end function reduce_window_v1( inputs::Vector{Value}, init_values::Vector{Value}; - results::Vector{IR.Type}, - window_dimensions, - window_strides, - base_dilations, - window_dilations, - padding, + results::Base.AbstractVecOrTuple{IR.Type}, + window_dimensions::IR.AbstractAttribute, + window_strides::IR.AbstractAttribute, + base_dilations::IR.AbstractAttribute, + window_dilations::IR.AbstractAttribute, + padding::IR.AbstractAttribute, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs..., init_values...] @@ -2312,7 +2389,9 @@ function reduce_window_v1( ) end -function remainder_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function remainder_v1( + lhs::Value, rhs::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -2331,7 +2410,7 @@ function remainder_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location ) end -function replica_id_v1(; result::IR.Type, location=Location()) +function replica_id_v1(; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] @@ -2350,7 +2429,7 @@ function replica_id_v1(; result::IR.Type, location=Location()) ) end -function reshape_v1(operand::Value; result::IR.Type, location=Location()) +function reshape_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2369,7 +2448,7 @@ function reshape_v1(operand::Value; result::IR.Type, location=Location()) ) end -function return_v1(results::Vector{Value}; location=Location()) +function return_v1(results::Vector{Value}; location::Location=Location()) op_ty_results = IR.Type[] operands = Value[results...,] owned_regions = Region[] @@ -2388,7 +2467,12 @@ function return_v1(results::Vector{Value}; location=Location()) ) end -function reverse_v1(operand::Value; result::IR.Type, dimensions, location=Location()) +function reverse_v1( + operand::Value; + result::IR.Type, + dimensions::IR.AbstractAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2411,8 +2495,8 @@ function rng_bit_generator_v1( initial_state::Value; output_state::IR.Type, output::IR.Type, - rng_algorithm, - location=Location(), + rng_algorithm::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[output_state, output] operands = Value[initial_state,] @@ -2433,7 +2517,12 @@ function rng_bit_generator_v1( end function rng_v1( - 
a::Value, b::Value, shape::Value; result::IR.Type, rng_distribution, location=Location() + a::Value, + b::Value, + shape::Value; + result::IR.Type, + rng_distribution::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[a, b, shape] @@ -2453,7 +2542,9 @@ function rng_v1( ) end -function round_nearest_even_v1(operand::Value; result::IR.Type, location=Location()) +function round_nearest_even_v1( + operand::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2472,7 +2563,9 @@ function round_nearest_even_v1(operand::Value; result::IR.Type, location=Locatio ) end -function round_nearest_afz_v1(operand::Value; result::IR.Type, location=Location()) +function round_nearest_afz_v1( + operand::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2491,7 +2584,7 @@ function round_nearest_afz_v1(operand::Value; result::IR.Type, location=Location ) end -function rsqrt_v1(operand::Value; result::IR.Type, location=Location()) +function rsqrt_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2514,15 +2607,15 @@ function scatter_v1( inputs::Vector{Value}, scatter_indices::Value, updates::Vector{Value}; - results::Vector{IR.Type}, - update_window_dims, - inserted_window_dims, - scatter_dims_to_operand_dims, - index_vector_dim, - indices_are_sorted, - unique_indices, + results::Base.AbstractVecOrTuple{IR.Type}, + update_window_dims::IR.AbstractAttribute, + inserted_window_dims::IR.AbstractAttribute, + scatter_dims_to_operand_dims::IR.AbstractAttribute, + index_vector_dim::IR.AbstractAttribute, + indices_are_sorted::IR.AbstractAttribute, + unique_indices::IR.AbstractAttribute, update_computation::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs..., scatter_indices, updates...] @@ -2553,17 +2646,17 @@ function scatter_v2( inputs::Vector{Value}, scatter_indices::Value, updates::Vector{Value}; - results::Vector{IR.Type}, - update_window_dims, - inserted_window_dims, - input_batching_dims, - scatter_indices_batching_dims, - scatter_dims_to_operand_dims, - index_vector_dim, - indices_are_sorted, - unique_indices, + results::Base.AbstractVecOrTuple{IR.Type}, + update_window_dims::IR.AbstractAttribute, + inserted_window_dims::IR.AbstractAttribute, + input_batching_dims::IR.AbstractAttribute, + scatter_indices_batching_dims::IR.AbstractAttribute, + scatter_dims_to_operand_dims::IR.AbstractAttribute, + index_vector_dim::IR.AbstractAttribute, + indices_are_sorted::IR.AbstractAttribute, + unique_indices::IR.AbstractAttribute, update_computation::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs..., scatter_indices, updates...] 
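# --- Two practical consequences of the typed keywords in these wrappers, sketched with
# calls taken from the LinearAlgebra.jl hunk later in this patch; treat the snippet as
# illustrative only. First, passing a raw Julia value where an IR.AbstractAttribute is
# expected now raises a TypeError at the call instead of failing inside create_operation.
# Second, enum-valued attributes go through the EnumX-generated types, e.g. the
# stablehlo.ComparisonDirection enum replacing the old "GE"/"LT" strings:
using Reactant: Ops
using Reactant.MLIR.Dialects: stablehlo

# inside a traced function:
row_idxs = Ops.iota(Int, [4, 4]; iota_dimension=1)
col_idxs = Ops.iota(Int, [4, 4]; iota_dimension=2)
mask = Ops.compare(row_idxs, col_idxs;
    comparison_direction=stablehlo.ComparisonDirection.GE)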
@@ -2597,12 +2690,12 @@ function select_and_scatter_v1( source::Value, init_value::Value; result::IR.Type, - window_dimensions, - window_strides, - padding, + window_dimensions::IR.AbstractAttribute, + window_strides::IR.AbstractAttribute, + padding::IR.AbstractAttribute, select::Region, scatter::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, source, init_value] @@ -2627,7 +2720,11 @@ function select_and_scatter_v1( end function select_v1( - pred::Value, on_true::Value, on_false::Value; result::IR.Type, location=Location() + pred::Value, + on_true::Value, + on_false::Value; + result::IR.Type, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[pred, on_true, on_false] @@ -2651,10 +2748,10 @@ function send_v1( inputs::Vector{Value}, token::Value; result::IR.Type, - channel_id, - channel_type, - is_host_transfer, - location=Location(), + channel_id::IR.AbstractAttribute, + channel_type::IR.AbstractAttribute, + is_host_transfer::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[inputs..., token] @@ -2679,7 +2776,11 @@ function send_v1( end function set_dimension_size_v1( - operand::Value, size::Value; result::IR.Type, dimension, location=Location() + operand::Value, + size::Value; + result::IR.Type, + dimension::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, size] @@ -2699,7 +2800,9 @@ function set_dimension_size_v1( ) end -function shift_left_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function shift_left_v1( + lhs::Value, rhs::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -2719,7 +2822,7 @@ function shift_left_v1(lhs::Value, rhs::Value; result::IR.Type, location=Locatio end function shift_right_arithmetic_v1( - lhs::Value, rhs::Value; result::IR.Type, location=Location() + lhs::Value, rhs::Value; result::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -2740,7 +2843,7 @@ function shift_right_arithmetic_v1( end function shift_right_logical_v1( - lhs::Value, rhs::Value; result::IR.Type, location=Location() + lhs::Value, rhs::Value; result::IR.Type, location::Location=Location() ) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] @@ -2760,7 +2863,7 @@ function shift_right_logical_v1( ) end -function sign_v1(operand::Value; result::IR.Type, location=Location()) +function sign_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2779,7 +2882,7 @@ function sign_v1(operand::Value; result::IR.Type, location=Location()) ) end -function sine_v1(operand::Value; result::IR.Type, location=Location()) +function sine_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2801,10 +2904,10 @@ end function slice_v1( operand::Value; result::IR.Type, - start_indices, - limit_indices, - strides, - location=Location(), + start_indices::IR.AbstractAttribute, + limit_indices::IR.AbstractAttribute, + strides::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -2830,11 +2933,11 @@ end function sort_v1( inputs::Vector{Value}; - 
results::Vector{IR.Type}, - dimension, - is_stable, + results::Base.AbstractVecOrTuple{IR.Type}, + dimension::IR.AbstractAttribute, + is_stable::IR.AbstractAttribute, comparator::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[inputs...,] @@ -2856,7 +2959,7 @@ function sort_v1( ) end -function sqrt_v1(operand::Value; result::IR.Type, location=Location()) +function sqrt_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2875,7 +2978,7 @@ function sqrt_v1(operand::Value; result::IR.Type, location=Location()) ) end -function subtract_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function subtract_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] @@ -2894,7 +2997,7 @@ function subtract_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location( ) end -function tan_v1(operand::Value; result::IR.Type, location=Location()) +function tan_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2913,7 +3016,7 @@ function tan_v1(operand::Value; result::IR.Type, location=Location()) ) end -function tanh_v1(operand::Value; result::IR.Type, location=Location()) +function tanh_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2933,7 +3036,12 @@ function tanh_v1(operand::Value; result::IR.Type, location=Location()) end function torch_index_select_v1( - operand::Value, index::Value; result::IR.Type, dim, batch_dims, location=Location() + operand::Value, + index::Value; + result::IR.Type, + dim::IR.AbstractAttribute, + batch_dims::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand, index] @@ -2955,7 +3063,12 @@ function torch_index_select_v1( ) end -function transpose_v1(operand::Value; result::IR.Type, permutation, location=Location()) +function transpose_v1( + operand::Value; + result::IR.Type, + permutation::IR.AbstractAttribute, + location::Location=Location(), +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -2978,11 +3091,11 @@ function triangular_solve_v1( a::Value, b::Value; result::IR.Type, - left_side, - lower, - unit_diagonal, - transpose_a, - location=Location(), + left_side::IR.AbstractAttribute, + lower::IR.AbstractAttribute, + unit_diagonal::IR.AbstractAttribute, + transpose_a::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[a, b] @@ -3007,7 +3120,7 @@ function triangular_solve_v1( ) end -function tuple_v1(val::Vector{Value}; result::IR.Type, location=Location()) +function tuple_v1(val::Vector{Value}; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[val...,] owned_regions = Region[] @@ -3027,7 +3140,10 @@ function tuple_v1(val::Vector{Value}; result::IR.Type, location=Location()) end function unary_einsum_v1( - operand::Value; result::IR.Type, einsum_config, location=Location() + operand::Value; + result::IR.Type, + einsum_config::IR.AbstractAttribute, + location::Location=Location(), ) op_ty_results = IR.Type[result,] operands = Value[operand,] @@ -3047,7 +3163,9 @@ function 
unary_einsum_v1( ) end -function uniform_dequantize_v1(operand::Value; result::IR.Type, location=Location()) +function uniform_dequantize_v1( + operand::Value; result::IR.Type, location::Location=Location() +) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -3066,7 +3184,7 @@ function uniform_dequantize_v1(operand::Value; result::IR.Type, location=Locatio ) end -function uniform_quantize_v1(operand::Value; result::IR.Type, location=Location()) +function uniform_quantize_v1(operand::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[operand,] owned_regions = Region[] @@ -3087,10 +3205,10 @@ end function while_v1( operand::Vector{Value}; - results::Vector{IR.Type}, + results::Base.AbstractVecOrTuple{IR.Type}, cond::Region, body::Region, - location=Location(), + location::Location=Location(), ) op_ty_results = IR.Type[results...,] operands = Value[operand...,] @@ -3110,7 +3228,7 @@ function while_v1( ) end -function xor_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) +function xor_v1(lhs::Value, rhs::Value; result::IR.Type, location::Location=Location()) op_ty_results = IR.Type[result,] operands = Value[lhs, rhs] owned_regions = Region[] diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index d7aac0083..e7eba1b53 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -1,7 +1,25 @@ -struct Attribute - attribute::API.MlirAttribute +abstract type AbstractAttribute end + +struct Attribute <: AbstractAttribute + attr::API.MlirAttribute +end + +getattribute(attr::API.MlirAttribute) = getattribute(Attribute(attr)) + +function getattribute(attr::Attribute) + if isdenseelements(attr) + issplat(attr) && return SplatAttribute(attr) + return DenseElementsAttribute(attr) + end + isflatsymbolref(attr) && return FlatSymbolRefAttribute(attr) + isarray(attr) && return [Attribute(API.mlirArrayAttrGetElement(attr, i)) for i in 1:length(attr)] + return attr end +Attribute(f::AbstractAttribute) = f.attr + +Base.convert(::Core.Type{API.MlirAttribute}, attribute::AbstractAttribute) = attribute.attr + """ Attribute() @@ -9,8 +27,6 @@ Returns an empty attribute. """ Attribute() = Attribute(API.mlirAttributeGetNull()) -Base.convert(::Core.Type{API.MlirAttribute}, attribute::Attribute) = attribute.attribute - """ parse(::Core.Type{Attribute}, str; context=context()) @@ -38,7 +54,7 @@ context(attr::Attribute) = Context(API.mlirAttributeGetContext(attr)) Gets the type of this attribute. """ -type(attr::Attribute) = Type(API.mlirAttributeGetType(attr)) +type(attr::AbstractAttribute) = Type(API.mlirAttributeGetType(Attribute(attr))) #TODO: remove Attribute here """ typeid(attribute) @@ -353,8 +369,19 @@ isflatsymbolref(attr::Attribute) = API.mlirAttributeIsAFlatSymbolRef(attr) Creates a flat symbol reference attribute in the given context referencing a symbol identified by the given string. 
""" -FlatSymbolRefAttribute(symbol::String; context::Context=context()) = - Attribute(API.mlirFlatSymbolRefAttrGet(context, symbol)) +struct FlatSymbolRefAttribute <: AbstractAttribute + attr::API.MlirAttribute + function FlatSymbolRefAttribute(symbol::String; context::Context=context()) + return new(API.mlirFlatSymbolRefAttrGet(context, symbol)) + end + + function FlatSymbolRefAttribute(attr::Attribute) + @assert isflatsymbolref(attr) "attribute $(attr) is not a flat symbol reference attribute" + return new(attr) + end +end + +Base.show(io::IO, f::FlatSymbolRefAttribute) = print(io, "@$(flatsymbol(f.attr))") """ flatsymbol(attr) @@ -420,6 +447,50 @@ isdenseelements(attr::Attribute) = API.mlirAttributeIsADenseElements(attr) isdenseintelements(attr::Attribute) = API.mlirAttributeIsADenseIntElements(attr) isdensefloatelements(attr::Attribute) = API.mlirAttributeIsADenseFPElements(attr) +abstract type AbstractDenseElementsAttribute{T} <: AbstractAttribute end + +DenseAttribute{T} = Union{Vector{T},AbstractDenseElementsAttribute{T}} + +struct DenseElementsAttribute{T} <: AbstractDenseElementsAttribute{T} + attr::API.MlirAttribute + function DenseElementsAttribute{T}(attr::API.MlirAttribute) where {T} + if !API.mlirAttributeIsADenseElements(attr) + throw("$attr is not a dense elements attribute.") + end + return new{T}(attr) + end + + DenseElementsAttribute(a::Attribute) = DenseElementsAttribute(a.attr) + + function DenseElementsAttribute(attr::API.MlirAttribute) + if !API.mlirAttributeIsADenseElements(attr) + throw("$attr is not a dense elements attribute.") + end + e = julia_type(eltype(type(Attribute(attr)))) + return new{e}(attr) + end +end + +struct SplatAttribute{T} <: AbstractDenseElementsAttribute{T} + attr::API.MlirAttribute + function SplatAttribute(attr::API.MlirAttribute) + if !issplat(Attribute(attr)) + throw("$attr is not a splat attribute.") + end + e = julia_type(eltype(type(Attribute(attr)))) + return new{e}(attr) + end + + SplatAttribute(a::Attribute) = SplatAttribute(a.attr) + + function SplatAttribute{T}(attr::API.MlirAttribute) where {T} + if !issplat(Attribute(attr)) + throw("$attr is not a splat attribute.") + end + return new{T}(attr) + end +end + """ DenseElementsAttribute(shapedType, elements) @@ -427,7 +498,9 @@ Creates a dense elements attribute with the given Shaped type and elements in th """ function DenseElementsAttribute(shaped_type::Type, elements::AbstractArray) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return Attribute(API.mlirDenseElementsAttrGet(shaped_type, length(elements), elements)) + return DenseElementsAttribute{shaped_type}( + API.mlirDenseElementsAttrGet(shaped_type, length(elements), elements) + ) end # TODO mlirDenseElementsAttrRawBufferGet @@ -439,52 +512,60 @@ Creates a dense elements attribute with the given Shaped type containing a singl """ function Base.fill(attr::Attribute, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return Attribute(API.mlirDenseElementsAttrSplatGet(shaped_type, attr)) + return SplatAttribute(API.mlirDenseElementsAttrSplatGet(shaped_type, attr)) end function Base.fill(value::Bool, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrBoolSplatGet(shaped_type, value) + return SplatAttribute{Bool}(API.mlirDenseElementsAttrBoolSplatGet(shaped_type, value)) end function Base.fill(value::UInt8, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a 
shaped type"
-    return API.mlirDenseElementsAttrUInt8SplatGet(shaped_type, value)
+    return SplatAttribute{UInt8}(API.mlirDenseElementsAttrUInt8SplatGet(shaped_type, value))
 end
 
 function Base.fill(value::Int8, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrInt8SplatGet(shaped_type, value)
+    return SplatAttribute{Int8}(API.mlirDenseElementsAttrInt8SplatGet(shaped_type, value))
 end
 
 function Base.fill(value::UInt32, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrUInt32SplatGet(shaped_type, value)
+    return SplatAttribute{UInt32}(
+        API.mlirDenseElementsAttrUInt32SplatGet(shaped_type, value)
+    )
 end
 
 function Base.fill(value::Int32, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrInt32SplatGet(shaped_type, value)
+    return SplatAttribute{Int32}(API.mlirDenseElementsAttrInt32SplatGet(shaped_type, value))
 end
 
 function Base.fill(value::UInt64, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrUInt64SplatGet(shaped_type, value)
+    return SplatAttribute{UInt64}(
+        API.mlirDenseElementsAttrUInt64SplatGet(shaped_type, value)
+    )
 end
 
 function Base.fill(value::Int64, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrInt64SplatGet(shaped_type, value)
+    return SplatAttribute{Int64}(API.mlirDenseElementsAttrInt64SplatGet(shaped_type, value))
 end
 
 function Base.fill(value::Float32, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrFloatSplatGet(shaped_type, value)
+    return SplatAttribute{Float32}(
+        API.mlirDenseElementsAttrFloatSplatGet(shaped_type, value)
+    )
 end
 
 function Base.fill(value::Float64, shaped_type::Type)
     @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type"
-    return API.mlirDenseElementsAttrDoubleSplatGet(shaped_type, value)
+    return SplatAttribute{Float64}(
+        API.mlirDenseElementsAttrDoubleSplatGet(shaped_type, value)
+    )
 end
 
 function Base.fill(::Core.Type{Attribute}, value, shape)
@@ -503,7 +584,7 @@ Creates a dense elements attribute with the given shaped type from elements of a
 """
 function DenseElementsAttribute(values::AbstractArray{Bool})
     shaped_type = TensorType(size(values), Type(Bool))
-    return Attribute(
+    return DenseElementsAttribute{Bool}(
         API.mlirDenseElementsAttrBoolGet(
             shaped_type, length(values), AbstractArray{Cint}(to_row_major(values))
         ),
@@ -512,21 +593,21 @@ end
 
 function DenseElementsAttribute(values::AbstractArray{UInt8})
     shaped_type = TensorType(size(values), Type(UInt8))
-    return Attribute(
+    return DenseElementsAttribute{UInt8}(
         API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), to_row_major(values))
     )
 end
 
 function DenseElementsAttribute(values::AbstractArray{Int8})
     shaped_type = TensorType(size(values), Type(Int8))
-    return Attribute(
+    return DenseElementsAttribute{Int8}(
         API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), to_row_major(values))
     )
 end
 
 function DenseElementsAttribute(values::AbstractArray{UInt16})
     shaped_type = TensorType(size(values), Type(UInt16))
-    return Attribute(
+    return DenseElementsAttribute{UInt16}(
         API.mlirDenseElementsAttrUInt16Get(
             shaped_type, length(values), to_row_major(values)
         ),
@@ -535,14 +616,14 @@ end
 
 function DenseElementsAttribute(values::AbstractArray{Int16})
     shaped_type = TensorType(size(values), Type(Int16))
-    return Attribute(
+    return DenseElementsAttribute{Int16}(
         API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), to_row_major(values))
     )
 end
 
 function DenseElementsAttribute(values::AbstractArray{UInt32})
     shaped_type = TensorType(size(values), Type(UInt32))
-    return Attribute(
+    return DenseElementsAttribute{UInt32}(
         API.mlirDenseElementsAttrUInt32Get(
             shaped_type, length(values), to_row_major(values)
         ),
@@ -551,14 +632,14 @@ end
 
 function DenseElementsAttribute(values::AbstractArray{Int32})
     shaped_type = TensorType(size(values), Type(Int32))
-    return Attribute(
+    return DenseElementsAttribute{Int32}(
         API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), to_row_major(values))
     )
 end
 
 function DenseElementsAttribute(values::AbstractArray{UInt64})
     shaped_type = TensorType(size(values), Type(UInt64))
-    return Attribute(
+    return DenseElementsAttribute{UInt64}(
         API.mlirDenseElementsAttrUInt64Get(
             shaped_type, length(values), to_row_major(values)
         ),
@@ -567,21 +648,21 @@ end
 
 function DenseElementsAttribute(values::AbstractArray{Int64})
     shaped_type = TensorType(size(values), Type(Int64))
-    return Attribute(
+    return DenseElementsAttribute{Int64}(
         API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), to_row_major(values))
     )
 end
 
 function DenseElementsAttribute(values::AbstractArray{Float32})
     shaped_type = TensorType(size(values), Type(Float32))
-    return Attribute(
+    return DenseElementsAttribute{Float32}(
         API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), to_row_major(values))
     )
 end
 
 function DenseElementsAttribute(values::AbstractArray{Float64})
     shaped_type = TensorType(size(values), Type(Float64))
-    return Attribute(
+    return DenseElementsAttribute{Float64}(
         API.mlirDenseElementsAttrDoubleGet(
             shaped_type, length(values), to_row_major(values)
         ),
@@ -591,7 +672,7 @@ end
 if isdefined(Core, :BFloat16)
     function DenseElementsAttribute(values::AbstractArray{Core.BFloat16})
         shaped_type = TensorType(size(values), Type(Core.BFloat16))
-        return Attribute(
+        return DenseElementsAttribute{Core.BFloat16}(
             API.mlirDenseElementsAttrBFloat16Get(
                 shaped_type, length(values), to_row_major(values)
             ),
@@ -601,16 +682,16 @@ end
 
 function DenseElementsAttribute(values::AbstractArray{Float16})
     shaped_type = TensorType(size(values), Type(Float16))
-    return Attribute(
+    return DenseElementsAttribute{Float16}(
         API.mlirDenseElementsAttrFloat16Get(
             shaped_type, length(values), to_row_major(values)
         ),
     )
 end
 
-function DenseElementsAttribute(values::AbstractArray)
+function DenseElementsAttribute(values::AbstractArray{T}) where {T}
     shaped_type = TensorType(size(values), Type(eltype(values)))
-    return Attribute(
+    return DenseElementsAttribute{T}(
         API.mlirDenseElementsAttrRawBufferGet(
             shaped_type, length(values) * Base.elsize(values), to_row_major(values)
         ),
@@ -625,7 +706,7 @@ Creates a dense elements attribute with the given shaped type from string elemen
 function DenseElementsAttribute(values::AbstractArray{String})
     # TODO may fail because `Type(String)` is not defined
     shaped_type = TensorType(size(values), Type(String))
-    return Attribute(
+    return DenseElementsAttribute{String}(
         API.mlirDenseElementsAttrStringGet(
             shaped_type, length(values), to_row_major(values)
         ),
@@ -637,12 +718,11 @@ end
 Creates a dense elements attribute that has the same data as the given dense elements attribute
 and a different shaped type. The new type must have the same total number of elements.
 """
-function Base.reshape(attr::Attribute, shape)
-    @assert isdenseelements(attr) "attribute $(attr) is not a dense elements attribute"
+function Base.reshape(attr::DenseElementsAttribute{T}, shape) where {T}
     @assert length(attr) == prod(shape) "new shape $(shape) has a different number of elements than the original attribute"
     element_type = eltype(type(attr))
     shaped_type = TensorType(shape, element_type)
-    return Attribute(API.mlirDenseElementsAttrReshape(attr, shaped_type))
+    return DenseElementsAttribute{T}(API.mlirDenseElementsAttrReshape(attr, shaped_type))
 end
 
 """
@@ -745,7 +825,43 @@ function Base.length(attr::Attribute)
     end
 end
 
+function Base.getindex(attr::DenseElementsAttribute, i)
+    @assert i >= 1
+    i -= 1
+    attr = Attribute(attr)
+    elem_type = julia_type(eltype(type(attr)))
+    if elem_type isa Bool
+        API.mlirDenseElementsAttrGetBoolValue(attr, i)
+    elseif elem_type isa Int8
+        API.mlirDenseElementsAttrGetInt8Value(attr, i)
+    elseif elem_type isa UInt8
+        API.mlirDenseElementsAttrGetUInt8Value(attr, i)
+    elseif elem_type isa Int16
+        API.mlirDenseElementsAttrGetInt16Value(attr, i)
+    elseif elem_type isa UInt16
+        API.mlirDenseElementsAttrGetUInt16Value(attr, i)
+    elseif elem_type isa Int32
+        API.mlirDenseElementsAttrGetInt32Value(attr, i)
+    elseif elem_type isa UInt32
+        API.mlirDenseElementsAttrGetUInt32Value(attr, i)
+    elseif elem_type isa Int64
+        API.mlirDenseElementsAttrGetInt64Value(attr, i)
+    elseif elem_type isa UInt64
+        API.mlirDenseElementsAttrGetUInt64Value(attr, i)
+    elseif elem_type isa Float32
+        API.mlirDenseElementsAttrGetFloatValue(attr, i)
+    elseif elem_type isa Float64
+        API.mlirDenseElementsAttrGetDoubleValue(attr, i)
+    elseif elem_type isa String # TODO does this case work?
+        String(API.mlirDenseElementsAttrGetStringValue(attr, i))
+    else
+        throw("unsupported element type $(elem_type)")
+    end
+end
+
 function Base.getindex(attr::Attribute, i)
+    @assert i >= 1
+    i -= 1
     if isarray(attr)
         Attribute(API.mlirArrayAttrGetElement(attr, i))
     elseif isdict(attr)
@@ -835,8 +951,8 @@ function Base.getindex(attr::Attribute)
     end
 end
 
-function Base.show(io::IO, attribute::Attribute)
-    print(io, "Attribute(#= ")
+function Base.show(io::IO, attribute::AbstractAttribute)
+    print(io, "$(typeof(attribute))(#= ")
     c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any))
     ref = Ref(io)
     API.mlirAttributePrint(attribute, c_print_callback, ref)
@@ -852,8 +968,8 @@ end
 
 Associates an attribute with the name. Takes ownership of neither.
 """
-function NamedAttribute(name, attribute; context=context(attribute))
-    @assert !mlirIsNull(attribute.attribute)
+function NamedAttribute(name, attribute::AbstractAttribute; context=context(attribute))
+    @assert !mlirIsNull(Attribute(attribute))
     name = API.mlirIdentifierGet(context, name)
     return NamedAttribute(API.mlirNamedAttributeGet(name, attribute))
 end
@@ -861,3 +977,7 @@ end
 function Base.convert(::Core.Type{API.MlirAttribute}, named_attribute::NamedAttribute)
     return named_attribute.named_attribute
 end
+
+function DenseArrayAttribute(values::Vector{<:Enum})
+    return Attribute([Attribute(value) for value in values])
+end
diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl
index 6f45bbf8e..caf846aa3 100644
--- a/src/mlir/IR/Operation.jl
+++ b/src/mlir/IR/Operation.jl
@@ -189,7 +189,7 @@ function attr(operation::Operation, name::AbstractString)
     if mlirIsNull(raw_attr)
        return nothing
     end
-    return Attribute(raw_attr)
+    return getattribute(raw_attr)
 end
 
 """
diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index 29a9a2874..5031e0573 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -10,7 +10,7 @@ using ..Reactant:
     unwrapped_eltype,
     Ops,
     MLIR
-
+using ..Reactant.MLIR.Dialects: stablehlo
 using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data!
 
 using LinearAlgebra
@@ -42,7 +42,10 @@ function TracedUtils.materialize_traced_array(
     return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
 end
 
-for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
+for (AT, comp) in (
+    (:LowerTriangular, stablehlo.ComparisonDirection.GE),
+    (:UpperTriangular, stablehlo.ComparisonDirection.LE),
+)
     uAT = Symbol(:Unit, AT)
     @eval begin
         function TracedUtils.materialize_traced_array(
@@ -61,7 +64,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
             m, n = size(x)
             row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
             col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
-            nondiag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="NE")
+            nondiag_indicator = Ops.compare(
+                row_idxs, col_idxs; comparison_direction=stablehlo.ComparisonDirection.NE
+            )
             x = materialize_traced_array($(AT)(parent(x)))
             return Ops.select(nondiag_indicator, x, one.(x))
         end
@@ -75,12 +80,16 @@ function TracedUtils.materialize_traced_array(
     row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
     col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
     if x.uplo == 'L'
-        indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="GT")
+        indicator = Ops.compare(
+            row_idxs, col_idxs; comparison_direction=stablehlo.ComparisonDirection.GT
+        )
         x_lt = Ops.select(indicator, parent(x), zero(parent(x)))
         x_ltd = materialize_traced_array(LowerTriangular(parent(x)))
         return Ops.add(x_lt, Ops.transpose(x_ltd, [2, 1]))
     else
-        indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="LT")
+        indicator = Ops.compare(
+            row_idxs, col_idxs; comparison_direction=stablehlo.ComparisonDirection.LT
+        )
         x_ut = Ops.select(indicator, parent(x), zero(parent(x)))
         x_utd = materialize_traced_array(UpperTriangular(parent(x)))
         return Ops.add(Ops.transpose(x_utd, [2, 1]), x_ut)
@@ -121,10 +130,18 @@ function TracedUtils.set_mlir_data!(
 end
 
 for (AT, dcomp, ocomp) in (
-    (:LowerTriangular, "GE", "LT"),
-    (:UnitLowerTriangular, "GT", "LE"),
-    (:UpperTriangular, "LE", "GT"),
-    (:UnitUpperTriangular, "LT", "GE"),
+    (:LowerTriangular, stablehlo.ComparisonDirection.GE, stablehlo.ComparisonDirection.LT),
+    (
+        :UnitLowerTriangular,
+        stablehlo.ComparisonDirection.GT,
+        stablehlo.ComparisonDirection.LE,
+    ),
+    (:UpperTriangular, stablehlo.ComparisonDirection.LE, stablehlo.ComparisonDirection.GT),
+    (
+        :UnitUpperTriangular,
+        stablehlo.ComparisonDirection.LT,
+        stablehlo.ComparisonDirection.GE,
+    ),
 )
     @eval function TracedUtils.set_mlir_data!(
         x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data
@@ -233,7 +250,9 @@ function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh
         Ops.iota(Int64, [size(X)...]; iota_dimension=2),
         TracedUtils.broadcast_to_size(k, size(X)),
     )
-    idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE")
+    idxs = Ops.compare(
+        iota_1, iota_2; comparison_direction=stablehlo.ComparisonDirection.LE
+    )
     X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data
     return X
 end
@@ -244,7 +263,9 @@ function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh
         Ops.iota(Int64, [size(X)...]; iota_dimension=2),
         TracedUtils.broadcast_to_size(k, size(X)),
     )
-    idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE")
+    idxs = Ops.compare(
+        iota_1, iota_2; comparison_direction=stablehlo.ComparisonDirection.GE
+    )
     X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data
     return X
 end
@@ -302,7 +323,9 @@ function LinearAlgebra._diagm(
         MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
         (size(scatter_indices, 1),),
     )
-    return Ops.scatter_setindex(Ops.fill(zero(T), (m, n)), scatter_indices, values)
+    return Ops.scatter_setindex(
+        Ops.constant(fill(zero(T), (m, n))), scatter_indices, values
+    )
 end
 
 # Common Utilities
diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl
index 617f1fac1..1a103b103 100644
--- a/src/stdlibs/Random.jl
+++ b/src/stdlibs/Random.jl
@@ -18,6 +18,7 @@ using ..Reactant:
     ConcreteRNumber,
     unwrapped_eltype
 using Random: Random, AbstractRNG
+using Reactant.MLIR.Dialects: stablehlo
 
 @noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) =
     Random.rand!(rng, Vector{UInt64}(undef, 2))
@@ -70,12 +71,13 @@ Base.copy(rng::ConcreteRNG) = ConcreteRNG(copy(rng.seed), rng.algorithm)
 Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
 
 @noinline ConcreteRNG() = ConcreteRNG(ConcreteRArray(make_seed()))
-@noinline ConcreteRNG(seed::ConcreteRArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")
+@noinline ConcreteRNG(seed::ConcreteRArray{UInt64,1}) =
+    ConcreteRNG(seed, stablehlo.RngAlgorithm.DEFAULT)
 
 @noinline default_rng() = ConcreteRNG()
 
 @noinline rng_algorithm(rng::TracedRNG) = rng.algorithm
-@noinline rng_algorithm(::AbstractRNG) = "DEFAULT"
+@noinline rng_algorithm(::AbstractRNG) = stablehlo.RngAlgorithm.DEFAULT
 
 @noinline function internal_overload_rand!(
     rng::TracedRNG, A::AnyTracedRArray{T,N}
diff --git a/test/ops.jl b/test/ops.jl
index 4cfe1305d..8f0ca48d1 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -1,5 +1,6 @@
 using Reactant, Test
 using Reactant: Ops
+using Reactant.MLIR.Dialects: stablehlo
 using LinearAlgebra
 using SpecialFunctions: SpecialFunctions
 
@@ -274,8 +275,8 @@ end
 end
 
 @testset "fft" begin
-    grfft(x) = Ops.fft(x; type="RFFT", length=[4])
-    gfft(x) = Ops.fft(x; type="FFT", length=[4])
+    grfft(x) = Ops.fft(x; type=stablehlo.FftType.RFFT, length=[4])
+    gfft(x) = Ops.fft(x; type=stablehlo.FftType.FFT, length=[4])
     x = ConcreteRArray([1.0, 1.0, 1.0, 1.0])
     @test ComplexF64[4.0, 0.0, 0.0] ≈ @jit grfft(x)