Mesa (main): ac/nir: Use a ballot that matches the wave size during NGG lowering.
GitLab Mirror
gitlab-mirror at kemper.freedesktop.org
Wed Jul 14 00:44:52 UTC 2021
Module: Mesa
Branch: main
Commit: 556a690bac3a48f7f1c0627f3fc4caf8f21d0f89
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=556a690bac3a48f7f1c0627f3fc4caf8f21d0f89
Author: Timur Kristóf <timur.kristof at gmail.com>
Date: Mon Jul 5 11:19:14 2021 +0200
ac/nir: Use a ballot that matches the wave size during NGG lowering.
This generates slightly more efficient code in Wave32 mode.
Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Daniel Schürmann <daniel at schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10525>
---
src/amd/common/ac_nir_lower_ngg.c | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 7af8708966b..9cc7687a352 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -34,6 +34,7 @@ typedef struct
bool passthrough;
bool export_prim_id;
bool early_prim_export;
+ unsigned wave_size;
unsigned max_num_waves;
unsigned num_vertices_per_primitives;
unsigned provoking_vtx_idx;
@@ -55,6 +56,7 @@ typedef struct
nir_variable *current_clear_primflag_idx_var;
int const_out_vtxcnt[4];
int const_out_prmcnt[4];
+ unsigned wave_size;
unsigned max_num_waves;
unsigned num_vertices_per_primitive;
unsigned lds_addr_gs_out_vtx;
@@ -80,17 +82,18 @@ typedef struct {
*/
static wg_repack_result
repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
- unsigned lds_addr_base, unsigned max_num_waves)
+ unsigned lds_addr_base, unsigned max_num_waves,
+ unsigned wave_size)
{
/* Input boolean: 1 if the current invocation should survive the repack. */
assert(input_bool->bit_size == 1);
/* STEP 1. Count surviving invocations in the current wave.
*
- * Implemented by a scalar instruction that simply counts the number of bits set in a 64-bit mask.
+ * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
*/
- nir_ssa_def *input_mask = nir_build_ballot(b, 1, 64, input_bool);
+ nir_ssa_def *input_mask = nir_build_ballot(b, 1, wave_size, input_bool);
nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
/* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
@@ -342,6 +345,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
.prim_exp_arg_var = prim_exp_arg_var,
.max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
.max_es_num_vertices = max_num_es_vertices,
+ .wave_size = wave_size,
};
/* We need LDS space when VS needs to export the primitive ID. */
@@ -488,7 +492,7 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st
unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
- nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, 64, nir_imm_bool(b, true)));
+ nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
} else {
nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
@@ -819,7 +823,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
* To ensure this, we need to repack invocations that have a live vertex.
*/
nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
- wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves);
+ wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
@@ -858,6 +862,7 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
lower_ngg_gs_state state = {
.max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
+ .wave_size = wave_size,
.lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
.lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
.lds_offs_primflags = gs_out_vtx_bytes,
More information about the mesa-commit mailing list