Mesa (main): spirv: Add support for SPV_KHR_integer_dot_product
GitLab Mirror
gitlab-mirror at kemper.freedesktop.org
Tue Aug 24 20:19:21 UTC 2021
Module: Mesa
Branch: main
Commit: fe956d0182aeb54d03fdc69711ceae15cc29168a
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=fe956d0182aeb54d03fdc69711ceae15cc29168a
Author: Ian Romanick <ian.d.romanick at intel.com>
Date: Mon Jun 14 14:12:36 2021 -0700
spirv: Add support for SPV_KHR_integer_dot_product
v2 (Ivan): Add missing capability enum handling.
v3 (idr): Properly handle cases where dest_size != 32.
v4 (idr): Rewrite most of the error checking to use vtn_fail_if. Use
nir_ssa_def with vtn_push_nir_ssa instead of vtn_ssa_value with
vtn_push_ssa_value. All suggested by Jason. Massive rewrite of the
handling of packed 4x8 saturating opcodes, based on some observations
made by Jason.
v5 (idr): Remove some debugging cruft accidentally added in v4. Noticed
by Jason.
v6: Emit the packed 4x8 instructions for vector operands when possible.
Suggested by Jason.
Reviewed-by: Jason Ekstrand <jason at jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>
---
src/compiler/spirv/spirv_to_nir.c | 13 ++
src/compiler/spirv/vtn_alu.c | 272 ++++++++++++++++++++++++++++++++++++++
src/compiler/spirv/vtn_private.h | 3 +
3 files changed, 288 insertions(+)
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 983d8f9f06f..a64039f9469 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -4364,6 +4364,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvCapabilityImageGatherExtended:
case SpvCapabilityStorageImageExtendedFormats:
case SpvCapabilityVector16:
+ case SpvCapabilityDotProductKHR:
+ case SpvCapabilityDotProductInputAllKHR:
+ case SpvCapabilityDotProductInput4x8BitKHR:
+ case SpvCapabilityDotProductInput4x8BitPackedKHR:
break;
case SpvCapabilityLinkage:
@@ -5650,6 +5654,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
vtn_handle_alu(b, opcode, w, count);
break;
+ case SpvOpSDotKHR:
+ case SpvOpUDotKHR:
+ case SpvOpSUDotKHR:
+ case SpvOpSDotAccSatKHR:
+ case SpvOpUDotAccSatKHR:
+ case SpvOpSUDotAccSatKHR:
+ vtn_handle_integer_dot(b, opcode, w, count);
+ break;
+
case SpvOpBitcast:
vtn_handle_bitcast(b, w, count);
break;
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 48f41ac249a..ed731184d2d 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -765,6 +765,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
break;
}
+ case SpvOpSDotKHR:
+ case SpvOpUDotKHR:
+ case SpvOpSUDotKHR:
+ case SpvOpSDotAccSatKHR:
+ case SpvOpUDotAccSatKHR:
+ case SpvOpSUDotAccSatKHR:
+ unreachable("Should have called vtn_handle_integer_dot instead.");
+
default: {
bool swap;
bool exact;
@@ -823,6 +831,270 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
b->nb.exact = b->exact;
}
+/**
+ * Translate one SPV_KHR_integer_dot_product instruction (OpSDotKHR,
+ * OpUDotKHR, OpSUDotKHR or one of their AccSat variants) to NIR and push the
+ * result as the SSA value of the Result <id>.
+ *
+ * \param opcode  One of the six integer dot-product opcodes.
+ * \param w       Instruction words: w[1] is the Result Type, w[2] the Result
+ *                <id>, w[3] and w[4] are Vector 1 and Vector 2, w[5] is the
+ *                Accumulator (AccSat variants only), and the optional Packed
+ *                Vector Format operand follows the last input.
+ * \param count   Number of words in \p w.
+ */
+void
+vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
+                       const uint32_t *w, unsigned count)
+{
+   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
+   const unsigned dest_size = glsl_get_bit_size(dest_type);
+
+   vtn_handle_no_contraction(b, dest_val);
+
+   /* Collect the various SSA sources.
+    *
+    * Due to the optional "Packed Vector Format" field, determine number of
+    * inputs from the opcode.  This differs from vtn_handle_alu.
+    */
+   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
+                                opcode == SpvOpUDotAccSatKHR ||
+                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
+
+   vtn_assert(count >= num_inputs + 3);
+
+   struct vtn_ssa_value *vtn_src[3] = { NULL, };
+   nir_ssa_def *src[3] = { NULL, };
+
+   for (unsigned i = 0; i < num_inputs; i++) {
+      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
+      src[i] = vtn_src[i]->def;
+
+      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
+   }
+
+   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
+    * the SPV_KHR_integer_dot_product spec says:
+    *
+    *    _Vector 1_ and _Vector 2_ must have the same type.
+    *
+    * The practical requirement is the same bit-size and the same number of
+    * components.
+    */
+   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
+               glsl_get_bit_size(vtn_src[1]->type) ||
+               glsl_get_vector_elements(vtn_src[0]->type) !=
+               glsl_get_vector_elements(vtn_src[1]->type),
+               "Vector 1 and vector 2 source of opcode %s must have the same "
+               "type",
+               spirv_op_to_string(opcode));
+
+   if (num_inputs == 3) {
+      /* The SPV_KHR_integer_dot_product spec says:
+       *
+       *    The type of Accumulator must be the same as Result Type.
+       *
+       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
+       * types (far below) assumes these types have the same size.
+       */
+      vtn_fail_if(dest_type != vtn_src[2]->type,
+                  "Accumulator type must be the same as Result Type for "
+                  "opcode %s",
+                  spirv_op_to_string(opcode));
+   }
+
+   if (glsl_type_is_vector(vtn_src[0]->type)) {
+      /* FINISHME: Is this actually as good or better for platforms that don't
+       * have the special instructions (i.e., one or both of has_dot_4x8 or
+       * has_sudot_4x8 is false)?
+       *
+       * A 4-component 8-bit vector with a result of at most 32 bits is packed
+       * into a single 32-bit scalar so that the packed 4x8 path below is used.
+       */
+      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
+          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
+          glsl_get_bit_size(dest_type) <= 32) {
+         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
+         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
+      }
+   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
+              glsl_type_is_32bit(vtn_src[0]->type)) {
+      /* The SPV_KHR_integer_dot_product spec says:
+       *
+       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
+       *    Vector Format_ must be specified to select how the integers are to
+       *    be interpreted as vectors.
+       *
+       * The "Packed Vector Format" value follows the last input.
+       */
+      vtn_assert(count == (num_inputs + 4));
+      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
+      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
+                  "Unsupported vector packing format %d for opcode %s",
+                  pack_format, spirv_op_to_string(opcode));
+   } else {
+      vtn_fail_with_opcode("Invalid source types.", opcode);
+   }
+
+   /* Signedness of each source.  Vector 1 is signed for the S and SU
+    * variants, Vector 2 only for the S variants.  The signedness of Vector 1
+    * also selects signed vs. unsigned saturation for the AccSat variants and
+    * signed vs. unsigned conversion of a packed dot-product result.
+    */
+   bool src0_is_signed;
+   bool src1_is_signed;
+
+   switch (opcode) {
+   case SpvOpSDotKHR:
+   case SpvOpSDotAccSatKHR:
+      src0_is_signed = true;
+      src1_is_signed = true;
+      break;
+
+   case SpvOpUDotKHR:
+   case SpvOpUDotAccSatKHR:
+      src0_is_signed = false;
+      src1_is_signed = false;
+      break;
+
+   case SpvOpSUDotKHR:
+   case SpvOpSUDotAccSatKHR:
+      src0_is_signed = true;
+      src1_is_signed = false;
+      break;
+
+   default:
+      unreachable("Invalid opcode.");
+   }
+
+   nir_ssa_def *dest = NULL;
+
+   if (src[0]->num_components > 1) {
+      /* The SPV_KHR_integer_dot_product spec says:
+       *
+       *    All components of the input vectors are sign-extended to the bit
+       *    width of the result's type. The sign-extended input vectors are
+       *    then multiplied component-wise and all components of the vector
+       *    resulting from the component-wise multiplication are added
+       *    together. The resulting value will equal the low-order N bits of
+       *    the correct result R, where N is the result width and R is
+       *    computed with enough precision to avoid overflow and underflow.
+       *
+       * Unsigned sources are zero-extended instead (nir_u2u below).  The
+       * nir_i2i / nir_u2u helpers return the channel unchanged when it is
+       * already dest_size bits wide, so no redundant conversion is emitted.
+       */
+      const unsigned vector_components =
+         glsl_get_vector_elements(vtn_src[0]->type);
+
+      for (unsigned i = 0; i < vector_components; i++) {
+         nir_ssa_def *const chan0 = nir_channel(&b->nb, src[0], i);
+         nir_ssa_def *const chan1 = nir_channel(&b->nb, src[1], i);
+
+         nir_ssa_def *const src0 = src0_is_signed
+            ? nir_i2i(&b->nb, chan0, dest_size)
+            : nir_u2u(&b->nb, chan0, dest_size);
+
+         nir_ssa_def *const src1 = src1_is_signed
+            ? nir_i2i(&b->nb, chan1, dest_size)
+            : nir_u2u(&b->nb, chan1, dest_size);
+
+         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
+
+         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
+      }
+
+      if (num_inputs == 3) {
+         /* For SpvOpSDotAccSatKHR and SpvOpSUDotAccSatKHR the spec requires
+          * signed saturating addition of the dot-product with _Accumulator_;
+          * for SpvOpUDotAccSatKHR it requires unsigned saturating addition.
+          */
+         dest = src0_is_signed
+            ? nir_iadd_sat(&b->nb, dest, src[2])
+            : nir_uadd_sat(&b->nb, dest, src[2]);
+      }
+   } else {
+      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
+      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
+
+      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
+
+      /* The saturating NIR 4x8 dot-product opcodes accumulate into a 32-bit
+       * value, so they are only used when the accumulator (and the result)
+       * is 32 bits.  Otherwise a plain dot-product is computed here and the
+       * accumulation is done separately below.
+       */
+      const bool use_sat_dot = num_inputs == 3 && dest_size == 32;
+
+      switch (opcode) {
+      case SpvOpSDotKHR:
+      case SpvOpSDotAccSatKHR:
+         dest = use_sat_dot
+            ? nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2])
+            : nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
+         break;
+
+      case SpvOpUDotKHR:
+      case SpvOpUDotAccSatKHR:
+         dest = use_sat_dot
+            ? nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2])
+            : nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
+         break;
+
+      case SpvOpSUDotKHR:
+      case SpvOpSUDotAccSatKHR:
+         dest = use_sat_dot
+            ? nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2])
+            : nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
+         break;
+
+      default:
+         unreachable("Invalid opcode.");
+      }
+
+      if (dest_size != 32) {
+         /* The SPV_KHR_integer_dot_product spec says:
+          *
+          *    If any of the multiplications or additions, with the exception
+          *    of the final accumulation, overflow or underflow, the result of
+          *    the instruction is undefined.
+          *
+          * Therefore it is safe to narrow the 32-bit dot-product to the size
+          * of the result before the (saturating) accumulation.  A 4x8
+          * dot-product always fits in 32 bits, so widening it to a 64-bit
+          * result is exact.
+          */
+         dest = src0_is_signed
+            ? nir_i2i(&b->nb, dest, dest_size)
+            : nir_u2u(&b->nb, dest, dest_size);
+
+         if (num_inputs == 3) {
+            dest = src0_is_signed
+               ? nir_iadd_sat(&b->nb, dest, src[2])
+               : nir_uadd_sat(&b->nb, dest, src[2]);
+         }
+      }
+   }
+
+   vtn_push_nir_ssa(b, w[2], dest);
+
+   /* vtn_handle_no_contraction() may have changed the builder's exact flag;
+    * restore the builder-wide setting.
+    */
+   b->nb.exact = b->exact;
+}
+
void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index f2cfe144405..d95e3c72e81 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -919,6 +919,9 @@ nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count);
+/* Translate an SPV_KHR_integer_dot_product instruction (OpSDotKHR,
+ * OpUDotKHR, OpSUDotKHR or one of their AccSat variants) to NIR.  `w` holds
+ * the instruction words and `count` the number of words.
+ */
+void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
+ const uint32_t *w, unsigned count);
+
void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w,
unsigned count);
More information about the mesa-commit mailing list