Skip to content

Commit

Permalink
Blit Node
Browse files Browse the repository at this point in the history
Summary: Introduce a graph node to call vkcmdBlitImage which can convert between dtypes (and also perform scaling, filtering etc. but we don't need them right now).

Differential Revision: D63839654
  • Loading branch information
abhishekchandra authored and facebook-github-bot committed Oct 9, 2024
1 parent 8a32657 commit d82f981
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 0 deletions.
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ void Context::register_shader_dispatch(
cmd_.dispatch(effective_global_wg);
}

void Context::register_blit(
vkapi::PipelineBarrier& pipeline_barrier,
vkapi::VulkanImage& src,
vkapi::VulkanImage& dst) {
cmd_.insert_barrier(pipeline_barrier);
cmd_.blit(src, dst);
}

void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) {
if (cmd_) {
cmd_.end();
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ class Context final {
const vkapi::ShaderInfo&,
const utils::uvec3&);

void register_blit(
vkapi::PipelineBarrier&,
vkapi::VulkanImage& src,
vkapi::VulkanImage& dst);

template <typename... Arguments>
bool submit_compute_job(
const vkapi::ShaderInfo&,
Expand Down
65 changes: 65 additions & 0 deletions backends/vulkan/runtime/graph/ops/BlitNode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

BlitNode::BlitNode(
ComputeGraph& graph,
ValueRef src,
ValueRef dst,
// const vkapi::ScalarType& dtype,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: ExecuteNode(resize_fn, resize_args, {}, "Blit Node"),
src_(src),
dst_(dst) {
(void)graph;
}

void BlitNode::encode(ComputeGraph* graph) {
auto src_tensor = graph->get_tensor(src_);
auto dst_tensor = graph->get_tensor(dst_);
VK_CHECK_COND(
src_tensor->storage_type() != utils::kBuffer &&
dst_tensor->storage_type() != utils::kBuffer,
"BlitNode: Only texture backed tensors are supported.");

api::Context* const context = graph->context();
vkapi::PipelineBarrier pipeline_barrier{};

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

// Hack to get timing data for non shader op
std::string kernel_name("Blit_");
kernel_name.reserve(32);
kernel_name += vkapi::to_string(src_tensor->dtype());
kernel_name += "_to_";
kernel_name += vkapi::to_string(dst_tensor->dtype());

context->report_shader_dispatch_start(
kernel_name, utils::uvec3(), utils::uvec3(), node_id_);

context->register_blit(
pipeline_barrier,
src_tensor->image(
pipeline_barrier,
vkapi::PipelineStage::TRANSFER,
vkapi::kRead),
dst_tensor->image(
pipeline_barrier,
vkapi::PipelineStage::TRANSFER,
vkapi::kWrite));

context->report_shader_dispatch_end();
}

} // namespace vkcompute
45 changes: 45 additions & 0 deletions backends/vulkan/runtime/graph/ops/BlitNode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>

namespace vkcompute {

/*
* Represents a tensor blit execution op in a ML model.
*/
class BlitNode final : public ExecuteNode {
friend class ComputeGraph;

public:
explicit BlitNode(
ComputeGraph& graph,
ValueRef src,
ValueRef dst,
/*const vkapi::ScalarType& dtype,*/
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});

~BlitNode() = default;

void encode(ComputeGraph* graph) override;

protected:
ValueRef src_;
ValueRef dst_;
// const vkapi::ScalarType &dtype_;
};

} // namespace vkcompute
39 changes: 39 additions & 0 deletions backends/vulkan/runtime/vk_api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,45 @@ void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) {
state_ = CommandBuffer::State::RECORDING;
}

void CommandBuffer::blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst) {
VK_CHECK_COND(
state_ == CommandBuffer::State::BARRIERS_INSERTED,
"Vulkan CommandBuffer: called blit() on a command buffer whose state "
"is not BARRIERS_INSERTED.");

auto src_extents = src.extents();
auto dst_extents = dst.extents();

VkImageBlit blit{};
blit.srcOffsets[0] = {0, 0, 0},
blit.srcOffsets[1] =
{static_cast<int32_t>(src_extents.width),
static_cast<int32_t>(src_extents.height),
static_cast<int32_t>(src_extents.depth)},
blit.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
blit.srcSubresource.mipLevel = 0, blit.srcSubresource.baseArrayLayer = 0,
blit.srcSubresource.layerCount = 1, blit.dstOffsets[0] = {0, 0, 0},
blit.dstOffsets[1] =
{static_cast<int32_t>(dst_extents.width),
static_cast<int32_t>(dst_extents.height),
static_cast<int32_t>(dst_extents.depth)},
blit.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
blit.dstSubresource.mipLevel = 0, blit.dstSubresource.baseArrayLayer = 0,
blit.dstSubresource.layerCount = 1,

vkCmdBlitImage(
handle_,
src.handle(),
src.layout(),
dst.handle(),
dst.layout(),
1,
&blit,
VK_FILTER_NEAREST);

state_ = CommandBuffer::State::RECORDING;
}

void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx)
const {
VK_CHECK_COND(
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/vk_api/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class CommandBuffer final {

void insert_barrier(PipelineBarrier& pipeline_barrier);
void dispatch(const utils::uvec3&);
void blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst);

void write_timestamp(VkQueryPool, const uint32_t) const;
void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const;
Expand Down

0 comments on commit d82f981

Please sign in to comment.