Skip to content

Commit

Permalink
use AutoDiffTypeInterface for batching
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Dec 21, 2024
1 parent 33c933b commit d8efc38
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ using namespace mlir;
using namespace mlir::enzyme;

namespace {

// Wraps `type` in a rank-1 tensor of extent `width` when batching is
// actually requested; a width of 1 (or less, other than the dynamic
// sentinel) means "no batching" and the type is returned untouched.
// A dynamic `width` (ShapedType::kDynamic) produces a tensor with a
// dynamic leading dimension.
static mlir::Type batchType(mlir::Type type, int64_t width) {
  bool needsBatchDim = ShapedType::isDynamic(width) || width > 1;
  if (!needsBatchDim)
    return type;
  return RankedTensorType::get({width}, type);
}

class FloatTypeInterface
: public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
FloatType> {
Expand All @@ -44,12 +52,8 @@ class FloatTypeInterface
return a;
}

Type getShadowType(Type self, unsigned width) const {
if (width > 1) {
return RankedTensorType::get({width}, self);
} else {
return self;
}
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down Expand Up @@ -108,9 +112,8 @@ class TensorTypeInterface
return added;
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down Expand Up @@ -141,9 +144,8 @@ class IntegerTypeInterface
return a;
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down Expand Up @@ -175,9 +177,8 @@ class ComplexTypeInterface
return builder.create<complex::ConjOp>(loc, a)->getResult(0);
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down
53 changes: 37 additions & 16 deletions enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace {
#include "Implementations/SCFDerivatives.inc"

// TODO: support non constant number of iteration by using unknown dimensions
static std::optional<int64_t> getNumberOfIterations(scf::ForOp forOp) {
static std::optional<int64_t> getConstantNumberOfIterations(scf::ForOp forOp) {
auto lb = forOp.getLowerBound();
auto ub = forOp.getUpperBound();
auto step = forOp.getStep();
Expand All @@ -55,6 +55,14 @@ static std::optional<int64_t> getNumberOfIterations(scf::ForOp forOp) {
return (ubI - lbI) / stepI;
}

// Emits IR computing the trip count of `forOp` at runtime, as
// (ub - lb) / step, at `builder`'s current insertion point.
// NOTE(review): arith::DivUIOp truncates; when `step` does not evenly
// divide (ub - lb) the scf.for trip count is ceildiv(ub - lb, step) —
// confirm callers only use this where the step divides exactly (the
// constant path in getConstantNumberOfIterations truncates the same way).
// NOTE(review): unsigned division assumes ub >= lb and non-negative
// operands — verify against how the loop bounds are produced.
static Value getNumberOfIterations(OpBuilder &builder, scf::ForOp forOp) {
Value lb = forOp.getLowerBound(), ub = forOp.getUpperBound(),
step = forOp.getStep();
// Distance covered by the loop's induction variable.
Value diff = builder.create<arith::SubIOp>(forOp->getLoc(), ub, lb);
// Number of steps taken to cover that distance.
Value nSteps = builder.create<arith::DivUIOp>(forOp->getLoc(), diff, step);
return nSteps;
}

struct ForOpEnzymeOpsRemover
: public EnzymeOpsRemoverOpInterface::ExternalModel<ForOpEnzymeOpsRemover,
scf::ForOp> {
Expand Down Expand Up @@ -132,7 +140,7 @@ struct ForOpEnzymeOpsRemover
}
}

auto numIters = getNumberOfIterations(forOp);
auto numIters = getConstantNumberOfIterations(forOp);
Value inductionVariable; // [0, N[ counter

if (matchPattern(forOp.getLowerBound(), m_Zero()) &&
Expand Down Expand Up @@ -186,16 +194,25 @@ struct ForOpEnzymeOpsRemover
}

auto newType =
info.batchType(numIters.value_or(mlir::ShapedType::kDynamic));
ValueRange operands =
numIters.has_value()
? ValueRange{}
: ValueRange{builder
.create<arith::ConstantOp>(
forOp->getLoc(), builder.getIndexAttr(10))
.getResult()};
auto initValue = builder.create<tensor::EmptyOp>(info.initOp->getLoc(),
newType, operands);
info.cachedType()
.cast<AutoDiffTypeInterface>()
.getShadowType(numIters.value_or(mlir::ShapedType::kDynamic))
.cast<ShapedType>();

SmallVector<Value> dynamicDims;

for (auto it : llvm::enumerate(newType.getShape())) {
if (ShapedType::isDynamic(it.value())) {
if (it.index() == 0)
dynamicDims.push_back(getNumberOfIterations(builder, forOp));
else
return failure(); // TODO: find dynamic dims within the body.
}
}

Value initValue = builder.create<tensor::EmptyOp>(info.initOp->getLoc(),
newType, dynamicDims);

// cast<AutoDiffTypeInterface>(newType).createNullValue(
// builder, info.initOp->getLoc());

Expand Down Expand Up @@ -241,9 +258,11 @@ struct ForOpEnzymeOpsRemover
builder.setInsertionPoint(otherForOp);
SmallVector<Value> operands(otherForOp.getInitArgs().begin(),
otherForOp.getInitArgs().end());
operands.push_back(builder.create<arith::ConstantOp>(
otherForOp->getLoc(),
builder.getIndexAttr(numIters.value_or(1) - 1)));
operands.push_back(numIters.has_value()
? builder.create<arith::ConstantOp>(
otherForOp->getLoc(),
builder.getIndexAttr(numIters.value() - 1))
: getNumberOfIterations(builder, forOp));

Block *otherBody = otherForOp.getBody();
Value otherInductionVariable =
Expand Down Expand Up @@ -285,7 +304,9 @@ struct ForOpEnzymeOpsRemover

Value cache = info.initOp.getResult();

auto newType = info.batchType(numIters.value());
auto newType =
info.cachedType().cast<AutoDiffTypeInterface>().getShadowType(
numIters.value());
enzyme::InitOp newInit = ({
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(info.initOp);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> {
}],
/*retTy=*/"::mlir::Type",
/*methodName=*/"getShadowType",
/*args=*/(ins "unsigned":$width)
/*args=*/(ins "int64_t":$width)
>,
InterfaceMethod<
/*desc=*/[{
Expand Down
17 changes: 0 additions & 17 deletions enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,6 @@
#include "RemovalUtils.h"
#include "Interfaces/AutoDiffOpInterface.h"

// Convenience overload: batch with an unknown (dynamic) leading
// dimension, deferring to the int64_t overload with kDynamic.
mlir::Type mlir::enzyme::CacheInfo::batchType() {
return mlir::enzyme::CacheInfo::batchType(mlir::ShapedType::kDynamic);
}

// Returns the type of the cached (pushed) value with a leading batch
// dimension of extent `dim` prepended (use ShapedType::kDynamic for an
// unknown extent).
mlir::Type mlir::enzyme::CacheInfo::batchType(int64_t dim) {
auto T = pushedValue().getType();

// Tensors are batched by prepending `dim` to their existing shape,
// preserving the element type and other attributes via clone().
if (auto TT = dyn_cast<mlir::TensorType>(T)) {
SmallVector<int64_t> shape;
shape.push_back(dim);
shape.append(TT.getShape().begin(), TT.getShape().end());
return TT.clone(shape);
}

// Scalars (and any non-tensor type) become a rank-1 tensor of extent
// `dim` over the original type.
return mlir::RankedTensorType::get({dim}, T);
}

mlir::LogicalResult mlir::enzyme::removeOpsWithinBlock(mlir::Block *block) {
bool valid = true;

Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ struct CacheInfo {
}

Value pushedValue() { return pushOp.getValue(); }

Type batchType(int64_t dim);
Type batchType(); // unknown size
Type cachedType() {
return initOp.getResult().getType().cast<enzyme::CacheType>().getType();
}
};

LogicalResult removeOpsWithinBlock(Block *block);
Expand Down

0 comments on commit d8efc38

Please sign in to comment.