Mesa (main): ac/nir: Refactor and optimize the repacking sequence.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jun 9 17:25:49 UTC 2021


Module: Mesa
Branch: main
Commit: f6b2db298f79aa40ceffd36f757105620cb2d66b
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=f6b2db298f79aa40ceffd36f757105620cb2d66b

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Fri May 28 21:59:21 2021 +0200

ac/nir: Refactor and optimize the repacking sequence.

Reviewers found the "exclusive scan" and "reduction" terminology
hard to follow. Replace it with "repack", which better describes
what this sequence is actually used for.

The new sequence stores only 1 byte per wave to LDS (previously
1 dword per wave) and uses packed byte instructions to compute the
results. This has lower latency and needs fewer instructions than
the previous sequence.

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Tony Wasserka <tony.wasserka at gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/11072>

---

 src/amd/common/ac_nir_lower_ngg.c | 178 +++++++++++++++++++++-----------------
 1 file changed, 101 insertions(+), 77 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 0181d1a701d..fffbdac49ef 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -68,103 +68,128 @@ typedef struct
 } lower_ngg_gs_state;
 
 typedef struct {
-   nir_ssa_def *reduction_result;
-   nir_ssa_def *excl_scan_result;
-} wg_scan_result;
+   nir_ssa_def *num_repacked_invocations;
+   nir_ssa_def *repacked_invocation_index;
+} wg_repack_result;
 
-static wg_scan_result
-workgroup_reduce_and_exclusive_scan(nir_builder *b, nir_ssa_def *input_bool,
-                                    unsigned lds_addr_base, unsigned max_num_waves)
+/**
+ * Repacks invocations in the current workgroup to eliminate gaps between them.
+ *
+ * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
+ * Assumes that all invocations in the workgroup are active (exec = -1).
+ */
+static wg_repack_result
+repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
+                                unsigned lds_addr_base, unsigned max_num_waves)
 {
-   /* This performs a reduction along with an exclusive scan addition accross the workgroup.
-    * Assumes that all lanes are enabled (exec = -1) where this is emitted.
+   /* Input boolean: 1 if the current invocation should survive the repack. */
+   assert(input_bool->bit_size == 1);
+
+   /* STEP 1. Count surviving invocations in the current wave.
     *
-    * Input:  (1) divergent bool
-    *             -- 1 if the lane has a live/valid vertex, 0 otherwise
-    * Output: (1) result of a reduction over the entire workgroup,
-    *             -- the total number of vertices emitted by the workgroup
-    *         (2) result of an exclusive scan over the entire workgroup
-    *             -- used for vertex compaction, in order to determine
-    *                which lane should export the current lane's vertex
+    * Implemented by a scalar instruction that simply counts the number of bits set in a 64-bit mask.
     */
 
-   assert(input_bool->bit_size == 1);
-
-   /* Reduce the boolean -- result is the number of live vertices in the current wave */
    nir_ssa_def *input_mask = nir_build_ballot(b, 1, 64, input_bool);
-   nir_ssa_def *wave_reduction = nir_bit_count(b, input_mask);
+   nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
 
-   /* Take care of when we know in compile time that the maximum workgroup size is small */
+   /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
    if (max_num_waves == 1) {
-      wg_scan_result r = {
-         .reduction_result = wave_reduction,
-         .excl_scan_result = nir_build_mbcnt_amd(b, input_mask),
+      wg_repack_result r = {
+         .num_repacked_invocations = surviving_invocations_in_current_wave,
+         .repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask),
       };
       return r;
    }
 
-   /* Number of LDS dwords written by all waves (if there is only 1, that is already handled above) */
-   unsigned num_lds_dwords = max_num_waves;
-   assert(num_lds_dwords >= 2 && num_lds_dwords <= 8);
+   /* STEP 2. Waves tell each other their number of surviving invocations.
+    *
+    * Each wave activates only its first lane (exec = 1), which stores the number of surviving
+    * invocations in that wave into the LDS, then reads the numbers from every wave.
+    *
+    * The workgroup size of NGG shaders is at most 256, which means
+    * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
+    * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
+    */
 
-   /* NIR doesn't have vec6 and vec7 so just use 8 for these cases. */
-   if (num_lds_dwords == 6 || num_lds_dwords == 7)
-      num_lds_dwords = 8;
+   const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
+   assert(num_lds_dwords <= 2);
 
    nir_ssa_def *wave_id = nir_build_load_subgroup_id(b);
-   nir_ssa_def *dont_care = nir_ssa_undef(b, num_lds_dwords, 32);
+   nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
    nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
 
-   /* The first lane of each wave stores the result of its subgroup reduction to LDS (NGG scratch). */
-   nir_ssa_def *wave_id_lds_addr = nir_imul_imm(b, wave_id, 4u);
-   nir_build_store_shared(b, wave_reduction, wave_id_lds_addr, .base = lds_addr_base, .align_mul = 4u, .write_mask = 0x1u);
+   nir_build_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base, .align_mul = 1u, .write_mask = 0x1u);
 
-   /* Workgroup barrier: wait for all waves to finish storing their result */
    nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
 
-   /* Only the first lane of each wave loads every wave's results from LDS, to avoid bank conflicts */
-   nir_ssa_def *reduction_vector = nir_build_load_shared(b, num_lds_dwords, 32, nir_imm_zero(b, 1, 32), .base = lds_addr_base, .align_mul = 16u);
+   nir_ssa_def *packed_counts = nir_build_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);
+
    nir_pop_if(b, if_first_lane);
 
-   reduction_vector = nir_if_phi(b, reduction_vector, dont_care);
+   packed_counts = nir_if_phi(b, packed_counts, dont_care);
 
-   nir_ssa_def *reduction_per_wave[8] = {0};
-   for (unsigned i = 0; i < num_lds_dwords; ++i) {
-      nir_ssa_def *reduction_wave_i = nir_channel(b, reduction_vector, i);
-      reduction_per_wave[i] = nir_build_read_first_invocation(b, reduction_wave_i);
-   }
+   /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
+    *
+    * By now, every wave knows the number of surviving invocations in all waves.
+    * Each number is 1 byte, and they are packed into up to 2 dwords.
+    *
+    * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
+    * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
+    * (Other lanes are not deactivated but their calculation is not used.)
+    *
+    * - We read the sum from the lane whose id is the current wave's id.
+    *   Add the masked bitcount to this, and we get the repacked invocation index.
+    * - We read the sum from the lane whose id is the number of waves in the workgroup.
+    *   This is the total number of surviving invocations in the workgroup.
+    */
 
    nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
-   nir_ssa_def *wg_reduction = reduction_per_wave[0];
-   nir_ssa_def *wg_excl_scan_base = NULL;
-
-   for (unsigned i = 0; i < num_lds_dwords; ++i) {
-      /* Workgroup reduction:
-       * Add the reduction results from all waves up to and including wave_count.
-       */
-      if (i != 0) {
-         nir_ssa_def *should_add = nir_ige(b, num_waves, nir_imm_int(b, i + 1u));
-         nir_ssa_def *addition = nir_bcsel(b, should_add, reduction_per_wave[i], nir_imm_zero(b, 1, 32));
-         wg_reduction = nir_iadd_nuw(b, wg_reduction, addition);
-      }
 
-      /* Base of workgroup exclusive scan:
-       * Add the reduction results from waves up to and excluding wave_id_in_tg.
-       */
-      if (i != (num_lds_dwords - 1u)) {
-         nir_ssa_def *should_add = nir_ige(b, wave_id, nir_imm_int(b, i + 1u));
-         nir_ssa_def *addition = nir_bcsel(b, should_add, reduction_per_wave[i], nir_imm_zero(b, 1, 32));
-         wg_excl_scan_base = !wg_excl_scan_base ? addition : nir_iadd_nuw(b, wg_excl_scan_base, addition);
-      }
+   /* sel = 0x01010101 * lane_id + 0x03020100 */
+   nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
+   nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0));
+   nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100);
+   nir_ssa_def *sum = NULL;
+
+   if (num_lds_dwords == 1) {
+      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
+      nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
+
+      /* Use byte-permute to filter out the bytes not needed by the current lane. */
+      nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel);
+
+      /* Horizontally add the packed bytes. */
+      sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0));
+   } else if (num_lds_dwords == 2) {
+      /* Create selectors for the byte-permutes below. */
+      nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4));
+      nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4));
+
+      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
+      nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
+      nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
+
+      /* Use byte-permute to filter out the bytes not needed by the current lane. */
+      nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector);
+      nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector);
+
+      /* Horizontally add the packed bytes. */
+      sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0));
+      sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum);
+   } else {
+      unreachable("Unimplemented NGG wave count");
    }
 
-   nir_ssa_def *sg_excl_scan = nir_build_mbcnt_amd(b, input_mask);
-   nir_ssa_def *wg_excl_scan = nir_iadd_nuw(b, wg_excl_scan_base, sg_excl_scan);
+   nir_ssa_def *wave_repacked_index = nir_build_mbcnt_amd(b, input_mask);
+   nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
+   nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
+   nir_ssa_def *wg_repacked_index = nir_iadd_nuw(b, wg_repacked_index_base, wave_repacked_index);
 
-   wg_scan_result r = {
-      .reduction_result = wg_reduction,
-      .excl_scan_result = wg_excl_scan,
+   wg_repack_result r = {
+      .num_repacked_invocations = wg_num_repacked_invocations,
+      .repacked_invocation_index = wg_repacked_index,
    };
 
    return r;
@@ -789,17 +814,16 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
       return;
    }
 
-   /* When the output is not known in compile time: there are gaps between the output vertices data in LDS.
-    * However, we need to make sure that the vertex exports are packed, meaning that there shouldn't be any gaps
-    * between the threads that perform the exports. We solve this using a perform a workgroup reduction + scan.
+   /* When the output vertex count is not known at compile time:
+    * There may be gaps between invocations that have live vertices, but NGG hardware
+    * requires that the invocations that export vertices are packed (ie. compact).
+    * To ensure this, we need to repack invocations that have a live vertex.
     */
    nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
-   wg_scan_result wg_scan = workgroup_reduce_and_exclusive_scan(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves);
+   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves);
 
-   /* Reduction result = total number of vertices emitted in the workgroup. */
-   nir_ssa_def *workgroup_num_vertices = wg_scan.reduction_result;
-   /* Exclusive scan result = the index of the thread in the workgroup that will export the current thread's vertex. */
-   nir_ssa_def *exporter_tid_in_tg = wg_scan.excl_scan_result;
+   nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
+   nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
 
    /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
    nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
@@ -836,13 +860,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
    lower_ngg_gs_state state = {
       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
       .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
-      .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 16u),
+      .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
       .lds_offs_primflags = gs_out_vtx_bytes,
       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
       .provoking_vertex_last = provoking_vertex_last,
    };
 
-   unsigned lds_scratch_bytes = state.max_num_waves * 4u;
+   unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
    shader->info.shared_size = total_lds_bytes;
 



More information about the mesa-commit mailing list