[Mesa-dev] [PATCH 36/56] spirv: Add subgroup ballot support

Jason Ekstrand jason at jlekstrand.net
Wed Mar 7 14:35:24 UTC 2018


Reviewed-by: Iago Toral Quiroga <itoral at igalia.com>
---
 src/compiler/shader_info.h         |   1 +
 src/compiler/spirv/spirv_to_nir.c  |   5 ++
 src/compiler/spirv/vtn_subgroup.c  | 143 ++++++++++++++++++++++++++++++++++---
 src/compiler/spirv/vtn_variables.c |  20 ++++++
 4 files changed, 161 insertions(+), 8 deletions(-)

diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h
index 6183432..3a6e545 100644
--- a/src/compiler/shader_info.h
+++ b/src/compiler/shader_info.h
@@ -44,6 +44,7 @@ struct spirv_supported_capabilities {
    bool multiview;
    bool variable_pointers;
    bool storage_16bit;
+   bool subgroup_ballot; /* SubgroupBallotKHR and GroupNonUniformBallot */
    bool subgroup_basic;
 };
 
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 0a06c39..451d44f 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -3296,6 +3296,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          spv_check_supported(subgroup_basic, cap);
          break;
 
+      case SpvCapabilitySubgroupBallotKHR:
+      case SpvCapabilityGroupNonUniformBallot:
+         spv_check_supported(subgroup_ballot, cap); /* one flag gates both */
+         break;
+
       case SpvCapabilityVariablePointersStorageBuffer:
       case SpvCapabilityVariablePointers:
          spv_check_supported(variable_pointers, cap);
diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c
index 033c43e..a86f0cb 100644
--- a/src/compiler/spirv/vtn_subgroup.c
+++ b/src/compiler/spirv/vtn_subgroup.c
@@ -23,6 +23,44 @@
 
 #include "vtn_private.h"
 
+/* Emit nir_op on src0 into dst, recursing through the elements of any value
+ * that is not a vector or scalar so that each leaf gets its own intrinsic.
+ */
+static void
+vtn_build_subgroup_instr(struct vtn_builder *b,
+                         nir_intrinsic_op nir_op,
+                         struct vtn_ssa_value *dst,
+                         struct vtn_ssa_value *src0,
+                         nir_ssa_def *index)
+{
+   /* SPIR-V accepts an index of any integer type.  Drivers are only ever
+    * handed 32-bit indices, which keeps things simpler on their side.
+    */
+   if (index != NULL && index->bit_size != 32)
+      index = nir_u2u32(&b->nb, index);
+
+   vtn_assert(dst->type == src0->type);
+   if (glsl_type_is_vector_or_scalar(dst->type)) {
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader, nir_op);
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 dst->type, NULL);
+      intrin->num_components = intrin->dest.ssa.num_components;
+      intrin->src[0] = nir_src_for_ssa(src0->def);
+      if (index != NULL)
+         intrin->src[1] = nir_src_for_ssa(index);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      dst->def = &intrin->dest.ssa;
+      return;
+   }
+
+   const unsigned num_elems = glsl_get_length(dst->type);
+   for (unsigned e = 0; e < num_elems; e++) {
+      vtn_build_subgroup_instr(b, nir_op, dst->elems[e], src0->elems[e],
+                               index);
+   }
+}
+
 void
 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count)
@@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
-   case SpvOpGroupNonUniformAll:
-   case SpvOpGroupNonUniformAny:
-   case SpvOpGroupNonUniformAllEqual:
-   case SpvOpGroupNonUniformBroadcast:
-   case SpvOpGroupNonUniformBroadcastFirst:
-   case SpvOpGroupNonUniformBallot:
-   case SpvOpGroupNonUniformInverseBallot:
+   case SpvOpGroupNonUniformBallot: {
+      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+                  "OpGroupNonUniformBallot must return a uvec4");
+      nir_intrinsic_instr *ballot =
+         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
+      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
+      ballot->num_components = 4;
+      nir_builder_instr_insert(&b->nb, &ballot->instr);
+      val->ssa->def = &ballot->dest.ssa;
+      break;
+   }
+
+   case SpvOpGroupNonUniformInverseBallot: {
+      /* This one is just a BallotBitfieldExtract with subgroup invocation.
+       * We could add a NIR intrinsic but it's easier to just lower it on the
+       * spot.
+       */
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader,
+                                    nir_intrinsic_ballot_bitfield_extract);
+
+      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
+
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
    case SpvOpGroupNonUniformBallotBitExtract:
    case SpvOpGroupNonUniformBallotBitCount:
    case SpvOpGroupNonUniformBallotFindLSB:
-   case SpvOpGroupNonUniformBallotFindMSB:
+   case SpvOpGroupNonUniformBallotFindMSB: {
+      nir_ssa_def *src0, *src1 = NULL;
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformBallotBitExtract:
+         op = nir_intrinsic_ballot_bitfield_extract;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         src1 = vtn_ssa_value(b, w[5])->def;
+         /* Drivers only see 32-bit indices; see vtn_build_subgroup_instr(). */
+         if (src1->bit_size != 32)
+            src1 = nir_u2u32(&b->nb, src1);
+         break;
+      case SpvOpGroupNonUniformBallotBitCount:
+         switch ((SpvGroupOperation)w[4]) {
+         case SpvGroupOperationReduce:
+            op = nir_intrinsic_ballot_bit_count_reduce;
+            break;
+         case SpvGroupOperationInclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_inclusive;
+            break;
+         case SpvGroupOperationExclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_exclusive;
+            break;
+         default:
+            vtn_fail("Invalid group operation");
+         }
+         src0 = vtn_ssa_value(b, w[5])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindLSB:
+         op = nir_intrinsic_ballot_find_lsb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindMSB:
+         op = nir_intrinsic_ballot_find_msb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      default:
+         unreachable("Unhandled opcode");
+      }
+
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader, op);
+      intrin->src[0] = nir_src_for_ssa(src0);
+      if (src1)
+         intrin->src[1] = nir_src_for_ssa(src1);
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
+   case SpvOpGroupNonUniformBroadcastFirst:
+      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+                               val->ssa, vtn_ssa_value(b, w[4]), NULL);
+      break;
+
+   case SpvOpGroupNonUniformBroadcast:
+      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+                               val->ssa, vtn_ssa_value(b, w[4]),
+                               vtn_ssa_value(b, w[5])->def);
+      break;
+
+   case SpvOpGroupNonUniformAll:
+   case SpvOpGroupNonUniformAny:
+   case SpvOpGroupNonUniformAllEqual:
    case SpvOpGroupNonUniformShuffle:
    case SpvOpGroupNonUniformShuffleXor:
    case SpvOpGroupNonUniformShuffleUp:
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 2b8c5f6..53952e8 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -1309,6 +1309,26 @@ vtn_get_builtin_location(struct vtn_builder *b,
       *location = SYSTEM_VALUE_VIEW_INDEX;
       set_mode_system_value(b, mode);
       break;
+   case SpvBuiltInSubgroupEqMask:
+      *location = SYSTEM_VALUE_SUBGROUP_EQ_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupGeMask:
+      *location = SYSTEM_VALUE_SUBGROUP_GE_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupGtMask:
+      *location = SYSTEM_VALUE_SUBGROUP_GT_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupLeMask:
+      *location = SYSTEM_VALUE_SUBGROUP_LE_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupLtMask:
+      *location = SYSTEM_VALUE_SUBGROUP_LT_MASK;
+      set_mode_system_value(b, mode);
+      break;
    default:
       vtn_fail("unsupported builtin");
    }
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list