Mesa (master): nir/algebraic: Add a bit-size validator

Jason Ekstrand jekstrand at kemper.freedesktop.org
Wed Apr 27 18:21:11 UTC 2016


Module: Mesa
Branch: master
Commit: e0806930ad2406d611a0d2fa1d3420a74122921c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=e0806930ad2406d611a0d2fa1d3420a74122921c

Author: Jason Ekstrand <jason.ekstrand at intel.com>
Date:   Mon Apr 25 20:58:47 2016 -0700

nir/algebraic: Add a bit-size validator

This commit adds a validator that ensures that all expressions passed
through nir_algebraic are completely unambiguous as far as bit sizes are
concerned.  This way, an ambiguity becomes a compile-time error rather than
a hard-to-trace failure in the generated C code some time later.

Reviewed-by: Samuel Iglesias Gonsálvez <siglesias at igalia.com>

---

 src/compiler/nir/nir_algebraic.py | 270 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 270 insertions(+)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 35f6597..1fc0289 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -33,6 +33,19 @@ import mako.template
 import re
 import traceback
 
+from nir_opcodes import opcodes
+
+_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
+
+def type_bits(type_str):
+   """Return the explicit bit size spelled in a NIR type string, or 0.
+
+   "float32" yields 32 and "int64" yields 64.  A bare base type such as
+   "uint" or "bool" carries no size and yields 0 ("unsized").  The string
+   must begin with one of the base types int/uint/bool/float.
+   """
+   match = _type_re.match(type_str)
+   assert match.group('type')
+
+   bits = match.group('bits')
+   return 0 if bits is None else int(bits)
+
 # Represents a set of variables, each with a unique id
 class VarSet(object):
    def __init__(self):
@@ -188,6 +201,261 @@ class Expression(Value):
       srcs = "\n".join(src.render() for src in self.sources)
       return srcs + super(Expression, self).render()
 
+class IntEquivalenceRelation(object):
+   """A class representing an equivalence relation on integers.
+
+   Each integer has a canonical form which is the maximum integer to which it
+   is equivalent.  Two integers are equivalent precisely when they have the
+   same canonical form.
+
+   The convention of maximum is explicitly chosen to make using it in
+   BitSizeValidator easier because it means that an actual bit_size (if any)
+   will always be the canonical form.
+
+   Internally, _remap maps a non-canonical integer to a strictly larger
+   integer of the same class; following the chain until reaching an integer
+   with no entry yields the canonical form.
+   """
+   def __init__(self):
+      self._remap = {}
+
+   def get_canonical(self, x):
+      """Get the canonical integer corresponding to x."""
+      while x in self._remap:
+         x = self._remap[x]
+      return x
+
+   def add_equiv(self, a, b):
+      """Add an equivalence and return the canonical form.
+
+      The entire classes of a and b are merged, not just a and b themselves:
+      the old canonical forms are re-pointed at the new one, so every integer
+      already equivalent to a or to b ends up with the same canonical form.
+      """
+      canon_a = self.get_canonical(a)
+      canon_b = self.get_canonical(b)
+      c = max(canon_a, canon_b)
+
+      # Remap the roots rather than a and b: overwriting a's or b's own link
+      # would orphan the other members of their old classes.  Roots have no
+      # entry yet and c > root, so chains stay strictly increasing.
+      if canon_a != c:
+         self._remap[canon_a] = c
+
+      if canon_b != c:
+         self._remap[canon_b] = c
+
+      return c
+
+class BitSizeValidator(object):
+   """A class for validating bit sizes of expressions.
+
+   NIR supports multiple bit-sizes on expressions in order to handle things
+   such as fp64.  The source and destination of every ALU operation is
+   assigned a type and that type may or may not specify a bit size.  Sources
+   and destinations whose type does not specify a bit size are considered
+   "unsized" and automatically take on the bit size of the corresponding
+   register or SSA value.  NIR has two simple rules for bit sizes that are
+   validated by nir_validator:
+
+    1) A given SSA def or register has a single bit size that is respected by
+       everything that reads from it or writes to it.
+
+    2) The bit sizes of all unsized inputs/outputs on any given ALU
+       instruction must match.  They need not match the sized inputs or
+       outputs but they must match each other.
+
+   In order to keep nir_algebraic relatively simple and easy-to-use,
+   nir_search supports a type of bit-size inference based on the two rules
+   above.  This is similar to type inference in many common programming
+   languages.  If, for instance, you are constructing an add operation and you
+   know the second source is 16-bit, then you know that the other source and
+   the destination must also be 16-bit.  There are, however, cases where this
+   inference can be ambiguous or contradictory.  Consider, for instance, the
+   following transformation:
+
+   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
+
+   This transformation can potentially cause a problem because usub_borrow is
+   well-defined for any bit-size of integer.  However, b2i always generates a
+   32-bit result so it could end up replacing a 64-bit expression with one
+   that takes two 64-bit values and produces a 32-bit value.  As another
+   example, consider this expression:
+
+   (('bcsel', a, b, 0), ('iand', a, b))
+
+   In this case, in the search expression a must be 32-bit but b can
+   potentially have any bit size.  If we had a 64-bit b value, we would end up
+   trying to AND a 32-bit value with a 64-bit value, which would be invalid.
+
+   This class solves that problem by providing a validation layer that proves
+   that a given search-and-replace operation is 100% well-defined before we
+   generate any code.  This ensures that bugs are caught at compile time
+   rather than at run time.
+
+   The basic operation of the validator is very similar to the bitsize_tree in
+   nir_search, only a little more subtle.  Instead of simply tracking bit
+   sizes, it tracks "bit classes" where each class is represented by an
+   integer.  A value of 0 means we don't know anything yet, positive values
+   are actual bit-sizes, and negative values are used to track equivalence
+   classes of sizes that must be the same but have yet to receive an actual
+   size.  The first stage uses the bitsize_tree algorithm to assign bit
+   classes to each variable.  If it ever comes across an inconsistency, it
+   assert-fails.  Then the second stage uses that information to prove that
+   the resulting expression can always validly be constructed.
+   """
+
+   def __init__(self, varset):
+      self._num_classes = 0
+      # One bit class per variable in varset, indexed by variable index;
+      # 0 until the search pass assigns a class.
+      self._var_classes = [0] * len(varset.names)
+      self._class_relation = IntEquivalenceRelation()
+
+   def validate(self, search, replace):
+      """Assert that replace is unambiguously sized wherever search matches."""
+      dst_class = self._propagate_bit_size_up(search)
+      if dst_class == 0:
+         dst_class = self._new_class()
+      self._propagate_bit_class_down(search, dst_class)
+
+      # Propagating down may have merged dst_class with other classes (for
+      # instance with a concrete size through a shared variable).  Every class
+      # computed while validating replace is canonical, so compare against
+      # the canonical form rather than the possibly-stale dst_class.
+      dst_class = self._class_relation.get_canonical(dst_class)
+
+      validate_dst_class = self._validate_bit_class_up(replace)
+      assert validate_dst_class == 0 or validate_dst_class == dst_class, \
+             'Search and replace expressions have different bit sizes'
+      self._validate_bit_class_down(replace, dst_class)
+
+   def _new_class(self):
+      """Allocate a fresh, unresolved (negative) bit class."""
+      self._num_classes += 1
+      return -self._num_classes
+
+   def _set_var_bit_class(self, var_id, bit_class):
+      """Record that variable var_id has bit class bit_class.
+
+      If the variable already has a class, the two classes are merged.
+      Merging fails only when both are already resolved to different sizes.
+      """
+      assert bit_class != 0
+      var_class = self._var_classes[var_id]
+      if var_class == 0:
+         self._var_classes[var_id] = bit_class
+      else:
+         # Compare canonical forms on both sides: either class may already
+         # have been merged with others, and either may still be unresolved
+         # (negative), in which case merging simply resolves it.  Checking
+         # only one side made the result depend on traversal order.
+         canon_var_class = self._class_relation.get_canonical(var_class)
+         canon_bit_class = self._class_relation.get_canonical(bit_class)
+         assert canon_var_class < 0 or canon_bit_class < 0 or \
+                canon_var_class == canon_bit_class, \
+                'Variable {0} cannot be both {1}-bit and {2}-bit' \
+                .format(var_id, canon_var_class, canon_bit_class)
+         var_class = self._class_relation.add_equiv(var_class, bit_class)
+         self._var_classes[var_id] = var_class
+
+   def _get_var_bit_class(self, var_id):
+      """Return the canonical bit class of variable var_id (0 if unset)."""
+      return self._class_relation.get_canonical(self._var_classes[var_id])
+
+   def _propagate_bit_size_up(self, val):
+      """Return the bit size implied by val's subtree, or 0 if unknown.
+
+      For expressions, the size shared by the unsized sources is stored in
+      val.common_size for use by _propagate_bit_class_down.
+      """
+      if isinstance(val, (Constant, Variable)):
+         return val.bit_size
+
+      elif isinstance(val, Expression):
+         nir_op = opcodes[val.opcode]
+         val.common_size = 0
+         for i in range(nir_op.num_inputs):
+            src_bits = self._propagate_bit_size_up(val.sources[i])
+            if src_bits == 0:
+               continue
+
+            src_type_bits = type_bits(nir_op.input_types[i])
+            if src_type_bits != 0:
+               assert src_bits == src_type_bits
+            else:
+               assert val.common_size == 0 or src_bits == val.common_size
+               val.common_size = src_bits
+
+         dst_type_bits = type_bits(nir_op.output_type)
+         if dst_type_bits != 0:
+            assert val.bit_size == 0 or val.bit_size == dst_type_bits
+            return dst_type_bits
+         else:
+            if val.common_size != 0:
+               assert val.bit_size == 0 or val.bit_size == val.common_size
+            else:
+               val.common_size = val.bit_size
+            return val.common_size
+
+   def _propagate_bit_class_down(self, val, bit_class):
+      """Push bit_class into val's subtree, assigning classes to variables."""
+      if isinstance(val, Constant):
+         assert val.bit_size == 0 or val.bit_size == bit_class
+
+      elif isinstance(val, Variable):
+         assert val.bit_size == 0 or val.bit_size == bit_class
+         self._set_var_bit_class(val.index, bit_class)
+
+      elif isinstance(val, Expression):
+         nir_op = opcodes[val.opcode]
+         dst_type_bits = type_bits(nir_op.output_type)
+         if dst_type_bits != 0:
+            assert bit_class == 0 or bit_class == dst_type_bits
+         else:
+            assert val.common_size == 0 or val.common_size == bit_class
+            val.common_size = bit_class
+
+         if val.common_size:
+            common_class = val.common_size
+         elif nir_op.num_inputs:
+            # If we got here then we have no idea what the actual size is.
+            # Instead, we use a generic class
+            common_class = self._new_class()
+         else:
+            # No sources, so there is nothing to propagate to.
+            common_class = 0
+
+         for i in range(nir_op.num_inputs):
+            src_type_bits = type_bits(nir_op.input_types[i])
+            if src_type_bits != 0:
+               self._propagate_bit_class_down(val.sources[i], src_type_bits)
+            else:
+               self._propagate_bit_class_down(val.sources[i], common_class)
+
+   def _validate_bit_class_up(self, val):
+      """Return the canonical bit class of a replace subtree, or 0 if unknown.
+
+      For expressions, the class shared by the unsized sources is stored in
+      val.common_class for use by _validate_bit_class_down.
+      """
+      if isinstance(val, Constant):
+         return val.bit_size
+
+      elif isinstance(val, Variable):
+         var_class = self._get_var_bit_class(val.index)
+         # By the time we get to validation, every variable should have a class
+         assert var_class != 0, \
+                'Variable {0} does not appear in the search expression' \
+                .format(val.index)
+
+         # If we have an explicit size provided by the user, the variable
+         # *must* exactly match the search.  It cannot be implicitly sized
+         # because otherwise we could end up with a conflict at runtime.
+         assert val.bit_size == 0 or val.bit_size == var_class
+
+         return var_class
+
+      elif isinstance(val, Expression):
+         nir_op = opcodes[val.opcode]
+         val.common_class = 0
+         for i in range(nir_op.num_inputs):
+            src_class = self._validate_bit_class_up(val.sources[i])
+            if src_class == 0:
+               continue
+
+            src_type_bits = type_bits(nir_op.input_types[i])
+            if src_type_bits != 0:
+               assert src_class == src_type_bits
+            else:
+               assert val.common_class == 0 or src_class == val.common_class
+               val.common_class = src_class
+
+         dst_type_bits = type_bits(nir_op.output_type)
+         if dst_type_bits != 0:
+            assert val.bit_size == 0 or val.bit_size == dst_type_bits
+            return dst_type_bits
+         else:
+            if val.common_class != 0:
+               assert val.bit_size == 0 or val.bit_size == val.common_class
+            else:
+               val.common_class = val.bit_size
+            return val.common_class
+
+   def _validate_bit_class_down(self, val, bit_class):
+      """Assert that every node of a replace subtree gets a known bit class."""
+      # At this point, everything *must* have a bit class.  Otherwise, we have
+      # a value we don't know how to define.
+      assert bit_class != 0
+
+      if isinstance(val, Constant):
+         assert val.bit_size == 0 or val.bit_size == bit_class
+
+      elif isinstance(val, Variable):
+         assert val.bit_size == 0 or val.bit_size == bit_class
+
+      elif isinstance(val, Expression):
+         nir_op = opcodes[val.opcode]
+         dst_type_bits = type_bits(nir_op.output_type)
+         if dst_type_bits != 0:
+            assert bit_class == dst_type_bits
+         else:
+            assert val.common_class == 0 or val.common_class == bit_class
+            val.common_class = bit_class
+
+         for i in range(nir_op.num_inputs):
+            src_type_bits = type_bits(nir_op.input_types[i])
+            if src_type_bits != 0:
+               self._validate_bit_class_down(val.sources[i], src_type_bits)
+            else:
+               self._validate_bit_class_down(val.sources[i], val.common_class)
+
 _optimization_ids = itertools.count()
 
 condition_list = ['true']
@@ -220,6 +488,8 @@ class SearchAndReplace(object):
       else:
          self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
 
+      BitSizeValidator(varset).validate(self.search, self.replace)
+
 _algebraic_pass_template = mako.template.Template("""
 #include "nir.h"
 #include "nir_search.h"




More information about the mesa-commit mailing list