Skip to content

Commit

Permalink
tablegen: Add StaticSelect to select based on static condition (#2206)
Browse files Browse the repository at this point in the history
* tablegen: Add StaticIf to select based on static condition

* rename to StaticSelect and implement SelectIfActive and SelectIfComplex
with it

* define for llvm

* put vector mode for LLVM back

* basic use analysis

* StaticSelect use analysis

* fixup

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
Pangoraw and wsmoses authored Jan 9, 2025
1 parent ac3be7e commit 495fde3
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 206 deletions.
8 changes: 5 additions & 3 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class ConstantFP<string val> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
}

class StaticSelect<string condition_> : Operation</*primal*/0, /*shadow*/0, /*custom*/0> {
string condition = condition_;
}

def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">;

class Attribute<string name_> {
string name = name_;
Expand Down Expand Up @@ -62,9 +67,6 @@ class Inst<string mnemonic> : Operation</*primal*/1, /*shadow*/0> {
def TypeOf : Operation</*primal*/0, /*shadow*/0> {
}
def VectorSize : Operation</*primal*/0, /*shadow*/0> {
}
def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

// Define ops to rewrite.
Expand Down
25 changes: 18 additions & 7 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,20 @@ def DiffeRet : DiffeRetIndex<[-1]>;
def Shadow : Operation</*primal*/0, /*shadow*/1> {
}

class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow>{
class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow> {
string value = val;
}

// Class for a dag operator that generates either a or b
// It can then be used with a two or three arguments.
// The two arguments version is (StaticSelect a, b)
// The three arguments version accepts a name as a first argument
// which is then available in the condition as a `Value` under the
// variable `imVal`.
class StaticSelect<string condition_> : Operation</*usesPrimal*/0, /*usesShadow*/0, /*usesCustom*/0> {
string condition = condition_;
}

class Inst<string mnemonic, string dialect_, string postop_=""> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
Expand All @@ -99,13 +109,14 @@ class Inst<string mnemonic, string dialect_, string postop_=""> : Operation</*p
def Op {
}

def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

def SelectIfComplex : Operation</*primal*/1, /*shadow*/0, /*custom*/0> {
def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">;

}
def SelectIfComplex : StaticSelect<[{
auto ty = imVal.getType();
ty.isa<ComplexType>() ||
ty.isa<TensorType>() &&
ty.cast<TensorType>().getElementType().isa<ComplexType>();
}]>;

class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
Expand Down
Loading

0 comments on commit 495fde3

Please sign in to comment.