Mesa (main): ac/nir: Implement NGG deferred attribute culling in NIR.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jul 14 00:44:52 UTC 2021


Module: Mesa
Branch: main
Commit: e97f0463a8f55d5d407178f74b0cdb916a42aef8
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=e97f0463a8f55d5d407178f74b0cdb916a42aef8

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Mon Apr 26 16:56:11 2021 +0200

ac/nir: Implement NGG deferred attribute culling in NIR.

Culling is traditionally done by the rasterizer, but that
can be a bottleneck when an app creates a large number
of primitives. E.g., a lot of tiny triangles reduce the
rasterization efficiency.

NGG makes it possible for the shader to check primitives
and delete those that it can prove are not needed.

After this is done, we have to repack the surviving invocations
so they remain compact. This also saves bandwidth, because
some memory loads are only executed by those vertices that
survived the culling.

Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Daniel Schürmann <daniel at schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10525>

---

 src/amd/common/ac_nir.h           |   1 +
 src/amd/common/ac_nir_lower_ngg.c | 568 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 557 insertions(+), 12 deletions(-)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index 1ce615b859b..6a32e64e159 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -93,6 +93,7 @@ ac_nir_lower_indirect_derefs(nir_shader *shader,
 
 typedef struct
 {
+   unsigned lds_bytes_if_culling_off; /* LDS size the lowered shader needs when culling is turned off at runtime */
    bool can_cull;
    bool passthrough;
 } ac_nir_ngg_config;
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 9cc7687a352..90dba07bb03 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -30,6 +30,8 @@ typedef struct
 {
    nir_variable *position_value_var;
    nir_variable *prim_exp_arg_var;
+   nir_variable *es_accepted_var;
+   nir_variable *gs_accepted_var;
 
    bool passthrough;
    bool export_prim_id;
@@ -69,6 +71,46 @@ typedef struct
    gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
 } lower_ngg_gs_state;
 
+typedef struct {
+   nir_variable *pre_cull_position_value_var; /* receives the POS output of the culling (position-only) shader part */
+} remove_culling_shader_outputs_state;
+
+typedef struct {
+   nir_variable *pos_value_replacement; /* NOTE(review): not read by remove_extra_pos_output yet (see its TODO) */
+} remove_extra_position_output_state;
+
+typedef struct {
+   nir_ssa_def *reduction_result; /* presumably the workgroup-wide reduction - confirm */
+   nir_ssa_def *excl_scan_result; /* presumably the per-invocation exclusive scan - confirm */
+} wg_scan_result; /* NOTE(review): not referenced by the code visible in this patch */
+
+/* Per-vertex LDS layout of culling shaders (byte offsets within a vertex's LDS space) */
+enum {
+   /* Position of the ES vertex (at the beginning for alignment reasons); before compaction X/Y hold X/W, Y/W */
+   lds_es_pos_x = 0,
+   lds_es_pos_y = 4,
+   lds_es_pos_z = 8,
+   lds_es_pos_w = 12,
+
+   /* One byte: non-zero (0xff) when the vertex is accepted, 0 if it should be culled */
+   lds_es_vertex_accepted = 16,
+   /* One byte: ID of the thread which will export the current thread's vertex */
+   lds_es_exporter_tid = 17,
+
+   /* Repacked arguments - also listed separately for VS and TES; argument i is at lds_es_arg_0 + 4 * i */
+   lds_es_arg_0 = 20,
+
+   /* VS arguments which need to be repacked */
+   lds_es_vs_vertex_id = 20,
+   lds_es_vs_instance_id = 24,
+
+   /* TES arguments which need to be repacked */
+   lds_es_tes_u = 20,
+   lds_es_tes_v = 24,
+   lds_es_tes_rel_patch_id = 28,
+   lds_es_tes_patch_id = 32,
+};
+
 typedef struct {
    nir_ssa_def *num_repacked_invocations;
    nir_ssa_def *repacked_invocation_index;
@@ -313,6 +355,467 @@ emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
                           .write_mask = 1u, .src_type = nir_type_uint32, .io_semantics = io_sem);
 }
 
+static bool
+remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
+{
+   remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;
+
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   /* These are not allowed in VS / TES */
+   assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
+          intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
+   /* We are only interested in output stores; every one of them is removed below. */
+   if (intrin->intrinsic != nir_intrinsic_store_output)
+      return false;
+
+   b->cursor = nir_before_instr(instr);
+   /* Position output - copy the value to a variable. Channels and write mask are relative
+    * to the store's first component, so shift them into place. TODO: indirect stores? */
+   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
+   if (io_sem.location == VARYING_SLOT_POS) {
+      nir_ssa_def *store_val = intrin->src[0].ssa;
+      unsigned component = nir_intrinsic_component(intrin);
+      nir_ssa_def *undef = nir_ssa_undef(b, 1, store_val->bit_size);
+      nir_ssa_def *chan[4] = {undef, undef, undef, undef};
+      for (unsigned i = 0; i < store_val->num_components; ++i)
+         chan[component + i] = nir_channel(b, store_val, i);
+      nir_store_var(b, s->pre_cull_position_value_var, nir_vec(b, chan, 4),
+                    nir_intrinsic_write_mask(intrin) << component);
+   }
+   nir_instr_remove(instr);
+   return true;
+}
+
+static void
+remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
+{
+   remove_culling_shader_outputs_state s = {
+      .pre_cull_position_value_var = pre_cull_position_value_var,
+   };
+   /* Delete all output stores; POS values are kept in pre_cull_position_value_var. (nogs_state is unused.) */
+   nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
+                                nir_metadata_block_index | nir_metadata_dominance, &s);
+
+   /* Remove dead code resulting from the deleted outputs; repeat until no pass makes progress. */
+   bool progress;
+   do {
+      progress = false;
+      NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
+      NIR_PASS(progress, culling_shader, nir_opt_dce);
+      NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
+   } while (progress);
+}
+
+static bool
+remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
+{
+   remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;
+   /* Deletes store_output to VARYING_SLOT_POS; ac_nir_lower_ngg_nogs re-stores POS from position_value_var. */
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+
+   /* These are not allowed in VS / TES */
+   assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
+          intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
+
+   /* We are only interested in output stores now */
+   if (intrin->intrinsic != nir_intrinsic_store_output)
+      return false;
+
+   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
+   if (io_sem.location != VARYING_SLOT_POS)
+      return false;
+
+   b->cursor = nir_before_instr(instr);
+
+   /* TODO: in case other outputs use what we calculated for pos, rewrite the usages of the store components here.
+    * Until then, the cursor set above and s->pos_value_replacement are unused. */
+   nir_instr_remove(instr);
+   return true;
+}
+
+static void
+remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
+{
+   remove_extra_position_output_state s = {
+      .pos_value_replacement = nogs_state->position_value_var,
+   };
+   /* Run remove_extra_pos_output on every instruction; it deletes all POS output stores. */
+   nir_shader_instructions_pass(shader, remove_extra_pos_output,
+                                nir_metadata_block_index | nir_metadata_dominance, &s);
+}
+
+/**
+ * Perform vertex compaction after culling.
+ *
+ * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
+ * 2. Surviving ES vertex invocations store their data to the exporter thread's LDS space
+ * 3. Emit GS_ALLOC_REQ (from wave 0 only) with the surviving vertex count
+ * 4. Repacked invocations load the vertex data from LDS
+ * 5. Accepted GS threads rebuild their primitive export argument from the exporter indices
+ */
+static void
+compact_vertices_after_culling(nir_builder *b,
+                               lower_ngg_nogs_state *nogs_state,
+                               nir_variable *vertices_in_wave_var,
+                               nir_variable *primitives_in_wave_var,
+                               nir_variable **repacked_arg_vars,
+                               nir_variable **gs_vtxaddr_vars,
+                               nir_ssa_def *invocation_index,
+                               nir_ssa_def *es_vertex_lds_addr,
+                               unsigned ngg_scratch_lds_base_addr,
+                               unsigned pervertex_lds_bytes,
+                               unsigned max_exported_args)
+{
+   nir_variable *es_accepted_var = nogs_state->es_accepted_var;
+   nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
+   nir_variable *position_value_var = nogs_state->position_value_var;
+   nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
+
+   nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);
+
+   /* Repack the vertices that survived the culling. */
+   wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
+                                                          nogs_state->max_num_waves, nogs_state->wave_size);
+   nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
+   nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
+
+   nir_if *if_es_accepted = nir_push_if(b, es_accepted);
+   {
+      nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
+
+      /* Store the exporter thread's index (truncated to 8 bits) to the LDS space of the current thread so GS threads can load it */
+      nir_build_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid, .align_mul = 1u, .write_mask = 0x1u);
+
+      /* Store the current thread's position output to the exporter thread's LDS space */
+      nir_ssa_def *pos = nir_load_var(b, position_value_var);
+      nir_build_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x, .align_mul = 4u, .write_mask = 0xfu);
+
+      /* Store the current thread's repackable arguments to the exporter thread's LDS space */
+      for (unsigned i = 0; i < max_exported_args; ++i) {
+         nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
+         nir_build_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u, .write_mask = 0x1u);
+      }
+   }
+   nir_pop_if(b, if_es_accepted);
+
+   /* If all vertices are culled, set primitive count to 0 as well. */
+   nir_ssa_def *num_exported_prims = nir_build_load_workgroup_num_input_primitives_amd(b);
+   num_exported_prims = nir_bcsel(b, nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u), nir_imm_int(b, 0u), num_exported_prims);
+
+   nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
+   {
+      /* Tell the final vertex and primitive count to the HW.
+       * We do this here to mask some of the latency of the LDS.
+       */
+      nir_build_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
+   }
+   nir_pop_if(b, if_wave_0);
+
+   /* Calculate the number of vertices and primitives left in the current wave (primitives are not repacked) */
+   nir_ssa_def *has_vtx_after_culling = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
+   nir_ssa_def *has_prm_after_culling = nir_ilt(b, invocation_index, num_exported_prims);
+   nir_ssa_def *vtx_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, has_vtx_after_culling));
+   nir_ssa_def *prm_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, has_prm_after_culling));
+   nir_store_var(b, vertices_in_wave_var, vtx_cnt, 0x1u);
+   nir_store_var(b, primitives_in_wave_var, prm_cnt, 0x1u);
+
+   /* TODO: Consider adding a shortcut exit.
+    * Waves that have no vertices and primitives left can s_endpgm right here.
+    */
+
+   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
+                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+
+   nir_if *if_packed_es_thread = nir_push_if(b, nir_ilt(b, invocation_index, num_live_vertices_in_workgroup));
+   {
+      /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
+      nir_ssa_def *exported_pos = nir_build_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x, .align_mul = 4u);
+      nir_store_var(b, position_value_var, exported_pos, 0xfu);
+
+      /* Read the repacked arguments */
+      for (unsigned i = 0; i < max_exported_args; ++i) {
+         nir_ssa_def *arg_val = nir_build_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u);
+         nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
+      }
+   }
+   nir_pop_if(b, if_packed_es_thread);
+
+   nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
+   {
+      nir_ssa_def *exporter_vtx_indices[3] = {0};
+
+      /* Load the index of the ES threads that will export the current GS thread's vertices */
+      for (unsigned v = 0; v < 3; ++v) {
+         nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
+         nir_ssa_def *exporter_vtx_idx = nir_build_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid, .align_mul = 1u);
+         exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
+      }
+
+      nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL);
+      nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
+   }
+   nir_pop_if(b, if_gs_accepted);
+}
+
+/* Emit the culling part of a VS/TES: a clone of the shader computes only the position, then
+ * (if culling is enabled at runtime) triangles are culled and surviving vertices are compacted.
+ * The caller reinserts the original shader body afterwards. */
+static void
+add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
+{
+   assert(b->shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_POS));
+
+   bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
+   bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
+
+   unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
+   if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
+      max_exported_args--;
+   else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
+      max_exported_args--;
+
+   unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
+   unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
+   unsigned max_num_waves = nogs_state->max_num_waves;
+   unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
+   unsigned ngg_scratch_lds_bytes = ALIGN(max_num_waves, 4u); /* in bytes: presumably 1 byte per wave, padded to whole dwords */
+   nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;
+   nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
+
+   /* Create some helper variables. */
+   nir_variable *position_value_var = nogs_state->position_value_var;
+   nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
+   nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
+   nir_variable *es_accepted_var = nogs_state->es_accepted_var;
+   nir_variable *vertices_in_wave_var = nir_local_variable_create(impl, glsl_uint_type(), "vertices_in_wave");
+   nir_variable *primitives_in_wave_var = nir_local_variable_create(impl, glsl_uint_type(), "primitives_in_wave");
+   nir_variable *gs_vtxaddr_vars[3] = {
+      nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
+      nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
+      nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
+   };
+   nir_variable *repacked_arg_vars[4] = {
+      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
+      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
+      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
+      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
+   };
+
+   /* Top part of the culling shader (aka. position shader part)
+    *
+    * We clone the full ES shader and emit it here, but we only really care
+    * about its position output, so we delete every other output from this part.
+    * The position output is stored into a temporary variable, and reloaded later.
+    */
+   b->cursor = nir_before_cf_list(&impl->body);
+
+   nir_if *if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
+   {
+      /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output.
+       * The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that.
+       */
+      nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
+
+      /* Now reinsert a clone of the shader code */
+      struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
+      nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
+      _mesa_hash_table_destroy(remap_table, NULL);
+      b->cursor = nir_after_cf_list(&if_es_thread->then_list);
+
+      /* Remember the current thread's shader arguments */
+      if (b->shader->info.stage == MESA_SHADER_VERTEX) {
+         nir_store_var(b, repacked_arg_vars[0], nir_build_load_vertex_id_zero_base(b), 0x1u);
+         if (uses_instance_id)
+            nir_store_var(b, repacked_arg_vars[1], nir_build_load_instance_id(b), 0x1u);
+      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
+         nir_ssa_def *tess_coord = nir_build_load_tess_coord(b);
+         nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
+         nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
+         nir_store_var(b, repacked_arg_vars[2], nir_build_load_tess_rel_patch_id_amd(b), 0x1u);
+         if (uses_tess_primitive_id)
+            nir_store_var(b, repacked_arg_vars[3], nir_build_load_primitive_id(b), 0x1u);
+      } else {
+         unreachable("Should be VS or TES.");
+      }
+   }
+   nir_pop_if(b, if_es_thread);
+
+   /* Remove all non-position outputs, and put the position output into the variable. */
+   nir_metadata_preserve(impl, nir_metadata_none);
+   remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
+   b->cursor = nir_after_cf_list(&impl->body);
+
+   /* Run culling algorithms if culling is enabled.
+    *
+    * NGG culling can be enabled or disabled at runtime.
+    * This is determined by an SGPR shader argument which is accessed
+    * by the following NIR intrinsic.
+    */
+   nir_if *if_cull_en = nir_push_if(b, nir_build_load_cull_any_enabled_amd(b));
+   {
+      nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
+      nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
+
+      /* ES invocations store their vertex data to LDS for GS threads to read. */
+      if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
+      {
+         /* Store position components that are relevant to culling in LDS */
+         nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
+         nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
+         nir_build_store_shared(b, pre_cull_w, es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_pos_w);
+         nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
+         nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
+         nir_build_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .write_mask = 0x3u, .align_mul = 4, .base = lds_es_pos_x);
+
+         /* Clear out the ES accepted flag in LDS */
+         nir_build_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_vertex_accepted);
+      }
+      nir_pop_if(b, if_es_thread);
+
+      nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
+                            .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+
+      nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
+      nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
+
+      /* GS invocations load the vertex data and perform the culling. */
+      nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
+      {
+         /* Load vertex indices from input VGPRs */
+         nir_ssa_def *vtx_idx[3] = {0};
+         for (unsigned vertex = 0; vertex < 3; ++vertex)
+            vtx_idx[vertex] = ngg_input_primitive_vertex_index(b, vertex);
+
+         nir_ssa_def *vtx_addr[3] = {0};
+         nir_ssa_def *pos[3][4] = {0};
+
+         /* Load W positions of vertices first because the culling code will use these first */
+         for (unsigned vtx = 0; vtx < 3; ++vtx) {
+            vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
+            pos[vtx][3] = nir_build_load_shared(b, 1, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_w);
+            nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
+         }
+
+         /* Load the X/W, Y/W positions of vertices */
+         for (unsigned vtx = 0; vtx < 3; ++vtx) {
+            nir_ssa_def *xy = nir_build_load_shared(b, 2, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_x);
+            pos[vtx][0] = nir_channel(b, xy, 0);
+            pos[vtx][1] = nir_channel(b, xy, 1);
+         }
+
+         /* See if the current primitive is accepted */
+         nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
+         nir_store_var(b, gs_accepted_var, accepted, 0x1u);
+
+         nir_if *if_gs_accepted = nir_push_if(b, accepted);
+         {
+            /* Store the accepted state to LDS for ES threads */
+            for (unsigned vtx = 0; vtx < 3; ++vtx)
+               nir_build_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u, .write_mask = 0x1u);
+         }
+         nir_pop_if(b, if_gs_accepted);
+      }
+      nir_pop_if(b, if_gs_thread);
+
+      nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
+                            .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+
+      nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
+
+      /* ES invocations load their accepted flag from LDS. */
+      if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
+      {
+         nir_ssa_def *accepted = nir_build_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
+         nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
+         nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
+      }
+      nir_pop_if(b, if_es_thread);
+
+      /* Vertex compaction. */
+      compact_vertices_after_culling(b, nogs_state,
+                                     vertices_in_wave_var, primitives_in_wave_var,
+                                     repacked_arg_vars, gs_vtxaddr_vars,
+                                     invocation_index, es_vertex_lds_addr,
+                                     ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
+   }
+   nir_push_else(b, if_cull_en);
+   {
+      /* When culling is disabled, we do the same as we would without culling. */
+      nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
+      {
+         nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
+         nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
+         nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
+      }
+      nir_pop_if(b, if_wave_0);
+      nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
+
+      nir_ssa_def *vtx_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, nir_build_has_input_vertex_amd(b)));
+      nir_ssa_def *prm_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, nir_build_has_input_primitive_amd(b)));
+      nir_store_var(b, vertices_in_wave_var, vtx_cnt, 0x1u);
+      nir_store_var(b, primitives_in_wave_var, prm_cnt, 0x1u);
+   }
+   nir_pop_if(b, if_cull_en);
+
+   /* Update shader arguments.
+    *
+    * The registers which hold information about the subgroup's
+    * vertices and primitives are updated here, so the rest of the shader
+    * doesn't need to worry about the culling.
+    *
+    * These "overwrite" intrinsics must be at top level control flow,
+    * otherwise they can mess up the backend (eg. ACO's SSA).
+    *
+    * TODO:
+    * A cleaner solution would be to simply replace all usages of these args
+    * with the load of the variables.
+    * However, this wouldn't work right now because the backend uses the arguments
+    * for purposes not expressed in NIR, eg. VS input loads, etc.
+    * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
+    */
+
+   if (b->shader->info.stage == MESA_SHADER_VERTEX)
+      nir_build_overwrite_vs_arguments_amd(b,
+         nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
+   else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
+      nir_build_overwrite_tes_arguments_amd(b,
+         nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
+         nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
+   else
+      unreachable("Should be VS or TES.");
+
+   nir_ssa_def *vertices_in_wave = nir_load_var(b, vertices_in_wave_var);
+   nir_ssa_def *primitives_in_wave = nir_load_var(b, primitives_in_wave_var);
+   nir_build_overwrite_subgroup_num_vertices_and_primitives_amd(b, vertices_in_wave, primitives_in_wave);
+}
+
+static bool
+can_use_deferred_attribute_culling(nir_shader *shader)
+{
+   /* The ES code runs twice (a position-only clone, then the full shader), so when the
+    * shader writes memory, it is difficult to guarantee correctness. Future work:
+    * - if only write-only SSBOs are used
+    * - if we can prove that non-position outputs don't rely on memory stores
+    * then it may be okay to keep the memory stores in the 1st shader part, and delete them from the 2nd.
+    */
+   if (shader->info.writes_memory)
+      return false;
+
+   /* When the shader relies on the subgroup invocation ID, we'd break it, because the ID changes after the culling.
+    * Future work: try to save this to LDS and reload, but it can still be broken in subtle ways.
+    */
+   if (BITSET_TEST(shader->info.system_values_read, SYSTEM_VALUE_SUBGROUP_INVOCATION))
+      return false;
+
+   return true;
+}
+
 ac_nir_ngg_config
 ac_nir_lower_ngg_nogs(nir_shader *shader,
                       unsigned max_num_es_vertices,
@@ -328,12 +831,15 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
    assert(impl);
    assert(max_num_es_vertices && max_workgroup_size && wave_size);
 
-   bool can_cull = false; /* TODO */
+   bool can_cull = consider_culling && (num_vertices_per_primitives == 3) &&
+                   can_use_deferred_attribute_culling(shader);
    bool passthrough = consider_passthrough && !can_cull &&
                       !(shader->info.stage == MESA_SHADER_VERTEX && export_prim_id);
 
    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
+   nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
+   nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
 
    lower_ngg_nogs_state state = {
       .passthrough = passthrough,
@@ -343,6 +849,8 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
       .position_value_var = position_value_var,
       .prim_exp_arg_var = prim_exp_arg_var,
+      .es_accepted_var = es_accepted_var,
+      .gs_accepted_var = gs_accepted_var,
       .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
       .max_es_num_vertices = max_num_es_vertices,
       .wave_size = wave_size,
@@ -352,12 +860,15 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
    if (shader->info.stage == MESA_SHADER_VERTEX && export_prim_id)
       state.total_lds_bytes = max_num_es_vertices * 4u;
 
-   nir_cf_list extracted;
-   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
+   /* The shader only needs this much LDS when culling is turned off. */
+   unsigned lds_bytes_if_culling_off = state.total_lds_bytes;
 
    nir_builder builder;
    nir_builder *b = &builder; /* This is to avoid the & */
    nir_builder_init(b, impl);
+
+   nir_cf_list extracted;
+   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
    b->cursor = nir_before_cf_list(&impl->body);
 
    if (!can_cull) {
@@ -376,23 +887,23 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       else
          nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
    } else {
-      abort(); /* TODO */
+      add_deferred_attribute_culling(b, &extracted, &state);
+      b->cursor = nir_after_cf_list(&impl->body);
+
+      if (state.early_prim_export)
+         emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
    }
 
+   nir_intrinsic_instr *export_vertex_instr;
+
    nir_if *if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
    {
-      if (can_cull) {
-         nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
-         nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
-         nir_build_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem, .write_mask = 0xfu);
-      }
-
       /* Run the actual shader */
       nir_cf_reinsert(&extracted, b->cursor);
       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
 
       /* Export all vertex attributes (except primitive ID) */
-      nir_build_export_vertex_amd(b);
+      export_vertex_instr = nir_build_export_vertex_amd(b);
 
       /* Export primitive ID (in case of early primitive export or TES) */
       if (state.export_prim_id && (state.early_prim_export || shader->info.stage != MESA_SHADER_VERTEX))
@@ -410,17 +921,50 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       }
    }
 
+   if (can_cull) {
+      /* Remove the redundant position output. */
+      remove_extra_pos_outputs(shader, &state);
+
+      /* After looking at the performance in apps eg. Doom Eternal, and The Witcher 3,
+       * it seems that it's best to put the position export always at the end, and
+       * then let ACO schedule it up (slightly) only when early prim export is used.
+       */
+      b->cursor = nir_before_instr(&export_vertex_instr->instr);
+
+      nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
+      nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
+      nir_build_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem, .write_mask = 0xfu);
+   }
+
    nir_metadata_preserve(impl, nir_metadata_none);
    nir_validate_shader(shader, "after emitting NGG VS/TES");
 
    /* Cleanup */
+   nir_opt_dead_write_vars(shader);
    nir_lower_vars_to_ssa(shader);
    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
-   nir_opt_undef(shader);
+   nir_lower_alu_to_scalar(shader, NULL, NULL);
+   nir_lower_phis_to_scalar(shader, true);
+
+   if (can_cull) {
+      /* It's beneficial to redo these opts after splitting the shader. */
+      nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
+      nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
+   }
+
+   bool progress;
+   do {
+      progress = false;
+      NIR_PASS(progress, shader, nir_opt_undef);
+      NIR_PASS(progress, shader, nir_opt_cse);
+      NIR_PASS(progress, shader, nir_opt_dce);
+      NIR_PASS(progress, shader, nir_opt_dead_cf);
+   } while (progress);
 
    shader->info.shared_size = state.total_lds_bytes;
 
    ac_nir_ngg_config ret = {
+      .lds_bytes_if_culling_off = lds_bytes_if_culling_off,
       .can_cull = can_cull,
       .passthrough = passthrough,
    };



More information about the mesa-commit mailing list