Mesa (main): nir/opt_vectorize: add callback for max vectorization width
GitLab Mirror
gitlab-mirror at kemper.freedesktop.org
Wed Jun 1 12:27:22 UTC 2022
Module: Mesa
Branch: main
Commit: bd151a256ebe5dba2f159d1cc72a8777b1a5b84c
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=bd151a256ebe5dba2f159d1cc72a8777b1a5b84c
Author: Daniel Schürmann <daniel at schuermann.dev>
Date: Fri Dec 18 19:05:47 2020 +0100
nir/opt_vectorize: add callback for max vectorization width
The callback allows drivers to request a different vectorization factor
per instruction, depending on, e.g., the bit size or opcode.
This patch also removes nir_opt_vectorize()'s use of the
vectorize_vec2_16bit option.
Reviewed-by: Alyssa Rosenzweig <alyssa at collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13080>
---
src/amd/vulkan/radv_pipeline.c | 18 ++++---
src/compiler/nir/nir.h | 19 +++++--
src/compiler/nir/nir_opt_vectorize.c | 78 +++++++++++++---------------
src/gallium/auxiliary/nir/nir_to_tgsi.c | 12 ++---
src/gallium/drivers/radeonsi/si_shader_nir.c | 14 ++++-
src/mesa/state_tracker/st_glsl_to_nir.cpp | 2 +-
src/panfrost/bifrost/bifrost_compile.c | 17 ++++--
src/panfrost/midgard/midgard_compile.c | 17 +++---
8 files changed, 98 insertions(+), 79 deletions(-)
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index d02e2e36c14..5cf91b0a6e7 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -4042,14 +4042,16 @@ lower_bit_size_callback(const nir_instr *instr, void *_)
return 0;
}
-static bool
-opt_vectorize_callback(const nir_instr *instr, void *_)
+static uint8_t
+opt_vectorize_callback(const nir_instr *instr, const void *_)
{
- assert(instr->type == nir_instr_type_alu);
- nir_alu_instr *alu = nir_instr_as_alu(instr);
- unsigned bit_size = alu->dest.dest.ssa.bit_size;
+ if (instr->type != nir_instr_type_alu)
+ return 0;
+
+ const nir_alu_instr *alu = nir_instr_as_alu(instr);
+ const unsigned bit_size = alu->dest.dest.ssa.bit_size;
if (bit_size != 16)
- return false;
+ return 1;
switch (alu->op) {
case nir_op_fadd:
@@ -4069,12 +4071,12 @@ opt_vectorize_callback(const nir_instr *instr, void *_)
case nir_op_imax:
case nir_op_umin:
case nir_op_umax:
- return true;
+ return 2;
case nir_op_ishl: /* TODO: in NIR, these have 32bit shift operands */
case nir_op_ishr: /* while Radeon needs 16bit operands when vectorized */
case nir_op_ushr:
default:
- return false;
+ return 1;
}
}
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index a3e7485ff30..3ab0696a5a5 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3228,6 +3228,15 @@ typedef enum {
*/
typedef bool (*nir_instr_filter_cb)(const nir_instr *, const void *);
+/** A vectorization width callback
+ *
+ * Returns the maximum vectorization width per instruction.
+ * 0, if the instruction must not be modified.
+ *
+ * The vectorization width must be a power of 2.
+ */
+typedef uint8_t (*nir_vectorize_cb)(const nir_instr *, const void *);
+
typedef struct nir_shader_compiler_options {
bool lower_fdiv;
bool lower_ffma16;
@@ -3455,7 +3464,11 @@ typedef struct nir_shader_compiler_options {
nir_instr_filter_cb lower_to_scalar_filter;
/**
- * Whether nir_opt_vectorize should only create 16-bit 2D vectors.
+ * Disables potentially harmful algebraic transformations for architectures
+ * with SIMD-within-a-register semantics.
+ *
+ * Note, to actually vectorize 16bit instructions, use nir_opt_vectorize()
+ * with a suitable callback function.
*/
bool vectorize_vec2_16bit;
@@ -5485,9 +5498,7 @@ bool nir_lower_undef_to_zero(nir_shader *shader);
bool nir_opt_uniform_atomics(nir_shader *shader);
-typedef bool (*nir_opt_vectorize_cb)(const nir_instr *instr, void *data);
-
-bool nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
void *data);
bool nir_opt_conditional_discard(nir_shader *shader);
diff --git a/src/compiler/nir/nir_opt_vectorize.c b/src/compiler/nir/nir_opt_vectorize.c
index 83c841ee63e..dc6e1d84b52 100644
--- a/src/compiler/nir/nir_opt_vectorize.c
+++ b/src/compiler/nir/nir_opt_vectorize.c
@@ -22,6 +22,16 @@
*
*/
+/**
+ * nir_opt_vectorize() aims to vectorize ALU instructions.
+ *
+ * The default vectorization width is 4.
+ * If desired, a callback function which returns the max vectorization width
+ * per instruction can be provided.
+ *
+ * The max vectorization width must be a power of 2.
+ */
+
#include "nir.h"
#include "nir_vla.h"
#include "nir_builder.h"
@@ -125,7 +135,7 @@ instrs_equal(const void *data1, const void *data2)
}
static bool
-instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
+instr_can_rewrite(nir_instr *instr)
{
switch (instr->type) {
case nir_instr_type_alu: {
@@ -139,12 +149,7 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
return false;
/* no need to hash instructions which are already vectorized */
- if (alu->dest.dest.ssa.num_components >= 4)
- return false;
-
- if (vectorize_16bit &&
- (alu->dest.dest.ssa.num_components >= 2 ||
- alu->dest.dest.ssa.bit_size != 16))
+ if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
return false;
if (nir_op_infos[alu->op].output_size != 0)
@@ -156,8 +161,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
/* don't hash instructions which are already swizzled
* outside of max_components: these should better be scalarized */
- uint32_t mask = vectorize_16bit ? ~1 : ~3;
- for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
+ uint32_t mask = ~(instr->pass_flags - 1);
+ for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
return false;
}
@@ -179,10 +184,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
* the same instructions into one vectorized instruction. Note that instr1
* should dominate instr2.
*/
-
static nir_instr *
-instr_try_combine(struct nir_shader *nir, struct set *instr_set,
- nir_instr *instr1, nir_instr *instr2)
+instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
{
assert(instr1->type == nir_instr_type_alu);
assert(instr2->type == nir_instr_type_alu);
@@ -194,14 +197,10 @@ instr_try_combine(struct nir_shader *nir, struct set *instr_set,
unsigned alu2_components = alu2->dest.dest.ssa.num_components;
unsigned total_components = alu1_components + alu2_components;
- if (total_components > 4)
+ assert(instr1->pass_flags == instr2->pass_flags);
+ if (total_components > instr1->pass_flags)
return NULL;
- if (nir->options->vectorize_vec2_16bit) {
- assert(total_components == 2);
- assert(alu1->dest.dest.ssa.bit_size == 16);
- }
-
nir_builder b;
nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
b.cursor = nir_after_instr(instr1);
@@ -352,28 +351,23 @@ vec_instr_set_destroy(struct set *instr_set)
}
static bool
-vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
- nir_instr *instr,
- nir_opt_vectorize_cb filter, void *data)
+vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
+ nir_vectorize_cb filter, void *data)
{
- if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
- return false;
+ /* set max vector to instr pass flags: this is used to hash swizzles */
+ instr->pass_flags = filter ? filter(instr, data) : 4;
+ assert(util_is_power_of_two_or_zero(instr->pass_flags));
- if (filter && !filter(instr, data))
+ if (!instr_can_rewrite(instr))
return false;
- /* set max vector to instr pass flags: this is used to hash swizzles */
- instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
-
struct set_entry *entry = _mesa_set_search(instr_set, instr);
if (entry) {
nir_instr *old_instr = (nir_instr *) entry->key;
_mesa_set_remove(instr_set, entry);
- nir_instr *new_instr = instr_try_combine(nir, instr_set,
- old_instr, instr);
+ nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
if (new_instr) {
- if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
- (!filter || filter(new_instr, data)))
+ if (instr_can_rewrite(new_instr))
_mesa_set_add(instr_set, new_instr);
return true;
}
@@ -384,25 +378,23 @@ vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
}
static bool
-vectorize_block(struct nir_shader *nir, nir_block *block,
- struct set *instr_set,
- nir_opt_vectorize_cb filter, void *data)
+vectorize_block(nir_block *block, struct set *instr_set,
+ nir_vectorize_cb filter, void *data)
{
bool progress = false;
nir_foreach_instr_safe(instr, block) {
- if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
+ if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
progress = true;
}
for (unsigned i = 0; i < block->num_dom_children; i++) {
nir_block *child = block->dom_children[i];
- progress |= vectorize_block(nir, child, instr_set, filter, data);
+ progress |= vectorize_block(child, instr_set, filter, data);
}
nir_foreach_instr_reverse(instr, block) {
- if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
- (!filter || filter(instr, data)))
+ if (instr_can_rewrite(instr))
_mesa_set_remove_key(instr_set, instr);
}
@@ -410,14 +402,14 @@ vectorize_block(struct nir_shader *nir, nir_block *block,
}
static bool
-nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
- nir_opt_vectorize_cb filter, void *data)
+nir_opt_vectorize_impl(nir_function_impl *impl,
+ nir_vectorize_cb filter, void *data)
{
struct set *instr_set = vec_instr_set_create();
nir_metadata_require(impl, nir_metadata_dominance);
- bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
+ bool progress = vectorize_block(nir_start_block(impl), instr_set,
filter, data);
if (progress) {
@@ -432,14 +424,14 @@ nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
}
bool
-nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
void *data)
{
bool progress = false;
nir_foreach_function(function, shader) {
if (function->impl)
- progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
+ progress |= nir_opt_vectorize_impl(function->impl, filter, data);
}
return progress;
diff --git a/src/gallium/auxiliary/nir/nir_to_tgsi.c b/src/gallium/auxiliary/nir/nir_to_tgsi.c
index 2993fc27e22..371b43c8a8e 100644
--- a/src/gallium/auxiliary/nir/nir_to_tgsi.c
+++ b/src/gallium/auxiliary/nir/nir_to_tgsi.c
@@ -3067,11 +3067,11 @@ type_size(const struct glsl_type *type, bool bindless)
/* Allow vectorizing of ALU instructions, but avoid vectorizing past what we
* can handle for 64-bit values in TGSI.
*/
-static bool
-ntt_should_vectorize_instr(const nir_instr *instr, void *data)
+static uint8_t
+ntt_should_vectorize_instr(const nir_instr *instr, const void *data)
{
if (instr->type != nir_instr_type_alu)
- return false;
+ return 0;
nir_alu_instr *alu = nir_instr_as_alu(instr);
@@ -3085,7 +3085,7 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
*
* https://gitlab.freedesktop.org/virgl/virglrenderer/-/issues/195
*/
- return false;
+ return 1;
default:
break;
@@ -3102,10 +3102,10 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
* 64-bit instrs in the first place, I don't see much reason to care about
* this.
*/
- return false;
+ return 1;
}
- return true;
+ return 4;
}
static bool
diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c
index 8b267023504..ec1013f8d74 100644
--- a/src/gallium/drivers/radeonsi/si_shader_nir.c
+++ b/src/gallium/drivers/radeonsi/si_shader_nir.c
@@ -43,6 +43,18 @@ static bool si_alu_to_scalar_filter(const nir_instr *instr, const void *data)
return true;
}
+static uint8_t si_vectorize_callback(const nir_instr *instr, const void *data)
+{
+ if (instr->type != nir_instr_type_alu)
+ return 0;
+
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+ if (nir_dest_bit_size(alu->dest.dest) == 16)
+ return 2;
+
+ return 1;
+}
+
void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
{
bool progress;
@@ -114,7 +126,7 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
NIR_PASS_V(nir, nir_opt_move_discards_to_top);
if (sscreen->options.fp16)
- NIR_PASS(progress, nir, nir_opt_vectorize, NULL, NULL);
+ NIR_PASS(progress, nir, nir_opt_vectorize, si_vectorize_callback, NULL);
} while (progress);
NIR_PASS_V(nir, nir_lower_var_copies);
diff --git a/src/mesa/state_tracker/st_glsl_to_nir.cpp b/src/mesa/state_tracker/st_glsl_to_nir.cpp
index aca156de2f0..e6bf9b8fd91 100644
--- a/src/mesa/state_tracker/st_glsl_to_nir.cpp
+++ b/src/mesa/state_tracker/st_glsl_to_nir.cpp
@@ -517,7 +517,7 @@ st_glsl_to_nir_post_opts(struct st_context *st, struct gl_program *prog,
if (nir->options->lower_int64_options)
NIR_PASS(lowered_64bit_ops, nir, nir_lower_int64);
- if (revectorize)
+ if (revectorize && !nir->options->vectorize_vec2_16bit)
NIR_PASS_V(nir, nir_opt_vectorize, nullptr, nullptr);
if (revectorize || lowered_64bit_ops)
diff --git a/src/panfrost/bifrost/bifrost_compile.c b/src/panfrost/bifrost/bifrost_compile.c
index b39e4b1eb5b..337e3e56454 100644
--- a/src/panfrost/bifrost/bifrost_compile.c
+++ b/src/panfrost/bifrost/bifrost_compile.c
@@ -4276,12 +4276,12 @@ bi_lower_bit_size(const nir_instr *instr, UNUSED void *data)
* (8-bit in Bifrost, 32-bit in NIR TODO - workaround!). Some conversions need
* to be scalarized due to type size. */
-static bool
-bi_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+bi_vectorize_filter(const nir_instr *instr, const void *data)
{
/* Defaults work for everything else */
if (instr->type != nir_instr_type_alu)
- return true;
+ return 0;
const nir_alu_instr *alu = nir_instr_as_alu(instr);
@@ -4293,10 +4293,17 @@ bi_vectorize_filter(const nir_instr *instr, void *data)
case nir_op_ushr:
case nir_op_f2i16:
case nir_op_f2u16:
- return false;
+ return 1;
default:
- return true;
+ break;
}
+
+ /* Vectorized instructions cannot write more than 32-bit */
+ int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
+ if (dst_bit_size == 16)
+ return 2;
+ else
+ return 1;
}
static bool
diff --git a/src/panfrost/midgard/midgard_compile.c b/src/panfrost/midgard/midgard_compile.c
index 6cd48caf0e5..3c17819750c 100644
--- a/src/panfrost/midgard/midgard_compile.c
+++ b/src/panfrost/midgard/midgard_compile.c
@@ -303,25 +303,20 @@ mdg_should_scalarize(const nir_instr *instr, const void *_unused)
}
/* Only vectorize int64 up to vec2 */
-static bool
-midgard_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+midgard_vectorize_filter(const nir_instr *instr, const void *data)
{
if (instr->type != nir_instr_type_alu)
- return true;
+ return 0;
const nir_alu_instr *alu = nir_instr_as_alu(instr);
-
- unsigned num_components = alu->dest.dest.ssa.num_components;
-
int src_bit_size = nir_src_bit_size(alu->src[0].src);
int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
- if (src_bit_size == 64 || dst_bit_size == 64) {
- if (num_components > 1)
- return false;
- }
+ if (src_bit_size == 64 || dst_bit_size == 64)
+ return 2;
- return true;
+ return 4;
}
static void
More information about the mesa-commit
mailing list