[Mesa-dev] [PATCH A 10/15] nir/algebraic: Add support for unsized conversion opcodes

Jason Ekstrand jason at jlekstrand.net
Fri Nov 9 03:45:11 UTC 2018


Unsized conversion opcodes require special handling in opt_algebraic
because they follow different bit size rules from regular opcodes.  In
particular, we now have a new case where we have an opcode with multiple
variable-size inputs and outputs but no common size.
---
 src/compiler/nir/nir_algebraic.py | 68 ++++++++++++++++++++++++-------
 src/compiler/nir/nir_search.c     | 19 +++++----
 2 files changed, 65 insertions(+), 22 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index f9ee637830c..837ef114349 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -406,6 +406,9 @@ class BitSizeValidator(object):
                       'Source {0} of nir_op_{1} must be a {2}-bit value but ' \
                       'the only possible matched values are {3}-bit: {4}' \
                       .format(i, val.opcode, src_type_bits, src_bits, str(val))
+            elif nir_op.is_unsized_conversion:
+                # Nothing to do here
+                pass
             else:
                assert val.common_size == 0 or src_bits == val.common_size, \
                       'Expression cannot have both {0}-bit and {1}-bit ' \
@@ -420,6 +423,10 @@ class BitSizeValidator(object):
                    'result was requested' \
                    .format(val.opcode, dst_type_bits, val.bit_size)
             return dst_type_bits
+         elif nir_op.is_unsized_conversion:
+            # No validation to do here.  If we have a non-zero bit size,
+            # that's the bit size of the result of the expression; return it.
+            return val.bit_size
          else:
             if val.common_size != 0:
                assert val.bit_size == 0 or val.bit_size == val.common_size, \
@@ -432,12 +439,16 @@ class BitSizeValidator(object):
 
    def _propagate_bit_class_down(self, val, bit_class):
       if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class, \
+         assert bit_class == 0 or val.bit_size == 0 or \
+                val.bit_size == bit_class, \
                 'Constant is {0}-bit but a {1}-bit value is required: {2}' \
                 .format(val.bit_size, bit_class, str(val))
 
       elif isinstance(val, Variable):
-         self._set_var_bit_class(val, bit_class)
+         if bit_class == 0:
+            self._set_var_bit_class(val, self._new_class())
+         else:
+            self._set_var_bit_class(val, bit_class)
 
       elif isinstance(val, Expression):
          nir_op = opcodes[val.opcode]
@@ -448,14 +459,19 @@ class BitSizeValidator(object):
                    'expression wants a {2} value' \
                    .format(val.opcode, dst_type_bits,
                            self._bit_class_to_str(bit_class))
+         elif nir_op.is_unsized_conversion:
+            # Nothing to do here
+            pass
          else:
-            assert val.common_size == 0 or val.common_size == bit_class, \
+            assert bit_class == 0 or val.common_size == 0 or \
+                   val.common_size == bit_class, \
                    'Variable-width expression produces a {0}-bit result ' \
                    'based on the source widths but the parent expression ' \
                    'wants a {1} value: {2}' \
                    .format(val.common_size,
                            self._bit_class_to_str(bit_class), str(val))
-            val.common_size = bit_class
+            if val.common_size == 0:
+               val.common_size = bit_class
 
          if val.common_size:
             common_class = val.common_size
@@ -468,6 +484,8 @@ class BitSizeValidator(object):
             src_type_bits = type_bits(nir_op.input_types[i])
             if src_type_bits != 0:
                self._propagate_bit_class_down(val.sources[i], src_type_bits)
+            elif nir_op.is_unsized_conversion:
+               self._propagate_bit_class_down(val.sources[i], 0)
             else:
                self._propagate_bit_class_down(val.sources[i], common_class)
 
@@ -507,6 +525,9 @@ class BitSizeValidator(object):
                       'the constructed value would be {}: {}' \
                       .format(i, val.opcode, src_type_bits,
                               self._bit_class_to_str(src_class), str(val))
+            elif nir_op.is_unsized_conversion:
+               # Nothing to do here
+               pass
             else:
                assert val.common_class == 0 or src_class == val.common_class, \
                       'Source {} of nir_op_{} must be a {} value based ' \
@@ -524,6 +545,8 @@ class BitSizeValidator(object):
                    'expression explicitly requests a {}-bit value' \
                    .format(val.opcode, dst_type_bits, val.bit_size)
             return dst_type_bits
+         elif nir_op.is_unsized_conversion:
+            return 0
          else:
             if val.common_class != 0:
                assert val.bit_size == 0 or val.bit_size == val.common_class, \
@@ -538,21 +561,21 @@ class BitSizeValidator(object):
             return val.common_class
 
    def _validate_bit_class_down(self, val, bit_class):
-      # At this point, everything *must* have a bit class.  Otherwise, we have
-      # a value we don't know how to define.
-      assert bit_class != 0, \
-             'Value cannot be constructed because no bit-size is implied '\
-             '{}'.format(str(val))
-
       if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class, \
+         assert bit_class != 0 or val.bit_size != 0, \
+                'Constant value {} cannot be constructed because ' \
+                'nothing provides or implies a bit size'.format(val)
+
+         assert val.bit_size == 0 or bit_class == 0 or \
+                val.bit_size == bit_class, \
                 'Constant value {} explicitly requests being {}-bit but ' \
                 'must be {} thanks to its consumer' \
                 .format(str(val), val.bit_size,
                         self._bit_class_to_str(bit_class))
 
       elif isinstance(val, Variable):
-         assert val.bit_size == 0 or val.bit_size == bit_class, \
+         assert val.bit_size == 0 or bit_class == 0 or \
+                val.bit_size == bit_class, \
                 'Variable {} explicitly only matches {}-bit values but ' \
                 'must be {} thanks to its consumer' \
                 .format(str(val), val.bit_size,
@@ -562,24 +585,39 @@ class BitSizeValidator(object):
          nir_op = opcodes[val.opcode]
          dst_type_bits = type_bits(nir_op.output_type)
          if dst_type_bits != 0:
-            assert bit_class == dst_type_bits, \
+            assert bit_class == 0 or bit_class == dst_type_bits, \
                    'Result of nir_op_{} must be a {}-bit value but the ' \
                    'consumer requires a {} value: {}' \
                    .format(val.opcode, dst_type_bits,
                            self._bit_class_to_str(bit_class), str(val))
+         elif nir_op.is_unsized_conversion:
+            # Nothing to do here
+            pass
          else:
-            assert val.common_class == 0 or val.common_class == bit_class, \
+            assert bit_class != 0 or val.common_class != 0, \
+                   'Expression cannot be constructed because nothing ' \
+                   'provides or implies a result bit size: {}' \
+                   .format(str(val))
+
+            assert bit_class == 0 or val.common_class == 0 or \
+                   val.common_class == bit_class, \
                    'Result of nir_op_{} must be a {} value but based on ' \
                    'the sources but the consumer requires a {} value: {}' \
                    .format(val.opcode,
                            self._bit_class_to_str(val.common_class),
                            self._bit_class_to_str(bit_class), str(val))
-            val.common_class = bit_class
+            if val.common_class == 0:
+               val.common_class = bit_class
 
          for i in range(nir_op.num_inputs):
             src_type_bits = type_bits(nir_op.input_types[i])
             if src_type_bits != 0:
                self._validate_bit_class_down(val.sources[i], src_type_bits)
+            elif nir_op.is_unsized_conversion:
+               # We can't imply a bit size and nothing is coming up the chain
+               # so we pass 0 and let the source use its own bit size.  If a
+               # constant source has none, its assert in this function fires.
+               self._validate_bit_class_down(val.sources[i], 0)
             else:
                self._validate_bit_class_down(val.sources[i], val.common_class)
 
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 0270302fd3d..838031e700d 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -297,6 +297,7 @@ typedef struct bitsize_tree {
    unsigned common_size;
    bool is_src_sized[4];
    bool is_dest_sized;
+   bool is_unsized_conversion;
 
    unsigned dest_size;
    unsigned src_size[4];
@@ -323,6 +324,7 @@ build_bitsize_tree(void *mem_ctx, struct match_state *state,
       tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
       if (tree->is_dest_sized)
          tree->dest_size = nir_alu_type_get_type_size(info.output_type);
+      tree->is_unsized_conversion = info.is_unsized_conversion;
       break;
    }
 
@@ -360,6 +362,9 @@ bitsize_tree_filter_up(bitsize_tree *tree)
 
       if (tree->is_src_sized[i]) {
          assert(src_size == tree->src_size[i]);
+      } else if (tree->is_unsized_conversion) {
+         assert(src_size);
+         tree->src_size[i] = src_size;
       } else if (tree->common_size != 0) {
          assert(src_size == tree->common_size);
          tree->src_size[i] = src_size;
@@ -369,11 +374,11 @@ bitsize_tree_filter_up(bitsize_tree *tree)
       }
    }
 
-   if (tree->num_srcs && tree->common_size) {
-      if (tree->dest_size == 0)
+   if (tree->common_size) {
+      if (!tree->is_dest_sized && !tree->is_unsized_conversion) {
+         assert(tree->dest_size == 0 || tree->dest_size == tree->common_size);
          tree->dest_size = tree->common_size;
-      else if (!tree->is_dest_sized)
-         assert(tree->dest_size == tree->common_size);
+      }
 
       for (unsigned i = 0; i < tree->num_srcs; i++) {
          if (!tree->src_size[i])
@@ -388,13 +393,13 @@ static void
 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
 {
    if (tree->dest_size)
-      assert(tree->dest_size == size);
+      assert(size == 0 || tree->dest_size == size);
    else
       tree->dest_size = size;
 
-   if (!tree->is_dest_sized) {
+   if (!tree->is_dest_sized && !tree->is_unsized_conversion) {
       if (tree->common_size)
-         assert(tree->common_size == size);
+         assert(size == 0 || tree->common_size == size);
       else
          tree->common_size = size;
    }
-- 
2.19.1



More information about the mesa-dev mailing list