[Mesa-dev] [PATCH 6/6] nir/algebraic: Add a bit-size validator
Samuel Iglesias Gonsálvez
siglesias at igalia.com
Wed Apr 27 06:20:12 UTC 2016
On 26/04/16 06:39, Jason Ekstrand wrote:
> This commit adds a validator that ensures that all expressions passed
> through nir_algebraic are 100% non-ambiguous as far as bit-sizes are
> concerned. This way it's a compile-time error rather than a hard-to-trace
> C exception some time later.
> ---
> src/compiler/nir/nir_algebraic.py | 270 ++++++++++++++++++++++++++++++++++++++
> 1 file changed, 270 insertions(+)
>
> diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
> index e9b5832..503371b 100644
> --- a/src/compiler/nir/nir_algebraic.py
> +++ b/src/compiler/nir/nir_algebraic.py
> @@ -33,6 +33,19 @@ import mako.template
> import re
> import traceback
>
> +from nir_opcodes import opcodes
> +
> +_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
> +
> +def type_bits(type_str):
> + m = _type_re.match(type_str)
> + assert m.group('type')
> +
> + if m.group('bits') is None:
> + return 0
> + else:
> + return int(m.group('bits'))
> +
> # Represents a set of variables, each with a unique id
> class VarSet(object):
> def __init__(self):
> @@ -188,6 +201,261 @@ class Expression(Value):
> srcs = "\n".join(src.render() for src in self.sources)
> return srcs + super(Expression, self).render()
>
> +class IntEquivalenceRelation(object):
> + """A class representing an equivalence relation on integers.
> +
> + Each integer has a cannonical form which is the maximum integer to which it
> + is equivalent. Two integers are equivalent precicely when they have the
Typo: "precicely" should be "precisely".
> + same cannonical form.
> +
Typo: "cannonical" should be "canonical". This misspelling is repeated throughout the rest of the patch.
> + The convention of maximum is explicitly chosen to make using it in
> + BitSizeValidator easier because it means that an actual bit_size (if any)
> + will always be the cannonical form.
> + """
> + def __init__(self):
> + self._remap = {}
> +
> + def get_cannonical(self, x):
> + """Get the cannonical integer corresponding to x."""
> + if x in self._remap:
> + return self.get_cannonical(self._remap[x])
> + else:
> + return x
> +
> + def add_equiv(self, a, b):
> + """Add an equivalence and return the cannonical form."""
> + c = max(self.get_cannonical(a), self.get_cannonical(b))
> + if a != c:
> + assert a < c
> + self._remap[a] = c
> +
> + if b != c:
> + assert b < c
> + self._remap[b] = c
> +
> + return c
> +
> +class BitSizeValidator(object):
> + """A class for validating bit sizes of expressions.
> +
> + NIR supports multiple bit-sizes on expressions in order to handle things
> + such as fp64. The source and destination of every ALU operation is
> + assigned a type and that type may or may not specify a bit size. Sources
> + and destinations whose type does not specify a bit size are considered
> + "unsized" and automatically take on the bit size of the corresponding
> + register or SSA value. NIR has two simple rules for bit sizes that are
> + validated by nir_validator:
> +
> + 1) A given SSA def or register has a single bit size that is respected by
> + everything that reads from it or writes to it.
> +
> + 2) The bit sizes of all unsized inputs/outputs on any given ALU
> + instruction must match. They need not match the sized inputs or
> + outputs but they must match each other.
> +
> + In order to keep nir_algebraic relatively simple and easy-to-use,
> + nir_search supports a type of bit-size inference based on the two rules
> + above. This is similar to type inference in many common programming
> + languages. If, for instance, you are constructing an add operation and you
> + know the second source is 16-bit, then you know that the other source and
> + the destination must also be 16-bit. There are, however, cases where this
> + inference can be ambiguous or contradictory. Consider, for instance, the
> + following transformation:
> +
> + (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
> +
> + This transformation can potentiall cause a problem because usub_borrow is
Typo: "potentiall" should be "potentially".
> + well-defined for any bit-size of integer. However, b2i always generates a
> + 32-bit result so it could end up replacing a 64-bit expression with one
> + that takes two 64-bit values and produces a 32-bit value. As another
> + example, consider this expression:
> +
> + (('bcsel', a, b, 0), ('iand', a, b))
> +
> + In this case, in the search expression a must be 32-bit but b can
> + potentially have any bit size. If we had a 64-bit b value, we would end up
> + trying to and a 32-bit value with a 64-bit value which would be invalid
> +
> + This class solves that problem by providing a validation layer that proves
> + that a given search-and-replace operation is 100% well-defined before we
> + generate any code. This ensures that bugs are caught at compile time
> + rather than at run time.
> +
> + The basic operation of the validator is very similar to the bitsize_tree in
> + nir_search only a little more subtle. Instead of simply tracking bit
> + sizes, it tracks "bit classes" where each class is represented by an
> + integer. A value of 0 means we don't know anything yet, positive values
> + are actual bit-sizes, and negative values are used to track equivalence
> + classes of sizes that must be the same but have yet to recieve an actual
Typo: "recieve" should be "receive".
> + size. The first stage uses the bitsize_tree algorithm to assign bit
> + classes to each variable. If it ever comes across an inconsistency, it
> + assert-fails. Then the second stage uses that information to prove that
> + the resulting expression can always validly be constructed.
> + """
> +
> + def __init__(self, varset):
> + self._num_classes = 0
> + self._var_classes = [0] * len(varset.names)
> + self._class_relation = IntEquivalenceRelation()
> +
> + def validate(self, search, replace):
> + dst_class = self._propagate_bit_size_up(search)
> + if dst_class == 0:
> + dst_class = self._new_class()
> + self._propagate_bit_class_down(search, dst_class)
> +
> + validate_dst_class = self._validate_bit_class_up(replace)
> + assert validate_dst_class == 0 or validate_dst_class == dst_class
> + self._validate_bit_class_down(replace, dst_class)
> +
> + def _new_class(self):
> + self._num_classes += 1
> + return -self._num_classes
> +
> + def _set_var_bit_class(self, var_id, bit_class):
> + assert bit_class != 0
> + var_class = self._var_classes[var_id]
> + if var_class == 0:
> + self._var_classes[var_id] = bit_class
> + else:
> + cannon_class = self._class_relation.get_cannonical(var_class)
Shouldn't this be "canon_class"?
> + assert cannon_class < 0 or cannon_class == bit_class
> + var_class = self._class_relation.add_equiv(var_class, bit_class)
> + self._var_classes[var_id] = var_class
> +
> + def _get_var_bit_class(self, var_id):
> + return self._class_relation.get_cannonical(self._var_classes[var_id])
> +
> + def _propagate_bit_size_up(self, val):
> + if isinstance(val, (Constant, Variable)):
> + return val.bit_size
> +
> + elif isinstance(val, Expression):
> + nir_op = opcodes[val.opcode]
> + val.common_size = 0
> + for i in range(nir_op.num_inputs):
> + src_bits = self._propagate_bit_size_up(val.sources[i])
> + if src_bits == 0:
> + continue
> +
> + src_type_bits = type_bits(nir_op.input_types[i])
> + if src_type_bits != 0:
> + assert src_bits == src_type_bits
> + else:
> + assert val.common_size == 0 or src_bits == val.common_size
> + val.common_size = src_bits
> +
> + dst_type_bits = type_bits(nir_op.output_type)
> + if dst_type_bits != 0:
> + assert val.bit_size == 0 or val.bit_size == dst_type_bits
> + return dst_type_bits
> + else:
> + if val.common_size != 0:
> + assert val.bit_size == 0 or val.bit_size == val.common_size
> + else:
> + val.common_size = val.bit_size
> + return val.common_size
> +
> + def _propagate_bit_class_down(self, val, bit_class):
> + if isinstance(val, Constant):
> + assert val.bit_size == 0 or val.bit_size == bit_class
> +
> + elif isinstance(val, Variable):
> + assert val.bit_size == 0 or val.bit_size == bit_class
> + self._set_var_bit_class(val.index, bit_class)
> +
> + elif isinstance(val, Expression):
> + nir_op = opcodes[val.opcode]
> + dst_type_bits = type_bits(nir_op.output_type)
> + if dst_type_bits != 0:
> + assert bit_class == 0 or bit_class == dst_type_bits
> + else:
> + assert val.common_size == 0 or val.common_size == bit_class
> + val.common_size = bit_class
> +
> + if val.common_size:
> + common_class = val.common_size
> + elif nir_op.num_inputs:
> + # If we got here then we have no idea what the actual size is.
> + # Instead, we use a generic class
> + common_class = self._new_class()
> +
> + for i in range(nir_op.num_inputs):
> + src_type_bits = type_bits(nir_op.input_types[i])
> + if src_type_bits != 0:
> + self._propagate_bit_class_down(val.sources[i], src_type_bits)
> + else:
> + self._propagate_bit_class_down(val.sources[i], common_class)
> +
> + def _validate_bit_class_up(self, val):
> + if isinstance(val, Constant):
> + return val.bit_size
> +
> + elif isinstance(val, Variable):
> + var_class = self._get_var_bit_class(val.index)
> + # By the time we get to validation, every variable should have a class
> + assert var_class != 0
> +
> + # If we have an explicit size provided by the user, the variable
> + # *must* exactly match the search. It cannot be implicitly sized
> + # because otherwise we could end up with a conflict at runtime.
> + assert val.bit_size == 0 or val.bit_size == var_class
> +
> + return var_class
> +
> + elif isinstance(val, Expression):
> + nir_op = opcodes[val.opcode]
> + val.common_class = 0
> + for i in range(nir_op.num_inputs):
> + src_class = self._validate_bit_class_up(val.sources[i])
> + if src_class == 0:
> + continue
> +
> + src_type_bits = type_bits(nir_op.input_types[i])
> + if src_type_bits != 0:
> + assert src_class == src_type_bits
> + else:
> + assert val.common_class == 0 or src_class == val.common_class
> + val.common_class = src_class
> +
> + dst_type_bits = type_bits(nir_op.output_type)
> + if dst_type_bits != 0:
> + assert val.bit_size == 0 or val.bit_size == dst_type_bits
> + return dst_type_bits
> + else:
> + if val.common_class != 0:
> + assert val.bit_size == 0 or val.bit_size == val.common_class
> + else:
> + val.common_class = val.bit_size
> + return val.common_class
> +
> + def _validate_bit_class_down(self, val, bit_class):
> + # At this point, everthing *must* have a bit class. Otherwise, we have
Typo: "everthing" should be "everything".
Other than that,
Reviewed-by: Samuel Iglesias Gonsálvez <siglesias at igalia.com>
Sam
> + # a value we don't know how to define.
> + assert bit_class != 0
> +
> + if isinstance(val, Constant):
> + assert val.bit_size == 0 or val.bit_size == bit_class
> +
> + elif isinstance(val, Variable):
> + assert val.bit_size == 0 or val.bit_size == bit_class
> +
> + elif isinstance(val, Expression):
> + nir_op = opcodes[val.opcode]
> + dst_type_bits = type_bits(nir_op.output_type)
> + if dst_type_bits != 0:
> + assert bit_class == dst_type_bits
> + else:
> + assert val.common_class == 0 or val.common_class == bit_class
> + val.common_class = bit_class
> +
> + for i in range(nir_op.num_inputs):
> + src_type_bits = type_bits(nir_op.input_types[i])
> + if src_type_bits != 0:
> + self._validate_bit_class_down(val.sources[i], src_type_bits)
> + else:
> + self._validate_bit_class_down(val.sources[i], val.common_class)
> +
> _optimization_ids = itertools.count()
>
> condition_list = ['true']
> @@ -220,6 +488,8 @@ class SearchAndReplace(object):
> else:
> self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
>
> + BitSizeValidator(varset).validate(self.search, self.replace)
> +
> _algebraic_pass_template = mako.template.Template("""
> #include "nir.h"
> #include "nir_search.h"
>
More information about the mesa-dev mailing list