[Mesa-dev] [PATCH 6/6] nir/algebraic: Add a bit-size validator

Samuel Iglesias Gonsálvez siglesias at igalia.com
Wed Apr 27 06:20:12 UTC 2016



On 26/04/16 06:39, Jason Ekstrand wrote:
> This commit adds a validator that ensures that all expressions passed
> through nir_algebraic are 100% non-ambiguous as far as bit-sizes are
> concerned.  This way it's a compile-time error rather than a hard-to-trace
> C exception some time later.
> ---
>  src/compiler/nir/nir_algebraic.py | 270 ++++++++++++++++++++++++++++++++++++++
>  1 file changed, 270 insertions(+)
> 
> diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
> index e9b5832..503371b 100644
> --- a/src/compiler/nir/nir_algebraic.py
> +++ b/src/compiler/nir/nir_algebraic.py
> @@ -33,6 +33,19 @@ import mako.template
>  import re
>  import traceback
>  
> +from nir_opcodes import opcodes
> +
> +_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
> +
> +def type_bits(type_str):
> +   m = _type_re.match(type_str)
> +   assert m.group('type')
> +
> +   if m.group('bits') is None:
> +      return 0
> +   else:
> +      return int(m.group('bits'))
> +
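
For reference, a minimal standalone sketch of how type_bits behaves on a few type strings. The regex and function body are copied from the hunk above; the sample strings are only illustrative. Both groups in the regex are optional, so the match itself always succeeds, and the assert on the 'type' group is what rejects an unknown type name.

```python
import re

# Same pattern as in the patch: optional base type, then optional bit count.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    m = _type_re.match(type_str)
    # The regex can match the empty string, so this assert is the real check.
    assert m.group('type')
    if m.group('bits') is None:
        return 0          # unsized type
    return int(m.group('bits'))

print(type_bits("float64"))  # sized type
print(type_bits("uint"))     # unsized type, reported as 0
print(type_bits("int32"))    # sized type
```

A return value of 0 is the "unsized" marker that the rest of the validator tests against.
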
>  # Represents a set of variables, each with a unique id
>  class VarSet(object):
>     def __init__(self):
> @@ -188,6 +201,261 @@ class Expression(Value):
>        srcs = "\n".join(src.render() for src in self.sources)
>        return srcs + super(Expression, self).render()
>  
> +class IntEquivalenceRelation(object):
> +   """A class representing an equivalence relation on integers.
> +
> +   Each integer has a cannonical form which is the maximum integer to which it
> +   is equivalent.  Two integers are equivalent precicely when they have the

precisely

> +   same cannonical form.
> +

canonical. This typo is repeated throughout the rest of the patch, including in identifiers such as get_cannonical.

> +   The convention of maximum is explicitly chosen to make using it in
> +   BitSizeValidator easier because it means that an actual bit_size (if any)
> +   will always be the cannonical form.
> +   """
> +   def __init__(self):
> +      self._remap = {}
> +
> +   def get_cannonical(self, x):
> +      """Get the cannonical integer corresponding to x."""
> +      if x in self._remap:
> +         return self.get_cannonical(self._remap[x])
> +      else:
> +         return x
> +
> +   def add_equiv(self, a, b):
> +      """Add an equivalence and return the cannonical form."""
> +      c = max(self.get_cannonical(a), self.get_cannonical(b))
> +      if a != c:
> +         assert a < c
> +         self._remap[a] = c
> +
> +      if b != c:
> +         assert b < c
> +         self._remap[b] = c
> +
> +      return c
> +
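
A small standalone sketch of the relation, not the patch's code: it uses the corrected spelling and, as an assumption about the intent, remaps the canonical forms of a and b rather than a and b themselves. In add_equiv as posted, merging an already-remapped value a second time overwrites its earlier entry in _remap, which can drop the first equivalence; linking the canonical forms avoids that.

```python
class IntEquivalenceRelation(object):
    """Equivalence relation on integers; the canonical form is the maximum."""

    def __init__(self):
        self._remap = {}

    def get_canonical(self, x):
        # Follow the remap chain until we reach a value with no entry.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        # Link the canonical forms so earlier equivalences are preserved.
        ca = self.get_canonical(a)
        cb = self.get_canonical(b)
        c = max(ca, cb)
        if ca != c:
            self._remap[ca] = c
        if cb != c:
            self._remap[cb] = c
        return c

rel = IntEquivalenceRelation()
rel.add_equiv(-1, -2)   # two placeholder classes become equivalent
rel.add_equiv(-2, 32)   # one of them later receives a real bit size
print(rel.get_canonical(-1), rel.get_canonical(-2))
```

With the maximum as the canonical form, a real (positive) bit size always wins over the negative placeholder classes, which is the convention the docstring describes.
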
> +class BitSizeValidator(object):
> +   """A class for validating bit sizes of expressions.
> +
> +   NIR supports multiple bit-sizes on expressions in order to handle things
> +   such as fp64.  The source and destination of every ALU operation is
> +   assigned a type and that type may or may not specify a bit size.  Sources
> +   and destinations whose type does not specify a bit size are considered
> +   "unsized" and automatically take on the bit size of the corresponding
> +   register or SSA value.  NIR has two simple rules for bit sizes that are
> +   validated by nir_validator:
> +
> +    1) A given SSA def or register has a single bit size that is respected by
> +       everything that reads from it or writes to it.
> +
> +    2) The bit sizes of all unsized inputs/outputs on any given ALU
> +       instruction must match.  They need not match the sized inputs or
> +       outputs but they must match each other.
> +
> +   In order to keep nir_algebraic relatively simple and easy-to-use,
> +   nir_search supports a type of bit-size inference based on the two rules
> +   above.  This is similar to type inference in many common programming
> +   languages.  If, for instance, you are constructing an add operation and you
> +   know the second source is 16-bit, then you know that the other source and
> +   the destination must also be 16-bit.  There are, however, cases where this
> +   inference can be ambiguous or contradictory.  Consider, for instance, the
> +   following transformation:
> +
> +   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
> +
> +   This transformation can potentiall cause a problem because usub_borrow is

potentially

> +   well-defined for any bit-size of integer.  However, b2i always generates a
> +   32-bit result so it could end up replacing a 64-bit expression with one
> +   that takes two 64-bit values and produces a 32-bit value.  As another
> +   example, consider this expression:
> +
> +   (('bcsel', a, b, 0), ('iand', a, b))
> +
> +   In this case, in the search expression a must be 32-bit but b can
> +   potentially have any bit size.  If we had a 64-bit b value, we would end up
> +   trying to and a 32-bit value with a 64-bit value which would be invalid
> +
> +   This class solves that problem by providing a validation layer that proves
> +   that a given search-and-replace operation is 100% well-defined before we
> +   generate any code.  This ensures that bugs are caught at compile time
> +   rather than at run time.
> +
> +   The basic operation of the validator is very similar to the bitsize_tree in
> +   nir_search only a little more subtle.  Instead of simply tracking bit
> +   sizes, it tracks "bit classes" where each class is represented by an
> +   integer.  A value of 0 means we don't know anything yet, positive values
> +   are actual bit-sizes, and negative values are used to track equivalence
> +   classes of sizes that must be the same but have yet to recieve an actual

receive

> +   size.  The first stage uses the bitsize_tree algorithm to assign bit
> +   classes to each variable.  If it ever comes across an inconsistency, it
> +   assert-fails.  Then the second stage uses that information to prove that
> +   the resulting expression can always validly be constructed.
> +   """
> +
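
To make the usub_borrow example from the docstring concrete, here is a toy sketch. TOY_OPCODES and dest_bits are hypothetical names used only for illustration, and the type strings in the table are assumptions chosen to match the docstring's description, not values taken from nir_opcodes. The sketch computes the destination bit size of an expression bottom-up and shows the mismatch for 64-bit inputs.

```python
import re

_type_re = re.compile(r"(?P<type>int|uint|bool|float)(?P<bits>\d+)?")

def type_bits(type_str):
    m = _type_re.match(type_str)
    assert m is not None
    return int(m.group("bits")) if m.group("bits") else 0

# Hypothetical opcode table: name -> (input types, output type).
TOY_OPCODES = {
    "usub_borrow": (["uint", "uint"], "uint"),
    "ult": (["uint", "uint"], "bool32"),
    "b2i": (["bool32"], "int32"),
}

def dest_bits(expr, var_bits):
    """Return the destination bit size of expr given each variable's size."""
    if isinstance(expr, str):
        return var_bits[expr]
    in_types, out_type = TOY_OPCODES[expr[0]]
    common = 0
    for src, in_type in zip(expr[1:], in_types):
        bits = dest_bits(src, var_bits)
        if type_bits(in_type):
            assert bits == type_bits(in_type)   # sized source must match
        else:
            assert common in (0, bits)          # unsized sources must agree
            common = bits
    # A sized output wins; otherwise the unsized output takes the common size.
    return type_bits(out_type) or common

search = ("usub_borrow", "a", "b")
replace = ("b2i", ("ult", "a", "b"))
sizes = {"a": 64, "b": 64}
print(dest_bits(search, sizes), dest_bits(replace, sizes))
```

For 32-bit a and b both sides agree; for 64-bit inputs the search side is 64-bit while the replacement is 32-bit, which is the kind of case the validator is meant to reject up front.
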
> +   def __init__(self, varset):
> +      self._num_classes = 0
> +      self._var_classes = [0] * len(varset.names)
> +      self._class_relation = IntEquivalenceRelation()
> +
> +   def validate(self, search, replace):
> +      dst_class = self._propagate_bit_size_up(search)
> +      if dst_class == 0:
> +         dst_class = self._new_class()
> +      self._propagate_bit_class_down(search, dst_class)
> +
> +      validate_dst_class = self._validate_bit_class_up(replace)
> +      assert validate_dst_class == 0 or validate_dst_class == dst_class
> +      self._validate_bit_class_down(replace, dst_class)
> +
> +   def _new_class(self):
> +      self._num_classes += 1
> +      return -self._num_classes
> +
> +   def _set_var_bit_class(self, var_id, bit_class):
> +      assert bit_class != 0
> +      var_class = self._var_classes[var_id]
> +      if var_class == 0:
> +         self._var_classes[var_id] = bit_class
> +      else:
> +         cannon_class = self._class_relation.get_cannonical(var_class)

canon_class? (one "n", to match "canonical")

> +         assert cannon_class < 0 or cannon_class == bit_class
> +         var_class = self._class_relation.add_equiv(var_class, bit_class)
> +         self._var_classes[var_id] = var_class
> +
> +   def _get_var_bit_class(self, var_id):
> +      return self._class_relation.get_cannonical(self._var_classes[var_id])
> +
> +   def _propagate_bit_size_up(self, val):
> +      if isinstance(val, (Constant, Variable)):
> +         return val.bit_size
> +
> +      elif isinstance(val, Expression):
> +         nir_op = opcodes[val.opcode]
> +         val.common_size = 0
> +         for i in range(nir_op.num_inputs):
> +            src_bits = self._propagate_bit_size_up(val.sources[i])
> +            if src_bits == 0:
> +               continue
> +
> +            src_type_bits = type_bits(nir_op.input_types[i])
> +            if src_type_bits != 0:
> +               assert src_bits == src_type_bits
> +            else:
> +               assert val.common_size == 0 or src_bits == val.common_size
> +               val.common_size = src_bits
> +
> +         dst_type_bits = type_bits(nir_op.output_type)
> +         if dst_type_bits != 0:
> +            assert val.bit_size == 0 or val.bit_size == dst_type_bits
> +            return dst_type_bits
> +         else:
> +            if val.common_size != 0:
> +               assert val.bit_size == 0 or val.bit_size == val.common_size
> +            else:
> +               val.common_size = val.bit_size
> +            return val.common_size
> +
> +   def _propagate_bit_class_down(self, val, bit_class):
> +      if isinstance(val, Constant):
> +         assert val.bit_size == 0 or val.bit_size == bit_class
> +
> +      elif isinstance(val, Variable):
> +         assert val.bit_size == 0 or val.bit_size == bit_class
> +         self._set_var_bit_class(val.index, bit_class)
> +
> +      elif isinstance(val, Expression):
> +         nir_op = opcodes[val.opcode]
> +         dst_type_bits = type_bits(nir_op.output_type)
> +         if dst_type_bits != 0:
> +            assert bit_class == 0 or bit_class == dst_type_bits
> +         else:
> +            assert val.common_size == 0 or val.common_size == bit_class
> +            val.common_size = bit_class
> +
> +         if val.common_size:
> +            common_class = val.common_size
> +         elif nir_op.num_inputs:
> +            # If we got here then we have no idea what the actual size is.
> +            # Instead, we use a generic class
> +            common_class = self._new_class()
> +
> +         for i in range(nir_op.num_inputs):
> +            src_type_bits = type_bits(nir_op.input_types[i])
> +            if src_type_bits != 0:
> +               self._propagate_bit_class_down(val.sources[i], src_type_bits)
> +            else:
> +               self._propagate_bit_class_down(val.sources[i], common_class)
> +
> +   def _validate_bit_class_up(self, val):
> +      if isinstance(val, Constant):
> +         return val.bit_size
> +
> +      elif isinstance(val, Variable):
> +         var_class = self._get_var_bit_class(val.index)
> +         # By the time we get to validation, every variable should have a class
> +         assert var_class != 0
> +
> +         # If we have an explicit size provided by the user, the variable
> +         # *must* exactly match the search.  It cannot be implicitly sized
> +         # because otherwise we could end up with a conflict at runtime.
> +         assert val.bit_size == 0 or val.bit_size == var_class
> +
> +         return var_class
> +
> +      elif isinstance(val, Expression):
> +         nir_op = opcodes[val.opcode]
> +         val.common_class = 0
> +         for i in range(nir_op.num_inputs):
> +            src_class = self._validate_bit_class_up(val.sources[i])
> +            if src_class == 0:
> +               continue
> +
> +            src_type_bits = type_bits(nir_op.input_types[i])
> +            if src_type_bits != 0:
> +               assert src_class == src_type_bits
> +            else:
> +               assert val.common_class == 0 or src_class == val.common_class
> +               val.common_class = src_class
> +
> +         dst_type_bits = type_bits(nir_op.output_type)
> +         if dst_type_bits != 0:
> +            assert val.bit_size == 0 or val.bit_size == dst_type_bits
> +            return dst_type_bits
> +         else:
> +            if val.common_class != 0:
> +               assert val.bit_size == 0 or val.bit_size == val.common_class
> +            else:
> +               val.common_class = val.bit_size
> +            return val.common_class
> +
> +   def _validate_bit_class_down(self, val, bit_class):
> +      # At this point, everthing *must* have a bit class.  Otherwise, we have

everything

Other than that,

Reviewed-by: Samuel Iglesias Gonsálvez <siglesias at igalia.com>

Sam

> +      # a value we don't know how to define.
> +      assert bit_class != 0
> +
> +      if isinstance(val, Constant):
> +         assert val.bit_size == 0 or val.bit_size == bit_class
> +
> +      elif isinstance(val, Variable):
> +         assert val.bit_size == 0 or val.bit_size == bit_class
> +
> +      elif isinstance(val, Expression):
> +         nir_op = opcodes[val.opcode]
> +         dst_type_bits = type_bits(nir_op.output_type)
> +         if dst_type_bits != 0:
> +            assert bit_class == dst_type_bits
> +         else:
> +            assert val.common_class == 0 or val.common_class == bit_class
> +            val.common_class = bit_class
> +
> +         for i in range(nir_op.num_inputs):
> +            src_type_bits = type_bits(nir_op.input_types[i])
> +            if src_type_bits != 0:
> +               self._validate_bit_class_down(val.sources[i], src_type_bits)
> +            else:
> +               self._validate_bit_class_down(val.sources[i], val.common_class)
> +
>  _optimization_ids = itertools.count()
>  
>  condition_list = ['true']
> @@ -220,6 +488,8 @@ class SearchAndReplace(object):
>        else:
>           self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
>  
> +      BitSizeValidator(varset).validate(self.search, self.replace)
> +
>  _algebraic_pass_template = mako.template.Template("""
>  #include "nir.h"
>  #include "nir_search.h"
> 

