Mesa (main): ac/nir/ngg: Clean up mesh shader output LDS layout.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jun 8 09:12:36 UTC 2022


Module: Mesa
Branch: main
Commit: 304a0e948b49161abfd9badbf0783f099258d6f1
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=304a0e948b49161abfd9badbf0783f099258d6f1

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Fri May 20 15:27:03 2022 +0200

ac/nir/ngg: Clean up mesh shader output LDS layout.

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset at gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16737>

---

 src/amd/common/ac_nir_lower_ngg.c | 356 ++++++++++++++++++++++----------------
 1 file changed, 210 insertions(+), 146 deletions(-)
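For illustration, the layout math introduced by the new ms_calculate_output_layout
(see the diff below) can be sketched as a standalone C program. This is not part
of the commit; the example values (32 bytes of API shared memory, two per-vertex
and one per-primitive output slot, 64 vertices, 126 primitives, triangles) are
hypothetical:

   #include <stdint.h>
   #include <stdio.h>

   /* 16-byte alignment, as used for every LDS region in the new layout. */
   #define ALIGN16(x) (((x) + 15u) & ~15u)

   int main(void)
   {
      /* Hypothetical inputs. */
      uint32_t api_shared_size = 32;  /* LDS already used by the API shader */
      uint32_t num_vtx_outputs = 2;   /* util_bitcount64(vtx_attr.mask) */
      uint32_t num_prm_outputs = 1;   /* util_bitcount64(prm_attr.mask) */
      uint32_t max_vertices = 64, max_primitives = 126, vertices_per_prim = 3;

      /* Workgroup info: 4 dwords (num_prims, reserved, wg_index, num_api_waves). */
      uint32_t workgroup_info_addr = ALIGN16(api_shared_size);
      uint32_t total = workgroup_info_addr + 16;

      /* Arrayed attributes: 16 bytes per output slot per vertex/primitive. */
      uint32_t vtx_attr_addr = ALIGN16(total);
      uint32_t prm_attr_addr =
         ALIGN16(vtx_attr_addr + num_vtx_outputs * max_vertices * 16);
      total = prm_attr_addr + num_prm_outputs * max_primitives * 16;

      /* Flat array of 8-bit vertex indices for each primitive. */
      uint32_t indices_addr = ALIGN16(total);
      total = indices_addr + max_primitives * vertices_per_prim;

      /* NGG is only allowed to address up to 32K of LDS. */
      printf("wg_info=%u vtx_attr=%u prm_attr=%u indices=%u total=%u (max 32768)\n",
             workgroup_info_addr, vtx_attr_addr, prm_attr_addr, indices_addr, total);
      return 0;
   }

With these inputs it prints wg_info=32 vtx_attr=48 prm_attr=2096 indices=4112
total=4490, comfortably under the 32K limit asserted in the commit.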

diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 8ab79e4580e..d2339967c6a 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -95,17 +95,48 @@ typedef struct
    gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
 } lower_ngg_gs_state;
 
+/* LDS layout of Mesh Shader workgroup info. */
+enum {
+   /* DW0: number of primitives */
+   lds_ms_num_prims = 0,
+   /* DW1: reserved for future use */
+   lds_ms_dw1_reserved = 4,
+   /* DW2: workgroup index within the current dispatch */
+   lds_ms_wg_index = 8,
+   /* DW3: number of API workgroups in flight */
+   lds_ms_num_api_waves = 12,
+};
+
+/* Potential locations for Mesh Shader outputs. */
+typedef enum {
+   ms_out_mode_lds,
+} ms_out_mode;
+
+typedef struct
+{
+   uint64_t mask; /* Mask of output locations */
+   uint32_t addr; /* Base address */
+} ms_out_part;
+
 typedef struct
 {
+   /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
+   struct {
+      uint32_t workgroup_info_addr;
+      ms_out_part vtx_attr;
+      ms_out_part prm_attr;
+      uint32_t indices_addr;
+      uint32_t total_size;
+   } lds;
+} ms_out_mem_layout;
+
+typedef struct
+{
+   ms_out_mem_layout layout;
    uint64_t per_vertex_outputs;
    uint64_t per_primitive_outputs;
-   unsigned num_per_vertex_outputs;
-   unsigned num_per_primitive_outputs;
    unsigned vertices_per_prim;
-   unsigned vertex_attr_lds_addr;
-   unsigned prim_attr_lds_addr;
-   unsigned prim_vtx_indices_addr;
-   unsigned numprims_lds_addr;
+
    unsigned wave_size;
    unsigned api_workgroup_size;
    unsigned hw_workgroup_size;
@@ -1977,7 +2008,7 @@ ms_store_prim_indices(nir_builder *b,
    if (!offset_src)
       offset_src = nir_imm_int(b, 0);
 
-   nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->prim_vtx_indices_addr + offset_const);
+   nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr + offset_const);
 }
 
 static nir_ssa_def *
@@ -1989,7 +2020,7 @@ ms_load_prim_indices(nir_builder *b,
    if (!offset_src)
       offset_src = nir_imm_int(b, 0);
 
-   return nir_load_shared(b, 1, 8, offset_src, .base = s->prim_vtx_indices_addr + offset_const);
+   return nir_load_shared(b, 1, 8, offset_src, .base = s->layout.lds.indices_addr + offset_const);
 }
 
 static void
@@ -1998,7 +2029,7 @@ ms_store_num_prims(nir_builder *b,
                    lower_ngg_ms_state *s)
 {
    nir_ssa_def *addr = nir_imm_int(b, 0);
-   nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->numprims_lds_addr);
+   nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
 }
 
 static nir_ssa_def *
@@ -2006,10 +2037,9 @@ ms_load_num_prims(nir_builder *b,
                   lower_ngg_ms_state *s)
 {
    nir_ssa_def *addr = nir_imm_int(b, 0);
-   return nir_load_shared(b, 1, 32, addr, .base = s->numprims_lds_addr);
+   return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
 }
 
-
 static nir_ssa_def *
 lower_ms_store_output(nir_builder *b,
                       nir_intrinsic_instr *intrin,
@@ -2135,25 +2165,9 @@ update_ms_output_info(nir_intrinsic_instr *intrin,
    }
 }
 
-static void
-ms_store_arrayed_output_intrin(nir_builder *b,
-                               nir_intrinsic_instr *intrin,
-                               unsigned num_arrayed_outputs,
-                               unsigned base_shared_addr)
+static nir_ssa_def *
+regroup_store_val(nir_builder *b, nir_ssa_def *store_val)
 {
-   nir_ssa_def *store_val = intrin->src[0].ssa;
-   nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
-   nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
-
-   unsigned driver_location = nir_intrinsic_base(intrin);
-   unsigned component_offset = nir_intrinsic_component(intrin);
-   unsigned write_mask = nir_intrinsic_write_mask(intrin);
-   unsigned bit_size = store_val->bit_size;
-
-   nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_arrayed_outputs);
-   nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
-   nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
-
    /* Vulkan spec 15.1.4-15.1.5:
     *
     * The shader interface consists of output slots with 4x 32-bit components.
@@ -2166,109 +2180,138 @@ ms_store_arrayed_output_intrin(nir_builder *b,
     * 32-bit and then stored contiguously.
     */
 
-   if (bit_size < 32) {
+   if (store_val->bit_size < 32) {
       assert(store_val->num_components <= 4);
       nir_ssa_def *comps[4] = {0};
       for (unsigned c = 0; c < store_val->num_components; ++c)
          comps[c] = nir_u2u32(b, nir_channel(b, store_val, c));
-      store_val = nir_vec(b, comps, store_val->num_components);
-      bit_size = 32;
+      return nir_vec(b, comps, store_val->num_components);
    }
 
-   unsigned const_off = base_shared_addr + component_offset * 4;
+   return store_val;
+}
+
+static nir_ssa_def *
+regroup_load_val(nir_builder *b, nir_ssa_def *load, unsigned dest_bit_size)
+{
+   if (dest_bit_size == load->bit_size)
+      return load;
+
+   /* Small bitsize components are not stored contiguously, take care of that here. */
+   unsigned num_components = load->num_components;
+   assert(num_components <= 4);
+   nir_ssa_def *components[4] = {0};
+   for (unsigned i = 0; i < num_components; ++i)
+      components[i] = nir_u2u(b, nir_channel(b, load, i), dest_bit_size);
+
+   return nir_vec(b, components, num_components);
+}
+
+static const ms_out_part *
+ms_get_out_layout_part(unsigned location,
+                       shader_info *info,
+                       ms_out_mode *out_mode,
+                       lower_ngg_ms_state *s)
+{
+   uint64_t mask = BITFIELD64_BIT(location);
+
+   if (info->per_primitive_outputs & mask) {
+      if (mask & s->layout.lds.prm_attr.mask) {
+         *out_mode = ms_out_mode_lds;
+         return &s->layout.lds.prm_attr;
+      }
+   } else {
+      if (mask & s->layout.lds.vtx_attr.mask) {
+         *out_mode = ms_out_mode_lds;
+         return &s->layout.lds.vtx_attr;
+      }
+   }
+
+   unreachable("Couldn't figure out mesh shader output mode.");
+}
+
+static void
+ms_store_arrayed_output_intrin(nir_builder *b,
+                               nir_intrinsic_instr *intrin,
+                               lower_ngg_ms_state *s)
+{
+   ms_out_mode out_mode;
+   unsigned location = nir_intrinsic_io_semantics(intrin).location;
+   const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
+
+   unsigned driver_location = nir_intrinsic_base(intrin);
+   unsigned component_offset = nir_intrinsic_component(intrin);
+   unsigned write_mask = nir_intrinsic_write_mask(intrin);
+   unsigned num_outputs = util_bitcount64(out->mask);
+   unsigned const_off = out->addr + component_offset * 4;
+
+   nir_ssa_def *store_val = regroup_store_val(b, intrin->src[0].ssa);
+   nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
+   nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
+   nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
+   nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
+   nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
 
-   nir_store_shared(b, store_val, addr, .base = const_off,
-                    .write_mask = write_mask, .align_mul = 16,
-                    .align_offset = const_off % 16);
+   if (out_mode == ms_out_mode_lds) {
+      nir_store_shared(b, store_val, addr, .base = const_off,
+                       .write_mask = write_mask, .align_mul = 16,
+                       .align_offset = const_off % 16);
+   } else {
+      unreachable("Invalid MS output mode for store");
+   }
 }
 
 static nir_ssa_def *
 ms_load_arrayed_output(nir_builder *b,
                        nir_ssa_def *arr_index,
                        nir_ssa_def *base_offset,
+                       unsigned location,
                        unsigned driver_location,
                        unsigned component_offset,
                        unsigned num_components,
                        unsigned load_bit_size,
-                       unsigned num_arrayed_outputs,
-                       unsigned base_shared_addr)
+                       lower_ngg_ms_state *s)
 {
+   ms_out_mode out_mode;
+   const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
+
    unsigned component_addr_off = component_offset * 4;
+   unsigned num_outputs = util_bitcount64(out->mask);
+   unsigned const_off = out->addr + component_offset * 4;
 
-   nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_arrayed_outputs);
+   nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
    nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
    nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
 
-   return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
-                          .align_offset = component_addr_off % 16,
-                          .base = base_shared_addr + component_addr_off);
+   if (out_mode == ms_out_mode_lds) {
+      return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
+                             .align_offset = component_addr_off % 16,
+                             .base = const_off);
+   } else {
+      unreachable("Invalid MS output mode for load");
+   }
 }
 
 static nir_ssa_def *
 ms_load_arrayed_output_intrin(nir_builder *b,
                               nir_intrinsic_instr *intrin,
-                              unsigned num_arrayed_outputs,
-                              unsigned base_shared_addr)
+                              lower_ngg_ms_state *s)
 {
    nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
    nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
 
+   unsigned location = nir_intrinsic_io_semantics(intrin).location;
    unsigned driver_location = nir_intrinsic_base(intrin);
    unsigned component_offset = nir_intrinsic_component(intrin);
    unsigned bit_size = intrin->dest.ssa.bit_size;
    unsigned num_components = intrin->dest.ssa.num_components;
    unsigned load_bit_size = MAX2(bit_size, 32);
 
-   nir_ssa_def *loaded =
-      ms_load_arrayed_output(b, arr_index, base_offset, driver_location, component_offset, num_components,
-                             load_bit_size, num_arrayed_outputs, base_shared_addr);
-
-   if (bit_size == load_bit_size)
-      return loaded;
-
-   /* Small bitsize components are not stored contiguously, take care of that here. */
-   assert(num_components <= 4);
-   nir_ssa_def *components[4] = {0};
-   for (unsigned i = 0; i < num_components; ++i)
-      components[i] = nir_u2u(b, nir_channel(b, loaded, i), bit_size);
-
-   return nir_vec(b, components, num_components);
-}
-
-static nir_ssa_def *
-lower_ms_store_per_vertex_output(nir_builder *b,
-                                 nir_intrinsic_instr *intrin,
-                                 lower_ngg_ms_state *s)
-{
-   update_ms_output_info(intrin, s);
-   ms_store_arrayed_output_intrin(b, intrin, s->num_per_vertex_outputs, s->vertex_attr_lds_addr);
-   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
-}
-
-static nir_ssa_def *
-lower_ms_load_per_vertex_output(nir_builder *b,
-                                nir_intrinsic_instr *intrin,
-                                lower_ngg_ms_state *s)
-{
-   return ms_load_arrayed_output_intrin(b, intrin, s->num_per_vertex_outputs, s->vertex_attr_lds_addr);
-}
-
-static nir_ssa_def *
-lower_ms_store_per_primitive_output(nir_builder *b,
-                                    nir_intrinsic_instr *intrin,
-                                    lower_ngg_ms_state *s)
-{
-   update_ms_output_info(intrin, s);
-   ms_store_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
-   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
-}
+   nir_ssa_def *load =
+      ms_load_arrayed_output(b, arr_index, base_offset, location, driver_location,
+                             component_offset, num_components, load_bit_size, s);
 
-static nir_ssa_def *
-lower_ms_load_per_primitive_output(nir_builder *b,
-                                   nir_intrinsic_instr *intrin,
-                                   lower_ngg_ms_state *s)
-{
-   return ms_load_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
+   return regroup_load_val(b, load, bit_size);
 }
 
 static nir_ssa_def *
@@ -2309,24 +2352,26 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
 
    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
 
-   if (intrin->intrinsic == nir_intrinsic_store_output)
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_store_output:
       return lower_ms_store_output(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_load_output)
+   case nir_intrinsic_load_output:
       return lower_ms_load_output(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_store_per_vertex_output)
-      return lower_ms_store_per_vertex_output(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_load_per_vertex_output)
-      return lower_ms_load_per_vertex_output(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_store_per_primitive_output)
-      return lower_ms_store_per_primitive_output(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_load_per_primitive_output)
-      return lower_ms_load_per_primitive_output(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_load_workgroup_id)
+   case nir_intrinsic_store_per_vertex_output:
+   case nir_intrinsic_store_per_primitive_output:
+      update_ms_output_info(intrin, s);
+      ms_store_arrayed_output_intrin(b, intrin, s);
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+   case nir_intrinsic_load_per_vertex_output:
+   case nir_intrinsic_load_per_primitive_output:
+      return ms_load_arrayed_output_intrin(b, intrin, s);
+   case nir_intrinsic_load_workgroup_id:
       return lower_ms_load_workgroup_id(b, intrin, s);
-   else if (intrin->intrinsic == nir_intrinsic_scoped_barrier)
+   case nir_intrinsic_scoped_barrier:
       return update_ms_scoped_barrier(b, intrin, s);
-   else
+   default:
       unreachable("Not a lowerable mesh shader intrinsic.");
+   }
 }
 
 static bool
@@ -2356,14 +2401,12 @@ lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
 static void
 ms_emit_arrayed_outputs(nir_builder *b,
                         nir_ssa_def *invocation_index,
-                        uint64_t arrayed_outputs_mask,
-                        unsigned num_arrayed_outputs,
-                        unsigned lds_base_addr,
+                        uint64_t mask,
                         lower_ngg_ms_state *s)
 {
    nir_ssa_def *zero = nir_imm_int(b, 0);
 
-   u_foreach_bit64(slot, arrayed_outputs_mask) {
+   u_foreach_bit64(slot, mask) {
      /* Should not occur here, handled separately. */
       assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
 
@@ -2376,8 +2419,8 @@ ms_emit_arrayed_outputs(nir_builder *b,
          u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
 
          nir_ssa_def *load =
-            ms_load_arrayed_output(b, invocation_index, zero, driver_location, start_comp,
-                                   num_components, 32, num_arrayed_outputs, lds_base_addr);
+            ms_load_arrayed_output(b, invocation_index, zero, slot, driver_location, start_comp,
+                                   num_components, 32, s);
 
          nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .component = start_comp,
                           .io_semantics = io_sem);
@@ -2416,7 +2459,7 @@ emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
       return;
    }
 
-   unsigned workgroup_index_lds_addr = s->numprims_lds_addr + 8;
+   unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
 
    nir_ssa_def *zero = nir_imm_int(b, 0);
    nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
@@ -2524,8 +2567,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
    nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
    {
       /* All per-vertex attributes. */
-      ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs,
-                              s->num_per_vertex_outputs, s->vertex_attr_lds_addr, s);
+      ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs, s);
       nir_export_vertex_amd(b);
    }
    nir_pop_if(b, if_has_output_vertex);
@@ -2535,8 +2577,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
    nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
    {
       /* Generic per-primitive attributes. */
-      ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs,
-                              s->num_per_primitive_outputs, s->prim_attr_lds_addr, s);
+      ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s);
 
       /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
       if (s->insert_layer_output) {
@@ -2549,7 +2590,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 
       /* Primitive connectivity data: describes which vertices the primitive uses. */
       nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
-      nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->prim_vtx_indices_addr);
+      nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
       nir_ssa_def *indices[3];
       nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
 
@@ -2591,7 +2632,7 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
    bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
    bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
 
-   unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12;
+   unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
    unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
 
    /* Scan the shader for workgroup barriers. */
@@ -2707,6 +2748,50 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
    }
 }
 
+static void
+ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
+                                   unsigned max_vertices,
+                                   unsigned max_primitives)
+{
+   uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
+   uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
+   l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
+   l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
+}
+
+static ms_out_mem_layout
+ms_calculate_output_layout(unsigned api_shared_size,
+                           uint64_t per_vertex_output_mask,
+                           uint64_t per_primitive_output_mask,
+                           unsigned max_vertices,
+                           unsigned max_primitives,
+                           unsigned vertices_per_prim)
+{
+   uint64_t lds_per_vertex_output_mask = per_vertex_output_mask;
+   uint64_t lds_per_primitive_output_mask = per_primitive_output_mask;
+
+   /* Shared memory used by the API shader. */
+   ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
+
+   /* Workgroup information, see the lds_ms_* enum for the layout. */
+   l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
+   l.lds.total_size = l.lds.workgroup_info_addr + 16;
+
+   /* Per-vertex and per-primitive output attributes. */
+   l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
+   l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
+   l.lds.prm_attr.mask = lds_per_primitive_output_mask;
+   ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
+
+   /* Indices: flat array of 8-bit vertex indices for each primitive. */
+   l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
+   l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
+
+   /* NGG is only allowed to address up to 32K of LDS. */
+   assert(l.lds.total_size <= 32 * 1024);
+   return l;
+}
+
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
                     unsigned wave_size,
@@ -2722,30 +2807,14 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    uint64_t per_primitive_outputs =
       shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
 
-   unsigned num_per_vertex_outputs = util_bitcount64(per_vertex_outputs);
-   unsigned num_per_primitive_outputs = util_bitcount64(per_primitive_outputs);
    unsigned max_vertices = shader->info.mesh.max_vertices_out;
    unsigned max_primitives = shader->info.mesh.max_primitives_out;
 
-   /* LDS area for total number of output primitives and other info.
-    * DW0: number of primitives
-    * DW1: reserved for later use
-    * DW2: workgroup index within the current dispatch
-    * DW3: number of API workgroups in flight
-    */
-   unsigned numprims_lds_addr = ALIGN(shader->info.shared_size, 16);
-   unsigned numprims_lds_size = 16;
-   /* LDS area for vertex attributes */
-   unsigned vertex_attr_lds_addr = ALIGN(numprims_lds_addr + numprims_lds_size, 16);
-   unsigned vertex_attr_lds_size = max_vertices * num_per_vertex_outputs * 16;
-   /* LDS area for primitive attributes */
-   unsigned prim_attr_lds_addr = ALIGN(vertex_attr_lds_addr + vertex_attr_lds_size, 16);
-   unsigned prim_attr_lds_size = max_primitives * num_per_primitive_outputs * 16;
-   /* LDS area for the vertex indices (stored as a flat array) */
-   unsigned prim_vtx_indices_addr = ALIGN(prim_attr_lds_addr + prim_attr_lds_size, 16);
-   unsigned prim_vtx_indices_size = max_primitives * vertices_per_prim;
-
-   shader->info.shared_size = prim_vtx_indices_addr + prim_vtx_indices_size;
+   ms_out_mem_layout layout =
+      ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
+                                 max_vertices, max_primitives, vertices_per_prim);
+
+   shader->info.shared_size = layout.lds.total_size;
 
    /* The workgroup size that is specified by the API shader may be different
     * from the size of the workgroup that actually runs on the HW, due to the
@@ -2762,16 +2831,11 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
       ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
 
    lower_ngg_ms_state state = {
+      .layout = layout,
       .wave_size = wave_size,
       .per_vertex_outputs = per_vertex_outputs,
       .per_primitive_outputs = per_primitive_outputs,
-      .num_per_vertex_outputs = num_per_vertex_outputs,
-      .num_per_primitive_outputs = num_per_primitive_outputs,
       .vertices_per_prim = vertices_per_prim,
-      .vertex_attr_lds_addr = vertex_attr_lds_addr,
-      .prim_attr_lds_addr = prim_attr_lds_addr,
-      .prim_vtx_indices_addr = prim_vtx_indices_addr,
-      .numprims_lds_addr = numprims_lds_addr,
       .api_workgroup_size = api_workgroup_size,
       .hw_workgroup_size = hw_workgroup_size,
       .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
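
As a side note on regroup_store_val / regroup_load_val above: because the Vulkan
shader interface consists of slots of 4x 32-bit components, components narrower
than 32 bits each still occupy a full dword in LDS. A minimal standalone sketch
(not Mesa code) of that round trip for 16-bit values:

   #include <assert.h>
   #include <stdint.h>

   /* Store side: widen each 16-bit component into its own 32-bit dword,
    * mirroring the nir_u2u32 performed by regroup_store_val. */
   static void widen_u16(const uint16_t *src, uint32_t *dwords, unsigned n)
   {
      for (unsigned c = 0; c < n; ++c)
         dwords[c] = (uint32_t)src[c];
   }

   /* Load side: narrow each dword back, mirroring regroup_load_val. */
   static void narrow_u16(const uint32_t *dwords, uint16_t *dst, unsigned n)
   {
      for (unsigned c = 0; c < n; ++c)
         dst[c] = (uint16_t)dwords[c];
   }

   int main(void)
   {
      uint16_t in[4] = {1, 2, 3, 4}, out[4];
      uint32_t lds[4]; /* one dword per component, i.e. a 4-byte stride */

      widen_u16(in, lds, 4);
      narrow_u16(lds, out, 4);

      for (unsigned c = 0; c < 4; ++c)
         assert(in[c] == out[c]);
      return 0;
   }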


