Mesa (main): ac/nir/ngg: Lower NV mesh shaders to NGG semantics.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Fri Dec 31 13:38:47 UTC 2021


Module: Mesa
Branch: main
Commit: 7aa42e023a7907c58062bb286c37da3a1f58999d
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=7aa42e023a7907c58062bb286c37da3a1f58999d

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Sun Aug 29 10:32:01 2021 +0200

ac/nir/ngg: Lower NV mesh shaders to NGG semantics.

Lower mesh shader outputs to shared memory.

At the end of the shader, read the outputs from shared memory
and export their values as NGG expects.

We allocate separate shared memory (LDS) areas for per-vertex and
per-primitive outputs, the primitive vertex indices, and the primitive count.
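
For illustration, a minimal standalone sketch of that LDS layout (the
sizes here are hypothetical, and we assume no other shared memory is in
use; the vec4 slot size and 16-byte alignment match the pass below):

    /* Hypothetical example of the LDS layout computed by ac_nir_lower_ngg_ms.
     * The real pass starts at ALIGN(shader->info.shared_size, 16); assume 0 here.
     */
    #include <stdio.h>

    #define ALIGN16(x) (((x) + 15u) & ~15u)

    int main(void)
    {
       unsigned max_vertices = 64, max_primitives = 84, vertices_per_prim = 3;
       unsigned num_per_vertex_outputs = 2, num_per_primitive_outputs = 1;

       /* Each output slot is a vec4 = 16 bytes. */
       unsigned vertex_attr_lds_addr = 0;
       unsigned prim_attr_lds_addr =
          ALIGN16(vertex_attr_lds_addr + max_vertices * num_per_vertex_outputs * 16);
       unsigned prim_vtx_indices_addr =
          ALIGN16(prim_attr_lds_addr + max_primitives * num_per_primitive_outputs * 16);
       unsigned numprims_lds_addr =
          ALIGN16(prim_vtx_indices_addr + max_primitives * vertices_per_prim);

       printf("vertex attrs at %u, prim attrs at %u, indices at %u, count at %u\n",
              vertex_attr_lds_addr, prim_attr_lds_addr,
              prim_vtx_indices_addr, numprims_lds_addr);
       return 0;
    }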

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13580>

---

 src/amd/common/ac_nir.h           |   4 +
 src/amd/common/ac_nir_lower_ngg.c | 565 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 569 insertions(+)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index 16db749baea..097c4be0522 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -114,6 +114,10 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned gs_total_out_vtx_bytes,
                     bool provoking_vtx_last);
 
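+/* Lowers NV mesh shader I/O to NGG semantics:
+ * outputs are stored to LDS, then read back and exported at the end of the shader.
+ */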
+void
+ac_nir_lower_ngg_ms(nir_shader *shader,
+                    unsigned wave_size);
+
 nir_ssa_def *
 ac_nir_cull_triangle(nir_builder *b,
                      nir_ssa_def *initially_accepted,
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index bdbcb613681..40ed7494c0c 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -94,6 +94,27 @@ typedef struct
    gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
 } lower_ngg_gs_state;
 
+typedef struct
+{
+   uint64_t per_vertex_outputs;
+   uint64_t per_primitive_outputs;
+   unsigned num_per_vertex_outputs;
+   unsigned num_per_primitive_outputs;
+   unsigned vertices_per_prim;
+   unsigned vertex_attr_lds_addr;
+   unsigned prim_attr_lds_addr;
+   unsigned prim_vtx_indices_addr;
+   unsigned numprims_lds_addr;
+   unsigned wave_size;
+
+   struct {
+      /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
+      uint32_t components_mask;
+      /* Driver location of the output slot, if used. */
+      unsigned driver_location;
+   } output_info[VARYING_SLOT_MAX];
+} lower_ngg_ms_state;
+
 typedef struct {
    nir_variable *pre_cull_position_value_var;
 } remove_culling_shader_outputs_state;
@@ -1924,3 +1945,547 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
    nir_metadata_preserve(impl, nir_metadata_none);
 }
+
+static nir_ssa_def *
+lower_ms_store_output(nir_builder *b,
+                      nir_intrinsic_instr *intrin,
+                      lower_ngg_ms_state *s)
+{
+   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
+   nir_ssa_def *store_val = intrin->src[0].ssa;
+   unsigned base = nir_intrinsic_base(intrin);
+
+   /* Component makes no sense here. */
+   assert(nir_intrinsic_component(intrin) == 0);
+
+   if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) {
+      /* Total number of primitives output by the mesh shader workgroup.
+       * This can be read and written by any invocation any number of times.
+       */
+
+      /* Base, offset and component make no sense here. */
+      assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
+      assert(base == 0);
+
+      nir_ssa_def *addr = nir_imm_int(b, 0);
+      nir_build_store_shared(b, nir_u2u32(b, store_val), addr,
+                             .write_mask = 0x1u, .base = s->numprims_lds_addr,
+                             .align_mul = 4u);
+   } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
+      /* Contrary to the name, these are not primitive indices, but
+       * vertex indices for each vertex of the output primitives.
+       * The NV mesh shader API stores these in a flat array.
+       */
+
+      nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa;
+      nir_build_store_shared(b, nir_u2u8(b, store_val), offset_src,
+                             .write_mask = 0x1u, .base = s->prim_vtx_indices_addr + base,
+                             .align_mul = 1u);
+   } else {
+      unreachable("Invalid mesh shader output");
+   }
+
+   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+}
+
+static nir_ssa_def *
+lower_ms_load_output(nir_builder *b,
+                     nir_intrinsic_instr *intrin,
+                     lower_ngg_ms_state *s)
+{
+   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
+   unsigned base = nir_intrinsic_base(intrin);
+
+   /* Component makes no sense here. */
+   assert(nir_intrinsic_component(intrin) == 0);
+
+   if (io_sem.location == VARYING_SLOT_PRIMITIVE_COUNT) {
+      /* Base, offset and component make no sense here. */
+      assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
+      assert(base == 0);
+
+      nir_ssa_def *addr = nir_imm_int(b, 0);
+      return nir_build_load_shared(b, 1, 32, addr, .base = s->numprims_lds_addr, .align_mul = 4u);
+   } else if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
+      nir_ssa_def *offset_src = nir_get_io_offset_src(intrin)->ssa;
+      nir_ssa_def *index = nir_build_load_shared(b, 1, 8, offset_src,
+                                                 .base = s->prim_vtx_indices_addr + base, .align_mul = 1u);
+      return nir_u2u(b, index, intrin->dest.ssa.bit_size);
+   }
+
+   unreachable("Invalid mesh shader output");
+}
+
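+/* Byte offset of an output slot within an arrayed (per-vertex or per-primitive)
+ * LDS area. Each vertex/primitive occupies num_arrayed_outputs vec4 (16-byte)
+ * slots. E.g. with 4 arrayed outputs, item 5 at driver location 2 starts at
+ * 5 * 4 * 16 + 2 * 16 = 352 bytes from the area's base address.
+ */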
+static nir_ssa_def *
+ms_arrayed_output_base_addr(nir_builder *b,
+                            nir_ssa_def *arr_index,
+                            unsigned driver_location,
+                            unsigned num_arrayed_outputs)
+{
+   /* Address offset of the array item (vertex or primitive). */
+   unsigned arr_index_stride = num_arrayed_outputs * 16u;
+   nir_ssa_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
+
+   /* IO address offset within the vertex or primitive data. */
+   unsigned io_offset = driver_location * 16u;
+   nir_ssa_def *io_off = nir_imm_int(b, io_offset);
+
+   return nir_iadd_nuw(b, arr_index_off, io_off);
+}
+
+static void
+update_ms_output_info_slot(lower_ngg_ms_state *s,
+                           unsigned slot, unsigned base, unsigned base_off,
+                           uint32_t components_mask)
+{
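+   /* components_mask packs 4 bits per slot, starting at base_off:
+    * e.g. 0x3F means xyzw of this slot plus xy of the next one.
+    */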
+   while (components_mask) {
+      unsigned driver_location = base + base_off;
+
+      /* If already set, it must match. */
+      if (s->output_info[slot + base_off].driver_location)
+         assert(s->output_info[slot + base_off].driver_location == driver_location);
+
+      s->output_info[slot + base_off].driver_location = driver_location;
+      s->output_info[slot + base_off].components_mask |= components_mask & 0xF;
+
+      components_mask >>= 4;
+      base_off++;
+   }
+}
+
+static void
+update_ms_output_info(nir_intrinsic_instr *intrin,
+                      lower_ngg_ms_state *s)
+{
+   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
+   nir_src *base_offset_src = nir_get_io_offset_src(intrin);
+   uint32_t write_mask = nir_intrinsic_write_mask(intrin);
+   unsigned component_offset = nir_intrinsic_component(intrin);
+   unsigned base = nir_intrinsic_base(intrin);
+
+   nir_ssa_def *store_val = intrin->src[0].ssa;
+   write_mask = util_widen_mask(write_mask, DIV_ROUND_UP(store_val->bit_size, 32));
+   uint32_t components_mask = write_mask << component_offset;
+
+   if (nir_src_is_const(*base_offset_src)) {
+      /* Simply mark the components of the current slot as used. */
+      unsigned base_off = nir_src_as_uint(*base_offset_src);
+      update_ms_output_info_slot(s, io_sem.location, base, base_off, components_mask);
+   } else {
+      /* Indirect offset: mark the components of all slots as used. */
+      for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off)
+         update_ms_output_info_slot(s, io_sem.location, base, base_off, components_mask);
+   }
+}
+
+static void
+ms_store_arrayed_output_intrin(nir_builder *b,
+                               nir_intrinsic_instr *intrin,
+                               unsigned num_arrayed_outputs,
+                               unsigned base_shared_addr)
+{
+   nir_ssa_def *store_val = intrin->src[0].ssa;
+   nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
+   nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
+
+   unsigned driver_location = nir_intrinsic_base(intrin);
+   unsigned component_offset = nir_intrinsic_component(intrin);
+   unsigned write_mask = nir_intrinsic_write_mask(intrin);
+   unsigned bit_size = store_val->bit_size;
+
+   nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_arrayed_outputs);
+   nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
+   nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
+
+   /* Vulkan spec 15.1.4-15.1.5:
+    *
+    * The shader interface consists of output slots with 4x 32-bit components.
+    * Small bitsize components consume the same space as 32-bit components,
+    * but 64-bit ones consume twice as much.
+    *
+    * The same output slot may consist of components of different bit sizes.
+    * Therefore for simplicity we don't store small bitsize components
+    * contiguously, but pad them instead. In practice, they are converted to
+    * 32-bit and then stored contiguously.
+    */
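+   /* E.g. a 16-bit vec2 store is widened to a 32-bit vec2 below, so each
+    * component keeps its 4-byte-aligned offset within the 16-byte slot.
+    */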
+
+   if (bit_size < 32) {
+      assert(store_val->num_components <= 4);
+      nir_ssa_def *comps[4] = {0};
+      for (unsigned c = 0; c < store_val->num_components; ++c)
+         comps[c] = nir_u2u32(b, nir_channel(b, store_val, c));
+      store_val = nir_vec(b, comps, store_val->num_components);
+      bit_size = 32;
+   }
+
+   unsigned const_off = base_shared_addr + component_offset * 4;
+
+   nir_build_store_shared(b, store_val, addr, .base = const_off,
+                          .write_mask = write_mask, .align_mul = 16,
+                          .align_offset = const_off % 16);
+}
+
+static nir_ssa_def *
+ms_load_arrayed_output(nir_builder *b,
+                       nir_ssa_def *arr_index,
+                       nir_ssa_def *base_offset,
+                       unsigned driver_location,
+                       unsigned component_offset,
+                       unsigned num_components,
+                       unsigned load_bit_size,
+                       unsigned num_arrayed_outputs,
+                       unsigned base_shared_addr)
+{
+   unsigned component_addr_off = component_offset * 4;
+
+   nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_arrayed_outputs);
+   nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
+   nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
+
+   return nir_build_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
+                                .align_offset = component_addr_off % 16,
+                                .base = base_shared_addr + component_addr_off);
+}
+
+static nir_ssa_def *
+ms_load_arrayed_output_intrin(nir_builder *b,
+                              nir_intrinsic_instr *intrin,
+                              unsigned num_arrayed_outputs,
+                              unsigned base_shared_addr)
+{
+   nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
+   nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
+
+   unsigned driver_location = nir_intrinsic_base(intrin);
+   unsigned component_offset = nir_intrinsic_component(intrin);
+   unsigned bit_size = intrin->dest.ssa.bit_size;
+   unsigned num_components = intrin->dest.ssa.num_components;
+   unsigned load_bit_size = MAX2(bit_size, 32);
+
+   nir_ssa_def *loaded =
+      ms_load_arrayed_output(b, arr_index, base_offset, driver_location, component_offset, num_components,
+                             load_bit_size, num_arrayed_outputs, base_shared_addr);
+
+   if (bit_size == load_bit_size)
+      return loaded;
+
+   /* Small bitsize components are not stored contiguously; take care of that here. */
+   assert(num_components <= 4);
+   nir_ssa_def *components[4] = {0};
+   for (unsigned i = 0; i < num_components; ++i)
+      components[i] = nir_u2u(b, nir_channel(b, loaded, i), bit_size);
+
+   return nir_vec(b, components, num_components);
+}
+
+static nir_ssa_def *
+lower_ms_store_per_vertex_output(nir_builder *b,
+                                 nir_intrinsic_instr *intrin,
+                                 lower_ngg_ms_state *s)
+{
+   update_ms_output_info(intrin, s);
+   ms_store_arrayed_output_intrin(b, intrin, s->num_per_vertex_outputs, s->vertex_attr_lds_addr);
+   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+}
+
+static nir_ssa_def *
+lower_ms_load_per_vertex_output(nir_builder *b,
+                                nir_intrinsic_instr *intrin,
+                                lower_ngg_ms_state *s)
+{
+   return ms_load_arrayed_output_intrin(b, intrin, s->num_per_vertex_outputs, s->vertex_attr_lds_addr);
+}
+
+static nir_ssa_def *
+lower_ms_store_per_primitive_output(nir_builder *b,
+                                    nir_intrinsic_instr *intrin,
+                                    lower_ngg_ms_state *s)
+{
+   update_ms_output_info(intrin, s);
+   ms_store_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
+   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+}
+
+static nir_ssa_def *
+lower_ms_load_per_primitive_output(nir_builder *b,
+                                   nir_intrinsic_instr *intrin,
+                                   lower_ngg_ms_state *s)
+{
+   return ms_load_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
+}
+
+static nir_ssa_def *
+update_ms_scoped_barrier(nir_builder *b,
+                         nir_intrinsic_instr *intrin,
+                         lower_ngg_ms_state *s)
+{
+   /* Output loads and stores are lowered to shared memory access,
+    * so we have to update the barriers to also reflect this.
+    */
+   unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
+   if (mem_modes & nir_var_shader_out)
+      mem_modes |= nir_var_mem_shared;
+   else
+      return NULL;
+
+   nir_intrinsic_set_memory_modes(intrin, mem_modes);
+
+   return NIR_LOWER_INSTR_PROGRESS;
+}
+
+static nir_ssa_def *
+lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
+{
+   lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
+
+   if (instr->type != nir_instr_type_intrinsic)
+      return NULL;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+
+   if (intrin->intrinsic == nir_intrinsic_store_output)
+      return lower_ms_store_output(b, intrin, s);
+   else if (intrin->intrinsic == nir_intrinsic_load_output)
+      return lower_ms_load_output(b, intrin, s);
+   else if (intrin->intrinsic == nir_intrinsic_store_per_vertex_output)
+      return lower_ms_store_per_vertex_output(b, intrin, s);
+   else if (intrin->intrinsic == nir_intrinsic_load_per_vertex_output)
+      return lower_ms_load_per_vertex_output(b, intrin, s);
+   else if (intrin->intrinsic == nir_intrinsic_store_per_primitive_output)
+      return lower_ms_store_per_primitive_output(b, intrin, s);
+   else if (intrin->intrinsic == nir_intrinsic_load_per_primitive_output)
+      return lower_ms_load_per_primitive_output(b, intrin, s);
+   else if (intrin->intrinsic == nir_intrinsic_scoped_barrier)
+      return update_ms_scoped_barrier(b, intrin, s);
+   else
+      unreachable("Not a lowerable mesh shader intrinsic.");
+}
+
+static bool
+filter_ms_intrinsic(const nir_instr *instr,
+                    UNUSED const void *st)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   return intrin->intrinsic == nir_intrinsic_store_output ||
+          intrin->intrinsic == nir_intrinsic_load_output ||
+          intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
+          intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
+          intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
+          intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
+          intrin->intrinsic == nir_intrinsic_scoped_barrier;
+}
+
+static void
+lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
+{
+   nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
+}
+
+static void
+ms_emit_arrayed_outputs(nir_builder *b,
+                        nir_ssa_def *invocation_index,
+                        uint64_t arrayed_outputs_mask,
+                        unsigned num_arrayed_outputs,
+                        unsigned lds_base_addr,
+                        lower_ngg_ms_state *s)
+{
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+
+   u_foreach_bit64(slot, arrayed_outputs_mask) {
+      const nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
+      const unsigned driver_location = s->output_info[slot].driver_location;
+      unsigned component_mask = s->output_info[slot].components_mask;
+
+      while (component_mask) {
+         int start_comp = 0, num_components = 1;
+         u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
+
+         nir_ssa_def *load =
+            ms_load_arrayed_output(b, invocation_index, zero, driver_location, start_comp,
+                                   num_components, 32, num_arrayed_outputs, lds_base_addr);
+
+         nir_build_store_output(b, load, nir_imm_int(b, 0), .write_mask = BITFIELD_MASK(num_components),
+                                .base = slot, .component = start_comp, .io_semantics = io_sem);
+      }
+   }
+}
+
+static void
+emit_ms_finale(nir_shader *shader, lower_ngg_ms_state *s)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+   assert(impl);
+   nir_block *last_block = nir_impl_last_block(impl);
+   assert(last_block);
+
+   /* We assume there is always a single end block in the shader. */
+
+   nir_builder builder;
+   nir_builder *b = &builder; /* Shorthand, to avoid writing &builder everywhere. */
+   nir_builder_init(b, impl);
+   b->cursor = nir_after_block(last_block);
+
+   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
+                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
+
+   /* Limitations of the NV extension:
+    * - Number of primitives can be written and read by any invocation,
+    *   so we have to store/load it to/from LDS to make sure the general case works.
+    * - Number of vertices is not actually known, so we just always use the
+    *   maximum number here.
+    *
+    * TODO: in a possible cross-vendor extension we expect to be able to do this smarter:
+    * - Lower SetMeshOutputCounts (not present in NV) directly to alloc_vertices_and_primitives.
+    * - We'll know the exact number of output vertices.
+    * - No longer need to ensure that these variables are readable by any invocation.
+    */
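+   /* Load the primitive count from LDS on a single lane,
+    * then broadcast it to every invocation in the wave.
+    */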
+   nir_ssa_def *loaded_num_prm;
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+   nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
+   nir_if *if_elected = nir_push_if(b, nir_build_elect(b, 1));
+   {
+      loaded_num_prm = nir_build_load_shared(b, 1, 32, zero, .base = s->numprims_lds_addr, .align_mul = 4u);
+   }
+   nir_pop_if(b, if_elected);
+   loaded_num_prm = nir_if_phi(b, loaded_num_prm, dont_care);
+   nir_ssa_def *num_prm = nir_build_read_first_invocation(b, loaded_num_prm);
+   nir_ssa_def *num_vtx = nir_imm_int(b, shader->info.mesh.max_vertices_out);
+
+   /* If the shader doesn't actually create any primitives, don't allocate any output. */
+   num_vtx = nir_bcsel(b, nir_ieq_imm(b, num_prm, 0), nir_imm_int(b, 0), num_vtx);
+
+   /* Emit GS_ALLOC_REQ on Wave 0 to let the HW know the output size. */
+   nir_ssa_def *wave_id = nir_build_load_subgroup_id(b);
+   nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
+   {
+      nir_build_alloc_vertices_and_primitives_amd(b, num_vtx, num_prm);
+   }
+   nir_pop_if(b, if_wave_0);
+
+   nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
+
+   /* Load vertex/primitive attributes from shared memory and
+    * emit store_output intrinsics for them.
+    *
+    * Contrary to the semantics of the API mesh shader, these are now
+    * compliant with NGG HW semantics, meaning that these store the
+    * current thread's vertex attributes in a way the HW can export.
+    */
+
+   /* Export vertices. */
+   nir_ssa_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
+   nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
+   {
+      /* All per-vertex attributes. */
+      ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs,
+                              s->num_per_vertex_outputs, s->vertex_attr_lds_addr, s);
+      nir_build_export_vertex_amd(b);
+   }
+   nir_pop_if(b, if_has_output_vertex);
+
+   /* Export primitives. */
+   nir_ssa_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
+   nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
+   {
+      /* Generic per-primitive attributes. */
+      ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs,
+                              s->num_per_primitive_outputs, s->prim_attr_lds_addr, s);
+
+      /* Primitive connectivity data: describes which vertices the primitive uses. */
+      nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
+      nir_ssa_def *indices_loaded = nir_build_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->prim_vtx_indices_addr, .align_mul = 1u);
+      nir_ssa_def *indices[3];
+      indices[0] = nir_u2u32(b, nir_channel(b, indices_loaded, 0));
+      indices[1] = s->vertices_per_prim > 1 ? nir_u2u32(b, nir_channel(b, indices_loaded, 1)) : NULL;
+      indices[2] = s->vertices_per_prim > 2 ? nir_u2u32(b, nir_channel(b, indices_loaded, 2)) : NULL;
+
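+      /* Pack the vertex indices into the NGG primitive export argument
+       * (a single dword), using the same helper as the other NGG paths.
+       */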
+      nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, NULL, false);
+      nir_build_export_primitive_amd(b, prim_exp_arg);
+   }
+   nir_pop_if(b, if_has_output_primitive);
+}
+
+void
+ac_nir_lower_ngg_ms(nir_shader *shader,
+                    unsigned wave_size)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+   assert(impl);
+
+   unsigned vertices_per_prim = 3;
+   if (shader->info.mesh.primitive_type == GL_POINTS)
+      vertices_per_prim = 1;
+   else if (shader->info.mesh.primitive_type == GL_LINES)
+      vertices_per_prim = 2;
+
+   uint64_t per_vertex_outputs = shader->info.outputs_written & ~shader->info.per_primitive_outputs
+                                 & ~BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT)
+                                 & ~BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
+   uint64_t per_primitive_outputs = shader->info.per_primitive_outputs & shader->info.outputs_written;
+   unsigned num_per_vertex_outputs = util_bitcount64(per_vertex_outputs);
+   unsigned num_per_primitive_outputs = util_bitcount64(per_primitive_outputs);
+   unsigned max_vertices = shader->info.mesh.max_vertices_out;
+   unsigned max_primitives = shader->info.mesh.max_primitives_out;
+
+   /* LDS area for vertex attributes */
+   unsigned vertex_attr_lds_addr = ALIGN(shader->info.shared_size, 16);
+   unsigned vertex_attr_lds_size = max_vertices * num_per_vertex_outputs * 16;
+   /* LDS area for primitive attributes */
+   unsigned prim_attr_lds_addr = ALIGN(vertex_attr_lds_addr + vertex_attr_lds_size, 16);
+   unsigned prim_attr_lds_size = max_primitives * num_per_primitive_outputs * 16;
+   /* LDS area for the vertex indices (stored as a flat array) */
+   unsigned prim_vtx_indices_addr = ALIGN(prim_attr_lds_addr + prim_attr_lds_size, 16);
+   unsigned prim_vtx_indices_size = max_primitives * vertices_per_prim;
+   /* LDS area for total number of output primitives. */
+   unsigned numprims_lds_addr = ALIGN(prim_vtx_indices_addr + prim_vtx_indices_size, 16);
+   unsigned numprims_lds_size = 4;
+
+   shader->info.shared_size = numprims_lds_addr + numprims_lds_size;
+
+   lower_ngg_ms_state state = {
+      .wave_size = wave_size,
+      .per_vertex_outputs = per_vertex_outputs,
+      .per_primitive_outputs = per_primitive_outputs,
+      .num_per_vertex_outputs = num_per_vertex_outputs,
+      .num_per_primitive_outputs = num_per_primitive_outputs,
+      .vertices_per_prim = vertices_per_prim,
+      .vertex_attr_lds_addr = vertex_attr_lds_addr,
+      .prim_attr_lds_addr = prim_attr_lds_addr,
+      .prim_vtx_indices_addr = prim_vtx_indices_addr,
+      .numprims_lds_addr = numprims_lds_addr,
+   };
+
+   /* Extract the full control flow. It is going to be wrapped in an if statement. */
+   nir_cf_list extracted;
+   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
+
+   nir_builder builder;
+   nir_builder *b = &builder; /* Shorthand, to avoid writing &builder everywhere. */
+   nir_builder_init(b, impl);
+   b->cursor = nir_before_cf_list(&impl->body);
+
+   /* There may be a difference between the MS workgroup size and the
+    * number of output vertices/primitives, so the actual HW workgroup
+    * may be larger than what the user intended.
+    * Therefore, only execute the API shader for invocations that the user needs.
+    */
+   unsigned num_ms_invocations = b->shader->info.workgroup_size[0] *
+                                 b->shader->info.workgroup_size[1] *
+                                 b->shader->info.workgroup_size[2];
+   nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
+   nir_ssa_def *has_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, num_ms_invocations));
+   nir_if *if_has_ms_invocation = nir_push_if(b, has_ms_invocation);
+   nir_cf_reinsert(&extracted, b->cursor);
+   b->cursor = nir_after_cf_list(&if_has_ms_invocation->then_list);
+   nir_pop_if(b, if_has_ms_invocation);
+
+   lower_ms_intrinsics(shader, &state);
+   emit_ms_finale(shader, &state);
+
+   /* Cleanup */
+   nir_validate_shader(shader, "after emitting NGG MS");
+   nir_metadata_preserve(impl, nir_metadata_none);
+}


