Skip to content

Commit

Permalink
add type conversions for width != 1.
Browse files Browse the repository at this point in the history
This still requires changes in the tblgen-generated derivative files. For example, `createForwardModeTangent` in `MulFOpFwdDerivative` could be altered like this:
```
  // Forward-mode tangent for arith::MulFOp using the product rule:
  //   d(a*b) = da*b + a*db.
  // Each active operand contributes one term; terms are accumulated with the
  // type's AutoDiff add operation.
  LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const
  {
    auto op = cast<arith::MulFOp>(op0);
    // Batched differentiation (width != 1): widen every result type of the
    // cloned op to a rank-1 tensor of extent `width` so batched tangents fit.
    // NOTE(review): assumes the original result types are scalars — a result
    // that is already a tensor would get nested; confirm against callers.
    if (gutils->width != 1) {
      auto newop = gutils->getNewFromOriginal(op0);
      for (auto res : newop->getResults()) {
        res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
      }
    }
    gutils->eraseIfUnused(op);
    // Nothing to differentiate if the instruction is constant w.r.t. AD.
    if (gutils->isConstantInstruction(op))
      return success();
    mlir::Value res = nullptr;
    // Term 1: da * b — only emitted when operand 0 is active.
    if (!gutils->isConstantValue(op->getOperand(0)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(0), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          dif.dump(); // NOTE(review): debug print left in — drop before landing?
          // TODO: gutils->makeBatched(...)
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        itmp.dump(); // NOTE(review): debug print left in — drop before landing?
        if (!res)
          res = itmp;
        else
        {
          // Accumulate into the running tangent via the type interface's add.
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    // Term 2: a * db — only emitted when operand 1 is active.
    if (!gutils->isConstantValue(op->getOperand(1)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(1), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          dif.dump(); // NOTE(review): debug print left in — drop before landing?
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else
        {
          // Accumulate into the running tangent via the type interface's add.
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    // At least one operand must have been active to reach this point.
    assert(res);
    gutils->setDiffe(op->getResult(0), res, builder);
    return success();
  }
```
  • Loading branch information
jumerckx authored and Pangoraw committed Dec 21, 2024
1 parent 6bf5d41 commit 33c933b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ class FloatTypeInterface
}

Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+    if (width > 1) {
+      return RankedTensorType::get({width}, self);
+    } else {
+      return self;
+    }
}

bool isMutable(Type self) const { return false; }
Expand Down
17 changes: 15 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) {
if (returnPrimal) {
-      RetTypes.push_back(Ty);
+      if (width != 1) {
+        RetTypes.push_back(mlir::RankedTensorType::get({width}, Ty));
+      } else {
+        RetTypes.push_back(Ty);
+      }
}
if (returnShadow) {
assert(activity != DIFFE_TYPE::CONSTANT);
Expand All @@ -39,7 +43,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
SmallVector<mlir::Type, 4> ArgTypes;

for (auto &&[ITy, act] : llvm::zip(FTy.getInputs(), ArgActivity)) {
-    ArgTypes.push_back(ITy);
+    if (width != 1) {
+      ArgTypes.push_back(mlir::RankedTensorType::get({width}, ITy));
+    } else {
+      ArgTypes.push_back(ITy);
+    }
if (act == DIFFE_TYPE::DUP_ARG || act == DIFFE_TYPE::DUP_NONEED) {
ArgTypes.push_back(getShadowType(ITy, width));
} else if (act == DIFFE_TYPE::OUT_DIFF) {
Expand Down Expand Up @@ -232,6 +240,11 @@ FunctionOpInterface CloneFunctionWithReturns(

{
auto &blk = NewF.getFunctionBody().front();
+    if (width != 1) {
+      for (auto &arg : blk.getArguments()) {
+        arg.setType(mlir::RankedTensorType::get({width}, arg.getType()));
+      }
+    }
assert(F.getFunctionBody().front().getNumArguments() == ArgActivity.size());
for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) {
mlir::Value oval = F.getFunctionBody().front().getArgument(i);
Expand Down

0 comments on commit 33c933b

Please sign in to comment.