Mesa (main): nir/algebraic: Add lowering for dot_4x8 instructions

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Aug 24 20:19:21 UTC 2021


Module: Mesa
Branch: main
Commit: 839495efc67cce5b44e76abc35b3b355c75a88c7
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=839495efc67cce5b44e76abc35b3b355c75a88c7

Author: Ian Romanick <ian.d.romanick at intel.com>
Date:   Wed Jun  9 14:53:49 2021 -0700

nir/algebraic: Add lowering for dot_4x8 instructions

v2: Fix copy-and-paste bugs in lowering patterns.

v3: Add has_sudot_4x8 flag.  Requested by Rhys.

v4: Since the names of the opcodes changed from dp4 to dot_4x8, also
change the names of the lowering helpers.  Suggested by Jason.

Reviewed-by: Jason Ekstrand <jason at jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>

---

 src/compiler/nir/nir_opt_algebraic.py | 36 +++++++++++++++++++++++++++++++++++
 src/compiler/nir/nir_search_helpers.h | 21 ++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
index 2f0a044e2a8..ee876b8b7ab 100644
--- a/src/compiler/nir/nir_opt_algebraic.py
+++ b/src/compiler/nir/nir_opt_algebraic.py
@@ -223,6 +223,42 @@ optimizations = [
    (('sudot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sudot_4x8_iadd', a, b, 0), c), '!options->lower_add_sat'),
 ]
 
+# Shorthand for the expansion of just the dot product part of the
+# sdot_4x8, udot_4x8, and sudot_4x8 instructions (accumulator not included).
+sdot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_i8', b, 0)),
+                                 ('imul', ('extract_i8', a, 1), ('extract_i8', b, 1))),
+                        ('iadd', ('imul', ('extract_i8', a, 2), ('extract_i8', b, 2)),
+                                 ('imul', ('extract_i8', a, 3), ('extract_i8', b, 3))))
+udot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_u8', a, 0), ('extract_u8', b, 0)),
+                                 ('imul', ('extract_u8', a, 1), ('extract_u8', b, 1))),
+                        ('iadd', ('imul', ('extract_u8', a, 2), ('extract_u8', b, 2)),
+                                 ('imul', ('extract_u8', a, 3), ('extract_u8', b, 3))))
+sudot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_u8', b, 0)),
+                                  ('imul', ('extract_i8', a, 1), ('extract_u8', b, 1))),
+                         ('iadd', ('imul', ('extract_i8', a, 2), ('extract_u8', b, 2)),
+                                  ('imul', ('extract_i8', a, 3), ('extract_u8', b, 3))))
+
+optimizations.extend([
+   (('sdot_4x8_iadd', a, b, c), ('iadd', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
+   (('udot_4x8_uadd', a, b, c), ('iadd', udot_4x8_a_b, c), '!options->has_dot_4x8'),
+   (('sudot_4x8_iadd', a, b, c), ('iadd', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
+
+   # For the unsigned dot-product, the largest possible value is 4*(255*255) =
+   # 0x3f804, so we don't have to worry about that intermediate result
+   # overflowing.  0x100000000 - 0x3f804 = 0xfffc07fc.  If c is a constant
+   # that is less than 0xfffc07fc, then the saturating add can never saturate.
+   (('udot_4x8_uadd_sat', a, b, '#c(is_ult_0xfffc07fc)'), ('udot_4x8_uadd', a, b, c)),
+   (('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', udot_4x8_a_b, c), '!options->has_dot_4x8'),
+
+   # For the signed dot-product, the largest positive value is 4*(-128*-128) =
+   # 0x10000, and the most negative value is 4*(-128*127) = -0xfe00.  We
+   # don't have to worry about that intermediate result overflowing or
+   # underflowing.
+   (('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
+
+   (('sudot_4x8_iadd_sat', a, b, c), ('iadd_sat', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
+])
+
 # Float sizes
 for s in [16, 32, 64]:
     optimizations.extend([
diff --git a/src/compiler/nir/nir_search_helpers.h b/src/compiler/nir/nir_search_helpers.h
index 24938484377..1188b50ed2d 100644
--- a/src/compiler/nir/nir_search_helpers.h
+++ b/src/compiler/nir/nir_search_helpers.h
@@ -205,6 +205,27 @@ is_not_const_zero(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
    return true;
 }
 
+/** Is src a constant whose every used component is unsigned < 0xfffc07fc? */
+static inline bool
+is_ult_0xfffc07fc(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
+                  unsigned src, unsigned num_components,
+                  const uint8_t *swizzle)
+{
+   /* Only constant sources can match; anything else is rejected. */
+   if (!nir_src_is_const(instr->src[src].src))
+      return false;
+
+   for (unsigned i = 0; i < num_components; i++) {
+      const unsigned val =
+         nir_src_comp_as_uint(instr->src[src].src, swizzle[i]);
+
+      if (val >= 0xfffc07fcU) /* 0xfffc07fc == 0x100000000 - 4*(255*255) */
+         return false;
+   }
+
+   return true;
+}
+
 static inline bool
 is_not_const(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
              unsigned src, UNUSED unsigned num_components,



More information about the mesa-commit mailing list