[Mesa-dev] [PATCH 19/44] intel/cs: Push subgroup ID instead of base thread ID

Jason Ekstrand jason at jlekstrand.net
Tue Sep 5 15:13:11 UTC 2017


We're going to want the subgroup ID for SPIR-V subgroups eventually anyway.
We only want to push one of the two values (subgroup ID or base thread
ID) and calculate the other from it in the shader; with this change the
base thread ID is computed as subgroup_id * dispatch_width.  Pushing the
subgroup ID makes a bit more sense because it's simpler to calculate and
because it's a real, API-visible value.  The only advantage of pushing
the base thread ID is avoiding a single SHL in the shader.
---
 src/compiler/nir/nir_intrinsics.h         |  4 +---
 src/intel/compiler/brw_compiler.h         |  2 +-
 src/intel/compiler/brw_fs.cpp             | 34 +++++++++++++++----------------
 src/intel/compiler/brw_fs.h               |  2 +-
 src/intel/compiler/brw_fs_nir.cpp         |  6 +++---
 src/intel/compiler/brw_fs_visitor.cpp     |  2 +-
 src/intel/compiler/brw_nir.h              |  2 +-
 src/intel/compiler/brw_nir_intrinsics.c   | 14 ++++++++-----
 src/intel/vulkan/anv_cmd_buffer.c         |  7 ++++---
 src/mesa/drivers/dri/i965/gen7_cs_state.c | 18 +++++++---------
 10 files changed, 45 insertions(+), 46 deletions(-)

diff --git a/src/compiler/nir/nir_intrinsics.h b/src/compiler/nir/nir_intrinsics.h
index 9389b74..54a51f8 100644
--- a/src/compiler/nir/nir_intrinsics.h
+++ b/src/compiler/nir/nir_intrinsics.h
@@ -355,6 +355,7 @@ SYSTEM_VALUE(subgroup_ge_mask, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(subgroup_gt_mask, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(subgroup_le_mask, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(subgroup_lt_mask, 1, 0, xx, xx, xx)
+SYSTEM_VALUE(subgroup_id, 1, 0, xx, xx, xx)
 
 /* Blend constant color values.  Float values are clamped. */
 SYSTEM_VALUE(blend_const_color_r_float, 1, 0, xx, xx, xx)
@@ -364,9 +365,6 @@ SYSTEM_VALUE(blend_const_color_a_float, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(blend_const_color_rgba8888_unorm, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(blend_const_color_aaaa8888_unorm, 1, 0, xx, xx, xx)
 
-/* Intel specific system values */
-SYSTEM_VALUE(intel_thread_local_id, 1, 0, xx, xx, xx)
-
 /**
  * Barycentric coordinate intrinsics.
  *
diff --git a/src/intel/compiler/brw_compiler.h b/src/intel/compiler/brw_compiler.h
index 930c5b8..cf1b854 100644
--- a/src/intel/compiler/brw_compiler.h
+++ b/src/intel/compiler/brw_compiler.h
@@ -660,7 +660,7 @@ struct brw_cs_prog_data {
    unsigned threads;
    bool uses_barrier;
    bool uses_num_work_groups;
-   int thread_local_id_index;
+   int subgroup_id_index;
 
    struct {
       struct brw_push_const_block cross_thread;
diff --git a/src/intel/compiler/brw_fs.cpp b/src/intel/compiler/brw_fs.cpp
index 5057557..6d7373f 100644
--- a/src/intel/compiler/brw_fs.cpp
+++ b/src/intel/compiler/brw_fs.cpp
@@ -1402,7 +1402,7 @@ fs_visitor::assign_curb_setup()
 {
    unsigned num_push_constants = stage_prog_data->nr_params;
    if (stage == MESA_SHADER_COMPUTE &&
-       brw_cs_prog_data(stage_prog_data)->thread_local_id_index >= 0)
+       brw_cs_prog_data(stage_prog_data)->subgroup_id_index >= 0)
       num_push_constants++;
 
    unsigned uniform_push_length = DIV_ROUND_UP(num_push_constants, 8);
@@ -2021,7 +2021,7 @@ fs_visitor::assign_constant_locations()
     * brw_curbe.c.
     */
    unsigned int max_push_components = 16 * 8;
-   if (thread_local_id_index >= 0)
+   if (subgroup_id_index >= 0)
       max_push_components--; /* Save a slot for the thread ID */
 
    /* We push small arrays, but no bigger than 16 floats.  This is big enough
@@ -2065,8 +2065,8 @@ fs_visitor::assign_constant_locations()
       if (!is_live[u])
          continue;
 
-      /* Skip thread_local_id_index to put it in the last push register. */
-      if (thread_local_id_index == (int)u)
+      /* Skip subgroup_id_index to put it in the last push register. */
+      if (subgroup_id_index == (int)u)
          continue;
 
       set_push_pull_constant_loc(u, &chunk_start, &max_chunk_bitsize,
@@ -2082,8 +2082,8 @@ fs_visitor::assign_constant_locations()
     * We don't increment num_push_constants because this never actually ends
     * up in the params array.
     */
-   if (thread_local_id_index >= 0 && is_live[thread_local_id_index])
-      push_constant_loc[thread_local_id_index] = num_push_constants;
+   if (subgroup_id_index >= 0 && is_live[subgroup_id_index])
+      push_constant_loc[subgroup_id_index] = num_push_constants;
 
    /* As the uniforms are going to be reordered, take the data from a temporary
     * copy of the original param[].
@@ -2120,7 +2120,7 @@ fs_visitor::assign_constant_locations()
    for (unsigned int i = 0; i < uniforms; i++) {
       const gl_constant_value *value = param[i];
 
-      if (thread_local_id_index == (int)i)
+      if (subgroup_id_index == (int)i)
          continue;
 
       if (pull_constant_loc[i] != -1) {
@@ -2132,8 +2132,8 @@ fs_visitor::assign_constant_locations()
    ralloc_free(param);
 
    if (stage == MESA_SHADER_COMPUTE)
-      brw_cs_prog_data(stage_prog_data)->thread_local_id_index =
-         push_constant_loc[thread_local_id_index];
+      brw_cs_prog_data(stage_prog_data)->subgroup_id_index =
+         push_constant_loc[subgroup_id_index];
 }
 
 bool
@@ -6700,21 +6700,21 @@ cs_fill_push_const_info(const struct gen_device_info *devinfo,
 {
    const struct brw_stage_prog_data *prog_data = &cs_prog_data->base;
    bool cross_thread_supported = devinfo->gen > 7 || devinfo->is_haswell;
-   bool fill_thread_id = cs_prog_data->thread_local_id_index >= 0;
+   bool fill_subgroup_id = cs_prog_data->subgroup_id_index >= 0;
 
    /* The thread ID should be stored in the last param dword */
-   if (fill_thread_id)
-      assert(cs_prog_data->thread_local_id_index == (int)prog_data->nr_params);
+   if (fill_subgroup_id)
+      assert(cs_prog_data->subgroup_id_index == (int)prog_data->nr_params);
 
-   const unsigned dwords = prog_data->nr_params + fill_thread_id;
+   const unsigned dwords = prog_data->nr_params + fill_subgroup_id;
 
    unsigned cross_thread_dwords, per_thread_dwords;
    if (!cross_thread_supported) {
       cross_thread_dwords = 0u;
       per_thread_dwords = dwords;
-   } else if (fill_thread_id) {
+   } else if (fill_subgroup_id) {
       /* Fill all but the last register with cross-thread payload */
-      cross_thread_dwords = 8 * (cs_prog_data->thread_local_id_index / 8);
+      cross_thread_dwords = 8 * (cs_prog_data->subgroup_id_index / 8);
       per_thread_dwords = dwords - cross_thread_dwords;
       assert(per_thread_dwords > 0 && per_thread_dwords <= 8);
    } else {
@@ -6735,7 +6735,7 @@ cs_fill_push_const_info(const struct gen_device_info *devinfo,
           cs_prog_data->push.per_thread.size == 0);
    assert(cs_prog_data->push.cross_thread.dwords +
           cs_prog_data->push.per_thread.dwords ==
-             prog_data->nr_params + fill_thread_id);
+             prog_data->nr_params + fill_subgroup_id);
 }
 
 static void
@@ -6757,7 +6757,7 @@ compile_cs_to_nir(const struct brw_compiler *compiler,
 {
    nir_shader *shader = nir_shader_clone(mem_ctx, src_shader);
    shader = brw_nir_apply_sampler_key(shader, compiler, &key->tex, true);
-   brw_nir_lower_intrinsics(shader);
+   brw_nir_lower_intrinsics(shader, dispatch_width);
    return brw_postprocess_nir(shader, compiler, true);
 }
 
diff --git a/src/intel/compiler/brw_fs.h b/src/intel/compiler/brw_fs.h
index 29605be..96ffdef 100644
--- a/src/intel/compiler/brw_fs.h
+++ b/src/intel/compiler/brw_fs.h
@@ -320,7 +320,7 @@ public:
    /**
     * Uniform index of the compute shader thread id
     */
-   int thread_local_id_index;
+   int subgroup_id_index;
 
    fs_reg frag_depth;
    fs_reg frag_stencil;
diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp
index ca82209..dd5570a 100644
--- a/src/intel/compiler/brw_fs_nir.cpp
+++ b/src/intel/compiler/brw_fs_nir.cpp
@@ -71,7 +71,7 @@ fs_visitor::nir_setup_uniforms()
    uniforms = nir->num_uniforms / 4;
 
    if (stage == MESA_SHADER_COMPUTE)
-      thread_local_id_index = uniforms++;
+      subgroup_id_index = uniforms++;
 }
 
 static bool
@@ -3396,8 +3396,8 @@ fs_visitor::nir_emit_cs_intrinsic(const fs_builder &bld,
       cs_prog_data->uses_barrier = true;
       break;
 
-   case nir_intrinsic_load_intel_thread_local_id: {
-      fs_reg uniform(UNIFORM, thread_local_id_index, BRW_REGISTER_TYPE_UD);
+   case nir_intrinsic_load_subgroup_id: {
+      fs_reg uniform(UNIFORM, subgroup_id_index, BRW_REGISTER_TYPE_UD);
       bld.MOV(retype(dest, BRW_REGISTER_TYPE_UD), uniform);
       break;
    }
diff --git a/src/intel/compiler/brw_fs_visitor.cpp b/src/intel/compiler/brw_fs_visitor.cpp
index 75ae463..fb575ef 100644
--- a/src/intel/compiler/brw_fs_visitor.cpp
+++ b/src/intel/compiler/brw_fs_visitor.cpp
@@ -887,7 +887,7 @@ fs_visitor::init()
    this->last_scratch = 0;
    this->pull_constant_loc = NULL;
    this->push_constant_loc = NULL;
-   this->thread_local_id_index = -1;
+   this->subgroup_id_index = -1;
 
    this->promoted_constants = 0,
 
diff --git a/src/intel/compiler/brw_nir.h b/src/intel/compiler/brw_nir.h
index df73303..1cd00f5 100644
--- a/src/intel/compiler/brw_nir.h
+++ b/src/intel/compiler/brw_nir.h
@@ -95,7 +95,7 @@ void brw_nir_analyze_boolean_resolves(nir_shader *nir);
 nir_shader *brw_preprocess_nir(const struct brw_compiler *compiler,
                                nir_shader *nir);
 
-bool brw_nir_lower_intrinsics(nir_shader *nir);
+bool brw_nir_lower_intrinsics(nir_shader *nir, unsigned dispatch_width);
 void brw_nir_lower_vs_inputs(nir_shader *nir,
                              bool use_legacy_snorm_formula,
                              const uint8_t *vs_attrib_wa_flags);
diff --git a/src/intel/compiler/brw_nir_intrinsics.c b/src/intel/compiler/brw_nir_intrinsics.c
index c4f6082..3b6403b 100644
--- a/src/intel/compiler/brw_nir_intrinsics.c
+++ b/src/intel/compiler/brw_nir_intrinsics.c
@@ -26,6 +26,7 @@
 
 struct lower_intrinsics_state {
    nir_shader *nir;
+   unsigned dispatch_width;
    nir_function_impl *impl;
    bool progress;
    nir_builder builder;
@@ -57,12 +58,14 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
           *    gl_LocalInvocationIndex =
           *       cs_thread_local_id + subgroup_invocation;
           */
-         nir_ssa_def *thread_local_id;
-         if (state->local_workgroup_size <= 8)
-            thread_local_id = nir_imm_int(b, 0);
+         nir_ssa_def *subgroup_id;
+         if (state->local_workgroup_size <= state->dispatch_width)
+            subgroup_id = nir_imm_int(b, 0);
          else
-            thread_local_id = nir_load_intel_thread_local_id(b);
+            subgroup_id = nir_load_subgroup_id(b);
 
+         nir_ssa_def *thread_local_id =
+            nir_imul(b, subgroup_id, nir_imm_int(b, state->dispatch_width));
          nir_ssa_def *channel = nir_load_subgroup_invocation(b);
          sysval = nir_iadd(b, channel, thread_local_id);
          break;
@@ -128,7 +131,7 @@ lower_cs_intrinsics_convert_impl(struct lower_intrinsics_state *state)
 }
 
 bool
-brw_nir_lower_intrinsics(nir_shader *nir)
+brw_nir_lower_intrinsics(nir_shader *nir, unsigned dispatch_width)
 {
    /* Currently we only lower intrinsics for compute shaders */
    if (nir->stage != MESA_SHADER_COMPUTE)
@@ -138,6 +141,7 @@ brw_nir_lower_intrinsics(nir_shader *nir)
    struct lower_intrinsics_state state;
    memset(&state, 0, sizeof(state));
    state.nir = nir;
+   state.dispatch_width = dispatch_width;
    state.local_workgroup_size = nir->info.cs.local_size[0] *
                                 nir->info.cs.local_size[1] *
                                 nir->info.cs.local_size[2];
diff --git a/src/intel/vulkan/anv_cmd_buffer.c b/src/intel/vulkan/anv_cmd_buffer.c
index c0d949c..c6d5dc7 100644
--- a/src/intel/vulkan/anv_cmd_buffer.c
+++ b/src/intel/vulkan/anv_cmd_buffer.c
@@ -688,8 +688,8 @@ anv_cmd_buffer_cs_push_constants(struct anv_cmd_buffer *cmd_buffer)
    uint32_t *u32_map = state.map;
 
    if (cs_prog_data->push.cross_thread.size > 0) {
-      assert(cs_prog_data->thread_local_id_index < 0 ||
-             cs_prog_data->thread_local_id_index >=
+      assert(cs_prog_data->subgroup_id_index < 0 ||
+             cs_prog_data->subgroup_id_index >=
                 cs_prog_data->push.cross_thread.dwords);
       for (unsigned i = 0;
            i < cs_prog_data->push.cross_thread.dwords;
@@ -709,7 +709,8 @@ anv_cmd_buffer_cs_push_constants(struct anv_cmd_buffer *cmd_buffer)
             uint32_t offset = (uintptr_t)prog_data->param[src];
             u32_map[dst] = *(uint32_t *)((uint8_t *)data + offset);
          }
-         if (cs_prog_data->thread_local_id_index >= 0)
+         /* Subgroup ID goes at the end */
+         if (cs_prog_data->subgroup_id_index >= 0)
             u32_map[dst] = t;
       }
    }
diff --git a/src/mesa/drivers/dri/i965/gen7_cs_state.c b/src/mesa/drivers/dri/i965/gen7_cs_state.c
index 26e4264..a3a1e13 100644
--- a/src/mesa/drivers/dri/i965/gen7_cs_state.c
+++ b/src/mesa/drivers/dri/i965/gen7_cs_state.c
@@ -74,8 +74,8 @@ brw_upload_cs_push_constants(struct brw_context *brw,
 
    if (cs_prog_data->push.cross_thread.size > 0) {
       gl_constant_value *param_copy = param;
-      assert(cs_prog_data->thread_local_id_index < 0 ||
-             cs_prog_data->thread_local_id_index >=
+      assert(cs_prog_data->subgroup_id_index < 0 ||
+             cs_prog_data->subgroup_id_index >=
                 cs_prog_data->push.cross_thread.dwords);
       for (unsigned i = 0;
            i < cs_prog_data->push.cross_thread.dwords;
@@ -84,21 +84,17 @@ brw_upload_cs_push_constants(struct brw_context *brw,
       }
    }
 
-   gl_constant_value thread_id;
    if (cs_prog_data->push.per_thread.size > 0) {
       for (unsigned t = 0; t < cs_prog_data->threads; t++) {
          unsigned dst =
             8 * (cs_prog_data->push.per_thread.regs * t +
                  cs_prog_data->push.cross_thread.regs);
          unsigned src = cs_prog_data->push.cross_thread.dwords;
-         for ( ; src < prog_data->nr_params; src++, dst++) {
-            if (src != cs_prog_data->thread_local_id_index)
-               param[dst] = *prog_data->param[src];
-            else {
-               thread_id.u = t * cs_prog_data->simd_size;
-               param[dst] = thread_id;
-            }
-         }
+         for ( ; src < prog_data->nr_params; src++, dst++)
+            param[dst] = *prog_data->param[src];
+         /* Subgroup ID goes at the end */
+         if (cs_prog_data->subgroup_id_index >= 0)
+            param[dst].u = t;
       }
    }
 
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list