Mesa (main): spirv: Add support for SPV_KHR_integer_dot_product

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Aug 24 20:19:21 UTC 2021


Module: Mesa
Branch: main
Commit: fe956d0182aeb54d03fdc69711ceae15cc29168a
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=fe956d0182aeb54d03fdc69711ceae15cc29168a

Author: Ian Romanick <ian.d.romanick at intel.com>
Date:   Mon Jun 14 14:12:36 2021 -0700

spirv: Add support for SPV_KHR_integer_dot_product

v2 (Ivan): Add missing capability enum handling.

v3 (idr): Properly handle cases where dest_size != 32.

v4 (idr): Rewrite most of the error checking to use vtn_fail_if.  Use
nir_ssa_def with vtn_push_nir_ssa instead of vtn_ssa_value with
vtn_push_ssa_value.  All suggested by Jason.  Massive rewrite of the
handling of packed 4x8 saturating opcodes.  Based on some observations
made by Jason.

v5 (idr): Remove some debugging cruft accidentally added in v4.  Noticed
by Jason.

v6: Emit packed versions of vectored instructions when possible.
Suggested by Jason.

Reviewed-by: Jason Ekstrand <jason at jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>

---

 src/compiler/spirv/spirv_to_nir.c |  13 ++
 src/compiler/spirv/vtn_alu.c      | 272 ++++++++++++++++++++++++++++++++++++++
 src/compiler/spirv/vtn_private.h  |   3 +
 3 files changed, 288 insertions(+)

diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 983d8f9f06f..a64039f9469 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -4364,6 +4364,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityImageGatherExtended:
       case SpvCapabilityStorageImageExtendedFormats:
       case SpvCapabilityVector16:
+      case SpvCapabilityDotProductKHR:
+      case SpvCapabilityDotProductInputAllKHR:
+      case SpvCapabilityDotProductInput4x8BitKHR:
+      case SpvCapabilityDotProductInput4x8BitPackedKHR:
          break;
 
       case SpvCapabilityLinkage:
@@ -5650,6 +5654,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_alu(b, opcode, w, count);
       break;
 
+   case SpvOpSDotKHR:
+   case SpvOpUDotKHR:
+   case SpvOpSUDotKHR:
+   case SpvOpSDotAccSatKHR:
+   case SpvOpUDotAccSatKHR:
+   case SpvOpSUDotAccSatKHR:
+      vtn_handle_integer_dot(b, opcode, w, count);
+      break;
+
    case SpvOpBitcast:
       vtn_handle_bitcast(b, w, count);
       break;
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 48f41ac249a..ed731184d2d 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -765,6 +765,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   case SpvOpSDotKHR:
+   case SpvOpUDotKHR:
+   case SpvOpSUDotKHR:
+   case SpvOpSDotAccSatKHR:
+   case SpvOpUDotAccSatKHR:
+   case SpvOpSUDotAccSatKHR:
+      unreachable("Should have called vtn_handle_integer_dot instead.");
+
    default: {
       bool swap;
       bool exact;
@@ -823,6 +831,270 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    b->nb.exact = b->exact;
 }
 
+void
+vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
+                       const uint32_t *w, unsigned count)
+{
+   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
+   const unsigned dest_size = glsl_get_bit_size(dest_type);
+   /* Presumably sets b->nb.exact; it is reset from b->exact at the end. */
+   vtn_handle_no_contraction(b, dest_val);
+
+   /* Translate one SPV_KHR_integer_dot_product opcode to NIR.  Sources start
+    * at w[3] (w[1] is the Result Type, w[2] the Result <id>).  Because of
+    * the optional "Packed Vector Format" operand, the number of inputs is
+    * determined from the opcode.  This differs from vtn_handle_alu.
+    */
+   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
+                                opcode == SpvOpUDotAccSatKHR ||
+                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
+
+   vtn_assert(count >= num_inputs + 3);
+
+   struct vtn_ssa_value *vtn_src[3] = { NULL, };
+   nir_ssa_def *src[3] = { NULL, };
+
+   for (unsigned i = 0; i < num_inputs; i++) {
+      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
+      src[i] = vtn_src[i]->def;
+
+      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
+   }
+
+   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
+    * the SPV_KHR_integer_dot_product spec says:
+    *
+    *    _Vector 1_ and _Vector 2_ must have the same type.
+    *
+    * The practical requirement, enforced below for every opcode (including
+    * the mixed-signedness ones), is the same bit-size and component count.
+    */
+   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
+               glsl_get_bit_size(vtn_src[1]->type) ||
+               glsl_get_vector_elements(vtn_src[0]->type) !=
+               glsl_get_vector_elements(vtn_src[1]->type),
+               "Vector 1 and vector 2 source of opcode %s must have the same "
+               "type",
+               spirv_op_to_string(opcode));
+
+   if (num_inputs == 3) {
+      /* The SPV_KHR_integer_dot_product spec says:
+       *
+       *    The type of Accumulator must be the same as Result Type.
+       *
+       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
+       * types (far below) assumes these types have the same size.
+       */
+      vtn_fail_if(dest_type != vtn_src[2]->type,
+                  "Accumulator type must be the same as Result Type for "
+                  "opcode %s",
+                  spirv_op_to_string(opcode));
+   }
+
+   if (glsl_type_is_vector(vtn_src[0]->type)) {
+      /* Pack 4x8-bit vectors into a 32-bit scalar to use the packed path.
+       * FINISHME: Is this as good or better for platforms that lack the
+       * special instructions (i.e., has_dot_4x8 and/or has_sudot_4x8 false)?
+       */
+      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
+          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
+          glsl_get_bit_size(dest_type) <= 32) {
+         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
+         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
+      }
+   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
+              glsl_type_is_32bit(vtn_src[0]->type)) {
+      /* The SPV_KHR_integer_dot_product spec says:
+       *
+       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
+       *    Vector Format_ must be specified to select how the integers are to
+       *    be interpreted as vectors.
+       *
+       * The "Packed Vector Format" value follows the last input.
+       */
+      vtn_assert(count == (num_inputs + 4));
+      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
+      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
+                  "Unsupported vector packing format %d for opcode %s",
+                  pack_format, spirv_op_to_string(opcode));
+   } else {
+      vtn_fail_with_opcode("Invalid source types.", opcode);
+   }
+
+   nir_ssa_def *dest = NULL;
+   /* Unpacked vectors: convert, multiply and sum one component at a time. */
+   if (src[0]->num_components > 1) {
+      const nir_op s_conversion_op =
+         nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
+                                nir_rounding_mode_undef);
+
+      const nir_op u_conversion_op =
+         nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
+                                nir_rounding_mode_undef);
+      /* Pick each source's conversion from the opcode's signedness. */
+      nir_op src0_conversion_op;
+      nir_op src1_conversion_op;
+
+      switch (opcode) {
+      case SpvOpSDotKHR:
+      case SpvOpSDotAccSatKHR:
+         src0_conversion_op = s_conversion_op;
+         src1_conversion_op = s_conversion_op;
+         break;
+
+      case SpvOpUDotKHR:
+      case SpvOpUDotAccSatKHR:
+         src0_conversion_op = u_conversion_op;
+         src1_conversion_op = u_conversion_op;
+         break;
+
+      case SpvOpSUDotKHR:
+      case SpvOpSUDotAccSatKHR:
+         src0_conversion_op = s_conversion_op;
+         src1_conversion_op = u_conversion_op;
+         break;
+
+      default:
+         unreachable("Invalid opcode.");
+      }
+
+      /* The SPV_KHR_integer_dot_product spec says:
+       *
+       *    All components of the input vectors are sign-extended to the bit
+       *    width of the result's type. The sign-extended input vectors are
+       *    then multiplied component-wise and all components of the vector
+       *    resulting from the component-wise multiplication are added
+       *    together. The resulting value will equal the low-order N bits of
+       *    the correct result R, where N is the result width and R is
+       *    computed with enough precision to avoid overflow and underflow.
+       */
+      const unsigned vector_components =
+         glsl_get_vector_elements(vtn_src[0]->type);
+      /* Unsigned sources (u_conversion_op) are zero-extended instead. */
+      for (unsigned i = 0; i < vector_components; i++) {
+         nir_ssa_def *const src0 =
+            nir_build_alu(&b->nb, src0_conversion_op,
+                          nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
+
+         nir_ssa_def *const src1 =
+            nir_build_alu(&b->nb, src1_conversion_op,
+                          nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
+
+         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
+
+         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
+      }
+
+      if (num_inputs == 3) {
+         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
+          *
+          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
+          *    signed saturating addition of the result with _Accumulator_.
+          *
+          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
+          *
+          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
+          *    unsigned saturating addition of the result with _Accumulator_.
+          *
+          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
+          *
+          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
+          *    2_ and signed saturating addition of the result with
+          *    _Accumulator_.
+          */
+         dest = (opcode == SpvOpUDotAccSatKHR)
+            ? nir_uadd_sat(&b->nb, dest, src[2])
+            : nir_iadd_sat(&b->nb, dest, src[2]);
+      }
+   } else {
+      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
+      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
+      /* Packed path: each 32-bit source holds four 8-bit components. */
+      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
+      bool is_signed;
+      /* is_signed selects i2i/iadd_sat vs. u2u/uadd_sat below. */
+      switch (opcode) {
+      case SpvOpSDotKHR:
+         dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
+         is_signed = true;
+         break;
+
+      case SpvOpUDotKHR:
+         dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
+         is_signed = false;
+         break;
+
+      case SpvOpSUDotKHR:
+         dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
+         is_signed = true;
+         break;
+
+      case SpvOpSDotAccSatKHR:
+         if (dest_size == 32)
+            dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
+         else
+            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
+
+         is_signed = true;
+         break;
+
+      case SpvOpUDotAccSatKHR:
+         if (dest_size == 32)
+            dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
+         else
+            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
+
+         is_signed = false;
+         break;
+
+      case SpvOpSUDotAccSatKHR:
+         if (dest_size == 32)
+            dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
+         else
+            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
+
+         is_signed = true;
+         break;
+
+      default:
+         unreachable("Invalid opcode.");
+      }
+
+      if (dest_size != 32) {
+         /* When the accumulator is 32-bits, a NIR dot-product with saturate
+          * is generated above.  In all other cases a regular dot-product is
+          * generated above, and separate addition with saturate is generated
+          * here.
+          *
+          * The SPV_KHR_integer_dot_product spec says:
+          *
+          *    If any of the multiplications or additions, with the exception
+          *    of the final accumulation, overflow or underflow, the result of
+          *    the instruction is undefined.
+          *
+          * Therefore it is safe to cast the dot-product result down to the
+          * size of the accumulator before doing the addition.  Since the
+          * result of the dot-product cannot overflow 32-bits, this is also
+          * safe to cast up.
+          */
+         if (num_inputs == 3) {
+            dest = is_signed
+               ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
+               : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
+         } else {
+            dest = is_signed
+               ? nir_i2i(&b->nb, dest, dest_size)
+               : nir_u2u(&b->nb, dest, dest_size);
+         }
+      }
+   }
+
+   vtn_push_nir_ssa(b, w[2], dest);
+   /* Reset b->nb.exact to b->exact, as vtn_handle_alu also does. */
+   b->nb.exact = b->exact;
+}
+
 void
 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
 {
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index f2cfe144405..d95e3c72e81 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -919,6 +919,9 @@ nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
 void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count);
 
+void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
+                            const uint32_t *w, unsigned count);
+
 void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w,
                         unsigned count);
 



More information about the mesa-commit mailing list