Mesa (main): radv: Build acceleration structures using LBVH

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Sun Apr 24 15:23:39 UTC 2022


Module: Mesa
Branch: main
Commit: be57b085be6d840c27a2d69a7df4b1261ad0a0f5
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=be57b085be6d840c27a2d69a7df4b1261ad0a0f5

Author: Konstantin Seurer <konstantin.seurer at gmail.com>
Date:   Wed Mar 30 14:09:36 2022 +0200

radv: Build acceleration structures using LBVH

This sorts the leaf nodes along a Morton (Z-order) curve before
creating the internal nodes. For reference:
https://developer.nvidia.com/blog/thinking-parallel-part-iii-tree-construction-gpu/

Ray query CTS results:
Test run totals:
  Passed:        22418/23426 (95.7%)
  Failed:        0/23426 (0.0%)
  Not supported: 1008/23426 (4.3%)
  Warnings:      0/23426 (0.0%)
  Waived:        0/23426 (0.0%)

Signed-off-by: Konstantin Seurer <konstantin.seurer at gmail.com>
Reviewed-by: Bas Nieuwenhuizen <bas at basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15648>

---

 src/amd/vulkan/radv_acceleration_structure.c | 608 +++++++++++++++++++++------
 src/amd/vulkan/radv_private.h                |   7 +
 2 files changed, 481 insertions(+), 134 deletions(-)

diff --git a/src/amd/vulkan/radv_acceleration_structure.c b/src/amd/vulkan/radv_acceleration_structure.c
index aa542ae94e0..946604baca5 100644
--- a/src/amd/vulkan/radv_acceleration_structure.c
+++ b/src/amd/vulkan/radv_acceleration_structure.c
@@ -29,12 +29,48 @@
 #include "radv_cs.h"
 #include "radv_meta.h"
 
+#include "radix_sort/radv_radix_sort.h"
+
+/* Min and max bounds of the bvh used to compute morton codes */
+#define SCRATCH_TOTAL_BOUNDS_SIZE (6 * sizeof(float))
+
+enum accel_struct_build {
+   accel_struct_build_unoptimized,
+   accel_struct_build_lbvh,
+};
+
+static enum accel_struct_build
+get_accel_struct_build(const struct radv_physical_device *pdevice,
+                       VkAccelerationStructureBuildTypeKHR buildType)
+{
+   if (buildType != VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
+      return accel_struct_build_unoptimized;
+
+   return (pdevice->rad_info.chip_class < GFX10) ? accel_struct_build_unoptimized
+                                                 : accel_struct_build_lbvh;
+}
+
+static uint32_t
+get_node_id_stride(enum accel_struct_build build_mode)
+{
+   switch (build_mode) {
+   case accel_struct_build_unoptimized:
+      return 4;
+   case accel_struct_build_lbvh:
+      return 8;
+   default:
+      unreachable("Unhandled accel_struct_build!");
+   }
+}
+
 VKAPI_ATTR void VKAPI_CALL
 radv_GetAccelerationStructureBuildSizesKHR(
    VkDevice _device, VkAccelerationStructureBuildTypeKHR buildType,
    const VkAccelerationStructureBuildGeometryInfoKHR *pBuildInfo,
    const uint32_t *pMaxPrimitiveCounts, VkAccelerationStructureBuildSizesInfoKHR *pSizeInfo)
 {
+   RADV_FROM_HANDLE(radv_device, device, _device);
+
    uint64_t triangles = 0, boxes = 0, instances = 0;
 
    STATIC_ASSERT(sizeof(struct radv_bvh_triangle_node) == 64);
@@ -79,9 +115,30 @@ radv_GetAccelerationStructureBuildSizesKHR(
 
    pSizeInfo->accelerationStructureSize = size;
 
-   /* 2x the max number of nodes in a BVH layer (one uint32_t each) */
-   pSizeInfo->updateScratchSize = pSizeInfo->buildScratchSize =
-      MAX2(4096, 2 * (boxes + instances + triangles) * sizeof(uint32_t));
+   /* 2x the max number of nodes in a BVH layer and order information for sorting when using
+    * LBVH (one uint32_t each, two buffers) plus space to store the bounds.
+    * LBVH is only supported for device builds and hardware that supports global atomics.
+    */
+   enum accel_struct_build build_mode = get_accel_struct_build(device->physical_device, buildType);
+   uint32_t node_id_stride = get_node_id_stride(build_mode);
+
+   uint32_t leaf_count = boxes + instances + triangles;
+   VkDeviceSize scratchSize = 2 * leaf_count * node_id_stride;
+
+   if (build_mode == accel_struct_build_lbvh) {
+      radix_sort_vk_memory_requirements_t requirements;
+      radix_sort_vk_get_memory_requirements(device->meta_state.accel_struct_build.radix_sort,
+                                            leaf_count, &requirements);
+
+      /* Make sure we have the space required by the radix sort. */
+      scratchSize = MAX2(scratchSize, requirements.keyvals_size * 2);
+
+      scratchSize += requirements.internal_size + SCRATCH_TOTAL_BOUNDS_SIZE;
+   }
+
+   scratchSize = MAX2(4096, scratchSize);
+   pSizeInfo->updateScratchSize = scratchSize;
+   pSizeInfo->buildScratchSize = scratchSize;
 }
 
 VKAPI_ATTR VkResult VKAPI_CALL
@@ -745,6 +802,19 @@ radv_CopyAccelerationStructureKHR(VkDevice _device, VkDeferredOperationKHR defer
    return VK_SUCCESS;
 }
 
+static nir_builder
+create_accel_build_shader(struct radv_device *device, const char *name)
+{
+   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "%s", name);
+   b.shader->info.workgroup_size[0] = 64;
+
+   assert(b.shader->info.workgroup_size[1] == 1);
+   assert(b.shader->info.workgroup_size[2] == 1);
+   assert(!b.shader->info.workgroup_size_variable);
+
+   return b;
+}
+
 static nir_ssa_def *
 get_indices(nir_builder *b, nir_ssa_def *addr, nir_ssa_def *type, nir_ssa_def *id)
 {
@@ -935,6 +1005,21 @@ struct build_primitive_constants {
    };
 };
 
+struct bounds_constants {
+   uint64_t node_addr;
+   uint64_t scratch_addr;
+};
+
+struct morton_constants {
+   uint64_t node_addr;
+   uint64_t scratch_addr;
+};
+
+struct fill_constants {
+   uint64_t addr;
+   uint32_t value;
+};
+
 struct build_internal_constants {
    uint64_t node_dst_addr;
    uint64_t scratch_addr;
@@ -972,6 +1057,29 @@ nir_invert_3x3(nir_builder *b, nir_ssa_def *in[3][3], nir_ssa_def *out[3][3])
    }
 }
 
+static nir_ssa_def *
+id_to_node_id_offset(nir_builder *b, nir_ssa_def *global_id,
+                     const struct radv_physical_device *pdevice)
+{
+   uint32_t stride = get_node_id_stride(
+      get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR));
+
+   return nir_umul24(b, global_id, nir_imm_int(b, stride));
+}
+
+static nir_ssa_def *
+id_to_morton_offset(nir_builder *b, nir_ssa_def *global_id,
+                    const struct radv_physical_device *pdevice)
+{
+   enum accel_struct_build build_mode =
+      get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
+   assert(build_mode == accel_struct_build_lbvh);
+
+   uint32_t stride = get_node_id_stride(build_mode);
+
+   return nir_iadd_imm(b, nir_umul24(b, global_id, nir_imm_int(b, stride)), sizeof(uint32_t));
+}
+
 static nir_shader *
 build_leaf_shader(struct radv_device *dev)
 {
@@ -1003,9 +1111,15 @@ build_leaf_shader(struct radv_device *dev)
                nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
                           nir_imm_int(&b, b.shader->info.workgroup_size[0])),
                nir_channels(&b, nir_load_local_invocation_id(&b), 1));
-   scratch_addr = nir_iadd(
-      &b, scratch_addr,
-      nir_u2u64(&b, nir_iadd(&b, scratch_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 4)))));
+   nir_ssa_def *scratch_dst_addr =
+      nir_iadd(&b, scratch_addr,
+               nir_u2u64(&b, nir_iadd(&b, scratch_offset,
+                                      id_to_node_id_offset(&b, global_id, dev->physical_device))));
+
+   nir_variable *bounds[2] = {
+      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
+      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
+   };
 
    nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_TRIANGLES_KHR));
    { /* Triangles */
@@ -1053,6 +1167,22 @@ build_leaf_shader(struct radv_device *dev)
          for (unsigned j = 0; j < 3; ++j)
             node_data[i * 3 + j] = nir_fdph(&b, positions[i], nir_load_var(&b, transform[j]));
 
+      nir_ssa_def *min_bound = NULL;
+      nir_ssa_def *max_bound = NULL;
+      for (unsigned i = 0; i < 3; ++i) {
+         nir_ssa_def *position = nir_vec(&b, node_data + i * 3, 3);
+         if (min_bound) {
+            min_bound = nir_fmin(&b, min_bound, position);
+            max_bound = nir_fmax(&b, max_bound, position);
+         } else {
+            min_bound = position;
+            max_bound = position;
+         }
+      }
+
+      nir_store_var(&b, bounds[0], min_bound, 7);
+      nir_store_var(&b, bounds[1], max_bound, 7);
+
       node_data[12] = global_id;
       node_data[13] = geometry_id;
       node_data[15] = nir_imm_int(&b, 9);
@@ -1066,7 +1196,7 @@ build_leaf_shader(struct radv_device *dev)
       }
 
       nir_ssa_def *node_id = nir_ushr_imm(&b, node_offset, 3);
-      nir_build_store_global(&b, node_id, scratch_addr);
+      nir_build_store_global(&b, node_id, scratch_dst_addr);
    }
    nir_push_else(&b, NULL);
    nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_AABBS_KHR));
@@ -1077,14 +1207,18 @@ build_leaf_shader(struct radv_device *dev)
       nir_ssa_def *node_offset =
          nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
       nir_ssa_def *aabb_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
+
       nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), 7);
-      nir_build_store_global(&b, node_id, scratch_addr);
+      nir_build_store_global(&b, node_id, scratch_dst_addr);
 
       aabb_addr = nir_iadd(&b, aabb_addr, nir_u2u64(&b, nir_imul(&b, aabb_stride, global_id)));
 
       nir_ssa_def *min_bound = nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, aabb_addr, 0));
       nir_ssa_def *max_bound = nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, aabb_addr, 12));
 
+      nir_store_var(&b, bounds[0], min_bound, 7);
+      nir_store_var(&b, bounds[1], max_bound, 7);
+
       nir_ssa_def *values[] = {nir_channel(&b, min_bound, 0),
                                nir_channel(&b, min_bound, 1),
                                nir_channel(&b, min_bound, 2),
@@ -1130,16 +1264,9 @@ build_leaf_shader(struct radv_device *dev)
       nir_ssa_def *node_offset =
          nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 128)));
       node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
-      nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), 6);
-      nir_build_store_global(&b, node_id, scratch_addr);
-
-      nir_variable *bounds[2] = {
-         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
-         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
-      };
 
-      nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
-      nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
+      nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), 6);
+      nir_build_store_global(&b, node_id, scratch_dst_addr);
 
       nir_ssa_def *header_addr = nir_pack_64_2x32(&b, nir_channels(&b, inst3, 12));
       nir_push_if(&b, nir_ine_imm(&b, header_addr, 0));
@@ -1204,6 +1331,32 @@ build_leaf_shader(struct radv_device *dev)
    nir_pop_if(&b, NULL);
    nir_pop_if(&b, NULL);
 
+   if (get_accel_struct_build(dev->physical_device,
+                              VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) !=
+       accel_struct_build_unoptimized) {
+      nir_ssa_def *min = nir_load_var(&b, bounds[0]);
+      nir_ssa_def *max = nir_load_var(&b, bounds[1]);
+
+      nir_ssa_def *min_reduced = nir_reduce(&b, min, .reduction_op = nir_op_fmin);
+      nir_ssa_def *max_reduced = nir_reduce(&b, max, .reduction_op = nir_op_fmax);
+
+      nir_push_if(&b, nir_elect(&b, 1));
+
+      nir_global_atomic_fmin(&b, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 24)),
+                             nir_channel(&b, min_reduced, 0));
+      nir_global_atomic_fmin(&b, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 20)),
+                             nir_channel(&b, min_reduced, 1));
+      nir_global_atomic_fmin(&b, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 16)),
+                             nir_channel(&b, min_reduced, 2));
+
+      nir_global_atomic_fmax(&b, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 12)),
+                             nir_channel(&b, max_reduced, 0));
+      nir_global_atomic_fmax(&b, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 8)),
+                             nir_channel(&b, max_reduced, 1));
+      nir_global_atomic_fmax(&b, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 4)),
+                             nir_channel(&b, max_reduced, 2));
+   }
+
    return b.shader;
 }
 
@@ -1267,6 +1420,89 @@ determine_bounds(nir_builder *b, nir_ssa_def *node_addr, nir_ssa_def *node_id,
    nir_pop_if(b, NULL);
 }
 
+/* https://developer.nvidia.com/blog/thinking-parallel-part-iii-tree-construction-gpu/ */
+static nir_ssa_def *
+build_morton_component(nir_builder *b, nir_ssa_def *x)
+{
+   x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000101u), 0x0F00F00Fu);
+   x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000011u), 0xC30C30C3u);
+   x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000005u), 0x49249249u);
+   return x;
+}
+
+static nir_shader *
+build_morton_shader(struct radv_device *dev)
+{
+   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
+
+   nir_builder b = create_accel_build_shader(dev, "accel_build_morton_shader");
+
+   /*
+    * push constants:
+    *   i32 x 2: node address
+    *   i32 x 2: scratch address
+    */
+   nir_ssa_def *pconst0 =
+      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
+
+   nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
+   nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
+
+   nir_ssa_def *global_id =
+      nir_iadd(&b,
+               nir_imul_imm(&b, nir_channel(&b, nir_load_workgroup_id(&b, 32), 0),
+                            b.shader->info.workgroup_size[0]),
+               nir_load_local_invocation_index(&b));
+
+   nir_ssa_def *node_id_addr = nir_iadd(
+      &b, scratch_addr, nir_u2u64(&b, id_to_node_id_offset(&b, global_id, dev->physical_device)));
+   nir_ssa_def *node_id =
+      nir_build_load_global(&b, 1, 32, node_id_addr, .align_mul = 4, .align_offset = 0);
+
+   nir_variable *node_bounds[2] = {
+      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
+      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
+   };
+
+   determine_bounds(&b, node_addr, node_id, node_bounds);
+
+   nir_ssa_def *node_min = nir_load_var(&b, node_bounds[0]);
+   nir_ssa_def *node_max = nir_load_var(&b, node_bounds[1]);
+   nir_ssa_def *node_pos =
+      nir_fmul(&b, nir_fadd(&b, node_min, node_max), nir_imm_vec3(&b, 0.5, 0.5, 0.5));
+
+   nir_ssa_def *bvh_min =
+      nir_build_load_global(&b, 3, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 24)),
+                            .align_mul = 4, .align_offset = 0);
+   nir_ssa_def *bvh_max =
+      nir_build_load_global(&b, 3, 32, nir_isub(&b, scratch_addr, nir_imm_int64(&b, 12)),
+                            .align_mul = 4, .align_offset = 0);
+   nir_ssa_def *bvh_size = nir_fsub(&b, bvh_max, bvh_min);
+
+   nir_ssa_def *normalized_node_pos = nir_fdiv(&b, nir_fsub(&b, node_pos, bvh_min), bvh_size);
+
+   nir_ssa_def *x_int =
+      nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 0), 255.0));
+   nir_ssa_def *x_morton = build_morton_component(&b, x_int);
+
+   nir_ssa_def *y_int =
+      nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 1), 255.0));
+   nir_ssa_def *y_morton = build_morton_component(&b, y_int);
+
+   nir_ssa_def *z_int =
+      nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 2), 255.0));
+   nir_ssa_def *z_morton = build_morton_component(&b, z_int);
+
+   nir_ssa_def *morton_code = nir_iadd(
+      &b, nir_iadd(&b, nir_ishl_imm(&b, x_morton, 2), nir_ishl_imm(&b, y_morton, 1)), z_morton);
+
+   nir_ssa_def *dst_addr = nir_iadd(
+      &b, scratch_addr, nir_u2u64(&b, id_to_morton_offset(&b, global_id, dev->physical_device)));
+   nir_build_store_global(&b, morton_code, dst_addr, .align_mul = 4);
+
+   return b.shader;
+}
+
 static nir_shader *
 build_internal_shader(struct radv_device *dev)
 {
@@ -1308,12 +1544,22 @@ build_internal_shader(struct radv_device *dev)
 
    nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_ishl_imm(&b, global_id, 7));
    nir_ssa_def *node_dst_addr = nir_iadd(&b, node_addr, nir_u2u64(&b, node_offset));
-   nir_ssa_def *src_nodes = nir_build_load_global(
-      &b, 4, 32,
-      nir_iadd(&b, scratch_addr,
-               nir_u2u64(&b, nir_iadd(&b, src_scratch_offset, nir_ishl_imm(&b, global_id, 4)))));
 
-   nir_build_store_global(&b, src_nodes, nir_iadd_imm(&b, node_dst_addr, 0));
+   nir_ssa_def *src_base_addr =
+      nir_iadd(&b, scratch_addr,
+               nir_u2u64(&b, nir_iadd(&b, src_scratch_offset,
+                                      id_to_node_id_offset(&b, src_idx, dev->physical_device))));
+
+   enum accel_struct_build build_mode =
+      get_accel_struct_build(dev->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
+   uint32_t node_id_stride = get_node_id_stride(build_mode);
+
+   nir_ssa_def *src_nodes[4];
+   for (uint32_t i = 0; i < 4; i++) {
+      src_nodes[i] =
+         nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, src_base_addr, i * node_id_stride));
+      nir_build_store_global(&b, src_nodes[i], nir_iadd_imm(&b, node_dst_addr, i * 4));
+   }
 
    nir_ssa_def *total_bounds[2] = {
       nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
@@ -1329,7 +1575,7 @@ build_internal_shader(struct radv_device *dev)
       nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
 
       nir_push_if(&b, nir_ilt(&b, nir_imm_int(&b, i), src_count));
-      determine_bounds(&b, node_addr, nir_channel(&b, src_nodes, i), bounds);
+      determine_bounds(&b, node_addr, src_nodes[i], bounds);
       nir_pop_if(&b, NULL);
       nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
                              nir_iadd_imm(&b, node_dst_addr, 16 + 24 * i));
@@ -1342,7 +1588,8 @@ build_internal_shader(struct radv_device *dev)
    nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), 5);
    nir_ssa_def *dst_scratch_addr =
       nir_iadd(&b, scratch_addr,
-               nir_u2u64(&b, nir_iadd(&b, dst_scratch_offset, nir_ishl_imm(&b, global_id, 2))));
+               nir_u2u64(&b, nir_iadd(&b, dst_scratch_offset,
+                                      id_to_node_id_offset(&b, global_id, dev->physical_device))));
    nir_build_store_global(&b, node_id, dst_scratch_addr);
 
    nir_push_if(&b, fill_header);
@@ -1586,138 +1833,133 @@ radv_device_finish_accel_struct_build_state(struct radv_device *device)
                         &state->alloc);
    radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.leaf_pipeline,
                         &state->alloc);
+   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.morton_pipeline,
+                        &state->alloc);
    radv_DestroyPipelineLayout(radv_device_to_handle(device),
                               state->accel_struct_build.copy_p_layout, &state->alloc);
    radv_DestroyPipelineLayout(radv_device_to_handle(device),
                               state->accel_struct_build.internal_p_layout, &state->alloc);
    radv_DestroyPipelineLayout(radv_device_to_handle(device),
                               state->accel_struct_build.leaf_p_layout, &state->alloc);
+   radv_DestroyPipelineLayout(radv_device_to_handle(device),
+                              state->accel_struct_build.morton_p_layout, &state->alloc);
+
+   if (state->accel_struct_build.radix_sort)
+      radix_sort_vk_destroy(state->accel_struct_build.radix_sort, radv_device_to_handle(device),
+                            &state->alloc);
 }
 
-VkResult
-radv_device_init_accel_struct_build_state(struct radv_device *device)
+static VkResult
+create_build_pipeline(struct radv_device *device, nir_shader *shader, unsigned push_constant_size,
+                      VkPipeline *pipeline, VkPipelineLayout *layout)
 {
-   VkResult result;
-   nir_shader *leaf_cs = build_leaf_shader(device);
-   nir_shader *internal_cs = build_internal_shader(device);
-   nir_shader *copy_cs = build_copy_shader(device);
-
-   const VkPipelineLayoutCreateInfo leaf_pl_create_info = {
+   const VkPipelineLayoutCreateInfo pl_create_info = {
       .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
       .setLayoutCount = 0,
       .pushConstantRangeCount = 1,
-      .pPushConstantRanges = &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0,
-                                                    sizeof(struct build_primitive_constants)},
+      .pPushConstantRanges =
+         &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0, push_constant_size},
    };
 
-   result = radv_CreatePipelineLayout(radv_device_to_handle(device), &leaf_pl_create_info,
-                                      &device->meta_state.alloc,
-                                      &device->meta_state.accel_struct_build.leaf_p_layout);
-   if (result != VK_SUCCESS)
-      goto fail;
+   VkResult result = radv_CreatePipelineLayout(radv_device_to_handle(device), &pl_create_info,
+                                               &device->meta_state.alloc, layout);
+   if (result != VK_SUCCESS) {
+      radv_device_finish_accel_struct_build_state(device);
+      ralloc_free(shader);
+      return result;
+   }
 
-   VkPipelineShaderStageCreateInfo leaf_shader_stage = {
+   VkPipelineShaderStageCreateInfo shader_stage = {
       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
       .stage = VK_SHADER_STAGE_COMPUTE_BIT,
-      .module = vk_shader_module_handle_from_nir(leaf_cs),
+      .module = vk_shader_module_handle_from_nir(shader),
       .pName = "main",
       .pSpecializationInfo = NULL,
    };
 
-   VkComputePipelineCreateInfo leaf_pipeline_info = {
+   VkComputePipelineCreateInfo pipeline_info = {
       .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
-      .stage = leaf_shader_stage,
+      .stage = shader_stage,
       .flags = 0,
-      .layout = device->meta_state.accel_struct_build.leaf_p_layout,
+      .layout = *layout,
    };
 
-   result = radv_CreateComputePipelines(
-      radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
-      &leaf_pipeline_info, NULL, &device->meta_state.accel_struct_build.leaf_pipeline);
-   if (result != VK_SUCCESS)
-      goto fail;
+   result = radv_CreateComputePipelines(radv_device_to_handle(device),
+                                        radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
+                                        &pipeline_info, &device->meta_state.alloc, pipeline);
 
-   const VkPipelineLayoutCreateInfo internal_pl_create_info = {
-      .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
-      .setLayoutCount = 0,
-      .pushConstantRangeCount = 1,
-      .pPushConstantRanges = &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0,
-                                                    sizeof(struct build_internal_constants)},
-   };
+   if (result != VK_SUCCESS) {
+      radv_device_finish_accel_struct_build_state(device);
+      ralloc_free(shader);
+      return result;
+   }
 
-   result = radv_CreatePipelineLayout(radv_device_to_handle(device), &internal_pl_create_info,
-                                      &device->meta_state.alloc,
-                                      &device->meta_state.accel_struct_build.internal_p_layout);
-   if (result != VK_SUCCESS)
-      goto fail;
+   return VK_SUCCESS;
+}
 
-   VkPipelineShaderStageCreateInfo internal_shader_stage = {
-      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
-      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
-      .module = vk_shader_module_handle_from_nir(internal_cs),
-      .pName = "main",
-      .pSpecializationInfo = NULL,
-   };
+static void
+radix_sort_fill_buffer(VkCommandBuffer commandBuffer,
+                       radix_sort_vk_buffer_info_t const *buffer_info, VkDeviceSize offset,
+                       VkDeviceSize size, uint32_t data)
+{
+   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
 
-   VkComputePipelineCreateInfo internal_pipeline_info = {
-      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
-      .stage = internal_shader_stage,
-      .flags = 0,
-      .layout = device->meta_state.accel_struct_build.internal_p_layout,
-   };
+   assert(size % 4 == 0);
+   assert(size != VK_WHOLE_SIZE);
 
-   result = radv_CreateComputePipelines(
-      radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
-      &internal_pipeline_info, NULL, &device->meta_state.accel_struct_build.internal_pipeline);
-   if (result != VK_SUCCESS)
-      goto fail;
+   radv_fill_buffer_shader(cmd_buffer, buffer_info->devaddr + buffer_info->offset + offset, size,
+                           data);
+}
 
-   const VkPipelineLayoutCreateInfo copy_pl_create_info = {
-      .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
-      .setLayoutCount = 0,
-      .pushConstantRangeCount = 1,
-      .pPushConstantRanges =
-         &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct copy_constants)},
-   };
+VkResult
+radv_device_init_accel_struct_build_state(struct radv_device *device)
+{
+   VkResult result;
+   nir_shader *leaf_cs = build_leaf_shader(device);
+   nir_shader *internal_cs = build_internal_shader(device);
+   nir_shader *copy_cs = build_copy_shader(device);
 
-   result = radv_CreatePipelineLayout(radv_device_to_handle(device), &copy_pl_create_info,
-                                      &device->meta_state.alloc,
-                                      &device->meta_state.accel_struct_build.copy_p_layout);
+   result = create_build_pipeline(device, leaf_cs, sizeof(struct build_primitive_constants),
+                                  &device->meta_state.accel_struct_build.leaf_pipeline,
+                                  &device->meta_state.accel_struct_build.leaf_p_layout);
    if (result != VK_SUCCESS)
-      goto fail;
+      return result;
 
-   VkPipelineShaderStageCreateInfo copy_shader_stage = {
-      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
-      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
-      .module = vk_shader_module_handle_from_nir(copy_cs),
-      .pName = "main",
-      .pSpecializationInfo = NULL,
-   };
+   result = create_build_pipeline(device, internal_cs, sizeof(struct build_internal_constants),
+                                  &device->meta_state.accel_struct_build.internal_pipeline,
+                                  &device->meta_state.accel_struct_build.internal_p_layout);
+   if (result != VK_SUCCESS)
+      return result;
 
-   VkComputePipelineCreateInfo copy_pipeline_info = {
-      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
-      .stage = copy_shader_stage,
-      .flags = 0,
-      .layout = device->meta_state.accel_struct_build.copy_p_layout,
-   };
+   result = create_build_pipeline(device, copy_cs, sizeof(struct copy_constants),
+                                  &device->meta_state.accel_struct_build.copy_pipeline,
+                                  &device->meta_state.accel_struct_build.copy_p_layout);
 
-   result = radv_CreateComputePipelines(
-      radv_device_to_handle(device), radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
-      &copy_pipeline_info, NULL, &device->meta_state.accel_struct_build.copy_pipeline);
    if (result != VK_SUCCESS)
-      goto fail;
+      return result;
 
-   ralloc_free(copy_cs);
-   ralloc_free(internal_cs);
-   ralloc_free(leaf_cs);
+   if (get_accel_struct_build(device->physical_device,
+                              VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ==
+       accel_struct_build_lbvh) {
+      nir_shader *morton_cs = build_morton_shader(device);
 
-   return VK_SUCCESS;
+      result = create_build_pipeline(device, morton_cs, sizeof(struct morton_constants),
+                                     &device->meta_state.accel_struct_build.morton_pipeline,
+                                     &device->meta_state.accel_struct_build.morton_p_layout);
+      if (result != VK_SUCCESS)
+         return result;
+
+      device->meta_state.accel_struct_build.radix_sort =
+         radv_create_radix_sort_u64(radv_device_to_handle(device), &device->meta_state.alloc,
+                                    radv_pipeline_cache_to_handle(&device->meta_state.cache));
+
+      struct radix_sort_vk_sort_devaddr_info *radix_sort_info =
+         &device->meta_state.accel_struct_build.radix_sort_info;
+      radix_sort_info->ext = NULL;
+      radix_sort_info->key_bits = 24;
+      radix_sort_info->fill_buffer = radix_sort_fill_buffer;
+   }
 
-fail:
-   radv_device_finish_accel_struct_build_state(device);
-   ralloc_free(copy_cs);
-   ralloc_free(internal_cs);
-   ralloc_free(leaf_cs);
    return result;
 }
 
@@ -1725,6 +1967,8 @@ struct bvh_state {
    uint32_t node_offset;
    uint32_t node_count;
    uint32_t scratch_offset;
+   uint32_t buffer_1_offset;
+   uint32_t buffer_2_offset;
 
    uint32_t instance_offset;
    uint32_t instance_count;
@@ -1739,12 +1983,35 @@ radv_CmdBuildAccelerationStructuresKHR(
    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
    struct radv_meta_saved_state saved_state;
 
+   enum radv_cmd_flush_bits flush_bits =
+      RADV_CMD_FLAG_CS_PARTIAL_FLUSH |
+      radv_src_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT,
+                            NULL) |
+      radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT,
+                            NULL);
+
+   enum accel_struct_build build_mode = get_accel_struct_build(
+      cmd_buffer->device->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
+   uint32_t node_id_stride = get_node_id_stride(build_mode);
+   uint32_t scratch_offset =
+      (build_mode != accel_struct_build_unoptimized) ? SCRATCH_TOTAL_BOUNDS_SIZE : 0;
+
    radv_meta_save(
       &saved_state, cmd_buffer,
       RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
    struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
 
-   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
+   if (build_mode != accel_struct_build_unoptimized) {
+      for (uint32_t i = 0; i < infoCount; ++i) {
+         /* Clear the bvh bounds with nan. */
+         radv_fill_buffer_shader(cmd_buffer, pInfos[i].scratchData.deviceAddress, 6 * sizeof(float),
+                                 0x7FC00000);
+      }
+
+      cmd_buffer->state.flush_bits |= flush_bits;
+   }
+
+   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                         cmd_buffer->device->meta_state.accel_struct_build.leaf_pipeline);
 
    for (uint32_t i = 0; i < infoCount; ++i) {
@@ -1753,7 +2020,7 @@ radv_CmdBuildAccelerationStructuresKHR(
 
       struct build_primitive_constants prim_consts = {
          .node_dst_addr = radv_accel_struct_get_va(accel_struct),
-         .scratch_addr = pInfos[i].scratchData.deviceAddress,
+         .scratch_addr = pInfos[i].scratchData.deviceAddress + scratch_offset,
          .dst_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64) + 128,
          .dst_scratch_offset = 0,
       };
@@ -1805,20 +2072,91 @@ radv_CmdBuildAccelerationStructuresKHR(
                unreachable("Unknown geometryType");
             }
 
-            radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
-                                  cmd_buffer->device->meta_state.accel_struct_build.leaf_p_layout,
-                                  VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(prim_consts),
-                                  &prim_consts);
+            radv_CmdPushConstants(
+               commandBuffer, cmd_buffer->device->meta_state.accel_struct_build.leaf_p_layout,
+               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(prim_consts), &prim_consts);
             radv_unaligned_dispatch(cmd_buffer, ppBuildRangeInfos[i][j].primitiveCount, 1, 1);
             prim_consts.dst_offset += prim_size * ppBuildRangeInfos[i][j].primitiveCount;
-            prim_consts.dst_scratch_offset += 4 * ppBuildRangeInfos[i][j].primitiveCount;
+            prim_consts.dst_scratch_offset +=
+               node_id_stride * ppBuildRangeInfos[i][j].primitiveCount;
          }
       }
       bvh_states[i].node_offset = prim_consts.dst_offset;
-      bvh_states[i].node_count = prim_consts.dst_scratch_offset / 4;
+      bvh_states[i].node_count = prim_consts.dst_scratch_offset / node_id_stride;
    }
 
-   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
+   if (build_mode == accel_struct_build_lbvh) {
+      cmd_buffer->state.flush_bits |= flush_bits;
+
+      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+                           cmd_buffer->device->meta_state.accel_struct_build.morton_pipeline);
+
+      for (uint32_t i = 0; i < infoCount; ++i) {
+         RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
+                          pInfos[i].dstAccelerationStructure);
+
+         const struct morton_constants consts = {
+            .node_addr = radv_accel_struct_get_va(accel_struct),
+            .scratch_addr = pInfos[i].scratchData.deviceAddress + SCRATCH_TOTAL_BOUNDS_SIZE,
+         };
+
+         radv_CmdPushConstants(commandBuffer,
+                               cmd_buffer->device->meta_state.accel_struct_build.morton_p_layout,
+                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
+         radv_unaligned_dispatch(cmd_buffer, bvh_states[i].node_count, 1, 1);
+      }
+
+      cmd_buffer->state.flush_bits |= flush_bits;
+
+      for (uint32_t i = 0; i < infoCount; ++i) {
+         struct radix_sort_vk_memory_requirements requirements;
+         radix_sort_vk_get_memory_requirements(
+            cmd_buffer->device->meta_state.accel_struct_build.radix_sort, bvh_states[i].node_count,
+            &requirements);
+
+         struct radix_sort_vk_sort_devaddr_info info =
+            cmd_buffer->device->meta_state.accel_struct_build.radix_sort_info;
+         info.count = bvh_states[i].node_count;
+
+         VkDeviceAddress base_addr =
+            pInfos[i].scratchData.deviceAddress + SCRATCH_TOTAL_BOUNDS_SIZE;
+
+         info.keyvals_even.buffer = VK_NULL_HANDLE;
+         info.keyvals_even.offset = 0;
+         info.keyvals_even.devaddr = base_addr;
+
+         info.keyvals_odd = base_addr + requirements.keyvals_size;
+
+         info.internal.buffer = VK_NULL_HANDLE;
+         info.internal.offset = 0;
+         info.internal.devaddr = base_addr + requirements.keyvals_size * 2;
+
+         VkDeviceAddress result_addr;
+         radix_sort_vk_sort_devaddr(cmd_buffer->device->meta_state.accel_struct_build.radix_sort,
+                                    &info, radv_device_to_handle(cmd_buffer->device), commandBuffer,
+                                    &result_addr);
+
+         assert(result_addr == info.keyvals_even.devaddr || result_addr == info.keyvals_odd);
+
+         if (result_addr == info.keyvals_even.devaddr) {
+            bvh_states[i].buffer_1_offset = SCRATCH_TOTAL_BOUNDS_SIZE;
+            bvh_states[i].buffer_2_offset = SCRATCH_TOTAL_BOUNDS_SIZE + requirements.keyvals_size;
+         } else {
+            bvh_states[i].buffer_1_offset = SCRATCH_TOTAL_BOUNDS_SIZE + requirements.keyvals_size;
+            bvh_states[i].buffer_2_offset = SCRATCH_TOTAL_BOUNDS_SIZE;
+         }
+         bvh_states[i].scratch_offset = bvh_states[i].buffer_1_offset;
+      }
+
+      cmd_buffer->state.flush_bits |= flush_bits;
+   } else {
+      for (uint32_t i = 0; i < infoCount; ++i) {
+         bvh_states[i].buffer_1_offset = 0;
+         bvh_states[i].buffer_2_offset = bvh_states[i].node_count * 4;
+      }
+   }
+
+   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                         cmd_buffer->device->meta_state.accel_struct_build.internal_pipeline);
    bool progress = true;
    for (unsigned iter = 0; progress; ++iter) {
@@ -1830,18 +2168,20 @@ radv_CmdBuildAccelerationStructuresKHR(
          if (iter && bvh_states[i].node_count == 1)
             continue;
 
-         if (!progress) {
-            cmd_buffer->state.flush_bits |=
-               RADV_CMD_FLAG_CS_PARTIAL_FLUSH |
-               radv_src_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_WRITE_BIT, NULL) |
-               radv_dst_access_flush(cmd_buffer,
-                                     VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT, NULL);
-         }
+         if (!progress)
+            cmd_buffer->state.flush_bits |= flush_bits;
+
          progress = true;
+
          uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 4));
          bool final_iter = dst_node_count == 1;
+
          uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
-         uint32_t dst_scratch_offset = src_scratch_offset ? 0 : bvh_states[i].node_count * 4;
+         uint32_t buffer_1_offset = bvh_states[i].buffer_1_offset;
+         uint32_t buffer_2_offset = bvh_states[i].buffer_2_offset;
+         uint32_t dst_scratch_offset =
+            (src_scratch_offset == buffer_1_offset) ? buffer_2_offset : buffer_1_offset;
+
          uint32_t dst_node_offset = bvh_states[i].node_offset;
          if (final_iter)
             dst_node_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64);
@@ -1855,7 +2195,7 @@ radv_CmdBuildAccelerationStructuresKHR(
             .fill_header = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
          };
 
-         radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
+         radv_CmdPushConstants(commandBuffer,
                                cmd_buffer->device->meta_state.accel_struct_build.internal_p_layout,
                                VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
          radv_unaligned_dispatch(cmd_buffer, dst_node_count, 1, 1);
diff --git a/src/amd/vulkan/radv_private.h b/src/amd/vulkan/radv_private.h
index b9a91d21150..89c66a0af86 100644
--- a/src/amd/vulkan/radv_private.h
+++ b/src/amd/vulkan/radv_private.h
@@ -82,6 +82,8 @@
 #include "radv_shader_args.h"
 #include "sid.h"
 
+#include "radix_sort/radix_sort_vk_devaddr.h"
+
 /* Pre-declarations needed for WSI entrypoints */
 struct wl_surface;
 struct wl_display;
@@ -661,10 +663,15 @@ struct radv_meta_state {
    struct {
       VkPipelineLayout leaf_p_layout;
       VkPipeline leaf_pipeline;
+      VkPipelineLayout morton_p_layout;
+      VkPipeline morton_pipeline;
       VkPipelineLayout internal_p_layout;
       VkPipeline internal_pipeline;
       VkPipelineLayout copy_p_layout;
       VkPipeline copy_pipeline;
+
+      struct radix_sort_vk *radix_sort;
+      struct radix_sort_vk_sort_devaddr_info radix_sort_info;
    } accel_struct_build;
 
    struct {



More information about the mesa-commit mailing list