Skip to content

Commit

Permalink
[mlir] Async: add support for lowering async value operands to LLVM
Browse files Browse the repository at this point in the history
Depends On D93592

Add support for lowering `async.execute` operands that are async values; each such operand is awaited and unwrapped before the body region executes:

```
%token = async.execute(%async_value as %unwrapped : !async.value<!my.type>) {
  ...
  async.yield
}
```

Reviewed By: csigg

Differential Revision: https://reviews.llvm.org/D93598
  • Loading branch information
ezhulenev committed Dec 25, 2020
1 parent 621ad46 commit 61422c8
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
23 changes: 13 additions & 10 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Collect all outlined function inputs.
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());
assert(execute.operands().empty() && "operands are not supported");
functionInputs.insert(execute.operands().begin(), execute.operands().end());
getUsedValuesDefinedAbove(execute.body(), functionInputs);

// Collect types for the outlined function inputs and outputs.
Expand Down Expand Up @@ -636,15 +636,26 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
resume, builder);

size_t numDependencies = execute.dependencies().size();
size_t numOperands = execute.operands().size();

// Await on all dependencies before starting to execute the body region.
builder.setInsertionPointToStart(resume);
for (size_t i = 0; i < execute.dependencies().size(); ++i)
for (size_t i = 0; i < numDependencies; ++i)
builder.create<AwaitOp>(func.getArgument(i));

// Await on all async value operands and unwrap the payload.
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
}

// Map from function inputs defined above the execute op to the function
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
valueMapping.map(execute.body().getArguments(), unwrappedOperands);

// Clone all operations from the execute operation body into the outlined
// function body.
Expand Down Expand Up @@ -1069,14 +1080,6 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
return WalkResult::interrupt();
}

// We currently do not support execute operations that have async value
// operands or produce async results.
if (!execute.operands().empty()) {
execute.emitOpError(
"can't outline async.execute op with async value operands");
return WalkResult::interrupt();
}

outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));

return WalkResult::advance();
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,54 @@ func @execute_and_return_f32() -> f32 {
// Emplace result token.
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])

// -----

// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s

// Test case: the !async.value<f32> result of one async.execute is passed as an
// operand to a second async.execute, whose body receives it unwrapped as f32.
func @async_value_operands() {
// The first execute op has no dependencies or operands; it yields an f32
// constant, so it produces both a token and an !async.value<f32>.
// CHECK: %[[RET:.*]]:2 = call @async_execute_fn
%token, %result = async.execute -> !async.value<f32> {
%c0 = constant 123.0 : f32
async.yield %c0 : f32
}

// The second execute op takes %result as an async value operand; inside the
// body it is visible as the unwrapped f32 block argument %value. The lowered
// call is expected to receive the second result (#1) of the first call.
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%[[RET]]#1)
%token0 = async.execute(%result as %value: !async.value<f32>) {
%0 = addf %value, %value : f32
async.yield
}

// Wait for the second execute op to complete.
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
async.await %token0 : !async.token

return
}

// Function outlined from the first async.execute operation.
// CHECK-LABEL: func private @async_execute_fn()

// Function outlined from the second async.execute operation.
// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>)
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin

// Suspend the coroutine at the beginning.
// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
// CHECK: llvm.call @llvm.coro.suspend

// Suspend the coroutine a second time while waiting for the async operand.
// CHECK: llvm.call @llvm.coro.save
// CHECK: call @mlirAsyncRuntimeAwaitValueAndExecute(%arg0, %[[HDL]],
// CHECK: llvm.call @llvm.coro.suspend

// Get the operand value storage, cast to f32 and add the value.
// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%arg0)
// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
// CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr<float>
// CHECK: %[[CASTED:.*]] = llvm.mlir.cast %[[LOADED]] : !llvm.float to f32
// CHECK: addf %[[CASTED]], %[[CASTED]] : f32

// Emplace result token.
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])


21 changes: 19 additions & 2 deletions mlir/test/mlir-cpu-runner/async-value.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func @main() {
// ------------------------------------------------------------------------ //
%token2, %result2 = async.execute[%token0] -> !async.value<memref<f32>> {
%5 = alloc() : memref<f32>
%c0 = constant 987.654 : f32
%c0 = constant 0.25 : f32
store %c0, %5[]: memref<f32>
async.yield %5 : memref<f32>
}
Expand All @@ -53,8 +53,25 @@ func @main() {

// CHECK: Unranked Memref
// CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
// CHECK-NEXT: [987.654]
// CHECK-NEXT: [0.25]
call @print_memref_f32(%7): (memref<*xf32>) -> ()

// ------------------------------------------------------------------------ //
// Memref passed as async.execute operand.
// ------------------------------------------------------------------------ //
%token3 = async.execute(%result2 as %unwrapped : !async.value<memref<f32>>) {
%8 = load %unwrapped[]: memref<f32>
%9 = addf %8, %8 : f32
store %9, %unwrapped[]: memref<f32>
async.yield
}
async.await %token3 : !async.token

// CHECK: Unranked Memref
// CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
// CHECK-NEXT: [0.5]
call @print_memref_f32(%7): (memref<*xf32>) -> ()

dealloc %6 : memref<f32>

return
Expand Down

0 comments on commit 61422c8

Please sign in to comment.