[Mesa-dev] [PATCH 48/56] spirv: Add support for subgroup arithmetic

Jason Ekstrand jason at jlekstrand.net
Wed Mar 7 14:35:36 UTC 2018


Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin at intel.com>
Reviewed-by: Iago Toral Quiroga <itoral at igalia.com>
---
 src/compiler/shader_info.h        |  1 +
 src/compiler/spirv/spirv_to_nir.c |  6 +++
 src/compiler/spirv/vtn_subgroup.c | 97 +++++++++++++++++++++++++++++++++++----
 3 files changed, 96 insertions(+), 8 deletions(-)

diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h
index 140f661..e23690a 100644
--- a/src/compiler/shader_info.h
+++ b/src/compiler/shader_info.h
@@ -44,6 +44,7 @@ struct spirv_supported_capabilities {
    bool multiview;
    bool variable_pointers;
    bool storage_16bit;
+   bool subgroup_arithmetic;
    bool subgroup_ballot;
    bool subgroup_basic;
    bool subgroup_quad;
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 7019ab8..ee7a900 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -3313,6 +3313,12 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityGroupNonUniformQuad:
          spv_check_supported(subgroup_quad, cap);
+         break;
 
+      case SpvCapabilityGroupNonUniformArithmetic:
+      case SpvCapabilityGroupNonUniformClustered:
+         spv_check_supported(subgroup_arithmetic, cap);
+         break;
+
       case SpvCapabilityVariablePointersStorageBuffer:
       case SpvCapabilityVariablePointers:
          spv_check_supported(variable_pointers, cap);
diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c
index 1204c59..bd31439 100644
--- a/src/compiler/spirv/vtn_subgroup.c
+++ b/src/compiler/spirv/vtn_subgroup.c
@@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
                           nir_intrinsic_op nir_op,
                           struct vtn_ssa_value *dst,
                           struct vtn_ssa_value *src0,
-                         nir_ssa_def *index)
+                         nir_ssa_def *index,
+                         unsigned const_idx0,
+                         unsigned const_idx1)
 {
    /* Some of the subgroup operations take an index.  SPIR-V allows this to be
     * any integer type.  To make things simpler for drivers, we only support
@@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
    if (!glsl_type_is_vector_or_scalar(dst->type)) {
       for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
          vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
-                                  src0->elems[i], index);
+                                  src0->elems[i], index,
+                                  const_idx0, const_idx1);
       }
       return;
    }
@@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
    if (index)
       intrin->src[1] = nir_src_for_ssa(index);
 
+   intrin->const_index[0] = const_idx0;
+   intrin->const_index[1] = const_idx1;
+
    nir_builder_instr_insert(&b->nb, &intrin->instr);
 
    dst->def = &intrin->dest.ssa;
@@ -169,13 +175,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
 
    case SpvOpGroupNonUniformBroadcastFirst:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
-                               val->ssa, vtn_ssa_value(b, w[4]), NULL);
+                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
       break;
 
    case SpvOpGroupNonUniformBroadcast:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
       break;
 
    case SpvOpGroupNonUniformAll:
@@ -248,14 +254,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
          unreachable("Invalid opcode");
       }
       vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
       break;
    }
 
    case SpvOpGroupNonUniformQuadBroadcast:
       vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
       break;
 
    case SpvOpGroupNonUniformQuadSwap: {
@@ -272,7 +278,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
          op = nir_intrinsic_quad_swap_diagonal;
          break;
       }
-      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), NULL);
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+                               NULL, 0, 0);
       break;
    }
 
@@ -291,7 +298,81 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformBitwiseXor:
    case SpvOpGroupNonUniformLogicalAnd:
    case SpvOpGroupNonUniformLogicalOr:
-   case SpvOpGroupNonUniformLogicalXor:
+   case SpvOpGroupNonUniformLogicalXor: {
+      nir_op reduction_op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformIAdd:
+         reduction_op = nir_op_iadd;
+         break;
+      case SpvOpGroupNonUniformFAdd:
+         reduction_op = nir_op_fadd;
+         break;
+      case SpvOpGroupNonUniformIMul:
+         reduction_op = nir_op_imul;
+         break;
+      case SpvOpGroupNonUniformFMul:
+         reduction_op = nir_op_fmul;
+         break;
+      case SpvOpGroupNonUniformSMin:
+         reduction_op = nir_op_imin;
+         break;
+      case SpvOpGroupNonUniformUMin:
+         reduction_op = nir_op_umin;
+         break;
+      case SpvOpGroupNonUniformFMin:
+         reduction_op = nir_op_fmin;
+         break;
+      case SpvOpGroupNonUniformSMax:
+         reduction_op = nir_op_imax;
+         break;
+      case SpvOpGroupNonUniformUMax:
+         reduction_op = nir_op_umax;
+         break;
+      case SpvOpGroupNonUniformFMax:
+         reduction_op = nir_op_fmax;
+         break;
+      case SpvOpGroupNonUniformBitwiseAnd:
+      case SpvOpGroupNonUniformLogicalAnd:
+         reduction_op = nir_op_iand;
+         break;
+      case SpvOpGroupNonUniformBitwiseOr:
+      case SpvOpGroupNonUniformLogicalOr:
+         reduction_op = nir_op_ior;
+         break;
+      case SpvOpGroupNonUniformBitwiseXor:
+      case SpvOpGroupNonUniformLogicalXor:
+         reduction_op = nir_op_ixor;
+         break;
+      default:
+         unreachable("Invalid reduction operation");
+      }
+
+      nir_intrinsic_op op;
+      unsigned cluster_size = 0;
+      switch ((SpvGroupOperation)w[4]) {
+      case SpvGroupOperationReduce:
+         op = nir_intrinsic_reduce;
+         break;
+      case SpvGroupOperationInclusiveScan:
+         op = nir_intrinsic_inclusive_scan;
+         break;
+      case SpvGroupOperationExclusiveScan:
+         op = nir_intrinsic_exclusive_scan;
+         break;
+      case SpvGroupOperationClusteredReduce:
+         op = nir_intrinsic_reduce;
+         vtn_fail_if(count != 7, "Missing ClusterSize operand");
+         cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0];
+         break;
+      default:
+         vtn_fail("Invalid group operation");
+      }
+
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+                               NULL, reduction_op, cluster_size);
+      break;
+   }
+
    default:
       unreachable("Invalid SPIR-V opcode");
    }
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list