[Mesa-dev] [PATCH v3] nir/algebraic: Rewrite bit-size inference

Jason Ekstrand jason at jlekstrand.net
Wed Dec 5 13:49:57 UTC 2018


Rb me.  Now you can review my comparison patches. 😁

On December 5, 2018 06:20:49 Connor Abbott <cwabbott0 at gmail.com> wrote:

> Before this commit, there were two copies of the algorithm: one in C,
> that we would use to figure out what bit-size to give the replacement
> expression, and one in Python, that emulated the C one and tried to
> prove that the C algorithm would never fail to correctly assign
> bit-sizes. That seemed pretty fragile, and likely to fall over if we
> make any changes. Furthermore, the C code was really just recomputing
> more-or-less the same thing as the Python code every time. Instead, we
> can just store the results of the Python algorithm in the C
> data structure, and consult it to compute the bitsize of each value,
> moving the "brains" entirely into Python. Since the Python algorithm no
> longer has to match C, it's also a lot easier to change it to something
> more closely approximating an actual type-inference algorithm. The
> algorithm used is based on Hindley-Milner, although deliberately
> weakened a little. It's a few more lines than the old one, judging by
> the diffstat, but I think it's easier to verify that it's correct while
> being as general as possible.
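The bookkeeping described above can be illustrated with a small, self-contained union-find in Python. This is a sketch under stated assumptions, not the patch's code: `BitSizeNode`, `find` and `unify` are hypothetical names, and the conflict rule is simplified to "two different physical sizes cannot be merged", ignoring the search/replace distinction the patch adds.

```python
class BitSizeNode:
    """Hypothetical stand-in for a search/replace value; the patch keeps the
    equivalent link in Value._bit_size."""

    def __init__(self, name, bit_size=None):
        self.name = name
        # None: unconstrained root; int: physical bit size;
        # BitSizeNode: the value this class was merged into.
        self.link = bit_size

    def find(self):
        """Return the representative: an int, or the root node ("find")."""
        node = self
        while isinstance(node, BitSizeNode) and node.link is not None:
            node = node.link
        if node is not self:
            self.link = node  # shortcut future lookups from this node
        return node


def unify(a, b):
    """Merge the classes of a and b ("union"), rejecting size conflicts."""
    ra, rb = a.find(), b.find()
    if ra == rb:
        return
    if isinstance(ra, int) and isinstance(rb, int):
        raise ValueError(f"{a.name} cannot be both {ra}-bit and {rb}-bit")
    # Point the unconstrained root at the other representative, so once a
    # physical size is reached it always stays the representative.
    if isinstance(ra, int):
        rb.link = ra
    else:
        ra.link = rb


# Usage: a and b are tied together, then b is pinned to 32 bits via c.
a, b, c = BitSizeNode("a"), BitSizeNode("b"), BitSizeNode("c", 32)
unify(a, b)
unify(b, c)
print(a.find())  # 32
```

Keeping a physical size as the representative is what lets a single `find` answer both "what size is this?" and "which class is this in?".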
>
> We could split this up into two changes, first making the C code use the
> results of the Python code and then rewriting the Python algorithm, but
> since the old algorithm never tracked which variable each equivalence
> class corresponded to, we'd have to add some non-trivial code which would
> then get thrown away. I think it's better to see the final state all at
> once, although I could also try splitting it up.
>
> v2:
> - Replace instances of "== None" and "!= None" with "is None" and
> "is not None".
> - Rename first_src to first_unsized_src
> - Only merge the destination with the first unsized source, since the
> sources have already been merged.
> - Add a comment explaining what nir_search_value::bit_size now means.
> v3:
> - Fix one last instance to use "is not" instead of !=
> - Don't try to be so clever when choosing which error message to print
> based on whether we're in the search or replace expression.
> - Fix trailing whitespace.
> ---
> src/compiler/nir/nir_algebraic.py | 520 ++++++++++++++++--------------
> src/compiler/nir/nir_search.c     | 146 +--------
> src/compiler/nir/nir_search.h     |  17 +-
> 3 files changed, 317 insertions(+), 366 deletions(-)
>
> diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
> index 728196136ab..efd6e52cdb9 100644
> --- a/src/compiler/nir/nir_algebraic.py
> +++ b/src/compiler/nir/nir_algebraic.py
> @@ -88,7 +88,7 @@ class Value(object):
>
>    __template = mako.template.Template("""
> static const ${val.c_type} ${val.name} = {
> -   { ${val.type_enum}, ${val.bit_size} },
> +   { ${val.type_enum}, ${val.c_bit_size} },
> % if isinstance(val, Constant):
>    ${val.type()}, { ${val.hex()} /* ${val.value} */ },
> % elif isinstance(val, Variable):
> @@ -112,6 +112,40 @@ static const ${val.c_type} ${val.name} = {
>    def __str__(self):
>       return self.in_val
>
> +   def get_bit_size(self):
> +      """Get the physical bit-size that has been chosen for this value, or if
> +      there is none, the canonical value which currently represents this
> +      bit-size class. Variables will be preferred, i.e. if there are any
> +      variables in the equivalence class, the canonical value will be a
> +      variable. We do this since we'll need to know which variable each value
> +      is equivalent to when constructing the replacement expression. This is
> +      the "find" part of the union-find algorithm.
> +      """
> +      bit_size = self
> +
> +      while isinstance(bit_size, Value):
> +         if bit_size._bit_size is None:
> +            break
> +         bit_size = bit_size._bit_size
> +
> +      if bit_size is not self:
> +         self._bit_size = bit_size
> +      return bit_size
> +
> +   def set_bit_size(self, other):
> +      """Make self.get_bit_size() return what other.get_bit_size() returned
> +      before calling this, or just "other" if it's a concrete bit-size. This is
> +      the "union" part of the union-find algorithm.
> +      """
> +
> +      self_bit_size = self.get_bit_size()
> +      other_bit_size = other if isinstance(other, int) else other.get_bit_size()
> +
> +      if self_bit_size == other_bit_size:
> +         return
> +
> +      self_bit_size._bit_size = other_bit_size
> +
>    @property
>    def type_enum(self):
>       return "nir_search_value_" + self.type_str
> @@ -124,6 +158,21 @@ static const ${val.c_type} ${val.name} = {
>    def c_ptr(self):
>       return "&{0}.value".format(self.name)
>
> +   @property
> +   def c_bit_size(self):
> +      bit_size = self.get_bit_size()
> +      if isinstance(bit_size, int):
> +         return bit_size
> +      elif isinstance(bit_size, Variable):
> +         return -bit_size.index - 1
> +      else:
> +         # If the bit-size class is neither a variable, nor an actual bit-size, then
> +         # - If it's in the search expression, we don't need to check anything
> +         # - If it's in the replace expression, either it's ambiguous (in which
> +         # case we'd reject it), or it equals the bit-size of the search value
> +         # We represent these cases with a 0 bit-size.
> +         return 0
> +
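The signed encoding that `c_bit_size` emits can be exercised on its own. In this sketch `encode_bit_size` mirrors the three cases of the property above, while `resolve_bit_size` is a hypothetical decoder showing how a consumer could turn the value back into a concrete size when building the replacement; the decoder and its arguments (`var_bit_sizes`, `search_bit_size`) are assumptions, since the matching `nir_search.c` code is not quoted here.

```python
def encode_bit_size(bit_size=None, var_index=None):
    """Mirror of c_bit_size: >0 physical size, <0 variable, 0 otherwise."""
    if bit_size is not None:
        return bit_size
    if var_index is not None:
        # -index - 1 keeps variable 0 distinct from the "unknown" value 0.
        return -var_index - 1
    return 0


def resolve_bit_size(encoded, var_bit_sizes, search_bit_size):
    """Hypothetical decoder used when constructing a replacement value.

    var_bit_sizes[i] is the bit size variable i matched, and
    search_bit_size is the bit size of the matched search expression.
    """
    if encoded > 0:
        return encoded
    if encoded < 0:
        return var_bit_sizes[-encoded - 1]
    return search_bit_size


# Variable 1 matched a 64-bit value; the search root is 16-bit.
print(resolve_bit_size(encode_bit_size(var_index=1), [8, 64], 16))  # 64
```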
>    def render(self):
>       return self.__template.render(val=self,
>                                     Constant=Constant,
> @@ -140,14 +189,14 @@ class Constant(Value):
>       if isinstance(val, (str)):
>          m = _constant_re.match(val)
>          self.value = ast.literal_eval(m.group('value'))
> -         self.bit_size = int(m.group('bits')) if m.group('bits') else 0
> +         self._bit_size = int(m.group('bits')) if m.group('bits') else None
>       else:
>          self.value = val
> -         self.bit_size = 0
> +         self._bit_size = None
>
>       if isinstance(self.value, bool):
> -         assert self.bit_size == 0 or self.bit_size == 32
> -         self.bit_size = 32
> +         assert self._bit_size is None or self._bit_size == 32
> +         self._bit_size = 32
>
>    def hex(self):
>       if isinstance(self.value, (bool)):
> @@ -191,11 +240,11 @@ class Variable(Value):
>       self.is_constant = m.group('const') is not None
>       self.cond = m.group('cond')
>       self.required_type = m.group('type')
> -      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
> +      self._bit_size = int(m.group('bits')) if m.group('bits') else None
>
>       if self.required_type == 'bool':
> -         assert self.bit_size == 0 or self.bit_size == 32
> -         self.bit_size = 32
> +         assert self._bit_size is None or self._bit_size == 32
> +         self._bit_size = 32
>
>       if self.required_type is not None:
>          assert self.required_type in ('float', 'bool', 'int', 'uint')
> @@ -225,7 +274,7 @@ class Expression(Value):
>       assert m and m.group('opcode') is not None
>
>       self.opcode = m.group('opcode')
> -      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
> +      self._bit_size = int(m.group('bits')) if m.group('bits') else None
>       self.inexact = m.group('inexact') is not None
>       self.cond = m.group('cond')
>       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
> @@ -235,40 +284,6 @@ class Expression(Value):
>       srcs = "\n".join(src.render() for src in self.sources)
>       return srcs + super(Expression, self).render()
>
> -class IntEquivalenceRelation(object):
> -   """A class representing an equivalence relation on integers.
> -
> -   Each integer has a canonical form which is the maximum integer to which it
> -   is equivalent.  Two integers are equivalent precisely when they have the
> -   same canonical form.
> -
> -   The convention of maximum is explicitly chosen to make using it in
> -   BitSizeValidator easier because it means that an actual bit_size (if any)
> -   will always be the canonical form.
> -   """
> -   def __init__(self):
> -      self._remap = {}
> -
> -   def get_canonical(self, x):
> -      """Get the canonical integer corresponding to x."""
> -      if x in self._remap:
> -         return self.get_canonical(self._remap[x])
> -      else:
> -         return x
> -
> -   def add_equiv(self, a, b):
> -      """Add an equivalence and return the canonical form."""
> -      c = max(self.get_canonical(a), self.get_canonical(b))
> -      if a != c:
> -         assert a < c
> -         self._remap[a] = c
> -
> -      if b != c:
> -         assert b < c
> -         self._remap[b] = c
> -
> -      return c
> -
> class BitSizeValidator(object):
>    """A class for validating bit sizes of expressions.
>
> @@ -296,7 +311,7 @@ class BitSizeValidator(object):
>    inference can be ambiguous or contradictory.  Consider, for instance, the
>    following transformation:
>
> -   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
> +   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
>
>    This transformation can potentially cause a problem because usub_borrow is
>    well-defined for any bit-size of integer.  However, b2i always generates a
> @@ -315,217 +330,250 @@ class BitSizeValidator(object):
>    generate any code.  This ensures that bugs are caught at compile time
>    rather than at run time.
>
> -   The basic operation of the validator is very similar to the bitsize_tree in
> -   nir_search only a little more subtle.  Instead of simply tracking bit
> -   sizes, it tracks "bit classes" where each class is represented by an
> -   integer.  A value of 0 means we don't know anything yet, positive values
> -   are actual bit-sizes, and negative values are used to track equivalence
> -   classes of sizes that must be the same but have yet to receive an actual
> -   size.  The first stage uses the bitsize_tree algorithm to assign bit
> -   classes to each variable.  If it ever comes across an inconsistency, it
> -   assert-fails.  Then the second stage uses that information to prove that
> -   the resulting expression can always validly be constructed.
> -   """
> -
> -   def __init__(self, varset):
> -      self._num_classes = 0
> -      self._var_classes = [0] * len(varset.names)
> -      self._class_relation = IntEquivalenceRelation()
> -
> -   def validate(self, search, replace):
> -      search_dst_class = self._propagate_bit_size_up(search)
> -      if search_dst_class == 0:
> -         search_dst_class = self._new_class()
> -      self._propagate_bit_class_down(search, search_dst_class)
> -
> -      replace_dst_class = self._validate_bit_class_up(replace)
> -      if replace_dst_class != 0:
> -         assert search_dst_class != 0, \
> -                'Search expression matches any bit size but replace ' \
> -                'expression can only generate {0}-bit values' \
> -                .format(replace_dst_class)
> -
> -         assert search_dst_class == replace_dst_class, \
> -                'Search expression matches any {0}-bit values but replace ' \
> -                'expression can only generates {1}-bit values' \
> -                .format(search_dst_class, replace_dst_class)
> -
> -      self._validate_bit_class_down(replace, search_dst_class)
> -
> -   def _new_class(self):
> -      self._num_classes += 1
> -      return -self._num_classes
> -
> -   def _set_var_bit_class(self, var, bit_class):
> -      assert bit_class != 0
> -      var_class = self._var_classes[var.index]
> -      if var_class == 0:
> -         self._var_classes[var.index] = bit_class
> -      else:
> -         canon_var_class = self._class_relation.get_canonical(var_class)
> -         canon_bit_class = self._class_relation.get_canonical(bit_class)
> -         assert canon_var_class < 0 or canon_bit_class < 0 or \
> -                canon_var_class == canon_bit_class, \
> -                'Variable {0} cannot be both {1}-bit and {2}-bit' \
> -                .format(str(var), bit_class, var_class)
> -         var_class = self._class_relation.add_equiv(var_class, bit_class)
> -         self._var_classes[var.index] = var_class
> -
> -   def _get_var_bit_class(self, var):
> -      return self._class_relation.get_canonical(self._var_classes[var.index])
> -
> -   def _propagate_bit_size_up(self, val):
> -      if isinstance(val, (Constant, Variable)):
> -         return val.bit_size
> -
> -      elif isinstance(val, Expression):
> -         nir_op = opcodes[val.opcode]
> -         val.common_size = 0
> -         for i in range(nir_op.num_inputs):
> -            src_bits = self._propagate_bit_size_up(val.sources[i])
> -            if src_bits == 0:
> -               continue
> -
> -            src_type_bits = type_bits(nir_op.input_types[i])
> -            if src_type_bits != 0:
> -               assert src_bits == src_type_bits, \
> -                      'Source {0} of nir_op_{1} must be a {2}-bit value but ' \
> -                      'the only possible matched values are {3}-bit: {4}' \
> -                      .format(i, val.opcode, src_type_bits, src_bits, str(val))
> -            else:
> -               assert val.common_size == 0 or src_bits == val.common_size, \
> -                      'Expression cannot have both {0}-bit and {1}-bit ' \
> -                      'variable-width sources: {2}' \
> -                      .format(src_bits, val.common_size, str(val))
> -               val.common_size = src_bits
> -
> -         dst_type_bits = type_bits(nir_op.output_type)
> -         if dst_type_bits != 0:
> -            assert val.bit_size == 0 or val.bit_size == dst_type_bits, \
> -                   'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \
> -                   'result was requested' \
> -                   .format(val.opcode, dst_type_bits, val.bit_size)
> -            return dst_type_bits
> -         else:
> -            if val.common_size != 0:
> -               assert val.bit_size == 0 or val.bit_size == val.common_size, \
> -                      'Variable width expression musr be {0}-bit based on ' \
> -                      'the sources but a {1}-bit result was requested: {2}' \
> -                      .format(val.common_size, val.bit_size, str(val))
> -            else:
> -               val.common_size = val.bit_size
> -            return val.common_size
> +   Each value maintains a "bit-size class", which is either an actual bit size
> +   or an equivalence class with other values that must have the same bit size.
> +   The validator works by combining bit-size classes with each other according
> +   to the NIR rules outlined above, checking that there are no inconsistencies.
> +   When doing this for the replacement expression, we make sure to never change
> +   the equivalence class of any of the search values. We could make the example
> +   transforms above work by doing some extra run-time checking of the search
> +   expression, but we make the user specify those constraints themselves, to
> +   avoid any surprises. Since the replacement bitsizes can only be connected to
> +   the source bitsize via variables (variables must have the same bitsize in
> +   the source and replacement expressions) or the roots of the expression (the
> +   replacement expression must produce the same bit size as the search
> +   expression), we prevent merging a variable with anything when processing the
> +   replacement expression, or specializing the search bitsize
> +   with anything. The former prevents
>
> -   def _propagate_bit_class_down(self, val, bit_class):
> -      if isinstance(val, Constant):
> -         assert val.bit_size == 0 or val.bit_size == bit_class, \
> -                'Constant is {0}-bit but a {1}-bit value is required: {2}' \
> -                .format(val.bit_size, bit_class, str(val))
> +   (('bcsel', a, b, 0), ('iand', a, b))
>
> -      elif isinstance(val, Variable):
> -         assert val.bit_size == 0 or val.bit_size == bit_class, \
> -                'Variable is {0}-bit but a {1}-bit value is required: {2}' \
> -                .format(val.bit_size, bit_class, str(val))
> -         self._set_var_bit_class(val, bit_class)
> +   from being allowed, since we'd have to merge the bitsizes for a and b due to
> +   the 'iand', while the latter prevents
>
> -      elif isinstance(val, Expression):
> -         nir_op = opcodes[val.opcode]
> -         dst_type_bits = type_bits(nir_op.output_type)
> -         if dst_type_bits != 0:
> -            assert bit_class == 0 or bit_class == dst_type_bits, \
> -                   'nir_op_{0} produces a {1}-bit result but the parent ' \
> -                   'expression wants a {2}-bit value' \
> -                   .format(val.opcode, dst_type_bits, bit_class)
> -         else:
> -            assert val.common_size == 0 or val.common_size == bit_class, \
> -                   'Variable-width expression produces a {0}-bit result ' \
> -                   'based on the source widths but the parent expression ' \
> -                   'wants a {1}-bit value: {2}' \
> -                   .format(val.common_size, bit_class, str(val))
> -            val.common_size = bit_class
> -
> -         if val.common_size:
> -            common_class = val.common_size
> -         elif nir_op.num_inputs:
> -            # If we got here then we have no idea what the actual size is.
> -            # Instead, we use a generic class
> -            common_class = self._new_class()
> -
> -         for i in range(nir_op.num_inputs):
> -            src_type_bits = type_bits(nir_op.input_types[i])
> -            if src_type_bits != 0:
> -               self._propagate_bit_class_down(val.sources[i], src_type_bits)
> -            else:
> -               self._propagate_bit_class_down(val.sources[i], common_class)
> +   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
>
> -   def _validate_bit_class_up(self, val):
> -      if isinstance(val, Constant):
> -         return val.bit_size
> +   from being allowed, since the search expression has the bit size of a and b,
> +   which can't be specialized to 32 which is the bitsize of the replace
> +   expression. It also prevents something like:
>
> -      elif isinstance(val, Variable):
> -         var_class = self._get_var_bit_class(val)
> -         # By the time we get to validation, every variable should have a class
> -         assert var_class != 0
> +   (('b2i', ('i2b', a)), ('ineq', a, 0))
>
> -         # If we have an explicit size provided by the user, the variable
> -         # *must* exactly match the search.  It cannot be implicitly sized
> -         # because otherwise we could end up with a conflict at runtime.
> -         assert val.bit_size == 0 or val.bit_size == var_class
> +   since the bitsize of 'b2i', which can be anything, can't be specialized to
> +   the bitsize of a.
>
> -         return var_class
> +   After doing all this, we check that every subexpression of the replacement
> +   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
> +   of the search expression, since those are the things that are known when
> +   constructing the replacement expression. Finally, we record the bitsize
> +   needed in nir_search_value so that we know what to do when building the
> +   replacement expression.
> +   """
>
> +   def __init__(self, varset):
> +      self._var_classes = [None] * len(varset.names)
> +
> +   def compare_bitsizes(self, a, b):
> +      """Determines which bitsize class is a specialization of the other, or
> +      whether neither is. When we merge two different bitsizes, the
> +      less-specialized bitsize always points to the more-specialized one, so
> +      that calling get_bit_size() always gets you the most specialized bitsize.
> +      The specialization partial order is given by:
> +      - Physical bitsizes are always the most specialized, and a different
> +        bitsize can never specialize another.
> +      - In the search expression, variables can always be specialized to each
> +        other and to physical bitsizes. In the replace expression, we disallow
> +        this to avoid adding extra constraints to the search expression that
> +        the user didn't specify.
> +      - Expressions and constants without a bitsize can always be specialized to
> +        each other and variables, but not the other way around.
> +
> +        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
> +        and None if they are not comparable (neither a <= b nor b <= a).
> +      """
> +      if isinstance(a, int):
> +         if isinstance(b, int):
> +            return 0 if a == b else None
> +         elif isinstance(b, Variable):
> +            return -1 if self.is_search else None
> +         else:
> +            return -1
> +      elif isinstance(a, Variable):
> +         if isinstance(b, int):
> +            return 1 if self.is_search else None
> +         elif isinstance(b, Variable):
> +            return 0 if self.is_search or a.index == b.index else None
> +         else:
> +            return -1
> +      else:
> +         if isinstance(b, int):
> +            return 1
> +         elif isinstance(b, Variable):
> +            return 1
> +         else:
> +            return 0
> +
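The partial order in `compare_bitsizes` is easier to see when pulled out of the class. This sketch restates the same branches as a free function; the `Var` wrapper and the `UNSIZED` sentinel are illustrative stand-ins (assumptions of the sketch) for a `Variable` and for an expression or constant whose class has no physical size yet.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Illustrative stand-in for the bit-size class of variable `index`."""
    index: int


UNSIZED = object()  # expression/constant class with no physical size yet


def compare(a, b, is_search):
    """Return -1 if b can be specialized to a, 0 if either may be merged
    into the other, 1 if a can be specialized to b, None if incomparable."""
    if isinstance(a, int):
        if isinstance(b, int):
            return 0 if a == b else None
        if isinstance(b, Var):
            return -1 if is_search else None
        return -1
    if isinstance(a, Var):
        if isinstance(b, int):
            return 1 if is_search else None
        if isinstance(b, Var):
            return 0 if is_search or a.index == b.index else None
        return -1
    # a is unsized: anything sized is more specialized than it.
    return 1 if isinstance(b, (int, Var)) else 0


# Two different variables may be unified in the search expression, but not
# when building the replacement.
print(compare(Var(0), Var(1), is_search=True))   # 0
print(compare(Var(0), Var(1), is_search=False))  # None
```

In replace mode the `None` results are what trip the assert in `unify_bit_size`, which matches the docstring's explanation of why the `bcsel`/`iand` example is rejected.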
> +   def unify_bit_size(self, a, b, error_msg):
> +      """Record that a must have the same bit-size as b. If both
> +      have been assigned conflicting physical bit-sizes, call "error_msg" with
> +      the bit-sizes of a and b to get a message and raise an error.
> +      In the replace expression, disallow merging variables with other
> +      variables and physical bit-sizes as well.
> +      """
> +      a_bit_size = a.get_bit_size()
> +      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
> +
> +      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
> +
> +      assert cmp_result is not None, \
> +         error_msg(a_bit_size, b_bit_size)
> +
> +      if cmp_result < 0:
> +         b_bit_size.set_bit_size(a)
> +      elif not isinstance(a_bit_size, int):
> +         a_bit_size.set_bit_size(b)
> +
> +   def merge_variables(self, val):
> +      """Perform the first part of type inference by merging all the different
> +      uses of the same variable. We always do this as if we're in the search
> +      expression, even if we're actually not, since otherwise we'd get errors
> +      if the search expression specified some constraint but the replace
> +      expression didn't, because we'd be merging a variable and a constant.
> +      """
> +      if isinstance(val, Variable):
> +         if self._var_classes[val.index] is None:
> +            self._var_classes[val.index] = val
> +         else:
> +            other = self._var_classes[val.index]
> +            self.unify_bit_size(other, val,
> +                  lambda other_bit_size, bit_size:
> +                     'Variable {} has conflicting bit size requirements: ' \
> +                     'it must have bit size {} and {}'.format(
> +                        val.var_name, other_bit_size, bit_size))
>       elif isinstance(val, Expression):
> -         nir_op = opcodes[val.opcode]
> -         val.common_class = 0
> -         for i in range(nir_op.num_inputs):
> -            src_class = self._validate_bit_class_up(val.sources[i])
> -            if src_class == 0:
> +         for src in val.sources:
> +            self.merge_variables(src)
> +
> +   def validate_value(self, val):
> +      """Validate an expression by performing classic Hindley-Milner
> +      type inference on bitsizes. This will detect if there are any conflicting
> +      requirements, and unify variables so that we know which variables must
> +      have the same bitsize. If we're operating on the replace expression, we
> +      will refuse to merge different variables together or merge a variable
> +      with a constant, in order to prevent surprises due to rules unexpectedly
> +      not matching at runtime.
> +      """
> +      if not isinstance(val, Expression):
> +         return
> +
> +      nir_op = opcodes[val.opcode]
> +      assert len(val.sources) == nir_op.num_inputs, \
> +         "Expression {} has {} sources, expected {}".format(
> +            val, len(val.sources), nir_op.num_inputs)
> +
> +      for src in val.sources:
> +         self.validate_value(src)
> +
> +      dst_type_bits = type_bits(nir_op.output_type)
> +
> +      # First, unify all the sources. That way, an error coming up because two
> +      # sources have an incompatible bit-size won't produce an error message
> +      # involving the destination.
> +      first_unsized_src = None
> +      for src_type, src in zip(nir_op.input_types, val.sources):
> +         src_type_bits = type_bits(src_type)
> +         if src_type_bits == 0:
> +            if first_unsized_src is None:
> +               first_unsized_src = src
>                continue
>
> -            src_type_bits = type_bits(nir_op.input_types[i])
> -            if src_type_bits != 0:
> -               assert src_class == src_type_bits
> +            if self.is_search:
> +               self.unify_bit_size(first_unsized_src, src,
> +                  lambda first_unsized_src_bit_size, src_bit_size:
> +                     'Source {} of {} must have bit size {}, while source {} ' \
> +                     'must have incompatible bit size {}'.format(
> +                        first_unsized_src, val, first_unsized_src_bit_size,
> +                        src, src_bit_size))
>             else:
> -               assert val.common_class == 0 or src_class == val.common_class
> -               val.common_class = src_class
> -
> -         dst_type_bits = type_bits(nir_op.output_type)
> -         if dst_type_bits != 0:
> -            assert val.bit_size == 0 or val.bit_size == dst_type_bits
> -            return dst_type_bits
> +               self.unify_bit_size(first_unsized_src, src,
> +                  lambda first_unsized_src_bit_size, src_bit_size:
> +                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
> +                     'of {} may not have the same bit size when building the ' \
> +                     'replacement expression.'.format(
> +                        first_unsized_src, first_unsized_src_bit_size, src,
> +                        src_bit_size, val))
>          else:
> -            if val.common_class != 0:
> -               assert val.bit_size == 0 or val.bit_size == val.common_class
> +            if self.is_search:
> +               self.unify_bit_size(src, src_type_bits,
> +                  lambda src_bit_size, unused:
> +                     '{} must have {} bits, but as a source of nir_op_{} '\
> +                     'it must have {} bits'.format(
> +                        src, src_bit_size, nir_op.name, src_type_bits))
> +            else:
> +               self.unify_bit_size(src, src_type_bits,
> +                  lambda src_bit_size, unused:
> +                     '{} has the bit size of {}, but as a source of ' \
> +                     'nir_op_{} it must have {} bits, which may not be the ' \
> +                     'same'.format(
> +                        src, src_bit_size, nir_op.name, src_type_bits))
> +
> +      if dst_type_bits == 0:
> +         if first_unsized_src is not None:
> +            if self.is_search:
> +               self.unify_bit_size(val, first_unsized_src,
> +                  lambda val_bit_size, src_bit_size:
> +                     '{} must have the bit size of {}, while its source {} ' \
> +                     'must have incompatible bit size {}'.format(
> +                        val, val_bit_size, first_unsized_src, src_bit_size))
>             else:
> -               val.common_class = val.bit_size
> -            return val.common_class
> +               self.unify_bit_size(val, first_unsized_src,
> +                  lambda val_bit_size, src_bit_size:
> +                     '{} must have {} bits, but its source {} ' \
> +                     '(bit size of {}) may not have that bit size ' \
> +                     'when building the replacement.'.format(
> +                        val, val_bit_size, first_unsized_src, src_bit_size))
> +      else:
> +         self.unify_bit_size(val, dst_type_bits,
> +            lambda dst_bit_size, unused:
> +               '{} must have {} bits, but as a destination of nir_op_{} ' \
> +               'it must have {} bits'.format(
> +                  val, dst_bit_size, nir_op.name, dst_type_bits))
> +
> +   def validate_replace(self, val, search):
> +      bit_size = val.get_bit_size()
> +      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
> +            bit_size == search.get_bit_size(), \
> +            'Ambiguous bit size for replacement value {}: ' \
> +            'it cannot be deduced from a variable, a fixed bit size ' \
> +            'somewhere, or the search expression.'.format(val)
> +
> +      if isinstance(val, Expression):
> +         for src in val.sources:
> +            self.validate_replace(src, search)
>
> -   def _validate_bit_class_down(self, val, bit_class):
> -      # At this point, everything *must* have a bit class.  Otherwise, we have
> -      # a value we don't know how to define.
> -      assert bit_class != 0
> +   def validate(self, search, replace):
> +      self.is_search = True
> +      self.merge_variables(search)
> +      self.merge_variables(replace)
> +      self.validate_value(search)
>
> -      if isinstance(val, Constant):
> -         assert val.bit_size == 0 or val.bit_size == bit_class
> +      self.is_search = False
> +      self.validate_value(replace)
>
> -      elif isinstance(val, Variable):
> -         assert val.bit_size == 0 or val.bit_size == bit_class
> +      # Check that search is always more specialized than replace. Note that
> +      # we're doing this in replace mode, disallowing merging variables.
> +      search_bit_size = search.get_bit_size()
> +      replace_bit_size = replace.get_bit_size()
> +      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
>
> -      elif isinstance(val, Expression):
> -         nir_op = opcodes[val.opcode]
> -         dst_type_bits = type_bits(nir_op.output_type)
> -         if dst_type_bits != 0:
> -            assert bit_class == dst_type_bits
> -         else:
> -            assert val.common_class == 0 or val.common_class == bit_class
> -            val.common_class = bit_class
> +      assert cmp_result is not None and cmp_result <= 0, \
> +         'The search expression bit size {} and replace expression ' \
> +         'bit size {} may not be the same'.format(
> +               search_bit_size, replace_bit_size)
>
> -         for i in range(nir_op.num_inputs):
> -            src_type_bits = type_bits(nir_op.input_types[i])
> -            if src_type_bits != 0:
> -               self._validate_bit_class_down(val.sources[i], src_type_bits)
> -            else:
> -               self._validate_bit_class_down(val.sources[i], val.common_class)
> +      replace.set_bit_size(search)
> +
> +      self.validate_replace(replace, search)
>
> _optimization_ids = itertools.count()
>
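The root check at the end of `validate` (comparing the search class against the replace class in replace mode and requiring -1 or 0) can be worked out by hand from the branches of `compare_bitsizes`, and it seems to reduce to a simple rule: the replacement root is fine when it has no physical size of its own, since `replace.set_bit_size(search)` then hands it the search root's class, or when it is already in exactly the same class as the search root. This sketch states that derived rule directly; `Var`, `UNSIZED` and `roots_compatible` are hypothetical names for the illustration, not code from the patch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Illustrative stand-in for the bit-size class of variable `index`."""
    index: int


UNSIZED = object()  # class with no physical size and no variable in it


def roots_compatible(search_cls, replace_cls):
    """True when the replacement root can always take the search root's size.

    An unsized replacement root simply inherits the search root's class;
    otherwise both roots must already be in the very same class (same
    physical size or same variable), because the replacement may not add
    constraints that the search expression did not state.
    """
    return replace_cls is UNSIZED or replace_cls == search_cls


# (('usub_borrow', a, b), ('b2i@32', ...)): the search root follows
# variable a while the replacement root is fixed at 32 bits, so it is rejected.
print(roots_compatible(Var(0), 32))  # False
```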
> diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
> index 0270302fd3d..a41fca876d5 100644
> --- a/src/compiler/nir/nir_search.c
> +++ b/src/compiler/nir/nir_search.c
> @@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
>       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
>
>    /* If the value has a specific bit size and it doesn't match, bail */
> -   if (value->bit_size &&
> +   if (value->bit_size > 0 &&
>        nir_src_bit_size(instr->src[src].src) != value->bit_size)
>       return false;
>
> @@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
>
>    assert(instr->dest.dest.is_ssa);
>
> -   if (expr->value.bit_size &&
> +   if (expr->value.bit_size > 0 &&
>        instr->dest.dest.ssa.bit_size != expr->value.bit_size)
>       return false;
>
> @@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
>    }
> }
>
> -typedef struct bitsize_tree {
> -   unsigned num_srcs;
> -   struct bitsize_tree *srcs[4];
> -
> -   unsigned common_size;
> -   bool is_src_sized[4];
> -   bool is_dest_sized;
> -
> -   unsigned dest_size;
> -   unsigned src_size[4];
> -} bitsize_tree;
> -
> -static bitsize_tree *
> -build_bitsize_tree(void *mem_ctx, struct match_state *state,
> -                   const nir_search_value *value)
> -{
> -   bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
> -
> -   switch (value->type) {
> -   case nir_search_value_expression: {
> -      nir_search_expression *expr = nir_search_value_as_expression(value);
> -      nir_op_info info = nir_op_infos[expr->opcode];
> -      tree->num_srcs = info.num_inputs;
> -      tree->common_size = 0;
> -      for (unsigned i = 0; i < info.num_inputs; i++) {
> -         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
> -         if (tree->is_src_sized[i])
> -            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
> -         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
> -      }
> -      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
> -      if (tree->is_dest_sized)
> -         tree->dest_size = nir_alu_type_get_type_size(info.output_type);
> -      break;
> -   }
> -
> -   case nir_search_value_variable: {
> -      nir_search_variable *var = nir_search_value_as_variable(value);
> -      tree->num_srcs = 0;
> -      tree->is_dest_sized = true;
> -      tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
> -      break;
> -   }
> -
> -   case nir_search_value_constant: {
> -      tree->num_srcs = 0;
> -      tree->is_dest_sized = false;
> -      tree->common_size = 0;
> -      break;
> -   }
> -   }
> -
> -   if (value->bit_size) {
> -      assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
> -      tree->common_size = value->bit_size;
> -   }
> -
> -   return tree;
> -}
> -
> static unsigned
> -bitsize_tree_filter_up(bitsize_tree *tree)
> +replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
> +                struct match_state *state)
> {
> -   for (unsigned i = 0; i < tree->num_srcs; i++) {
> -      unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
> -      if (src_size == 0)
> -         continue;
> -
> -      if (tree->is_src_sized[i]) {
> -         assert(src_size == tree->src_size[i]);
> -      } else if (tree->common_size != 0) {
> -         assert(src_size == tree->common_size);
> -         tree->src_size[i] = src_size;
> -      } else {
> -         tree->common_size = src_size;
> -         tree->src_size[i] = src_size;
> -      }
> -   }
> -
> -   if (tree->num_srcs && tree->common_size) {
> -      if (tree->dest_size == 0)
> -         tree->dest_size = tree->common_size;
> -      else if (!tree->is_dest_sized)
> -         assert(tree->dest_size == tree->common_size);
> -
> -      for (unsigned i = 0; i < tree->num_srcs; i++) {
> -         if (!tree->src_size[i])
> -            tree->src_size[i] = tree->common_size;
> -      }
> -   }
> -
> -   return tree->dest_size;
> -}
> -
> -static void
> -bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
> -{
> -   if (tree->dest_size)
> -      assert(tree->dest_size == size);
> -   else
> -      tree->dest_size = size;
> -
> -   if (!tree->is_dest_sized) {
> -      if (tree->common_size)
> -         assert(tree->common_size == size);
> -      else
> -         tree->common_size = size;
> -   }
> -
> -   for (unsigned i = 0; i < tree->num_srcs; i++) {
> -      if (!tree->src_size[i]) {
> -         assert(tree->common_size);
> -         tree->src_size[i] = tree->common_size;
> -      }
> -      bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
> -   }
> +   if (value->bit_size > 0)
> +      return value->bit_size;
> +   if (value->bit_size < 0)
> +      return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
> +   return search_bitsize;
> }
>
> static nir_alu_src
> construct_value(nir_builder *build,
>                 const nir_search_value *value,
> -                unsigned num_components, bitsize_tree *bitsize,
> +                unsigned num_components, unsigned search_bitsize,
>                 struct match_state *state,
>                 nir_instr *instr)
> {
> @@ -424,7 +317,7 @@ construct_value(nir_builder *build,
>
>       nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
>       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
> -                        bitsize->dest_size, NULL);
> +                        replace_bitsize(value, search_bitsize, state), NULL);
>       alu->dest.write_mask = (1 << num_components) - 1;
>       alu->dest.saturate = false;
>
> @@ -443,7 +336,7 @@ construct_value(nir_builder *build,
>             num_components = nir_op_infos[alu->op].input_sizes[i];
>
>          alu->src[i] = construct_value(build, expr->srcs[i],
> -                                       num_components, bitsize->srcs[i],
> +                                       num_components, search_bitsize,
>                                        state, instr);
>       }
>
> @@ -472,16 +365,17 @@ construct_value(nir_builder *build,
>
>    case nir_search_value_constant: {
>       const nir_search_constant *c = nir_search_value_as_constant(value);
> +      unsigned bit_size = replace_bitsize(value, search_bitsize, state);
>
>       nir_ssa_def *cval;
>       switch (c->type) {
>       case nir_type_float:
> -         cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size);
> +         cval = nir_imm_floatN_t(build, c->data.d, bit_size);
>          break;
>
>       case nir_type_int:
>       case nir_type_uint:
> -         cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size);
> +         cval = nir_imm_intN_t(build, c->data.i, bit_size);
>          break;
>
>       case nir_type_bool:
> @@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
>                          swizzle, &state))
>       return NULL;
>
> -   void *bitsize_ctx = ralloc_context(NULL);
> -   bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
> -   bitsize_tree_filter_up(tree);
> -   bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
> -
>    build->cursor = nir_before_instr(&instr->instr);
>
>    nir_alu_src val = construct_value(build, replace,
>                                      instr->dest.dest.ssa.num_components,
> -                                     tree, &state, &instr->instr);
> +                                     instr->dest.dest.ssa.bit_size,
> +                                     &state, &instr->instr);
>
>    /* Inserting a mov may be unnecessary.  However, it's much easier to
>     * simply let copy propagation clean this up than to try to go through
> @@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
>     */
>    nir_instr_remove(&instr->instr);
>
> -   ralloc_free(bitsize_ctx);
> -
>    return ssa_val;
> }
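After this change, `nir_replace_instr()` no longer builds and filters a per-match `bitsize_tree`. `construct_value()` passes the matched instruction's destination bit size down unchanged, and `replace_bitsize()` resolves each replacement value from its precomputed `bit_size` in one of three ways. The sketch below reproduces that decision on hypothetical stand-in types (`demo_value`, `demo_state`, `demo_replace_bitsize`). A plain array replaces the `nir_src_bit_size(state->variables[i].src)` lookup. It is not Mesa code.

```c
#include <assert.h>

/* Hypothetical stand-in for nir_search_value: only the signed bit_size. */
typedef struct {
   int bit_size;
} demo_value;

#define DEMO_MAX_VARIABLES 8

/* Hypothetical stand-in for match_state: the bit size of the SSA source
 * bound to each search variable during matching. */
typedef struct {
   unsigned variable_bit_size[DEMO_MAX_VARIABLES];
} demo_state;

/* Same three-way decision as replace_bitsize() in the patch. */
static unsigned
demo_replace_bitsize(const demo_value *value, unsigned search_bitsize,
                     const demo_state *state)
{
   if (value->bit_size > 0)
      return (unsigned)value->bit_size;     /* size fixed by the generator */
   if (value->bit_size < 0) {
      int var = -value->bit_size - 1;       /* -1 -> variable 0, -2 -> variable 1 */
      assert(var < DEMO_MAX_VARIABLES);
      return state->variable_bit_size[var];
   }
   return search_bitsize;                   /* 0: size of the matched instruction */
}
```

Each lookup is constant-time and needs no allocation, which is why the `ralloc_context()`/`ralloc_free()` pair around the old tree also disappears.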
> diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
> index df4189ede74..a76f39e0f40 100644
> --- a/src/compiler/nir/nir_search.h
> +++ b/src/compiler/nir/nir_search.h
> @@ -43,7 +43,22 @@ typedef enum {
> typedef struct {
>    nir_search_value_type type;
>
> -   unsigned bit_size;
> +   /**
> +    * Bit size of the value. It is interpreted as follows:
> +    *
> +    * For a search expression:
> +    * - If bit_size > 0, then the value only matches an SSA value with the
> +    *   given bit size.
> +    * - If bit_size <= 0, then the value matches any size SSA value.
> +    *
> +    * For a replace expression:
> +    * - If bit_size > 0, then the value is constructed with the given bit size.
> +    * - If bit_size == 0, then the value is constructed with the same bit size
> +    *   as the search value.
> +    * - If bit_size < 0, then the value is constructed with the same bit size
> +    *   as variable (-bit_size - 1).
> +    */
> +   int bit_size;
> } nir_search_value;
>
> typedef struct {
> --
> 2.17.2
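The header comment defines a signed encoding. This also explains why the matcher's tests change from `value->bit_size &&` to `value->bit_size > 0`. Once the field is an `int`, a negative value is truthy, so the old test would treat it as a required size. The documented rule, by contrast, is that any `bit_size <= 0` matches an SSA value of any size. The helpers below are hypothetical (not part of the patch) and only illustrate the convention. A replace value that should take the size of search variable `i` is stored as `-i - 1`, which keeps 0 free to mean "same as the search value".

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical helper: encode "same bit size as search variable var_index"
 * for a replace value, following the comment in nir_search.h. */
static int
encode_variable_bit_size(unsigned var_index)
{
   return -(int)var_index - 1;
}

/* Hypothetical helper: search-side rule. Only a positive bit_size
 * constrains the match; zero and negative values match any size. */
static bool
search_bit_size_matches(int bit_size, unsigned ssa_bit_size)
{
   return !(bit_size > 0 && ssa_bit_size != (unsigned)bit_size);
}
```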