Skip to content

Commit

Permalink
Implemented TLAS (for now, only for instances)
Browse files Browse the repository at this point in the history
  • Loading branch information
corporateshark committed Sep 8, 2024
1 parent 9cb6e86 commit 6f84591
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 2 deletions.
108 changes: 106 additions & 2 deletions lvk/vulkan/VulkanClasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3307,15 +3307,15 @@ lvk::Holder<lvk::AccelStructHandle> lvk::VulkanContext::createAccelerationStruct
handle = createBLAS(desc, &result);
break;
case AccelStructType_TLAS:
LVK_ASSERT_MSG(false, "Not implemented (yet)");
handle = createTLAS(desc, &result);
break;
default:
LVK_ASSERT_MSG(false, "Invalid acceleration structure type");
Result::setResult(outResult, Result(Result::Code::ArgumentOutOfRange, "Invalid acceleration structure type"));
return {};
}

if (!LVK_VERIFY(result.isOk())) {
if (!LVK_VERIFY(result.isOk() && handle.valid())) {
Result::setResult(outResult, Result(Result::Code::RuntimeError, "Cannot create AccelerationStructure"));
return {};
}
Expand Down Expand Up @@ -3754,6 +3754,110 @@ lvk::AccelStructHandle lvk::VulkanContext::createBLAS(const AccelStructDesc& des
return accelStructuresPool_.create(std::move(accelStruct));
}

lvk::AccelStructHandle lvk::VulkanContext::createTLAS(const AccelStructDesc& desc, Result* outResult) {
LVK_ASSERT(desc.type == AccelStructType_TLAS);
LVK_ASSERT(desc.geometryType == AccelStructGeomType_Instances);
LVK_ASSERT(desc.numVertices == 0);
LVK_ASSERT(desc.instancesBuffer.valid());
LVK_ASSERT(desc.buildRange.primitiveCount);

const VkAccelerationStructureGeometryKHR accelerationStructureGeometry{
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR,
.geometryType = VK_GEOMETRY_TYPE_INSTANCES_KHR,
.geometry =
{
.instances =
{
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_INSTANCES_DATA_KHR,
.arrayOfPointers = VK_FALSE,
.data = {.deviceAddress = gpuAddress(desc.instancesBuffer)},
},
},
.flags = VK_GEOMETRY_OPAQUE_BIT_KHR,
};

const VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfo = {
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR,
.type = VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,
.flags = buildFlagsToVkBuildAccelerationStructureFlags(desc.buildFlags),
.geometryCount = 1,
.pGeometries = &accelerationStructureGeometry,
};

VkAccelerationStructureBuildSizesInfoKHR accelerationStructureBuildSizesInfo{
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR,
};
vkGetAccelerationStructureBuildSizesKHR(vkDevice_,
VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR,
&accelerationStructureBuildGeometryInfo,
&desc.buildRange.primitiveCount,
&accelerationStructureBuildSizesInfo);

char debugNameBuffer[256] = {0};
if (desc.debugName) {
snprintf(debugNameBuffer, sizeof(debugNameBuffer) - 1, "Buffer: %s", desc.debugName);
}
lvk::AccelerationStructure accelStruct = {
.buffer = createBuffer(
{
.usage = lvk::BufferUsageBits_AccelStructStorage,
.storage = lvk::StorageType_Device,
.size = accelerationStructureBuildSizesInfo.accelerationStructureSize,
.debugName = debugNameBuffer,
},
outResult),
};

const VkAccelerationStructureCreateInfoKHR ciAccelerationStructure = {
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR,
.buffer = getVkBuffer(this, accelStruct.buffer),
.size = accelerationStructureBuildSizesInfo.accelerationStructureSize,
.type = VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,
};
vkCreateAccelerationStructureKHR(vkDevice_, &ciAccelerationStructure, nullptr, &accelStruct.vkHandle);

lvk::Holder<lvk::BufferHandle> scratchBuffer = createBuffer(
{
.usage = lvk::BufferUsageBits_Storage,
.storage = lvk::StorageType_Device,
.size = accelerationStructureBuildSizesInfo.buildScratchSize,
.debugName = "Buffer: TLAS scratch",
},
outResult);

const VkAccelerationStructureBuildGeometryInfoKHR accelerationBuildGeometryInfo = {
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR,
.type = VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,
.flags = buildFlagsToVkBuildAccelerationStructureFlags(desc.buildFlags),
.mode = VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR,
.dstAccelerationStructure = accelStruct.vkHandle,
.geometryCount = 1,
.pGeometries = &accelerationStructureGeometry,
.scratchData = {.deviceAddress = gpuAddress(scratchBuffer)},
};

const VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfo{
.primitiveCount = desc.buildRange.primitiveCount,
.primitiveOffset = desc.buildRange.primitiveOffset,
.firstVertex = desc.buildRange.firstVertex,
.transformOffset = desc.buildRange.transformOffset,
};
const VkAccelerationStructureBuildRangeInfoKHR* accelerationBuildStructureRangeInfos[] = {&accelerationStructureBuildRangeInfo};

lvk::ICommandBuffer& buffer = acquireCommandBuffer();
vkCmdBuildAccelerationStructuresKHR(
lvk::getVkCommandBuffer(buffer), 1, &accelerationBuildGeometryInfo, accelerationBuildStructureRangeInfos);
wait(submit(buffer, {}));

const VkAccelerationStructureDeviceAddressInfoKHR accelerationDeviceAddressInfo = {
.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR,
.accelerationStructure = accelStruct.vkHandle,
};
accelStruct.deviceAddress = vkGetAccelerationStructureDeviceAddressKHR(vkDevice_, &accelerationDeviceAddressInfo);

return accelStructuresPool_.create(std::move(accelStruct));
}

static_assert(1 << (sizeof(lvk::Format) * 8) <= LVK_ARRAY_NUM_ELEMENTS(lvk::VulkanContextImpl::ycbcrConversionData_),
"There aren't enough elements in `ycbcrConversionData_` to be accessed by lvk::Format");

Expand Down
1 change: 1 addition & 0 deletions lvk/vulkan/VulkanClasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ class VulkanContext final : public IContext {
lvk::Format yuvFormat = Format_Invalid,
const char* debugName = nullptr);
AccelStructHandle createBLAS(const AccelStructDesc& desc, Result* outResult);
AccelStructHandle createTLAS(const AccelStructDesc& desc, Result* outResult);

bool hasSwapchain() const noexcept {
return swapchain_ != nullptr;
Expand Down

0 comments on commit 6f84591

Please sign in to comment.