Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add type conversions for width != 1.
This still requires changes in the tblgenerated derivative files. For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this: ``` LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const { auto op = cast<arith::MulFOp>(op0); if (gutils->width != 1) { auto newop = gutils->getNewFromOriginal(op0); for (auto res : newop->getResults()) { res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType())); } } gutils->eraseIfUnused(op); if (gutils->isConstantInstruction(op)) return success(); mlir::Value res = nullptr; if (!gutils->isConstantValue(op->getOperand(0))) { auto dif = gutils->invertPointerM(op->getOperand(0), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); // TODO: gutils->makeBatched(...) auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1)); builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1); }); itmp.dump(); if (!res) res = itmp; else { auto operandType = cast<AutoDiffTypeInterface>(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } if (!gutils->isConstantValue(op->getOperand(1))) { auto dif = gutils->invertPointerM(op->getOperand(1), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0)); builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1); }); if (!res) res = itmp; else { auto operandType = cast<AutoDiffTypeInterface>(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } assert(res); gutils->setDiffe(op->getResult(0), res, builder); return success(); } ```
- Loading branch information