Mesa (master): aco: Optimize workgroup exclusive scan to better avoid bank conflicts.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Apr 14 14:15:10 UTC 2021


Module: Mesa
Branch: master
Commit: c1346e5c2202902a2359d928634e41dff4d2eb64
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=c1346e5c2202902a2359d928634e41dff4d2eb64

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Sat Apr 10 14:52:55 2021 +0200

aco: Optimize workgroup exclusive scan to better avoid bank conflicts.

Previously, each wave had multiple active lanes reading from LDS, and
the data was processed by VALU DPP instructions.

Now, only the first lane of each wave reads from LDS in order to avoid
bank conflicts, and the results are processed by SALU instructions.

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Daniel Schürmann <daniel at schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10155>

---

 src/amd/compiler/aco_instruction_selection.cpp | 70 +++++++++++++++-----------
 1 file changed, 42 insertions(+), 28 deletions(-)

diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 3472aa3968f..75b9fff0159 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -11400,50 +11400,64 @@ std::pair<Temp, Temp> ngg_gs_workgroup_reduce_and_scan(isel_context *ctx, Temp s
    Temp wave_id_in_tg_lds_addr = bld.vop2_e64(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand(2u), wave_id_in_tg);
    store_lds(ctx, 4u, as_vgpr(ctx, sg_reduction), 0x1u, wave_id_in_tg_lds_addr, ctx->ngg_gs_scratch_addr, 4u);
 
-   begin_divergent_if_else(ctx, &ic);
-   end_divergent_if(ctx, &ic);
-   bld.reset(ctx->block);
-
    /* Wait for all waves to write to LDS. */
    create_workgroup_barrier(bld);
 
-   /* Activate one lane per wave. */
-   Temp wave_count = wave_count_in_threadgroup(ctx);
-   Temp wave_count_mask = lanecount_to_mask(ctx, wave_count, false);
-   begin_divergent_if_then(ctx, &ic, wave_count_mask);
-   bld.reset(ctx->block);
-
-   /* Each lane loads the reduction result from the corresponding wave. */
-   Temp thread_id_in_wave = emit_mbcnt(ctx, bld.tmp(v1));
-   Temp loaded_wave_id_lds_addr = bld.v_mul24_imm(bld.def(v1), thread_id_in_wave, 4u);
-   Temp red_per_w = load_lds(ctx, 4u, bld.tmp(v1), loaded_wave_id_lds_addr, ctx->ngg_gs_scratch_addr, 4u);
+   /* Number of LDS dwords written by all waves (if there is only 1, that is already handled above) */
+   unsigned num_lds_dwords = DIV_ROUND_UP(MIN2(ctx->program->workgroup_size, 256), ctx->program->wave_size);
+   assert(num_lds_dwords >= 2 && num_lds_dwords <= 8);
 
-   /* Inclusive scan on the per-wave reduction results, only care about the first 8 lanes. */
-   Temp sgincl = bld.vop2_dpp(aco_opcode::v_add_u32, bld.def(v1), red_per_w, red_per_w, dpp_row_sr(1), 0b0001, 0b0111, true);
-   sgincl = bld.vop2_dpp(aco_opcode::v_add_u32, bld.def(v1), sgincl, sgincl, dpp_row_sr(2), 0x1, 0xf, true);
-   sgincl = bld.vop2_dpp(aco_opcode::v_add_u32, bld.def(v1), sgincl, sgincl, dpp_row_sr(4), 0x1, 0xf, true);
+   /* The first lane of each wave loads every wave's results from LDS, to avoid bank conflicts */
+   Temp reduction_per_wave_vector = load_lds(ctx, 4u * num_lds_dwords, bld.tmp(RegClass(RegType::vgpr, num_lds_dwords)),
+                                             bld.copy(bld.def(v1), Operand(0u)), ctx->ngg_gs_scratch_addr, 4u);
 
    begin_divergent_if_else(ctx, &ic);
    end_divergent_if(ctx, &ic);
+   bld.reset(ctx->block);
 
-   /* Create phi which gets us the above reduction results, or undef. */
+   /* Create phis which get us the above reduction results, or undef. */
    bld.reset(&ctx->block->instructions, ctx->block->instructions.begin());
-   sgincl = bld.pseudo(aco_opcode::p_phi, bld.def(sgincl.regClass()), sgincl, Operand(v1));
+   reduction_per_wave_vector = bld.pseudo(aco_opcode::p_phi, bld.def(reduction_per_wave_vector.regClass()), reduction_per_wave_vector, Operand(reduction_per_wave_vector.regClass()));
    bld.reset(ctx->block);
 
-   /* Make it an exclusive scan by shifting the results right by one lane. */
-   Temp per_wave_excl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), sgincl, dpp_row_sr(1), 0x1, 0xf, true);
+   emit_split_vector(ctx, reduction_per_wave_vector, num_lds_dwords);
+   Temp reduction_per_wave[8];
+
+   for (unsigned i = 0; i < num_lds_dwords; ++i) {
+      Temp reduction_current_wave = emit_extract_vector(ctx, reduction_per_wave_vector, i, v1);
+      reduction_per_wave[i] = bld.readlane(bld.def(s1), reduction_current_wave, Operand(0u));
+   }
+
+   Temp wave_count = wave_count_in_threadgroup(ctx);
+   Temp reduction_result = reduction_per_wave[0];
+   Temp excl_base;
 
-   /* WG reduction result: the last lane of the above exclusive scan. */
-   Temp wg_reduction = bld.readlane(bld.def(s1), per_wave_excl, wave_count);
+   for (unsigned i = 0; i < num_lds_dwords; ++i) {
+      /* Workgroup reduction:
+       * Add the reduction results from all waves (up to and including wave_count).
+       */
+      if (i != 0) {
+         Temp should_add = bld.sopc(aco_opcode::s_cmp_ge_u32, bld.def(s1, scc), wave_count, Operand(i + 1u));
+         Temp addition = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), reduction_per_wave[i], Operand(0u), bld.scc(should_add));
+         reduction_result = bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.def(s1, scc), reduction_result, addition);
+      }
+
+      /* Base of workgroup exclusive scan:
+       * Add the reduction results from waves up to and excluding wave_id_in_tg.
+       */
+      if (i != (num_lds_dwords - 1)) {
+         Temp should_add = bld.sopc(aco_opcode::s_cmp_ge_u32, bld.def(s1, scc), wave_id_in_tg, Operand(i + 1u));
+         Temp addition = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), reduction_per_wave[i], Operand(0u), bld.scc(should_add));
+         excl_base = !excl_base.id() ? addition : bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.def(s1, scc), excl_base, addition);
+      }
+   }
 
-   /* Base of the exclusive WG scan: the above exclusive result corresponding to the current wave. */
-   Temp wg_excl_base = bld.readlane(bld.def(s1), per_wave_excl, wave_id_in_tg);
+   assert(excl_base.id());
 
    /* WG exclusive scan result: base + subgroup exclusive result. */
-   Temp wg_excl = bld.vadd32(bld.def(v1), Operand(wg_excl_base), Operand(sg_excl));
+   Temp wg_excl = bld.vadd32(bld.def(v1), Operand(excl_base), Operand(sg_excl));
 
-   return std::make_pair(wg_reduction, wg_excl);
+   return std::make_pair(reduction_result, wg_excl);
 }
 
 void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream)



More information about the mesa-commit mailing list