Mesa (main): nir/opcodes: Add integer dot-product opcodes

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Aug 24 20:19:20 UTC 2021


Module: Mesa
Branch: main
Commit: 6c18a3b497a68ad5143d9de8d539237e3ab1c8e2
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=6c18a3b497a68ad5143d9de8d539237e3ab1c8e2

Author: Ian Romanick <ian.d.romanick at intel.com>
Date:   Tue Feb 23 17:33:04 2021 -0800

nir/opcodes: Add integer dot-product opcodes

Six opcodes are added: sdot_4x8_iadd, udot_4x8_uadd, sudot_4x8_iadd,
sdot_4x8_iadd_sat, udot_4x8_uadd_sat, and sudot_4x8_iadd_sat.  These
represent the combinations of integer dot-product and add that operate
on packed source vectors.  That is, the four 8-bit values for each
vector are stored in a single 32-bit integer.

Some hardware may prefer to operate on unpacked byte vectors.  When such
hardware comes to Mesa, we'll have to figure out how to name things.

v2: Add nir_op_iudp4a and nir_op_iudp4a_sat instructions.  These opcodes
are not 2-source commutative.

v3: Rename all opcodes to be more like some existing 4x8 opcodes.
Suggested by Timur.  Change type of packed vector sources to uint32,
change types of constant folding variables to have explicit size, and
delete some extra casts.  All suggested by Jason.

v4: Fix typo previously noticed by Alyssa but missed in v2.

v5: Add has_sudot_4x8 flag.  Requested by Rhys.

Reviewed-by: Jason Ekstrand <jason at jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>

---

 src/compiler/nir/nir.h          |   6 +++
 src/compiler/nir/nir_opcodes.py | 107 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 113 insertions(+)

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 0ee29855fd5..04ca38fb014 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3688,6 +3688,12 @@ typedef struct nir_shader_compiler_options {
     * for rect texture lowering. */
    bool has_txs;
 
+   /** Backend supports sdot_4x8 and udot_4x8 opcodes. */
+   bool has_dot_4x8;
+
+   /** Backend supports sudot_4x8 opcodes. */
+   bool has_sudot_4x8;
+
    /* Whether to generate only scoped_barrier intrinsics instead of the set of
     * memory and control barrier intrinsics based on GLSL.
     */
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index 6b2fc24300a..527cf6bb56a 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -1314,3 +1314,110 @@ unop_horiz("pack_double_2x32_dxil", 1, tuint64, 2, tuint32,
            "dst.x = src0.x | ((uint64_t)src0.y << 32);")
 unop_horiz("unpack_double_2x32_dxil", 2, tuint32, 1, tuint64,
            "dst.x = src0.x; dst.y = src0.x >> 32;")
+
+# src0 and src1 are i8vec4 packed in a uint32, and src2 is an int32.  The int8
+# components are sign-extended to 32 bits, a dot-product is performed on the
+# resulting vectors, and src2 is added to the result with 32-bit wrap-around.
+opcode("sdot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
+       False, _2src_commutative, """
+   const int32_t v0x = (int8_t)(src0      );
+   const int32_t v0y = (int8_t)(src0 >>  8);
+   const int32_t v0z = (int8_t)(src0 >> 16);
+   const int32_t v0w = (int8_t)(src0 >> 24);
+   const int32_t v1x = (int8_t)(src1      );
+   const int32_t v1y = (int8_t)(src1 >>  8);
+   const int32_t v1z = (int8_t)(src1 >> 16);
+   const int32_t v1w = (int8_t)(src1 >> 24);
+   /* |dot| <= 4 * 128 * 128 cannot overflow, but adding src2 can: do that add in uint32_t so it wraps instead of being UB. */
+   dst = (int32_t)((uint32_t)((v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w)) + (uint32_t)src2);
+""")
+
+# Unsigned variant of sdot_4x8_iadd: the packed bytes are zero-extended instead.
+opcode("udot_4x8_uadd", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32],
+       False, _2src_commutative, """
+   const uint32_t a0 = (uint8_t)(src0      );
+   const uint32_t a1 = (uint8_t)(src0 >>  8);
+   const uint32_t a2 = (uint8_t)(src0 >> 16);
+   const uint32_t a3 = (uint8_t)(src0 >> 24);
+   const uint32_t b0 = (uint8_t)(src1      );
+   const uint32_t b1 = (uint8_t)(src1 >>  8);
+   const uint32_t b2 = (uint8_t)(src1 >> 16);
+   const uint32_t b3 = (uint8_t)(src1 >> 24);
+
+   dst = src2 + (a0 * b0) + (a1 * b1) + (a2 * b2) + (a3 * b3);
+""")
+
+# src0 holds four signed 8-bit values and src1 holds four unsigned 8-bit
+# values, each packed into a 32-bit integer; src2 is an int32.  src0's bytes
+# are sign-extended and src1's bytes are zero-extended to 32 bits, the two
+# vectors are multiplied component-wise and summed, and src2 is added.
+#
+# NOTE: Unlike sdot_4x8_iadd and udot_4x8_uadd, the mixed signs of source 0
+# and source 1 mean that this opcode is not 2-source commutative.
+opcode("sudot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
+       False, "", """
+   const int32_t s0 = (int8_t)(src0      );
+   const int32_t s1 = (int8_t)(src0 >>  8);
+   const int32_t s2 = (int8_t)(src0 >> 16);
+   const int32_t s3 = (int8_t)(src0 >> 24);
+   const uint32_t u0 = (uint8_t)(src1      );
+   const uint32_t u1 = (uint8_t)(src1 >>  8);
+   const uint32_t u2 = (uint8_t)(src1 >> 16);
+   const uint32_t u3 = (uint8_t)(src1 >> 24);
+
+   dst = (s0 * u0) + (s1 * u1) + (s2 * u2) + (s3 * u3) + src2;
+""")
+
+# Like sdot_4x8_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
+opcode("sdot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
+       False, _2src_commutative, """
+   const int64_t v0x = (int8_t)(src0      );
+   const int64_t v0y = (int8_t)(src0 >>  8);
+   const int64_t v0z = (int8_t)(src0 >> 16);
+   const int64_t v0w = (int8_t)(src0 >> 24);
+   const int64_t v1x = (int8_t)(src1      );
+   const int64_t v1y = (int8_t)(src1 >>  8);
+   const int64_t v1z = (int8_t)(src1 >> 16);
+   const int64_t v1w = (int8_t)(src1 >> 24);
+
+   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
+   /* tmp is summed in 64 bits, so it is exact and can be saturated to int32 range. */
+   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
+""")
+
+# Like udot_4x8_uadd, but the result is clamped to the range [0, 0xffffffff].
+opcode("udot_4x8_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
+       False, _2src_commutative, """
+   const uint64_t v0x = (uint8_t)(src0      );
+   const uint64_t v0y = (uint8_t)(src0 >>  8);
+   const uint64_t v0z = (uint8_t)(src0 >> 16);
+   const uint64_t v0w = (uint8_t)(src0 >> 24);
+   const uint64_t v1x = (uint8_t)(src1      );
+   const uint64_t v1y = (uint8_t)(src1 >>  8);
+   const uint64_t v1z = (uint8_t)(src1 >> 16);
+   const uint64_t v1w = (uint8_t)(src1 >> 24);
+   /* src2 is declared tint32: zero-extend its bits so values >= 2^31 saturate instead of sign-extending and wrapping. */
+   const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + (uint32_t)src2;
+
+   dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp;
+""")
+
+# Like sudot_4x8_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
+#
+# NOTE: Unlike sdot_4x8_iadd_sat and udot_4x8_uadd_sat, the mixed signs of
+# source 0 and source 1 mean that this opcode is not 2-source commutative.
+opcode("sudot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
+       False, "", """
+   const int64_t v0x = (int8_t)(src0      );
+   const int64_t v0y = (int8_t)(src0 >>  8);
+   const int64_t v0z = (int8_t)(src0 >> 16);
+   const int64_t v0w = (int8_t)(src0 >> 24);
+   const int64_t v1x = (uint8_t)(src1      );
+   const int64_t v1y = (uint8_t)(src1 >>  8);
+   const int64_t v1z = (uint8_t)(src1 >> 16);
+   const int64_t v1w = (uint8_t)(src1 >> 24);
+   /* src1's bytes are zero-extended into int64_t so the sum is evaluated in signed 64-bit arithmetic, exactly, before clamping. */
+   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
+
+   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
+""")



More information about the mesa-commit mailing list