diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 4d2a854de3b..7cb7f1bf5d9 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -145,6 +145,14 @@ void Context::register_shader_dispatch( cmd_.dispatch(effective_global_wg); } +void Context::register_blit( + vkapi::PipelineBarrier& pipeline_barrier, + vkapi::VulkanImage& src, + vkapi::VulkanImage& dst) { + cmd_.insert_barrier(pipeline_barrier); + cmd_.blit(src, dst); +} + void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) { if (cmd_) { cmd_.end(); diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 6681cb7cbfb..2a63fb9599c 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -195,6 +195,11 @@ class Context final { const vkapi::ShaderInfo&, const utils::uvec3&); + void register_blit( + vkapi::PipelineBarrier&, + vkapi::VulkanImage& src, + vkapi::VulkanImage& dst); + template bool submit_compute_job( const vkapi::ShaderInfo&, diff --git a/backends/vulkan/runtime/graph/ops/BlitNode.cpp b/backends/vulkan/runtime/graph/ops/BlitNode.cpp new file mode 100644 index 00000000000..328623eaa76 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/BlitNode.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace vkcompute { + +BlitNode::BlitNode( + ComputeGraph& graph, + ValueRef src, + ValueRef dst, + // const vkapi::ScalarType& dtype, + const ResizeFunction& resize_fn, + const std::vector& resize_args) + : ExecuteNode(resize_fn, resize_args, {}, "Blit Node"), + src_(src), + dst_(dst) { + (void)graph; + } + +void BlitNode::encode(ComputeGraph* graph) { + auto src_tensor = graph->get_tensor(src_); + auto dst_tensor = graph->get_tensor(dst_); + VK_CHECK_COND( + src_tensor->storage_type() != utils::kBuffer && + dst_tensor->storage_type() != utils::kBuffer, + "BlitNode: Only texture backed tensors are supported."); + + api::Context* const context = graph->context(); + vkapi::PipelineBarrier pipeline_barrier{}; + + std::unique_lock cmd_lock = context->dispatch_lock(); + + // Hack to get timing data for non shader op + std::string kernel_name("Blit_"); + kernel_name.reserve(32); + kernel_name += vkapi::to_string(src_tensor->dtype()); + kernel_name += "_to_"; + kernel_name += vkapi::to_string(dst_tensor->dtype()); + + context->report_shader_dispatch_start( + kernel_name, utils::uvec3(), utils::uvec3(), node_id_); + + context->register_blit( + pipeline_barrier, + src_tensor->image( + pipeline_barrier, + vkapi::PipelineStage::TRANSFER, + vkapi::kRead), + dst_tensor->image( + pipeline_barrier, + vkapi::PipelineStage::TRANSFER, + vkapi::kWrite)); + + context->report_shader_dispatch_end(); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/BlitNode.h b/backends/vulkan/runtime/graph/ops/BlitNode.h new file mode 100644 index 00000000000..a50f84bea6d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/BlitNode.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#include + +namespace vkcompute { + +/* + * Represents a tensor blit execution op in a ML model. + */ +class BlitNode final : public ExecuteNode { + friend class ComputeGraph; + + public: + explicit BlitNode( + ComputeGraph& graph, + ValueRef src, + ValueRef dst, + /*const vkapi::ScalarType& dtype,*/ + const ResizeFunction& resize_fn = nullptr, + const std::vector& resize_args = {}); + + ~BlitNode() = default; + + void encode(ComputeGraph* graph) override; + + protected: + ValueRef src_; + ValueRef dst_; + // const vkapi::ScalarType &dtype_; +}; + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Command.cpp b/backends/vulkan/runtime/vk_api/Command.cpp index f971a8f8358..408103cd5d9 100644 --- a/backends/vulkan/runtime/vk_api/Command.cpp +++ b/backends/vulkan/runtime/vk_api/Command.cpp @@ -179,6 +179,45 @@ void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) { state_ = CommandBuffer::State::RECORDING; } +void CommandBuffer::blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst) { + VK_CHECK_COND( + state_ == CommandBuffer::State::BARRIERS_INSERTED, + "Vulkan CommandBuffer: called blit() on a command buffer whose state " + "is not BARRIERS_INSERTED."); + + auto src_extents = src.extents(); + auto dst_extents = dst.extents(); + + VkImageBlit blit{}; + blit.srcOffsets[0] = {0, 0, 0}, + blit.srcOffsets[1] = + {static_cast(src_extents.width), + static_cast(src_extents.height), + static_cast(src_extents.depth)}, + blit.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT, + blit.srcSubresource.mipLevel = 0, blit.srcSubresource.baseArrayLayer = 0, + blit.srcSubresource.layerCount = 1, blit.dstOffsets[0] = {0, 0, 0}, + blit.dstOffsets[1] = + {static_cast(dst_extents.width), + static_cast(dst_extents.height), + static_cast(dst_extents.depth)}, + blit.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT, + blit.dstSubresource.mipLevel = 0, blit.dstSubresource.baseArrayLayer = 0, + blit.dstSubresource.layerCount = 1, + + vkCmdBlitImage( + handle_, + src.handle(), + src.layout(), + dst.handle(), + dst.layout(), + 1, + &blit, + VK_FILTER_NEAREST); + + state_ = CommandBuffer::State::RECORDING; +} + void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx) const { VK_CHECK_COND( diff --git a/backends/vulkan/runtime/vk_api/Command.h b/backends/vulkan/runtime/vk_api/Command.h index e78d410aec4..56b9940eb1e 100644 --- a/backends/vulkan/runtime/vk_api/Command.h +++ b/backends/vulkan/runtime/vk_api/Command.h @@ -92,6 +92,7 @@ class CommandBuffer final { void insert_barrier(PipelineBarrier& pipeline_barrier); void dispatch(const utils::uvec3&); + void blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst); void write_timestamp(VkQueryPool, const uint32_t) const; void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const;