Mesa (main): ac/nir/ngg: Use mesh shader scratch ring when outputs don't fit LDS.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jun 8 09:12:36 UTC 2022


Module: Mesa
Branch: main
Commit: b664279755bc963c098a327c278e7e2482c32ed2
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=b664279755bc963c098a327c278e7e2482c32ed2

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Fri May 20 18:09:12 2022 +0200

ac/nir/ngg: Use mesh shader scratch ring when outputs don't fit LDS.

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset at gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16737>

---

 src/amd/common/ac_nir.h           |  1 +
 src/amd/common/ac_nir_lower_ngg.c | 63 ++++++++++++++++++++++++++++++++++++++-
 src/amd/vulkan/radv_shader.c      |  3 +-
 3 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index 8920e985517..f5c59acbcf9 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -138,6 +138,7 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
+                    bool *out_needs_scratch_ring,
                     unsigned wave_size,
                     bool multiview);
 
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index d2339967c6a..5b99940b469 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -110,6 +110,7 @@ enum {
 /* Potential location for Mesh Shader outputs. */
 typedef enum {
    ms_out_mode_lds,
+   ms_out_mode_vram, /* Output lives in the VRAM "mesh shader scratch ring". */
 } ms_out_mode;
 
 typedef struct
@@ -128,6 +129,11 @@ typedef struct
       uint32_t indices_addr;
       uint32_t total_size;
    } lds;
+   /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. */
+   struct {
+      ms_out_part vtx_attr;
+      ms_out_part prm_attr;
+   } vram;
 } ms_out_mem_layout;
 
 typedef struct
@@ -2219,11 +2225,17 @@ ms_get_out_layout_part(unsigned location,
       if (mask & s->layout.lds.prm_attr.mask) {
          *out_mode = ms_out_mode_lds;
          return &s->layout.lds.prm_attr;
+      } else if (mask & s->layout.vram.prm_attr.mask) {
+         *out_mode = ms_out_mode_vram;
+         return &s->layout.vram.prm_attr;
       }
    } else {
       if (mask & s->layout.lds.vtx_attr.mask) {
          *out_mode = ms_out_mode_lds;
          return &s->layout.lds.vtx_attr;
+      } else if (mask & s->layout.vram.vtx_attr.mask) {
+         *out_mode = ms_out_mode_vram;
+         return &s->layout.vram.vtx_attr;
       }
    }
 
@@ -2256,6 +2268,13 @@ ms_store_arrayed_output_intrin(nir_builder *b,
       nir_store_shared(b, store_val, addr, .base = const_off,
                      .write_mask = write_mask, .align_mul = 16,
                      .align_offset = const_off % 16);
+   } else if (out_mode == ms_out_mode_vram) {
+      /* Use the same address operand as the LDS store above, so both modes agree. */
+      nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
+      nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
+      nir_store_buffer_amd(b, store_val, ring, addr, off,
+                           .base = const_off, .write_mask = write_mask,
+                           .memory_modes = nir_var_shader_out);
    } else {
       unreachable("Invalid MS output mode for store");
    }
@@ -2287,6 +2306,12 @@ ms_load_arrayed_output(nir_builder *b,
       return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
                              .align_offset = component_addr_off % 16,
                              .base = const_off);
+   } else if (out_mode == ms_out_mode_vram) {
+      nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
+      nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
+      return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off,
+                                 .base = const_off,
+                                 .memory_modes = nir_var_shader_out);
    } else {
       unreachable("Invalid MS output mode for load");
    }
@@ -2748,6 +2773,15 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
    }
 }
 
+static void
+ms_move_output(ms_out_part *from, ms_out_part *to)
+{
+   /* Hand the location picked by util_logbase2_64() over to 'to'; callers check from->mask != 0. */
+   const uint64_t moved = BITFIELD64_BIT(util_logbase2_64(from->mask));
+   from->mask ^= moved;
+   to->mask |= moved;
+}
+
 static void
 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
                                    unsigned max_vertices,
@@ -2757,6 +2791,9 @@ ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
    uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
    l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
    l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
+
+   uint32_t vram_vtx_attr_size = util_bitcount64(l->vram.vtx_attr.mask) * max_vertices * 16;
+   l->vram.prm_attr.addr = ALIGN(l->vram.vtx_attr.addr + vram_vtx_attr_size, 16);
 }
 
 static ms_out_mem_layout
@@ -2777,12 +2814,34 @@ ms_calculate_output_layout(unsigned api_shared_size,
    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
    l.lds.total_size = l.lds.workgroup_info_addr + 16;
 
-   /* Per-vertex and per-primitive output attributes. */
+   /* Per-vertex and per-primitive output attributes.
+    * First, try to put all outputs into LDS (shared memory).
+    * If they don't fit, try to move them to VRAM one by one.
+    */
    l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
    l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
    l.lds.prm_attr.mask = lds_per_primitive_output_mask;
    ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
 
+   /* NGG shaders can only address up to 32K LDS memory.
+    * The spec requires us to allow the application to use at least up to 28K
+    * shared memory. Additionally, we reserve 2K for driver internal use
+    * (e.g. primitive indices and such, see below).
+    *
+    * Move the outputs that do not fit into LDS to VRAM.
+    * Start with per-primitive attributes, because those are grouped at the end.
+    */
+   while (l.lds.total_size >= 30 * 1024) {
+      if (l.lds.prm_attr.mask)
+         ms_move_output(&l.lds.prm_attr, &l.vram.prm_attr);
+      else if (l.lds.vtx_attr.mask)
+         ms_move_output(&l.lds.vtx_attr, &l.vram.vtx_attr);
+      else
+         unreachable("API shader uses too much shared memory.");
+
+      ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
+   }
+
    /* Indices: flat array of 8-bit vertex indices for each primitive. */
    l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
    l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
@@ -2794,6 +2853,7 @@ ms_calculate_output_layout(unsigned api_shared_size,
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
+                    bool *out_needs_scratch_ring,
                     unsigned wave_size,
                     bool multiview)
 {
@@ -2815,6 +2875,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                                  max_vertices, max_primitives, vertices_per_prim);
 
    shader->info.shared_size = layout.lds.total_size;
+   *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
 
    /* The workgroup size that is specified by the API shader may be different
     * from the size of the workgroup that actually runs on the HW, due to the
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 20dd59c3c83..3f4d1b3551a 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -1240,7 +1240,8 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
                  info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
                  info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
-      NIR_PASS_V(nir, ac_nir_lower_ngg_ms, info->wave_size, pl_key->has_multiview_view_index);
+      bool scratch_ring = false;
+      NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
    } else {
       unreachable("invalid SW stage passed to radv_lower_ngg");
    }



More information about the mesa-commit mailing list