<div dir="ltr"><br><br><div class="gmail_quote"><div dir="ltr">On Thu, Nov 29, 2018 at 12:32 PM Connor Abbott <<a href="mailto:cwabbott0@gmail.com">cwabbott0@gmail.com</a>> wrote:<br></div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">Before this commit, there were two copies of the algorithm: one in C,<br>
that we would use to figure out what bit-size to give the replacement<br>
expression, and one in Python, that emulated the C one and tried to<br>
prove that the C algorithm would never fail to correctly assign<br>
bit-sizes. That seemed pretty fragile, and likely to fall over if we<br>
make any changes. Furthermore, the C code was really just recomputing<br>
more-or-less the same thing as the Python code every time. Instead, we<br>
can just store the results of the Python algorithm in the C<br>
datastructure, and consult it to compute the bitsize of each value,<br>
moving the "brains" entirely into Python. Since the Python algorithm no<br>
longer has to match C, it's also a lot easier to change it to something<br>
more closely approximating an actual type-inference algorithm. The<br>
algorithm used is based on Hindley-Milner, although deliberately<br>
weakened a little. It's a few more lines than the old one, judging by<br>
the diffstat, but I think it's easier to verify that it's correct while<br>
being as general as possible.<br>
<br>
We could split this up into two changes, first making the C code use the<br>
results of the Python code and then rewriting the Python algorithm, but<br>
since the old algorithm never tracked which variable each equivalence<br>
class, it would mean we'd have to add some non-trivial code which would<br>
then get thrown away. I think it's better to see the final state all at<br>
once, although I could also try splitting it up.<br>
---<br>
 src/compiler/nir/nir_algebraic.py | 518 ++++++++++++++++--------------<br>
 src/compiler/nir/nir_search.c     | 146 +--------<br>
 src/compiler/nir/nir_search.h     |   2 +-<br>
 3 files changed, 295 insertions(+), 371 deletions(-)<br>
<br>
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py<br>
index 728196136ab..48390dbde38 100644<br>
--- a/src/compiler/nir/nir_algebraic.py<br>
+++ b/src/compiler/nir/nir_algebraic.py<br>
@@ -88,7 +88,7 @@ class Value(object):<br>
<br>
    __template = mako.template.Template("""<br>
 static const ${val.c_type} ${<a href="http://val.name" rel="noreferrer" target="_blank">val.name</a>} = {<br>
-   { ${val.type_enum}, ${val.bit_size} },<br>
+   { ${val.type_enum}, ${val.c_bit_size} },<br>
 % if isinstance(val, Constant):<br>
    ${val.type()}, { ${val.hex()} /* ${val.value} */ },<br>
 % elif isinstance(val, Variable):<br>
@@ -112,6 +112,40 @@ static const ${val.c_type} ${<a href="http://val.name" rel="noreferrer" target="_blank">val.name</a>} = {<br>
    def __str__(self):<br>
       return self.in_val<br>
<br>
+   def get_bit_size(self):<br>
+      """Get the physical bit-size that has been chosen for this value, or if<br>
+      there is none, the canonical value which currently represents this<br>
+      bit-size class. Variables will be preferred, i.e. if there are any<br>
+      variables in the equivalence class, the canonical value will be a<br>
+      variable. We do this since we'll need to know which variable each value<br>
+      is equivalent to when constructing the replacement expression. This is<br>
+      the "find" part of the union-find algorithm.<br>
+      """<br>
+      bit_size = self<br>
+<br>
+      while isinstance(bit_size, Value):<br>
+         if bit_size._bit_size == None:<br>
+            break<br>
+         bit_size = bit_size._bit_size<br>
+<br>
+      if bit_size != self:<br>
+         self._bit_size = bit_size<br>
+      return bit_size<br>
+<br>
+   def set_bit_size(self, other):<br>
+      """Make self.get_bit_size() return what other.get_bit_size() return<br>
+      before calling this, or just "other" if it's a concrete bit-size. This is<br>
+      the "union" part of the union-find algorithm.<br>
+      """<br>
+<br>
+      self_bit_size = self.get_bit_size()<br>
+      other_bit_size = other if isinstance(other, int) else other.get_bit_size()<br>
+<br>
+      if self_bit_size == other_bit_size:<br>
+         return<br>
+<br>
+      self_bit_size._bit_size = other_bit_size<br>
+<br>
    @property<br>
    def type_enum(self):<br>
       return "nir_search_value_" + self.type_str<br>
@@ -124,6 +158,21 @@ static const ${val.c_type} ${<a href="http://val.name" rel="noreferrer" target="_blank">val.name</a>} = {<br>
    def c_ptr(self):<br>
       return "&{0}.value".format(<a href="http://self.name" rel="noreferrer" target="_blank">self.name</a>)<br>
<br>
+   @property<br>
+   def c_bit_size(self):<br>
+      bit_size = self.get_bit_size()<br>
+      if isinstance(bit_size, int):<br>
+         return bit_size<br>
+      elif isinstance(bit_size, Variable):<br>
+         return -bit_size.index - 1<br>
+      else:<br>
+         # If the bit-size class is neither a variable, nor an actual bit-size, then<br>
+         # - If it's in the search expression, we don't need to check anything<br>
+         # - If it's in the replace expression, either it's ambiguous (in which<br>
+         # case we'd reject it), or it equals the bit-size of the search value<br>
+         # We represent these cases with a 0 bit-size.<br>
+         return 0<br>
+<br>
    def render(self):<br>
       return self.__template.render(val=self,<br>
                                     Constant=Constant,<br>
@@ -140,14 +189,14 @@ class Constant(Value):<br>
       if isinstance(val, (str)):<br>
          m = _constant_re.match(val)<br>
          self.value = ast.literal_eval(m.group('value'))<br>
-         self.bit_size = int(m.group('bits')) if m.group('bits') else 0<br>
+         self._bit_size = int(m.group('bits')) if m.group('bits') else None<br>
       else:<br>
          self.value = val<br>
-         self.bit_size = 0<br>
+         self._bit_size = None<br>
<br>
       if isinstance(self.value, bool):<br>
-         assert self.bit_size == 0 or self.bit_size == 32<br>
-         self.bit_size = 32<br>
+         assert self._bit_size == None or self._bit_size == 32<br>
+         self._bit_size = 32<br>
<br>
    def hex(self):<br>
       if isinstance(self.value, (bool)):<br>
@@ -191,11 +240,11 @@ class Variable(Value):<br>
       self.is_constant = m.group('const') is not None<br>
       self.cond = m.group('cond')<br>
       self.required_type = m.group('type')<br>
-      self.bit_size = int(m.group('bits')) if m.group('bits') else 0<br>
+      self._bit_size = int(m.group('bits')) if m.group('bits') else None<br>
<br>
       if self.required_type == 'bool':<br>
-         assert self.bit_size == 0 or self.bit_size == 32<br>
-         self.bit_size = 32<br>
+         assert self._bit_size == None or self._bit_size == 32<br>
+         self._bit_size = 32<br>
<br>
       if self.required_type is not None:<br>
          assert self.required_type in ('float', 'bool', 'int', 'uint')<br>
@@ -225,7 +274,7 @@ class Expression(Value):<br>
       assert m and m.group('opcode') is not None<br>
<br>
       self.opcode = m.group('opcode')<br>
-      self.bit_size = int(m.group('bits')) if m.group('bits') else 0<br>
+      self._bit_size = int(m.group('bits')) if m.group('bits') else None<br>
       self.inexact = m.group('inexact') is not None<br>
       self.cond = m.group('cond')<br>
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)<br>
@@ -235,40 +284,6 @@ class Expression(Value):<br>
       srcs = "\n".join(src.render() for src in self.sources)<br>
       return srcs + super(Expression, self).render()<br>
<br>
-class IntEquivalenceRelation(object):<br>
-   """A class representing an equivalence relation on integers.<br>
-<br>
-   Each integer has a canonical form which is the maximum integer to which it<br>
-   is equivalent.  Two integers are equivalent precisely when they have the<br>
-   same canonical form.<br>
-<br>
-   The convention of maximum is explicitly chosen to make using it in<br>
-   BitSizeValidator easier because it means that an actual bit_size (if any)<br>
-   will always be the canonical form.<br>
-   """<br>
-   def __init__(self):<br>
-      self._remap = {}<br>
-<br>
-   def get_canonical(self, x):<br>
-      """Get the canonical integer corresponding to x."""<br>
-      if x in self._remap:<br>
-         return self.get_canonical(self._remap[x])<br>
-      else:<br>
-         return x<br>
-<br>
-   def add_equiv(self, a, b):<br>
-      """Add an equivalence and return the canonical form."""<br>
-      c = max(self.get_canonical(a), self.get_canonical(b))<br>
-      if a != c:<br>
-         assert a < c<br>
-         self._remap[a] = c<br>
-<br>
-      if b != c:<br>
-         assert b < c<br>
-         self._remap[b] = c<br>
-<br>
-      return c<br>
-<br>
 class BitSizeValidator(object):<br>
    """A class for validating bit sizes of expressions.<br>
<br>
@@ -296,7 +311,7 @@ class BitSizeValidator(object):<br>
    inference can be ambiguous or contradictory.  Consider, for instance, the<br>
    following transformation:<br>
<br>
-   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))<br>
+   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))<br>
<br>
    This transformation can potentially cause a problem because usub_borrow is<br>
    well-defined for any bit-size of integer.  However, b2i always generates a<br>
@@ -315,217 +330,238 @@ class BitSizeValidator(object):<br>
    generate any code.  This ensures that bugs are caught at compile time<br>
    rather than at run time.<br>
<br>
-   The basic operation of the validator is very similar to the bitsize_tree in<br>
-   nir_search only a little more subtle.  Instead of simply tracking bit<br>
-   sizes, it tracks "bit classes" where each class is represented by an<br>
-   integer.  A value of 0 means we don't know anything yet, positive values<br>
-   are actual bit-sizes, and negative values are used to track equivalence<br>
-   classes of sizes that must be the same but have yet to receive an actual<br>
-   size.  The first stage uses the bitsize_tree algorithm to assign bit<br>
-   classes to each variable.  If it ever comes across an inconsistency, it<br>
-   assert-fails.  Then the second stage uses that information to prove that<br>
-   the resulting expression can always validly be constructed.<br>
-   """<br>
+   Each value maintains a "bit-size class", which is either an actual bit size<br>
+   or an equivalence class with other values that must have the same bit size.<br>
+   The validator works by combining bit-size classes with each other according<br>
+   to the NIR rules outlined above, checking that there are no inconsistencies.<br>
+   When doing this for the replacement expression, we make sure to never change<br>
+   the equivalence class of any of the search values. We could make the example<br>
+   transforms above work by doing some extra run-time checking of the search<br>
+   expression, but we make the user specify those constraints themselves, to<br>
+   avoid any surprises. Since the replacement bitsizes can only be connected to<br>
+   the source bitsize via variables (variables must have the same bitsize in<br>
+   the source and replacment expressions) or the roots of the expression (the<br>
+   replacement expression must produce the same bit size as the search<br>
+   expression), we prevent merging a variable with anything when processing the<br>
+   replacement expression, or specializing the search bitsize<br>
+   with anything. The former prevents<br>
<br>
-   def __init__(self, varset):<br>
-      self._num_classes = 0<br>
-      self._var_classes = [0] * len(varset.names)<br>
-      self._class_relation = IntEquivalenceRelation()<br>
+   (('bcsel', a, b, 0), ('iand', a, b))<br>
<br>
-   def validate(self, search, replace):<br>
-      search_dst_class = self._propagate_bit_size_up(search)<br>
-      if search_dst_class == 0:<br>
-         search_dst_class = self._new_class()<br>
-      self._propagate_bit_class_down(search, search_dst_class)<br>
-<br>
-      replace_dst_class = self._validate_bit_class_up(replace)<br>
-      if replace_dst_class != 0:<br>
-         assert search_dst_class != 0, \<br>
-                'Search expression matches any bit size but replace ' \<br>
-                'expression can only generate {0}-bit values' \<br>
-                .format(replace_dst_class)<br>
-<br>
-         assert search_dst_class == replace_dst_class, \<br>
-                'Search expression matches any {0}-bit values but replace ' \<br>
-                'expression can only generates {1}-bit values' \<br>
-                .format(search_dst_class, replace_dst_class)<br>
-<br>
-      self._validate_bit_class_down(replace, search_dst_class)<br>
-<br>
-   def _new_class(self):<br>
-      self._num_classes += 1<br>
-      return -self._num_classes<br>
-<br>
-   def _set_var_bit_class(self, var, bit_class):<br>
-      assert bit_class != 0<br>
-      var_class = self._var_classes[var.index]<br>
-      if var_class == 0:<br>
-         self._var_classes[var.index] = bit_class<br>
-      else:<br>
-         canon_var_class = self._class_relation.get_canonical(var_class)<br>
-         canon_bit_class = self._class_relation.get_canonical(bit_class)<br>
-         assert canon_var_class < 0 or canon_bit_class < 0 or \<br>
-                canon_var_class == canon_bit_class, \<br>
-                'Variable {0} cannot be both {1}-bit and {2}-bit' \<br>
-                .format(str(var), bit_class, var_class)<br>
-         var_class = self._class_relation.add_equiv(var_class, bit_class)<br>
-         self._var_classes[var.index] = var_class<br>
-<br>
-   def _get_var_bit_class(self, var):<br>
-      return self._class_relation.get_canonical(self._var_classes[var.index])<br>
-<br>
-   def _propagate_bit_size_up(self, val):<br>
-      if isinstance(val, (Constant, Variable)):<br>
-         return val.bit_size<br>
+   from being allowed, since we'd have to merge the bitsizes for a and b due to<br>
+   the 'iand', while the latter prevents<br>
<br>
-      elif isinstance(val, Expression):<br>
-         nir_op = opcodes[val.opcode]<br>
-         val.common_size = 0<br>
-         for i in range(nir_op.num_inputs):<br>
-            src_bits = self._propagate_bit_size_up(val.sources[i])<br>
-            if src_bits == 0:<br>
-               continue<br>
+   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))<br>
<br>
-            src_type_bits = type_bits(nir_op.input_types[i])<br>
-            if src_type_bits != 0:<br>
-               assert src_bits == src_type_bits, \<br>
-                      'Source {0} of nir_op_{1} must be a {2}-bit value but ' \<br>
-                      'the only possible matched values are {3}-bit: {4}' \<br>
-                      .format(i, val.opcode, src_type_bits, src_bits, str(val))<br>
-            else:<br>
-               assert val.common_size == 0 or src_bits == val.common_size, \<br>
-                      'Expression cannot have both {0}-bit and {1}-bit ' \<br>
-                      'variable-width sources: {2}' \<br>
-                      .format(src_bits, val.common_size, str(val))<br>
-               val.common_size = src_bits<br>
-<br>
-         dst_type_bits = type_bits(nir_op.output_type)<br>
-         if dst_type_bits != 0:<br>
-            assert val.bit_size == 0 or val.bit_size == dst_type_bits, \<br>
-                   'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \<br>
-                   'result was requested' \<br>
-                   .format(val.opcode, dst_type_bits, val.bit_size)<br>
-            return dst_type_bits<br>
-         else:<br>
-            if val.common_size != 0:<br>
-               assert val.bit_size == 0 or val.bit_size == val.common_size, \<br>
-                      'Variable width expression musr be {0}-bit based on ' \<br>
-                      'the sources but a {1}-bit result was requested: {2}' \<br>
-                      .format(val.common_size, val.bit_size, str(val))<br>
-            else:<br>
-               val.common_size = val.bit_size<br>
-            return val.common_size<br>
-<br>
-   def _propagate_bit_class_down(self, val, bit_class):<br>
-      if isinstance(val, Constant):<br>
-         assert val.bit_size == 0 or val.bit_size == bit_class, \<br>
-                'Constant is {0}-bit but a {1}-bit value is required: {2}' \<br>
-                .format(val.bit_size, bit_class, str(val))<br>
-<br>
-      elif isinstance(val, Variable):<br>
-         assert val.bit_size == 0 or val.bit_size == bit_class, \<br>
-                'Variable is {0}-bit but a {1}-bit value is required: {2}' \<br>
-                .format(val.bit_size, bit_class, str(val))<br>
-         self._set_var_bit_class(val, bit_class)<br>
+   from being allowed, since the search expression has the bit size of a and b,<br>
+   which can't be specialized to 32 which is the bitsize of the replace<br>
+   expression. It also prevents something like:<br>
<br>
-      elif isinstance(val, Expression):<br>
-         nir_op = opcodes[val.opcode]<br>
-         dst_type_bits = type_bits(nir_op.output_type)<br>
-         if dst_type_bits != 0:<br>
-            assert bit_class == 0 or bit_class == dst_type_bits, \<br>
-                   'nir_op_{0} produces a {1}-bit result but the parent ' \<br>
-                   'expression wants a {2}-bit value' \<br>
-                   .format(val.opcode, dst_type_bits, bit_class)<br>
-         else:<br>
-            assert val.common_size == 0 or val.common_size == bit_class, \<br>
-                   'Variable-width expression produces a {0}-bit result ' \<br>
-                   'based on the source widths but the parent expression ' \<br>
-                   'wants a {1}-bit value: {2}' \<br>
-                   .format(val.common_size, bit_class, str(val))<br>
-            val.common_size = bit_class<br>
-<br>
-         if val.common_size:<br>
-            common_class = val.common_size<br>
-         elif nir_op.num_inputs:<br>
-            # If we got here then we have no idea what the actual size is.<br>
-            # Instead, we use a generic class<br>
-            common_class = self._new_class()<br>
-<br>
-         for i in range(nir_op.num_inputs):<br>
-            src_type_bits = type_bits(nir_op.input_types[i])<br>
-            if src_type_bits != 0:<br>
-               self._propagate_bit_class_down(val.sources[i], src_type_bits)<br>
-            else:<br>
-               self._propagate_bit_class_down(val.sources[i], common_class)<br>
-<br>
-   def _validate_bit_class_up(self, val):<br>
-      if isinstance(val, Constant):<br>
-         return val.bit_size<br>
-<br>
-      elif isinstance(val, Variable):<br>
-         var_class = self._get_var_bit_class(val)<br>
-         # By the time we get to validation, every variable should have a class<br>
-         assert var_class != 0<br>
-<br>
-         # If we have an explicit size provided by the user, the variable<br>
-         # *must* exactly match the search.  It cannot be implicitly sized<br>
-         # because otherwise we could end up with a conflict at runtime.<br>
-         assert val.bit_size == 0 or val.bit_size == var_class<br>
-<br>
-         return var_class<br>
+   (('b2i', ('i2b', a)), ('ineq', a, 0))<br>
+<br>
+   since the bitsize of 'b2i', which can be anything, can't be specialized to<br>
+   the bitsize of a.<br>
<br>
+   After doing all this, we check that every subexpression of the replacement<br>
+   was assigned a constant bitsize, the bitsize of a variable, or the bitsize<br>
+   of the search expresssion, since those are the things that are known when<br>
+   constructing the replacement expresssion. Finally, we record the bitsize<br>
+   needed in nir_search_value so that we know what to do when building the<br>
+   replacement expression.<br>
+   """<br>
+<br>
+   def __init__(self, varset):<br>
+      self._var_classes = [None] * len(varset.names)<br>
+<br>
+   def compare_bitsizes(self, a, b):<br>
+      """Determines which bitsize class is a specialization of the other, or<br>
+      whether neither is. When we merge two different bitsizes, the<br>
+      less-specialized bitsize always points to the more-specialized one, so<br>
+      that calling get_bit_size() always gets you the most specialized bitsize.<br>
+      The specialization partial order is given by:<br>
+      - Physical bitsizes are always the most specialized, and a different<br>
+        bitsize can never specialize another.<br>
+      - In the search expression, variables can always be specialized to each<br>
+        other and to physical bitsizes. In the replace expression, we disallow<br>
+        this to avoid adding extra constraints to the search expression that<br>
+        the user didn't specify.<br>
+      - Expressions and constants without a bitsize can always be specialized to<br>
+        each other and variables, but not the other way around.<br>
+<br>
+        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,<br>
+        and None if they are not comparable (neither a <= b nor b <= a).<br>
+      """<br>
+      if isinstance(a, int):<br>
+         if isinstance(b, int):<br>
+            return 0 if a == b else None<br>
+         elif isinstance(b, Variable):<br>
+            return -1 if self.is_search else None<br>
+         else:<br>
+            return -1<br>
+      elif isinstance(a, Variable):<br>
+         if isinstance(b, int):<br>
+            return 1 if self.is_search else None<br>
+         elif isinstance(b, Variable):<br>
+            return 0 if self.is_search or a.index == b.index else None<br>
+         else:<br>
+            return -1<br>
+      else:<br>
+         if isinstance(b, int):<br>
+            return 1<br>
+         elif isinstance(b, Variable):<br>
+            return 1<br>
+         else:<br>
+            return 0<br>
+<br>
+   def unify_bit_size(self, a, b, error_msg):<br>
+      """Record that a must have the same bit-size as b. If both<br>
+      have been assigned conflicting physical bit-sizes, call "error_msg" with<br>
+      the bit-sizes of self and other to get a message and raise an error.<br>
+      In the replace expression, disallow merging variables with other<br>
+      variables and physical bit-sizes as well.<br>
+      """<br>
+      a_bit_size = a.get_bit_size()<br>
+      b_bit_size = b if isinstance(b, int) else b.get_bit_size()<br>
+<br>
+      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)<br>
+<br>
+      assert cmp_result != None, \<br>
+         error_msg(a_bit_size, b_bit_size)<br>
+<br>
+      if cmp_result < 0:<br>
+         b_bit_size.set_bit_size(a)<br>
+      elif not isinstance(a_bit_size, int):<br>
+         a_bit_size.set_bit_size(b)<br>
+<br>
+   def merge_variables(self, val):<br>
+      """Perform the first part of type inference by merging all the different<br>
+      uses of the same variable. We always do this as if we're in the search<br>
+      expression, even if we're actually not, since otherwise we'd get errors<br>
+      if the search expression specified some constraint but the replace<br>
+      expression didn't, because we'd be merging a variable and a constant.<br>
+      """<br>
+      if isinstance(val, Variable):<br>
+         if self._var_classes[val.index] == None:<br>
+            self._var_classes[val.index] = val<br>
+         else:<br>
+            other = self._var_classes[val.index]<br>
+            self.unify_bit_size(other, val,<br>
+                  lambda other_bit_size, bit_size:<br>
+                     'Variable {} has conflicting bit size requirements: ' \<br>
+                     'it must have bit size {} and {}'.format(<br>
+                        val.var_name, other_bit_size, bit_size))<br>
       elif isinstance(val, Expression):<br>
-         nir_op = opcodes[val.opcode]<br>
-         val.common_class = 0<br>
-         for i in range(nir_op.num_inputs):<br>
-            src_class = self._validate_bit_class_up(val.sources[i])<br>
-            if src_class == 0:<br>
+         for src in val.sources:<br>
+            self.merge_variables(src)<br>
+      <br>
+   def validate_value(self, val):<br>
+      """Validate the an expression by performing classic Hindley-Milner<br>
+      type inference on bitsizes. This will detect if there are any conflicting<br>
+      requirements, and unify variables so that we know which variables must<br>
+      have the same bitsize. If we're operating on the replace expression, we<br>
+      will refuse to merge different variables together or merge a variable<br>
+      with a constant, in order to prevent surprises due to rules unexpectedly<br>
+      not matching at runtime.<br>
+      """<br>
+      if not isinstance(val, Expression):<br>
+         return<br>
+<br>
+      nir_op = opcodes[val.opcode]<br>
+      assert len(val.sources) == nir_op.num_inputs, \<br>
+         "Expression {} has {} sources, expected {}".format(<br>
+            val, len(val.sources), nir_op.num_inputs)<br>
+<br>
+      for src in val.sources:<br>
+         self.validate_value(src)<br>
+<br>
+      dst_type_bits = type_bits(nir_op.output_type)<br>
+<br>
+      # First, unify all the sources. That way, an error coming up because two<br>
+      # sources have an incompatible bit-size won't produce an error message<br>
+      # involving the destination.<br>
+      first_src = None<br></blockquote><div><br></div><div>Maybe call this first_unsized_src?<br></div><div> </div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">
+      for src_type, src in zip(nir_op.input_types, val.sources):<br>
+         src_type_bits = type_bits(src_type)<br>
+         if src_type_bits == 0:<br>
+            if first_src == None:<br>
+               first_src = src<br>
                continue<br>
<br>
-            src_type_bits = type_bits(nir_op.input_types[i])<br>
-            if src_type_bits != 0:<br>
-               assert src_class == src_type_bits<br>
-            else:<br>
-               assert val.common_class == 0 or src_class == val.common_class<br>
-               val.common_class = src_class<br>
-<br>
-         dst_type_bits = type_bits(nir_op.output_type)<br>
-         if dst_type_bits != 0:<br>
-            assert val.bit_size == 0 or val.bit_size == dst_type_bits<br>
-            return dst_type_bits<br>
+            self.unify_bit_size(first_src, src,<br>
+               lambda first_src_bit_size, src_bit_size:<br>
+                  'Source {} of {} must have bit size {}, while source {} must '\<br>
+                  'have incompatible bit size {}'.format(<br>
+                     first_src, val, first_src_bit_size, src, src_bit_size)<br>
+                  if self.is_search else<br>
+                  'Sources {} (bit size of {}) and {} (bit size of {}) ' \<br>
+                  'of {} may not have the same bit size when building the ' \<br>
+                  'replacement expression.'.format(<br>
+                     first_src, first_src_bit_size, src, src_bit_size, val))<br>
          else:<br>
-            if val.common_class != 0:<br>
-               assert val.bit_size == 0 or val.bit_size == val.common_class<br>
-            else:<br>
-               val.common_class = val.bit_size<br>
-            return val.common_class<br>
+            self.unify_bit_size(src, src_type_bits,<br>
+               lambda src_bit_size, unused:<br>
+                  '{} must have {} bits, but as a source of nir_op_{} '\<br>
+                  'it must have {} bits'.format(src, src_bit_size, <a href="http://nir_op.name" rel="noreferrer" target="_blank">nir_op.name</a>, src_type_bits)<br>
+                  if self.is_search else<br>
+                  '{} has the bit size of {}, but as a source of nir_op_{} '\<br>
+                  'it must have {} bits, which may not be the same'.format(<br>
+                     src, src_bit_size, <a href="http://nir_op.name" rel="noreferrer" target="_blank">nir_op.name</a>, src_type_bits))<br>
+<br>
+      if dst_type_bits == 0:<br>
+         for src_type, src in zip(nir_op.input_types, val.sources):<br></blockquote><div><br></div><div>We don't need to walk the list of sources.  It should be sufficient to do</div><div><br></div><div>if dst_type_bit_size == 0:</div><div>   if first_unsized_src is not None:</div><div>      self.unify_bit_size(val, first_unsized_src, ...)</div><div>else:</div><div>   self.unify_bit_size(val, dst_type_bit_size, ...)<br></div><div> </div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">
+            src_type_bits = type_bits(src_type)<br>
+            if src_type_bits == 0:<br>
+               self.unify_bit_size(val, src,<br>
+                  lambda val_bit_size, src_bit_size:<br>
+                     '{} must have the bit size of {}, while its source {} must '<br>
+                     'have incompatible bit size {}'.format(<br>
+                        val, val_bit_size, src, src_bit_size)<br>
+                     if self.is_search else<br>
+                     '{} must have {} bits, but its source {} ' \<br>
+                     '(bit size of {}) may not have that bit size ' \<br>
+                     'when building the replacement.'.format(<br>
+                        val, val_bit_size, src, src_bit_size))<br>
+      else:<br>
+         self.unify_bit_size(val, dst_type_bits,<br>
+            lambda dst_bit_size, unused:<br>
+               '{} must have {} bits, but as a destination of nir_op_{} '\<br>
+               'it must have {} bits'.format(val, dst_bit_size, <a href="http://nir_op.name" rel="noreferrer" target="_blank">nir_op.name</a>, dst_type_bits))<br>
+<br>
+   def validate_replace(self, val, search):<br>
+      bit_size = val.get_bit_size()<br>
+      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \<br>
+            bit_size == search.get_bit_size(), \<br>
+            'Ambiguous bit size for replacement value {}: ' \<br>
+            'it cannot be deduced from a variable, a fixed bit size ' \<br>
+            'somewhere, or the search expression.'.format(val)<br>
+<br>
+      if isinstance(val, Expression):<br>
+         for src in val.sources:<br>
+            self.validate_replace(src, search)<br>
<br>
-   def _validate_bit_class_down(self, val, bit_class):<br>
-      # At this point, everything *must* have a bit class.  Otherwise, we have<br>
-      # a value we don't know how to define.<br>
-      assert bit_class != 0<br>
+   def validate(self, search, replace):<br>
+      self.is_search = True<br>
+      self.merge_variables(search)<br>
+      self.merge_variables(replace)<br>
+      self.validate_value(search)<br>
<br>
-      if isinstance(val, Constant):<br>
-         assert val.bit_size == 0 or val.bit_size == bit_class<br>
+      self.is_search = False<br>
+      self.validate_value(replace)<br>
<br>
-      elif isinstance(val, Variable):<br>
-         assert val.bit_size == 0 or val.bit_size == bit_class<br>
+      # Check that search is always more specialized than replace. Note that<br>
+      # we're doing this in replace mode, disallowing merging variables.<br></blockquote><div><br></div><div><br></div><div> </div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">
+      search_bit_size = search.get_bit_size()<br>
+      replace_bit_size = replace.get_bit_size()<br>
+      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)<br>
<br>
-      elif isinstance(val, Expression):<br>
-         nir_op = opcodes[val.opcode]<br>
-         dst_type_bits = type_bits(nir_op.output_type)<br>
-         if dst_type_bits != 0:<br>
-            assert bit_class == dst_type_bits<br>
-         else:<br>
-            assert val.common_class == 0 or val.common_class == bit_class<br>
-            val.common_class = bit_class<br>
-<br>
-         for i in range(nir_op.num_inputs):<br>
-            src_type_bits = type_bits(nir_op.input_types[i])<br>
-            if src_type_bits != 0:<br>
-               self._validate_bit_class_down(val.sources[i], src_type_bits)<br>
-            else:<br>
-               self._validate_bit_class_down(val.sources[i], val.common_class)<br>
+      assert cmp_result != None and cmp_result <= 0, \<br></blockquote><div><br></div><div>Should be "is not None" rather than "!= None"<br></div><div> </div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">
+         'The search expression bit size {} and replace expression ' \<br>
+         'bit size {} may not be the same'.format(<br>
+               search_bit_size, replace_bit_size)<br>
+<br>
+      replace.set_bit_size(search)<br>
+<br>
+      self.validate_replace(replace, search)<br></blockquote><div><br></div><div>Ok... I think I've convinced myself that this is correct.  I think it'll also be very easy to modify for either approach in my nasty patch series.<br></div><div> </div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">
<br>
 _optimization_ids = itertools.count()<br>
<br>
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c<br>
index 0270302fd3d..a41fca876d5 100644<br>
--- a/src/compiler/nir/nir_search.c<br>
+++ b/src/compiler/nir/nir_search.c<br>
@@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,<br>
       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];<br>
<br>
    /* If the value has a specific bit size and it doesn't match, bail */<br>
-   if (value->bit_size &&<br>
+   if (value->bit_size > 0 &&<br>
        nir_src_bit_size(instr->src[src].src) != value->bit_size)<br>
       return false;<br>
<br>
@@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,<br>
<br>
    assert(instr->dest.dest.is_ssa);<br>
<br>
-   if (expr->value.bit_size &&<br>
+   if (expr->value.bit_size > 0 &&<br>
        instr->dest.dest.ssa.bit_size != expr->value.bit_size)<br>
       return false;<br>
<br>
@@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,<br>
    }<br>
 }<br>
<br>
-typedef struct bitsize_tree {<br>
-   unsigned num_srcs;<br>
-   struct bitsize_tree *srcs[4];<br>
-<br>
-   unsigned common_size;<br>
-   bool is_src_sized[4];<br>
-   bool is_dest_sized;<br>
-<br>
-   unsigned dest_size;<br>
-   unsigned src_size[4];<br>
-} bitsize_tree;<br>
-<br>
-static bitsize_tree *<br>
-build_bitsize_tree(void *mem_ctx, struct match_state *state,<br>
-                   const nir_search_value *value)<br>
-{<br>
-   bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);<br>
-<br>
-   switch (value->type) {<br>
-   case nir_search_value_expression: {<br>
-      nir_search_expression *expr = nir_search_value_as_expression(value);<br>
-      nir_op_info info = nir_op_infos[expr->opcode];<br>
-      tree->num_srcs = info.num_inputs;<br>
-      tree->common_size = 0;<br>
-      for (unsigned i = 0; i < info.num_inputs; i++) {<br>
-         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);<br>
-         if (tree->is_src_sized[i])<br>
-            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);<br>
-         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);<br>
-      }<br>
-      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);<br>
-      if (tree->is_dest_sized)<br>
-         tree->dest_size = nir_alu_type_get_type_size(info.output_type);<br>
-      break;<br>
-   }<br>
-<br>
-   case nir_search_value_variable: {<br>
-      nir_search_variable *var = nir_search_value_as_variable(value);<br>
-      tree->num_srcs = 0;<br>
-      tree->is_dest_sized = true;<br>
-      tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);<br>
-      break;<br>
-   }<br>
-<br>
-   case nir_search_value_constant: {<br>
-      tree->num_srcs = 0;<br>
-      tree->is_dest_sized = false;<br>
-      tree->common_size = 0;<br>
-      break;<br>
-   }<br>
-   }<br>
-<br>
-   if (value->bit_size) {<br>
-      assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);<br>
-      tree->common_size = value->bit_size;<br>
-   }<br>
-<br>
-   return tree;<br>
-}<br>
-<br>
 static unsigned<br>
-bitsize_tree_filter_up(bitsize_tree *tree)<br>
+replace_bitsize(const nir_search_value *value, unsigned search_bitsize,<br>
+                struct match_state *state)<br>
 {<br>
-   for (unsigned i = 0; i < tree->num_srcs; i++) {<br>
-      unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);<br>
-      if (src_size == 0)<br>
-         continue;<br>
-<br>
-      if (tree->is_src_sized[i]) {<br>
-         assert(src_size == tree->src_size[i]);<br>
-      } else if (tree->common_size != 0) {<br>
-         assert(src_size == tree->common_size);<br>
-         tree->src_size[i] = src_size;<br>
-      } else {<br>
-         tree->common_size = src_size;<br>
-         tree->src_size[i] = src_size;<br>
-      }<br>
-   }<br>
-<br>
-   if (tree->num_srcs && tree->common_size) {<br>
-      if (tree->dest_size == 0)<br>
-         tree->dest_size = tree->common_size;<br>
-      else if (!tree->is_dest_sized)<br>
-         assert(tree->dest_size == tree->common_size);<br>
-<br>
-      for (unsigned i = 0; i < tree->num_srcs; i++) {<br>
-         if (!tree->src_size[i])<br>
-            tree->src_size[i] = tree->common_size;<br>
-      }<br>
-   }<br>
-<br>
-   return tree->dest_size;<br>
-}<br>
-<br>
-static void<br>
-bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)<br>
-{<br>
-   if (tree->dest_size)<br>
-      assert(tree->dest_size == size);<br>
-   else<br>
-      tree->dest_size = size;<br>
-<br>
-   if (!tree->is_dest_sized) {<br>
-      if (tree->common_size)<br>
-         assert(tree->common_size == size);<br>
-      else<br>
-         tree->common_size = size;<br>
-   }<br>
-<br>
-   for (unsigned i = 0; i < tree->num_srcs; i++) {<br>
-      if (!tree->src_size[i]) {<br>
-         assert(tree->common_size);<br>
-         tree->src_size[i] = tree->common_size;<br>
-      }<br>
-      bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);<br>
-   }<br>
+   if (value->bit_size > 0)<br>
+      return value->bit_size;<br>
+   if (value->bit_size < 0)<br>
+      return nir_src_bit_size(state->variables[-value->bit_size - 1].src);<br>
+   return search_bitsize;<br>
 }<br>
<br>
 static nir_alu_src<br>
 construct_value(nir_builder *build,<br>
                 const nir_search_value *value,<br>
-                unsigned num_components, bitsize_tree *bitsize,<br>
+                unsigned num_components, unsigned search_bitsize,<br>
                 struct match_state *state,<br>
                 nir_instr *instr)<br>
 {<br>
@@ -424,7 +317,7 @@ construct_value(nir_builder *build,<br>
<br>
       nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);<br>
       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,<br>
-                        bitsize->dest_size, NULL);<br>
+                        replace_bitsize(value, search_bitsize, state), NULL);<br>
       alu->dest.write_mask = (1 << num_components) - 1;<br>
       alu->dest.saturate = false;<br>
<br>
@@ -443,7 +336,7 @@ construct_value(nir_builder *build,<br>
             num_components = nir_op_infos[alu->op].input_sizes[i];<br>
<br>
          alu->src[i] = construct_value(build, expr->srcs[i],<br>
-                                       num_components, bitsize->srcs[i],<br>
+                                       num_components, search_bitsize,<br>
                                        state, instr);<br>
       }<br>
<br>
@@ -472,16 +365,17 @@ construct_value(nir_builder *build,<br>
<br>
    case nir_search_value_constant: {<br>
       const nir_search_constant *c = nir_search_value_as_constant(value);<br>
+      unsigned bit_size = replace_bitsize(value, search_bitsize, state);<br>
<br>
       nir_ssa_def *cval;<br>
       switch (c->type) {<br>
       case nir_type_float:<br>
-         cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size);<br>
+         cval = nir_imm_floatN_t(build, c->data.d, bit_size);<br>
          break;<br>
<br>
       case nir_type_int:<br>
       case nir_type_uint:<br>
-         cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size);<br>
+         cval = nir_imm_intN_t(build, c->data.i, bit_size);<br>
          break;<br>
<br>
       case nir_type_bool:<br>
@@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,<br>
                          swizzle, &state))<br>
       return NULL;<br>
<br>
-   void *bitsize_ctx = ralloc_context(NULL);<br>
-   bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);<br>
-   bitsize_tree_filter_up(tree);<br>
-   bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);<br>
-<br>
    build->cursor = nir_before_instr(&instr->instr);<br>
<br>
    nir_alu_src val = construct_value(build, replace,<br>
                                      instr->dest.dest.ssa.num_components,<br>
-                                     tree, &state, &instr->instr);<br>
+                                     instr->dest.dest.ssa.bit_size,<br>
+                                     &state, &instr->instr);<br>
<br>
    /* Inserting a mov may be unnecessary.  However, it's much easier to<br>
     * simply let copy propagation clean this up than to try to go through<br>
@@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,<br>
     */<br>
    nir_instr_remove(&instr->instr);<br>
<br>
-   ralloc_free(bitsize_ctx);<br>
-<br>
    return ssa_val;<br>
 }<br>
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h<br>
index df4189ede74..82c3b2e1bc8 100644<br>
--- a/src/compiler/nir/nir_search.h<br>
+++ b/src/compiler/nir/nir_search.h<br>
@@ -43,7 +43,7 @@ typedef enum {<br>
 typedef struct {<br>
    nir_search_value_type type;<br>
<br></blockquote><div><br></div><div>/** Bit size of the value.</div><div> *</div><div> * For a search expression:</div><div> *  - If bit_size > 0, the value only matches a SSA value with the given bit size</div><div> *  - If bit_size <= 0, the value matches any size SSA value<br></div><div> *</div><div> * For a replace expression:<br></div><div> *  - If bit_size > 0, the value is constructed with the given bit size</div><div> *  - If bit_size < 0, the value is constructed with the same bit size as variable (-bit_size - 1)<br></div><div> *  - If bit_size == 0, the value is constructed with the same bit size as the search value<br></div><div> */</div><div> </div><blockquote class="gmail_quote" style="margin:0 0 0 .8ex;border-left:1px #ccc solid;padding-left:1ex">
-   unsigned bit_size;<br>
+   int bit_size;<br>
 } nir_search_value;<br>
<br>
 typedef struct {<br>
-- <br>
2.17.2<br>
<br>
_______________________________________________<br>
mesa-dev mailing list<br>
<a href="mailto:mesa-dev@lists.freedesktop.org" target="_blank">mesa-dev@lists.freedesktop.org</a><br>
<a href="https://lists.freedesktop.org/mailman/listinfo/mesa-dev" rel="noreferrer" target="_blank">https://lists.freedesktop.org/mailman/listinfo/mesa-dev</a><br>
</blockquote></div></div>