Mesa (main): ac/nir: Add remappability to tess and ESGS I/O lowering passes.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Jun 7 02:12:42 UTC 2022


Module: Mesa
Branch: main
Commit: f7f2770e7241bf47e2aea01474bd64187d118e2c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=f7f2770e7241bf47e2aea01474bd64187d118e2c

Author: Timur Kristóf <timur.kristof at gmail.com>
Date:   Thu May 12 15:48:24 2022 +0200

ac/nir: Add remappability to tess and ESGS I/O lowering passes.

This will be used by radeonsi to map common I/O locations to fixed
slots agreed upon by the different shader stages.

Reviewed-by: Marek Olšák <marek.olsak at amd.com>
Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer at amd.com>
Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16418>

---

 src/amd/common/ac_nir.c                      | 31 ++++++++++++++++++++++++++++
 src/amd/common/ac_nir.h                      | 16 ++++++++++++++
 src/amd/common/ac_nir_lower_esgs_io_to_mem.c | 11 ++++++++--
 src/amd/common/ac_nir_lower_tess_io_to_mem.c | 21 ++++++++++++++-----
 src/amd/vulkan/radv_shader.c                 | 16 +++++++-------
 src/compiler/nir/nir_builder.h               | 26 -----------------------
 6 files changed, 81 insertions(+), 40 deletions(-)

diff --git a/src/amd/common/ac_nir.c b/src/amd/common/ac_nir.c
index 6c8db43c516..49134cf33df 100644
--- a/src/amd/common/ac_nir.c
+++ b/src/amd/common/ac_nir.c
@@ -35,6 +35,37 @@ ac_nir_load_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_
       return nir_load_vector_arg_amd(b, num_components, .base = arg.arg_index);
 }
 
+/**
+ * This function takes an I/O intrinsic like load/store_input,
+ * and emits a sequence that calculates the full offset of that instruction,
+ * including a stride to the base and component offsets.
+ */
+nir_ssa_def *
+ac_nir_calc_io_offset(nir_builder *b,
+                      nir_intrinsic_instr *intrin,
+                      nir_ssa_def *base_stride,
+                      unsigned component_stride,
+                      ac_nir_map_io_driver_location map_io)
+{
+   unsigned base = nir_intrinsic_base(intrin);
+   unsigned semantic = nir_intrinsic_io_semantics(intrin).location;
+   unsigned mapped_driver_location = map_io ? map_io(semantic) : base;
+
+   /* base is the driver_location, which is in slots (1 slot = 4x4 bytes) */
+   nir_ssa_def *base_op = nir_imul_imm(b, base_stride, mapped_driver_location);
+
+   /* offset should be interpreted in relation to the base,
+    * so the instruction effectively reads/writes another input/output
+    * when it has an offset
+    */
+   nir_ssa_def *offset_op = nir_imul(b, base_stride, nir_ssa_for_src(b, *nir_get_io_offset_src(intrin), 1));
+
+   /* component is in bytes */
+   unsigned const_op = nir_intrinsic_component(intrin) * component_stride;
+
+   return nir_iadd_imm_nuw(b, nir_iadd_nuw(b, base_op, offset_op), const_op);
+}
+
 bool
 ac_nir_lower_indirect_derefs(nir_shader *shader,
                              enum amd_gfx_level gfx_level)
diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index 07613668f1a..8920e985517 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -48,6 +48,9 @@ enum
    AC_EXP_PARAM_UNDEFINED = 255, /* deprecated, use AC_EXP_PARAM_DEFAULT_VAL_0000 instead */
 };
 
+/* Maps I/O semantics to the actual location used by the lowering pass. */
+typedef unsigned (*ac_nir_map_io_driver_location)(unsigned semantic);
+
 /* Forward declaration of nir_builder so we don't have to include nir_builder.h here */
 struct nir_builder;
 typedef struct nir_builder nir_builder;
@@ -55,21 +58,31 @@ typedef struct nir_builder nir_builder;
 nir_ssa_def *
 ac_nir_load_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg);
 
+nir_ssa_def *
+ac_nir_calc_io_offset(nir_builder *b,
+                      nir_intrinsic_instr *intrin,
+                      nir_ssa_def *base_stride,
+                      unsigned component_stride,
+                      ac_nir_map_io_driver_location map_io);
+
 bool ac_nir_optimize_outputs(nir_shader *nir, bool sprite_tex_disallowed,
                              int8_t slot_remap[NUM_TOTAL_VARYING_SLOTS],
                              uint8_t param_export_index[NUM_TOTAL_VARYING_SLOTS]);
 
 void
 ac_nir_lower_ls_outputs_to_mem(nir_shader *ls,
+                               ac_nir_map_io_driver_location map,
                                bool tcs_in_out_eq,
                                uint64_t tcs_temp_only_inputs);
 
 void
 ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
+                              ac_nir_map_io_driver_location map,
                               bool tcs_in_out_eq);
 
 void
 ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                enum amd_gfx_level gfx_level,
                                bool tes_reads_tessfactors,
                                uint64_t tes_inputs_read,
@@ -80,16 +93,19 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
 
 void
 ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                unsigned num_reserved_tcs_outputs,
                                unsigned num_reserved_tcs_patch_outputs);
 
 void
 ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                enum amd_gfx_level gfx_level,
                                unsigned num_reserved_es_outputs);
 
 void
 ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
+                              ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               unsigned num_reserved_es_outputs);
 
diff --git a/src/amd/common/ac_nir_lower_esgs_io_to_mem.c b/src/amd/common/ac_nir_lower_esgs_io_to_mem.c
index 1d5f9e9032b..9cd7f4d6f9b 100644
--- a/src/amd/common/ac_nir_lower_esgs_io_to_mem.c
+++ b/src/amd/common/ac_nir_lower_esgs_io_to_mem.c
@@ -44,6 +44,9 @@ typedef struct {
    /* Which hardware generation we're dealing with */
    enum amd_gfx_level gfx_level;
 
+   /* I/O semantic -> real location used by lowering. */
+   ac_nir_map_io_driver_location map_io;
+
    /* Number of ES outputs for which memory should be reserved.
     * When compacted, this should be the number of linked ES outputs.
     */
@@ -125,7 +128,7 @@ lower_es_output_store(nir_builder *b,
    unsigned write_mask = nir_intrinsic_write_mask(intrin);
 
    b->cursor = nir_before_instr(instr);
-   nir_ssa_def *io_off = nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u);
+   nir_ssa_def *io_off = ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io);
 
    if (st->gfx_level <= GFX8) {
       /* GFX6-8: ES is a separate HW stage, data is passed from ES to GS in VRAM. */
@@ -198,7 +201,7 @@ gs_per_vertex_input_offset(nir_builder *b,
                                 : gs_per_vertex_input_vertex_offset_gfx6(b, vertex_src);
 
    unsigned base_stride = st->gfx_level >= GFX9 ? 1 : 64 /* Wave size on GFX6-8 */;
-   nir_ssa_def *io_off = nir_build_calc_io_offset(b, instr, nir_imm_int(b, base_stride * 4u), base_stride);
+   nir_ssa_def *io_off = ac_nir_calc_io_offset(b, instr, nir_imm_int(b, base_stride * 4u), base_stride, st->map_io);
    nir_ssa_def *off = nir_iadd(b, io_off, vertex_offset);
    return nir_imul_imm(b, off, 4u);
 }
@@ -230,12 +233,14 @@ filter_load_per_vertex_input(const nir_instr *instr, UNUSED const void *state)
 
 void
 ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                enum amd_gfx_level gfx_level,
                                unsigned num_reserved_es_outputs)
 {
    lower_esgs_io_state state = {
       .gfx_level = gfx_level,
       .num_reserved_es_outputs = num_reserved_es_outputs,
+      .map_io = map,
    };
 
    nir_shader_instructions_pass(shader,
@@ -246,12 +251,14 @@ ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
 
 void
 ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
+                              ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               unsigned num_reserved_es_outputs)
 {
    lower_esgs_io_state state = {
       .gfx_level = gfx_level,
       .num_reserved_es_outputs = num_reserved_es_outputs,
+      .map_io = map,
    };
 
    nir_shader_lower_instructions(shader,
diff --git a/src/amd/common/ac_nir_lower_tess_io_to_mem.c b/src/amd/common/ac_nir_lower_tess_io_to_mem.c
index 8483f4609ee..bb886f4c38c 100644
--- a/src/amd/common/ac_nir_lower_tess_io_to_mem.c
+++ b/src/amd/common/ac_nir_lower_tess_io_to_mem.c
@@ -123,6 +123,9 @@ typedef struct {
    /* Which hardware generation we're dealing with */
    enum amd_gfx_level gfx_level;
 
+   /* I/O semantic -> real location used by lowering. */
+   ac_nir_map_io_driver_location map_io;
+
    /* True if merged VS+TCS (on GFX9+) has the same number
     * of input and output patch size.
     */
@@ -239,7 +242,7 @@ lower_ls_output_store(nir_builder *b,
    nir_ssa_def *vertex_idx = nir_load_local_invocation_index(b);
    nir_ssa_def *base_off_var = nir_imul(b, vertex_idx, nir_load_lshs_vertex_stride_amd(b));
 
-   nir_ssa_def *io_off = nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u);
+   nir_ssa_def *io_off = ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io);
    unsigned write_mask = nir_intrinsic_write_mask(intrin);
 
    nir_ssa_def *off = nir_iadd_nuw(b, base_off_var, io_off);
@@ -299,7 +302,7 @@ hs_per_vertex_input_lds_offset(nir_builder *b,
 
    nir_ssa_def *tcs_in_current_patch_offset = nir_imul(b, rel_patch_id, tcs_in_patch_stride);
 
-   nir_ssa_def *io_offset = nir_build_calc_io_offset(b, instr, nir_imm_int(b, 16u), 4u);
+   nir_ssa_def *io_offset = ac_nir_calc_io_offset(b, instr, nir_imm_int(b, 16u), 4u, st->map_io);
 
    return nir_iadd_nuw(b, nir_iadd_nuw(b, tcs_in_current_patch_offset, vertex_index_off), io_offset);
 }
@@ -323,7 +326,7 @@ hs_output_lds_offset(nir_builder *b,
    nir_ssa_def *output_patch0_offset = nir_imul(b, input_patch_size, tcs_num_patches);
 
    nir_ssa_def *off = intrin
-                    ? nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u)
+                    ? ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io)
                     : nir_imm_int(b, 0);
 
    nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
@@ -353,7 +356,7 @@ hs_per_vertex_output_vmem_offset(nir_builder *b,
 
    nir_ssa_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
    nir_ssa_def *attr_stride = nir_imul(b, tcs_num_patches, nir_imul_imm(b, out_vertices_per_patch, 16u));
-   nir_ssa_def *io_offset = nir_build_calc_io_offset(b, intrin, attr_stride, 4u);
+   nir_ssa_def *io_offset = ac_nir_calc_io_offset(b, intrin, attr_stride, 4u, st->map_io);
 
    nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
    nir_ssa_def *patch_offset = nir_imul(b, rel_patch_id, nir_imul_imm(b, out_vertices_per_patch, 16u));
@@ -379,7 +382,7 @@ hs_per_patch_output_vmem_offset(nir_builder *b,
    nir_ssa_def *per_patch_data_offset = nir_imul(b, tcs_num_patches, per_vertex_output_patch_size);
 
    nir_ssa_def * off = intrin
-                    ? nir_build_calc_io_offset(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u)
+                    ? ac_nir_calc_io_offset(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u, st->map_io)
                     : nir_imm_int(b, 0);
 
    if (const_base_offset)
@@ -650,6 +653,7 @@ filter_any_input_access(const nir_instr *instr,
 
 void
 ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                bool tcs_in_out_eq,
                                uint64_t tcs_temp_only_inputs)
 {
@@ -658,6 +662,7 @@ ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
    lower_tess_io_state state = {
       .tcs_in_out_eq = tcs_in_out_eq,
       .tcs_temp_only_inputs = tcs_in_out_eq ? tcs_temp_only_inputs : 0,
+      .map_io = map,
    };
 
    nir_shader_instructions_pass(shader,
@@ -668,12 +673,14 @@ ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
 
 void
 ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
+                              ac_nir_map_io_driver_location map,
                               bool tcs_in_out_eq)
 {
    assert(shader->info.stage == MESA_SHADER_TESS_CTRL);
 
    lower_tess_io_state state = {
       .tcs_in_out_eq = tcs_in_out_eq,
+      .map_io = map,
    };
 
    nir_shader_lower_instructions(shader,
@@ -684,6 +691,7 @@ ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
 
 void
 ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                enum amd_gfx_level gfx_level,
                                bool tes_reads_tessfactors,
                                uint64_t tes_inputs_read,
@@ -702,6 +710,7 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
       .tcs_num_reserved_outputs = num_reserved_tcs_outputs,
       .tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
       .tcs_out_patch_fits_subgroup = 32 % shader->info.tess.tcs_vertices_out == 0,
+      .map_io = map,
    };
 
    nir_shader_lower_instructions(shader,
@@ -715,6 +724,7 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
 
 void
 ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
+                               ac_nir_map_io_driver_location map,
                                unsigned num_reserved_tcs_outputs,
                                unsigned num_reserved_tcs_patch_outputs)
 {
@@ -723,6 +733,7 @@ ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
    lower_tess_io_state state = {
       .tcs_num_reserved_outputs = num_reserved_tcs_outputs,
       .tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
+      .map_io = map,
    };
 
    nir_shader_lower_instructions(shader,
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 9743e6be191..ec5c00f0b1d 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -1059,34 +1059,36 @@ radv_lower_io_to_mem(struct radv_device *device, struct radv_pipeline_stage *sta
 
    if (nir->info.stage == MESA_SHADER_VERTEX) {
       if (info->vs.as_ls) {
-         NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, info->vs.tcs_in_out_eq,
+         NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, NULL, info->vs.tcs_in_out_eq,
                     info->vs.tcs_temp_only_input_mask);
          return true;
       } else if (info->vs.as_es) {
-         NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem,
+         NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, NULL,
                     device->physical_device->rad_info.gfx_level, info->vs.num_linked_outputs);
          return true;
       }
    } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
-      NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, info->vs.tcs_in_out_eq);
-      NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, device->physical_device->rad_info.gfx_level,
+      NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, NULL, info->vs.tcs_in_out_eq);
+      NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, NULL,
+                 device->physical_device->rad_info.gfx_level,
                  info->tcs.tes_reads_tess_factors, info->tcs.tes_inputs_read,
                  info->tcs.tes_patch_inputs_read, info->tcs.num_linked_outputs,
                  info->tcs.num_linked_patch_outputs, true);
 
       return true;
    } else if (nir->info.stage == MESA_SHADER_TESS_EVAL) {
-      NIR_PASS_V(nir, ac_nir_lower_tes_inputs_to_mem, info->tes.num_linked_inputs,
+      NIR_PASS_V(nir, ac_nir_lower_tes_inputs_to_mem, NULL, info->tes.num_linked_inputs,
                  info->tes.num_linked_patch_inputs);
 
       if (info->tes.as_es) {
-         NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem,
+         NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, NULL,
                     device->physical_device->rad_info.gfx_level, info->tes.num_linked_outputs);
       }
 
       return true;
    } else if (nir->info.stage == MESA_SHADER_GEOMETRY) {
-      NIR_PASS_V(nir, ac_nir_lower_gs_inputs_to_mem, device->physical_device->rad_info.gfx_level,
+      NIR_PASS_V(nir, ac_nir_lower_gs_inputs_to_mem, NULL,
+                 device->physical_device->rad_info.gfx_level,
                  info->gs.num_linked_inputs);
       return true;
    } else if (nir->info.stage == MESA_SHADER_TASK) {
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index af0660a7d88..bfe27164539 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -1543,32 +1543,6 @@ nir_load_param(nir_builder *build, uint32_t param_idx)
    return nir_build_load_param(build, param->num_components, param->bit_size, param_idx);
 }
 
-/**
- * This function takes an I/O intrinsic like load/store_input,
- * and emits a sequence that calculates the full offset of that instruction,
- * including a stride to the base and component offsets.
- */
-static inline nir_ssa_def *
-nir_build_calc_io_offset(nir_builder *b,
-                         nir_intrinsic_instr *intrin,
-                         nir_ssa_def *base_stride,
-                         unsigned component_stride)
-{
-   /* base is the driver_location, which is in slots (1 slot = 4x4 bytes) */
-   nir_ssa_def *base_op = nir_imul_imm(b, base_stride, nir_intrinsic_base(intrin));
-
-   /* offset should be interpreted in relation to the base,
-    * so the instruction effectively reads/writes another input/output
-    * when it has an offset
-    */
-   nir_ssa_def *offset_op = nir_imul(b, base_stride, nir_ssa_for_src(b, *nir_get_io_offset_src(intrin), 1));
-
-   /* component is in bytes */
-   unsigned const_op = nir_intrinsic_component(intrin) * component_stride;
-
-   return nir_iadd_imm_nuw(b, nir_iadd_nuw(b, base_op, offset_op), const_op);
-}
-
 /* calculate a `(1 << value) - 1` in ssa without overflows */
 static inline nir_ssa_def *
 nir_mask(nir_builder *b, nir_ssa_def *bits, unsigned dst_bit_size)



More information about the mesa-commit mailing list