Mesa (main): ac/nir: Properly handle when mesh API workgroup size is smaller than HW.

GitLab Mirror gitlab-mirror at
Fri Feb 25 06:51:17 UTC 2022

Module: Mesa
Branch: main
Commit: 0746b98f4a7fd0c4b332a084ba4461724a31fa5f

Author: Timur Kristóf <timur.kristof at>
Date:   Thu Feb 10 23:55:51 2022 +0100

ac/nir: Properly handle when mesh API workgroup size is smaller than HW.

The problem is that the real workgroup launched on NGG HW
can be larger than the size specified by the API, and the
extra waves need to keep up with barriers in the API waves.

There are 2 different cases:

1. The whole API workgroup fits in a single wave.
   We can shrink the barriers to subgroup scope and
   don't need to insert any extra ones.

2. The API workgroup occupies multiple waves, but not
   all. In this case, we emit code that consumes every
   barrier on the extra waves.

Signed-off-by: Timur Kristóf <timur.kristof at>
Reviewed-by: Rhys Perry <pendingchaos02 at>
Part-of: <>


 src/amd/common/ac_nir_lower_ngg.c | 191 +++++++++++++++++++++++++++++++++-----
 1 file changed, 169 insertions(+), 22 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index ca7498ea1c7..583a1c1b222 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -2417,13 +2417,153 @@ emit_ms_finale(nir_shader *shader, lower_ngg_ms_state *s)
    nir_pop_if(b, if_has_output_primitive);
+static void
+handle_smaller_ms_api_workgroup(nir_function_impl *impl,
+                                nir_builder *b,
+                                unsigned api_workgroup_size,
+                                unsigned hw_workgroup_size,
+                                lower_ngg_ms_state *s)
+   /* Handle barriers manually when the API workgroup
+    * size is less than the HW workgroup size.
+    *
+    * The problem is that the real workgroup launched on NGG HW
+    * will be larger than the size specified by the API, and the
+    * extra waves need to keep up with barriers in the API waves.
+    *
+    * There are 2 different cases:
+    * 1. The whole API workgroup fits in a single wave.
+    *    We can shrink the barriers to subgroup scope and
+    *    don't need to insert any extra ones.
+    * 2. The API workgroup occupies multiple waves, but not
+    *    all. In this case, we emit code that consumes every
+    *    barrier on the extra waves.
+    */
+   assert(hw_workgroup_size % s->wave_size == 0);
+   bool scan_barriers = ALIGN(api_workgroup_size, s->wave_size) < hw_workgroup_size;
+   bool can_shrink_barriers = api_workgroup_size <= s->wave_size;
+   bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
+   unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12;
+   unsigned num_api_waves = DIV_ROUND_UP(api_workgroup_size, s->wave_size);
+   /* Scan the shader for workgroup barriers. */
+   if (scan_barriers) {
+      bool has_any_workgroup_barriers = false;
+      nir_foreach_block(block, impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_intrinsic)
+               continue;
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            bool is_workgroup_barrier =
+               intrin->intrinsic == nir_intrinsic_scoped_barrier &&
+               nir_intrinsic_execution_scope(intrin) == NIR_SCOPE_WORKGROUP;
+            if (!is_workgroup_barrier)
+               continue;
+            if (can_shrink_barriers) {
+               /* Every API invocation runs in the first wave.
+                * In this case, we can change the barriers to subgroup scope
+                * and avoid adding additional barriers.
+                */
+               nir_intrinsic_set_memory_scope(intrin, NIR_SCOPE_SUBGROUP);
+               nir_intrinsic_set_execution_scope(intrin, NIR_SCOPE_SUBGROUP);
+            } else {
+               has_any_workgroup_barriers = true;
+            }
+         }
+      }
+      need_additional_barriers &= has_any_workgroup_barriers;
+   }
+   /* Extract the full control flow of the shader. */
+   nir_cf_list extracted;
+   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
+   b->cursor = nir_before_cf_list(&impl->body);
+   /* Wrap the shader in an if to ensure that only the necessary number of lanes run it. */
+   nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+   if (need_additional_barriers) {
+      /* First invocation initializes the counter of API waves in flight to num_api_waves. */
+      nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
+      {
+         nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
+      }
+      nir_pop_if(b, if_first_in_workgroup);
+      nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                            .memory_scope = NIR_SCOPE_WORKGROUP,
+                            .memory_semantics = NIR_MEMORY_ACQ_REL,
+                            .memory_modes = nir_var_shader_out | nir_var_mem_shared);
+   }
+   nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, api_workgroup_size));
+   nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
+   {
+      nir_cf_reinsert(&extracted, b->cursor);
+      b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
+      if (need_additional_barriers) {
+         /* One invocation in each API wave decrements the number of API waves in flight. */
+         nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
+         {
+            nir_shared_atomic_add(b, 32, zero, nir_imm_int(b, -1u), .base = api_waves_in_flight_addr);
+         }
+         nir_pop_if(b, if_elected_again);
+         nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                               .memory_scope = NIR_SCOPE_WORKGROUP,
+                               .memory_semantics = NIR_MEMORY_ACQ_REL,
+                               .memory_modes = nir_var_shader_out | nir_var_mem_shared);
+      }
+   }
+   nir_pop_if(b, if_has_api_ms_invocation);
+   if (need_additional_barriers) {
+      /* Make sure that waves that don't run any API invocations execute
+       * the same amount of barriers as those that do.
+       *
+       * We do this by executing a barrier until the number of API waves
+       * in flight becomes zero.
+       */
+      nir_ssa_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
+      nir_ssa_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
+      nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
+      {
+         nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
+         {
+            nir_loop *loop = nir_push_loop(b);
+            {
+               nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                                     .memory_scope = NIR_SCOPE_WORKGROUP,
+                                     .memory_semantics = NIR_MEMORY_ACQ_REL,
+                                     .memory_modes = nir_var_shader_out | nir_var_mem_shared);
+               nir_ssa_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
+               nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
+               {
+                  nir_jump(b, nir_jump_break);
+               }
+               nir_pop_if(b, if_break);
+            }
+            nir_pop_loop(b, loop);
+         }
+         nir_pop_if(b, if_elected);
+      }
+      nir_pop_if(b, if_wave_has_no_api_ms);
+   }
 ac_nir_lower_ngg_ms(nir_shader *shader,
                     unsigned wave_size)
-   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-   assert(impl);
    unsigned vertices_per_prim =
@@ -2436,9 +2576,14 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    unsigned max_vertices = shader->info.mesh.max_vertices_out;
    unsigned max_primitives = shader->info.mesh.max_primitives_out;
-   /* LDS area for total number of output primitives. */
+   /* LDS area for total number of output primitives and other info.
+    * DW0: number of primitives
+    * DW1: reserved for later use
+    * DW2: reserved for later use
+    * DW3: number of API waves in flight
+    */
    unsigned numprims_lds_addr = ALIGN(shader->info.shared_size, 16);
-   unsigned numprims_lds_size = 4;
+   unsigned numprims_lds_size = 16;
    /* LDS area for vertex attributes */
    unsigned vertex_attr_lds_addr = ALIGN(numprims_lds_addr + numprims_lds_size, 16);
    unsigned vertex_attr_lds_size = max_vertices * num_per_vertex_outputs * 16;
@@ -2464,29 +2609,31 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
       .numprims_lds_addr = numprims_lds_addr,
-   /* Extract the full control flow. It is going to be wrapped in an if statement. */
-   nir_cf_list extracted;
-   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
+   /* The workgroup size that is specified by the API shader may be different
+    * from the size of the workgroup that actually runs on the HW, due to the
+    * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed.
+    *
+    * Therefore, we must make sure that when the API workgroup size is smaller,
+    * we don't run the API shader on more HW invocations than is necessary.
+    */
+   unsigned api_workgroup_size = shader->info.workgroup_size[0] *
+                                 shader->info.workgroup_size[1] *
+                                 shader->info.workgroup_size[2];
+   unsigned hw_workgroup_size =
+      ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+   assert(impl);
    nir_builder builder;
    nir_builder *b = &builder; /* This is to avoid the & */
    nir_builder_init(b, impl);
    b->cursor = nir_before_cf_list(&impl->body);
-   /* There may be a difference between MS workgroup size and the
-    * number of output vertices/primitives. So it is possible that the actual H
-    * workgroup is larger than what the user wants.
-    * So, only execute the API shader for invocations that the user needs.
-    */
-   unsigned num_ms_invocations = b->shader->info.workgroup_size[0] *
-                                 b->shader->info.workgroup_size[1] *
-                                 b->shader->info.workgroup_size[2];
-   nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
-   nir_ssa_def *has_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, num_ms_invocations));
-   nir_if *if_has_ms_invocation = nir_push_if(b, has_ms_invocation);
-   nir_cf_reinsert(&extracted, b->cursor);
-   b->cursor = nir_after_cf_list(&if_has_ms_invocation->then_list);
-   nir_pop_if(b, if_has_ms_invocation);
+   if (api_workgroup_size < hw_workgroup_size) {
+      handle_smaller_ms_api_workgroup(impl, b, api_workgroup_size, hw_workgroup_size, &state);
+   }
    lower_ms_intrinsics(shader, &state);
    emit_ms_finale(shader, &state);

More information about the mesa-commit mailing list