Mesa (main): nir/opt_vectorize: add callback for max vectorization width

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jun 1 12:27:22 UTC 2022


Module: Mesa
Branch: main
Commit: bd151a256ebe5dba2f159d1cc72a8777b1a5b84c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=bd151a256ebe5dba2f159d1cc72a8777b1a5b84c

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Fri Dec 18 19:05:47 2020 +0100

nir/opt_vectorize: add callback for max vectorization width

The callback allows a driver to request a different vectorization factor
per instruction, depending on e.g. bit size or opcode.

This patch also removes the use of the vectorize_vec2_16bit option
from nir_opt_vectorize().

Reviewed-by: Alyssa Rosenzweig <alyssa at collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13080>

---

 src/amd/vulkan/radv_pipeline.c               | 18 ++++---
 src/compiler/nir/nir.h                       | 19 +++++--
 src/compiler/nir/nir_opt_vectorize.c         | 78 +++++++++++++---------------
 src/gallium/auxiliary/nir/nir_to_tgsi.c      | 12 ++---
 src/gallium/drivers/radeonsi/si_shader_nir.c | 14 ++++-
 src/mesa/state_tracker/st_glsl_to_nir.cpp    |  2 +-
 src/panfrost/bifrost/bifrost_compile.c       | 17 ++++--
 src/panfrost/midgard/midgard_compile.c       | 17 +++---
 8 files changed, 98 insertions(+), 79 deletions(-)

diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index d02e2e36c14..5cf91b0a6e7 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -4042,14 +4042,16 @@ lower_bit_size_callback(const nir_instr *instr, void *_)
    return 0;
 }
 
-static bool
-opt_vectorize_callback(const nir_instr *instr, void *_)
+static uint8_t
+opt_vectorize_callback(const nir_instr *instr, const void *_)
 {
-   assert(instr->type == nir_instr_type_alu);
-   nir_alu_instr *alu = nir_instr_as_alu(instr);
-   unsigned bit_size = alu->dest.dest.ssa.bit_size;
+   if (instr->type != nir_instr_type_alu)
+      return 0;
+
+   const nir_alu_instr *alu = nir_instr_as_alu(instr);
+   const unsigned bit_size = alu->dest.dest.ssa.bit_size;
    if (bit_size != 16)
-      return false;
+      return 1;
 
    switch (alu->op) {
    case nir_op_fadd:
@@ -4069,12 +4071,12 @@ opt_vectorize_callback(const nir_instr *instr, void *_)
    case nir_op_imax:
    case nir_op_umin:
    case nir_op_umax:
-      return true;
+      return 2;
    case nir_op_ishl: /* TODO: in NIR, these have 32bit shift operands */
    case nir_op_ishr: /* while Radeon needs 16bit operands when vectorized */
    case nir_op_ushr:
    default:
-      return false;
+      return 1;
    }
 }
 
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index a3e7485ff30..3ab0696a5a5 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3228,6 +3228,15 @@ typedef enum {
  */
 typedef bool (*nir_instr_filter_cb)(const nir_instr *, const void *);
 
+/** A vectorization width callback
+ *
+ * Returns the maximum vectorization width per instruction.
+ * 0, if the instruction must not be modified.
+ *
+ * The vectorization width must be a power of 2.
+ */
+typedef uint8_t (*nir_vectorize_cb)(const nir_instr *, const void *);
+
 typedef struct nir_shader_compiler_options {
    bool lower_fdiv;
    bool lower_ffma16;
@@ -3455,7 +3464,11 @@ typedef struct nir_shader_compiler_options {
    nir_instr_filter_cb lower_to_scalar_filter;
 
    /**
-    * Whether nir_opt_vectorize should only create 16-bit 2D vectors.
+    * Disables potentially harmful algebraic transformations for architectures
+    * with SIMD-within-a-register semantics.
+    *
+    * Note, to actually vectorize 16bit instructions, use nir_opt_vectorize()
+    * with a suitable callback function.
     */
    bool vectorize_vec2_16bit;
 
@@ -5485,9 +5498,7 @@ bool nir_lower_undef_to_zero(nir_shader *shader);
 
 bool nir_opt_uniform_atomics(nir_shader *shader);
 
-typedef bool (*nir_opt_vectorize_cb)(const nir_instr *instr, void *data);
-
-bool nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                        void *data);
 
 bool nir_opt_conditional_discard(nir_shader *shader);
diff --git a/src/compiler/nir/nir_opt_vectorize.c b/src/compiler/nir/nir_opt_vectorize.c
index 83c841ee63e..dc6e1d84b52 100644
--- a/src/compiler/nir/nir_opt_vectorize.c
+++ b/src/compiler/nir/nir_opt_vectorize.c
@@ -22,6 +22,16 @@
  *
  */
 
+/**
+ * nir_opt_vectorize() aims to vectorize ALU instructions.
+ *
+ * The default vectorization width is 4.
+ * If desired, a callback function which returns the max vectorization width
+ * per instruction can be provided.
+ *
+ * The max vectorization width must be a power of 2.
+ */
+
 #include "nir.h"
 #include "nir_vla.h"
 #include "nir_builder.h"
@@ -125,7 +135,7 @@ instrs_equal(const void *data1, const void *data2)
 }
 
 static bool
-instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
+instr_can_rewrite(nir_instr *instr)
 {
    switch (instr->type) {
    case nir_instr_type_alu: {
@@ -139,12 +149,7 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
          return false;
 
       /* no need to hash instructions which are already vectorized */
-      if (alu->dest.dest.ssa.num_components >= 4)
-         return false;
-
-      if (vectorize_16bit &&
-          (alu->dest.dest.ssa.num_components >= 2 ||
-           alu->dest.dest.ssa.bit_size != 16))
+      if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
          return false;
 
       if (nir_op_infos[alu->op].output_size != 0)
@@ -156,8 +161,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
 
          /* don't hash instructions which are already swizzled
           * outside of max_components: these should better be scalarized */
-         uint32_t mask = vectorize_16bit ? ~1 : ~3;
-         for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
+         uint32_t mask = ~(instr->pass_flags - 1);
+         for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
             if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
                return false;
          }
@@ -179,10 +184,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
  * the same instructions into one vectorized instruction. Note that instr1
  * should dominate instr2.
  */
-
 static nir_instr *
-instr_try_combine(struct nir_shader *nir, struct set *instr_set,
-                  nir_instr *instr1, nir_instr *instr2)
+instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
 {
    assert(instr1->type == nir_instr_type_alu);
    assert(instr2->type == nir_instr_type_alu);
@@ -194,14 +197,10 @@ instr_try_combine(struct nir_shader *nir, struct set *instr_set,
    unsigned alu2_components = alu2->dest.dest.ssa.num_components;
    unsigned total_components = alu1_components + alu2_components;
 
-   if (total_components > 4)
+   assert(instr1->pass_flags == instr2->pass_flags);
+   if (total_components > instr1->pass_flags)
       return NULL;
 
-   if (nir->options->vectorize_vec2_16bit) {
-      assert(total_components == 2);
-      assert(alu1->dest.dest.ssa.bit_size == 16);
-   }
-
    nir_builder b;
    nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
    b.cursor = nir_after_instr(instr1);
@@ -352,28 +351,23 @@ vec_instr_set_destroy(struct set *instr_set)
 }
 
 static bool
-vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
-                             nir_instr *instr,
-                             nir_opt_vectorize_cb filter, void *data)
+vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
+                             nir_vectorize_cb filter, void *data)
 {
-   if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
-      return false;
+   /* set max vector to instr pass flags: this is used to hash swizzles */
+   instr->pass_flags = filter ? filter(instr, data) : 4;
+   assert(util_is_power_of_two_or_zero(instr->pass_flags));
 
-   if (filter && !filter(instr, data))
+   if (!instr_can_rewrite(instr))
       return false;
 
-   /* set max vector to instr pass flags: this is used to hash swizzles */
-   instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
-
    struct set_entry *entry = _mesa_set_search(instr_set, instr);
    if (entry) {
       nir_instr *old_instr = (nir_instr *) entry->key;
       _mesa_set_remove(instr_set, entry);
-      nir_instr *new_instr = instr_try_combine(nir, instr_set,
-                                               old_instr, instr);
+      nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
       if (new_instr) {
-         if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
-             (!filter || filter(new_instr, data)))
+         if (instr_can_rewrite(new_instr))
             _mesa_set_add(instr_set, new_instr);
          return true;
       }
@@ -384,25 +378,23 @@ vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
 }
 
 static bool
-vectorize_block(struct nir_shader *nir, nir_block *block,
-                struct set *instr_set,
-                nir_opt_vectorize_cb filter, void *data)
+vectorize_block(nir_block *block, struct set *instr_set,
+                nir_vectorize_cb filter, void *data)
 {
    bool progress = false;
 
    nir_foreach_instr_safe(instr, block) {
-      if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
+      if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
          progress = true;
    }
 
    for (unsigned i = 0; i < block->num_dom_children; i++) {
       nir_block *child = block->dom_children[i];
-      progress |= vectorize_block(nir, child, instr_set, filter, data);
+      progress |= vectorize_block(child, instr_set, filter, data);
    }
 
    nir_foreach_instr_reverse(instr, block) {
-      if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
-          (!filter || filter(instr, data)))
+      if (instr_can_rewrite(instr))
          _mesa_set_remove_key(instr_set, instr);
    }
 
@@ -410,14 +402,14 @@ vectorize_block(struct nir_shader *nir, nir_block *block,
 }
 
 static bool
-nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
-                       nir_opt_vectorize_cb filter, void *data)
+nir_opt_vectorize_impl(nir_function_impl *impl,
+                       nir_vectorize_cb filter, void *data)
 {
    struct set *instr_set = vec_instr_set_create();
 
    nir_metadata_require(impl, nir_metadata_dominance);
 
-   bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
+   bool progress = vectorize_block(nir_start_block(impl), instr_set,
                                    filter, data);
 
    if (progress) {
@@ -432,14 +424,14 @@ nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
 }
 
 bool
-nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                   void *data)
 {
    bool progress = false;
 
    nir_foreach_function(function, shader) {
       if (function->impl)
-         progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
+         progress |= nir_opt_vectorize_impl(function->impl, filter, data);
    }
 
    return progress;
diff --git a/src/gallium/auxiliary/nir/nir_to_tgsi.c b/src/gallium/auxiliary/nir/nir_to_tgsi.c
index 2993fc27e22..371b43c8a8e 100644
--- a/src/gallium/auxiliary/nir/nir_to_tgsi.c
+++ b/src/gallium/auxiliary/nir/nir_to_tgsi.c
@@ -3067,11 +3067,11 @@ type_size(const struct glsl_type *type, bool bindless)
 /* Allow vectorizing of ALU instructions, but avoid vectorizing past what we
  * can handle for 64-bit values in TGSI.
  */
-static bool
-ntt_should_vectorize_instr(const nir_instr *instr, void *data)
+static uint8_t
+ntt_should_vectorize_instr(const nir_instr *instr, const void *data)
 {
    if (instr->type != nir_instr_type_alu)
-      return false;
+      return 0;
 
    nir_alu_instr *alu = nir_instr_as_alu(instr);
 
@@ -3085,7 +3085,7 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
        *
        * https://gitlab.freedesktop.org/virgl/virglrenderer/-/issues/195
        */
-      return false;
+      return 1;
 
    default:
       break;
@@ -3102,10 +3102,10 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
        * 64-bit instrs in the first place, I don't see much reason to care about
        * this.
        */
-      return false;
+      return 1;
    }
 
-   return true;
+   return 4;
 }
 
 static bool
diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c
index 8b267023504..ec1013f8d74 100644
--- a/src/gallium/drivers/radeonsi/si_shader_nir.c
+++ b/src/gallium/drivers/radeonsi/si_shader_nir.c
@@ -43,6 +43,18 @@ static bool si_alu_to_scalar_filter(const nir_instr *instr, const void *data)
    return true;
 }
 
+static uint8_t si_vectorize_callback(const nir_instr *instr, const void *data)
+{
+   if (instr->type != nir_instr_type_alu)
+      return 0;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   if (nir_dest_bit_size(alu->dest.dest) == 16)
+      return 2;
+
+   return 1;
+}
+
 void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
 {
    bool progress;
@@ -114,7 +126,7 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
          NIR_PASS_V(nir, nir_opt_move_discards_to_top);
 
       if (sscreen->options.fp16)
-         NIR_PASS(progress, nir, nir_opt_vectorize, NULL, NULL);
+         NIR_PASS(progress, nir, nir_opt_vectorize, si_vectorize_callback, NULL);
    } while (progress);
 
    NIR_PASS_V(nir, nir_lower_var_copies);
diff --git a/src/mesa/state_tracker/st_glsl_to_nir.cpp b/src/mesa/state_tracker/st_glsl_to_nir.cpp
index aca156de2f0..e6bf9b8fd91 100644
--- a/src/mesa/state_tracker/st_glsl_to_nir.cpp
+++ b/src/mesa/state_tracker/st_glsl_to_nir.cpp
@@ -517,7 +517,7 @@ st_glsl_to_nir_post_opts(struct st_context *st, struct gl_program *prog,
       if (nir->options->lower_int64_options)
          NIR_PASS(lowered_64bit_ops, nir, nir_lower_int64);
 
-      if (revectorize)
+      if (revectorize && !nir->options->vectorize_vec2_16bit)
          NIR_PASS_V(nir, nir_opt_vectorize, nullptr, nullptr);
 
       if (revectorize || lowered_64bit_ops)
diff --git a/src/panfrost/bifrost/bifrost_compile.c b/src/panfrost/bifrost/bifrost_compile.c
index b39e4b1eb5b..337e3e56454 100644
--- a/src/panfrost/bifrost/bifrost_compile.c
+++ b/src/panfrost/bifrost/bifrost_compile.c
@@ -4276,12 +4276,12 @@ bi_lower_bit_size(const nir_instr *instr, UNUSED void *data)
  * (8-bit in Bifrost, 32-bit in NIR TODO - workaround!). Some conversions need
  * to be scalarized due to type size. */
 
-static bool
-bi_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+bi_vectorize_filter(const nir_instr *instr, const void *data)
 {
         /* Defaults work for everything else */
         if (instr->type != nir_instr_type_alu)
-                return true;
+                return 0;
 
         const nir_alu_instr *alu = nir_instr_as_alu(instr);
 
@@ -4293,10 +4293,17 @@ bi_vectorize_filter(const nir_instr *instr, void *data)
         case nir_op_ushr:
         case nir_op_f2i16:
         case nir_op_f2u16:
-                return false;
+                return 1;
         default:
-                return true;
+                break;
         }
+
+        /* Vectorized instructions cannot write more than 32-bit */
+        int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
+        if (dst_bit_size == 16)
+                return 2;
+        else
+                return 1;
 }
 
 static bool
diff --git a/src/panfrost/midgard/midgard_compile.c b/src/panfrost/midgard/midgard_compile.c
index 6cd48caf0e5..3c17819750c 100644
--- a/src/panfrost/midgard/midgard_compile.c
+++ b/src/panfrost/midgard/midgard_compile.c
@@ -303,25 +303,20 @@ mdg_should_scalarize(const nir_instr *instr, const void *_unused)
 }
 
 /* Only vectorize int64 up to vec2 */
-static bool
-midgard_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+midgard_vectorize_filter(const nir_instr *instr, const void *data)
 {
         if (instr->type != nir_instr_type_alu)
-                return true;
+                return 0;
 
         const nir_alu_instr *alu = nir_instr_as_alu(instr);
-
-        unsigned num_components = alu->dest.dest.ssa.num_components;
-
         int src_bit_size = nir_src_bit_size(alu->src[0].src);
         int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
 
-        if (src_bit_size == 64 || dst_bit_size == 64) {
-                if (num_components > 1)
-                        return false;
-        }
+        if (src_bit_size == 64 || dst_bit_size == 64)
+                return 2;
 
-        return true;
+        return 4;
 }
 
 static void



More information about the mesa-commit mailing list