Mesa (main): ac/nir: Use a ballot that matches the wave size during NGG lowering.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jul 14 00:44:52 UTC 2021


Module: Mesa
Branch: main
Commit: 556a690bac3a48f7f1c0627f3fc4caf8f21d0f89
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=556a690bac3a48f7f1c0627f3fc4caf8f21d0f89

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Mon Jul  5 11:19:14 2021 +0200

ac/nir: Use a ballot that matches the wave size during NGG lowering.

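Previously the ballot was always 64 bits wide, regardless of the wave size.
Using a ballot that matches the wave size generates slightly more efficient
code in Wave32 mode.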

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Daniel Schürmann <daniel at schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10525>

---

 src/amd/common/ac_nir_lower_ngg.c | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 7af8708966b..9cc7687a352 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -34,6 +34,7 @@ typedef struct
    bool passthrough;
    bool export_prim_id;
    bool early_prim_export;
+   unsigned wave_size;
    unsigned max_num_waves;
    unsigned num_vertices_per_primitives;
    unsigned provoking_vtx_idx;
@@ -55,6 +56,7 @@ typedef struct
    nir_variable *current_clear_primflag_idx_var;
    int const_out_vtxcnt[4];
    int const_out_prmcnt[4];
+   unsigned wave_size;
    unsigned max_num_waves;
    unsigned num_vertices_per_primitive;
    unsigned lds_addr_gs_out_vtx;
@@ -80,17 +82,18 @@ typedef struct {
  */
 static wg_repack_result
 repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
-                                unsigned lds_addr_base, unsigned max_num_waves)
+                                unsigned lds_addr_base, unsigned max_num_waves,
+                                unsigned wave_size)
 {
    /* Input boolean: 1 if the current invocation should survive the repack. */
    assert(input_bool->bit_size == 1);
 
    /* STEP 1. Count surviving invocations in the current wave.
     *
-    * Implemented by a scalar instruction that simply counts the number of bits set in a 64-bit mask.
+    * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
     */
 
-   nir_ssa_def *input_mask = nir_build_ballot(b, 1, 64, input_bool);
+   nir_ssa_def *input_mask = nir_build_ballot(b, 1, wave_size, input_bool);
    nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
 
    /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
@@ -342,6 +345,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       .prim_exp_arg_var = prim_exp_arg_var,
       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
       .max_es_num_vertices = max_num_es_vertices,
+      .wave_size = wave_size,
    };
 
    /* We need LDS space when VS needs to export the primitive ID. */
@@ -488,7 +492,7 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st
       unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
       unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
       unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
-      nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, 64, nir_imm_bool(b, true)));
+      nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
       num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
    } else {
       nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
@@ -819,7 +823,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
     * To ensure this, we need to repack invocations that have a live vertex.
     */
    nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
-   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves);
+   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
 
    nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
    nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
@@ -858,6 +862,7 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
 
    lower_ngg_gs_state state = {
       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
+      .wave_size = wave_size,
       .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
       .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
       .lds_offs_primflags = gs_out_vtx_bytes,



More information about the mesa-commit mailing list