[Mesa-dev] [PATCH 2/2] nir/search: Allow conditions on expressions as well as variables

Jason Ekstrand jason at jlekstrand.net
Wed Jan 11 19:13:57 UTC 2017


---
 src/compiler/nir/nir_algebraic.py | 10 +++++++---
 src/compiler/nir/nir_search.c     | 28 +++++++++++++++++-----------
 src/compiler/nir/nir_search.h     | 18 +++++++++---------
 3 files changed, 33 insertions(+), 23 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 19ac6ee..e70c511 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -78,14 +78,13 @@ class Value(object):
    __template = mako.template.Template("""
 #include "compiler/nir/nir_search_helpers.h"
 static const ${val.c_type} ${val.name} = {
-   { ${val.type_enum}, ${val.bit_size} },
+   { ${val.type_enum}, ${val.bit_size}, ${val.cond if val.cond else 'NULL'} },
 % if isinstance(val, Constant):
    ${val.type()}, { ${hex(val)} /* ${val.value} */ },
 % elif isinstance(val, Variable):
    ${val.index}, /* ${val.var_name} */
    ${'true' if val.is_constant else 'false'},
    ${val.type() or 'nir_type_invalid' },
-   ${val.cond if val.cond else 'NULL'},
 % elif isinstance(val, Expression):
    ${'true' if val.inexact else 'false'},
    nir_op_${val.opcode},
@@ -121,6 +120,9 @@ class Constant(Value):
    def __init__(self, val, name):
       Value.__init__(self, name, "constant")
 
+      # Constants can't have conditions.  They either match or they don't.
+      self.cond = None
+
       if isinstance(val, (str)):
          m = _constant_re.match(val)
          self.value = ast.literal_eval(m.group('value'))
@@ -185,7 +187,8 @@ class Variable(Value):
       elif self.required_type == 'float':
          return "nir_type_float"
 
-_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?")
+_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
+                        r"(?P<cond>\([^\)]+\))?")
 
 class Expression(Value):
    def __init__(self, expr, name_base, varset):
@@ -198,6 +201,7 @@ class Expression(Value):
       self.opcode = m.group('opcode')
       self.bit_size = int(m.group('bits')) if m.group('bits') else 0
       self.inexact = m.group('inexact') is not None
+      self.cond = m.group('cond')  # parens included; emitted verbatim in C
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 2f57821..0148b2f 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -119,6 +119,17 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
    for (unsigned i = 0; i < num_components; ++i)
       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
 
+   if (value->cond) {
+      uint8_t read_mask = 0;
+      for (unsigned i = 0; i < num_components; i++)
+         read_mask |= 1 << new_swizzle[i];
+
+      if (!value->cond(instr->src[src].src.ssa,
+                       nir_op_infos[instr->op].input_types[src],
+                       read_mask))
+         return false;
+   }
+
    /* If the value has a specific bit size and it doesn't match, bail */
    if (value->bit_size &&
        nir_src_bit_size(instr->src[src].src) != value->bit_size)
@@ -154,17 +165,6 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
             return false;
 
-         if (var->cond) {
-            uint8_t read_mask = 0;
-            for (unsigned i = 0; i < num_components; i++)
-               read_mask |= 1 << new_swizzle[i];
-
-            if (!var->cond(instr->src[src].src.ssa,
-                           nir_op_infos[instr->op].input_types[src],
-                           read_mask))
-               return false;
-         }
-
          if (var->type != nir_type_invalid &&
              !src_is_type(instr->src[src].src, var->type))
             return false;
@@ -604,6 +604,12 @@ nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
    state.has_exact_alu = false;
    state.variables_seen = 0;
 
+   if (search->value.cond) {
+      if (!search->value.cond(&instr->dest.dest.ssa, nir_type_invalid,
+                              instr->dest.write_mask))
+         return NULL;
+   }
+
    if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
                          swizzle, &state))
       return NULL;
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index 9d25018..6399033 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -42,6 +42,15 @@ typedef struct {
    nir_search_value_type type;
 
    unsigned bit_size;
+
+   /** Optional condition fxn ptr
+    *
+    * Only allowed in search patterns (not replacements).  Variables and
+    * expressions may carry one (constants never do) to place additional
+    * constraints on the match.  Typically used for 'is_constant' variables
+    * to require, for example, power-of-two in order for the search to match.
+    */
+   bool (*cond)(nir_ssa_def *def, nir_alu_type type, uint8_t read_mask);
 } nir_search_value;
 
 typedef struct {
@@ -68,15 +77,6 @@ typedef struct {
     * never match anything.
     */
    nir_alu_type type;
-
-   /** Optional condition fxn ptr
-    *
-    * This is only allowed in search expressions, and allows additional
-    * constraints to be placed on the match.  Typically used for 'is_constant'
-    * variables to require, for example, power-of-two in order for the search
-    * to match.
-    */
-   bool (*cond)(nir_ssa_def *def, nir_alu_type type, uint8_t read_mask);
 } nir_search_variable;
 
 typedef struct {
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list