[Mesa-dev] [RFC] nir/algebraic: support for power-of-two optimizations

Rob Clark robdclark at gmail.com
Sat May 7 17:06:18 UTC 2016


From: Rob Clark <robclark at freedesktop.org>

It was kinda sad that we couldn't optimize imul/idiv by a power of two.
So I bashed my head against python for a while and this is what I came
up with.  In the search expression, you can use "#a^2" to match only
constants which are a power of two.  The rest is taken care of with a
normal replacement expression.  (It might be nice if we had an ilog2 to
avoid the float/int conversion stuff.)

Still a couple rough edges and things which should be split out.
---
 src/compiler/nir/nir_algebraic.py           |  9 ++++--
 src/compiler/nir/nir_opt_algebraic.py       |  5 ++++
 src/compiler/nir/nir_search.c               | 27 +++++++++++++++++
 src/compiler/nir/nir_search.h               |  9 +++++-
 src/gallium/drivers/freedreno/ir3/ir3_nir.c | 45 +++++++++++++++++++----------
 5 files changed, 77 insertions(+), 18 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 285f853..c2b47fd 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -83,6 +83,7 @@ static const ${val.c_type} ${val.name} = {
 % elif isinstance(val, Variable):
    ${val.index}, /* ${val.var_name} */
    ${'true' if val.is_constant else 'false'},
+   ${'true' if val.is_power_of_two else 'false'},
    ${val.type() or 'nir_type_invalid' },
 % elif isinstance(val, Expression):
    ${'true' if val.inexact else 'false'},
@@ -113,7 +114,7 @@ static const ${val.c_type} ${val.name} = {
                                     Variable=Variable,
                                     Expression=Expression)
 
-_constant_re = re.compile(r"(?P<value>[^@]+)(?:@(?P<bits>\d+))?")
+_constant_re = re.compile(r"(?P<value>[^@\^]+)(?P<PoT>\^2)?(?:@(?P<bits>\d+))?")
 
 class Constant(Value):
    def __init__(self, val, name):
@@ -123,6 +124,7 @@ class Constant(Value):
          m = _constant_re.match(val)
          self.value = ast.literal_eval(m.group('value'))
          self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+         self.power_of_two = True if m.group('PoT') else False
       else:
          self.value = val
          self.bit_size = 0
@@ -149,7 +151,7 @@ class Constant(Value):
       elif isinstance(self.value, float):
          return "nir_type_float"
 
-_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
+_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)(?P<PoT>\^2)?"
                           r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?")
 
 class Variable(Value):
@@ -161,6 +163,9 @@ class Variable(Value):
 
       self.var_name = m.group('name')
       self.is_constant = m.group('const') is not None
+      self.is_power_of_two = m.group('PoT') is not None
+      if self.is_power_of_two:
+         assert self.is_constant
       self.required_type = m.group('type')
       self.bit_size = int(m.group('bits')) if m.group('bits') else 0
 
diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
index 0a95725..e1381b2 100644
--- a/src/compiler/nir/nir_opt_algebraic.py
+++ b/src/compiler/nir/nir_opt_algebraic.py
@@ -62,6 +62,11 @@ d = 'd'
 # constructed value should have that bit-size.
 
 optimizations = [
+
+   # add 64b variants?
+   (('imul', a, '#b^2@32'), ('ishl', a, ('f2i', ('flog2', ('i2f', b))))),
+   (('idiv', a, '#b^2@32'), ('ishr', a, ('f2i', ('flog2', ('i2f', b))))),
+
    (('fneg', ('fneg', a)), a),
    (('ineg', ('ineg', a)), a),
    (('fabs', ('fabs', a)), ('fabs', a)),
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 2c2fd92..92af521 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -70,6 +70,13 @@ alu_instr_is_bool(nir_alu_instr *instr)
    }
 }
 
+/* true iff x has exactly one bit set (0 rejected); existing helper somewhere? */
+static bool
+is_power_of_two(unsigned int x)
+{
+   return ((x != 0) && !(x & (x - 1)));
+}
+
 static bool
 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
             unsigned num_components, const uint8_t *swizzle,
@@ -127,6 +134,26 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
             return false;
 
+         if (var->is_power_of_two) {
+            assert(var->is_constant);
+            nir_const_value *val = nir_src_as_const_value(instr->src[src].src);
+            for (unsigned i = 0; i < num_components; i++) {
+               switch (nir_op_infos[instr->op].input_types[src]) {
+               // TODO handle other types??
+               case nir_type_int:
+                  if (!is_power_of_two(val->i32[new_swizzle[i]]))
+                     return false;
+                  break;
+               case nir_type_uint:
+                  if (!is_power_of_two(val->u32[new_swizzle[i]]))
+                     return false;
+                  break;
+               default:
+                  return false;
+               }
+            }
+         }
+
          if (var->type != nir_type_invalid) {
             if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
                return false;
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index c49eba7..32ed538 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -52,11 +52,18 @@ typedef struct {
 
    /** Indicates that the given variable must be a constant
     *
-    * This is only alloed in search expressions and indicates that the
+    * This is only allowed in search expressions and indicates that the
     * given variable is only allowed to match constant values.
     */
    bool is_constant;
 
+   /** Indicates that the given constant is a power of two
+    *
+    * This is only allowed in search expressions, and only for constant
+    * variables.
+    */
+   bool is_power_of_two;
+
    /** Indicates that the given variable must have a certain type
     *
     * This is only allowed in search expressions and indicates that the
diff --git a/src/gallium/drivers/freedreno/ir3/ir3_nir.c b/src/gallium/drivers/freedreno/ir3/ir3_nir.c
index 7e3ccc0..44c694a 100644
--- a/src/gallium/drivers/freedreno/ir3/ir3_nir.c
+++ b/src/gallium/drivers/freedreno/ir3/ir3_nir.c
@@ -77,6 +77,27 @@ ir3_key_lowers_nir(const struct ir3_shader_key *key)
 
 #define OPT_V(nir, pass, ...) NIR_PASS_V(nir, pass, ##__VA_ARGS__)
 
+static void
+ir3_optimize_loop(nir_shader *s)
+{
+	bool progress;
+	do {
+		progress = false;
+
+		OPT_V(s, nir_lower_vars_to_ssa);
+		OPT_V(s, nir_lower_alu_to_scalar);
+		OPT_V(s, nir_lower_phis_to_scalar);
+
+		progress |= OPT(s, nir_copy_prop);
+		progress |= OPT(s, nir_opt_dce);
+		progress |= OPT(s, nir_opt_cse);
+		progress |= OPT(s, ir3_nir_lower_if_else);
+		progress |= OPT(s, nir_opt_algebraic);
+		progress |= OPT(s, nir_opt_constant_folding);
+
+	} while (progress);
+}
+
 struct nir_shader *
 ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
 		const struct ir3_shader_key *key)
@@ -84,7 +105,6 @@ ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
 	struct nir_lower_tex_options tex_options = {
 			.lower_rect = 0,
 	};
-	bool progress;
 
 	if (key) {
 		switch (shader->type) {
@@ -140,24 +160,19 @@ ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
 	}
 
 	OPT_V(s, nir_lower_tex, &tex_options);
-	OPT_V(s, nir_lower_idiv);
 	OPT_V(s, nir_lower_load_const_to_scalar);
 
-	do {
-		progress = false;
-
-		OPT_V(s, nir_lower_vars_to_ssa);
-		OPT_V(s, nir_lower_alu_to_scalar);
-		OPT_V(s, nir_lower_phis_to_scalar);
+	ir3_optimize_loop(s);
 
-		progress |= OPT(s, nir_copy_prop);
-		progress |= OPT(s, nir_opt_dce);
-		progress |= OPT(s, nir_opt_cse);
-		progress |= OPT(s, ir3_nir_lower_if_else);
-		progress |= OPT(s, nir_opt_algebraic);
-		progress |= OPT(s, nir_opt_constant_folding);
+	/* do idiv lowering after the first opt loop, to give division by an
+	 * immediate power-of-two a chance to be optimized to a shift first:
+	 *
+	 * XXX TODO nir_lower_idiv should return progress so we could
+	 * skip the second loop when it changed nothing..
+	 */
+	OPT_V(s, nir_lower_idiv);
 
-	} while (progress);
+	ir3_optimize_loop(s);
 
 	OPT_V(s, nir_remove_dead_variables, nir_var_local);
 
-- 
2.5.5



More information about the mesa-dev mailing list