[Mesa-dev] [PATCH 08/14] nir/algebraic: Add support for unsized conversion opcodes

Jason Ekstrand jason at jlekstrand.net
Fri Nov 9 03:46:13 UTC 2018


All conversion opcodes require a destination size, but this makes
constructing certain algebraic expressions rather cumbersome.  This
commit adds support to nir_search and nir_algebraic for writing
conversion opcodes without a size.  These meta-opcodes match any
conversion of that type regardless of destination size, and the size is
inferred from the sizes of the things being matched or from other
opcodes in the expression.
---
 src/compiler/nir/nir_algebraic.py | 110 ++++++++++++++++++++----
 src/compiler/nir/nir_search.c     | 133 ++++++++++++++++++++++++------
 src/compiler/nir/nir_search.h     |  13 ++-
 3 files changed, 214 insertions(+), 42 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index abe1d64e042..4506cb1c649 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -33,7 +33,19 @@ import mako.template
 import re
 import traceback
 
-from nir_opcodes import opcodes
+from nir_opcodes import opcodes, type_sizes
+
+# These opcodes are only employed by nir_search.  This provides a mapping from
+# opcode to destination type.
+search_opcode_types = {
+    'i2f' : 'float',
+    'u2f' : 'float',
+    'f2f' : 'float',
+    'f2u' : 'uint',
+    'f2i' : 'int',
+    'u2u' : 'uint',
+    'i2i' : 'int',
+}
 
 if sys.version_info < (3, 0):
     integer_types = (int, long)
@@ -98,7 +110,7 @@ static const ${val.c_type} ${val.name} = {
    ${val.cond if val.cond else 'NULL'},
 % elif isinstance(val, Expression):
    ${'true' if val.inexact else 'false'},
-   nir_op_${val.opcode},
+   ${val.c_opcode()},
    { ${', '.join(src.c_ptr for src in val.sources)} },
    ${val.cond if val.cond else 'NULL'},
 % endif
@@ -227,6 +239,18 @@ class Expression(Value):
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
+      if self.opcode in search_opcode_types:
+         assert self.bit_size == 0, \
+                'Expression cannot use an unsized conversion opcode with ' \
+                'an explicit size; that\'s silly.'
+
+
+   def c_opcode(self):
+      if self.opcode in search_opcode_types:
+         return 'nir_search_op_' + self.opcode
+      else:
+         return 'nir_op_' + self.opcode
+
    def render(self):
       srcs = "\n".join(src.render() for src in self.sources)
       return srcs + super(Expression, self).render()
@@ -393,6 +417,13 @@ class BitSizeValidator(object):
          return val.bit_size
 
       elif isinstance(val, Expression):
+         # These fake opcodes don't require a source size, don't have a
+         # destination size, and don't require the two to match
+         if val.opcode in search_opcode_types:
+            self._propagate_bit_size_up(val.sources[0])
+            val.common_size = 0
+            return 0
+
          nir_op = opcodes[val.opcode]
          val.common_size = 0
          for i in range(nir_op.num_inputs):
@@ -432,14 +463,30 @@ class BitSizeValidator(object):
 
    def _propagate_bit_class_down(self, val, bit_class):
       if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class, \
+         assert bit_class == 0 or val.bit_size == 0 or \
+                val.bit_size == bit_class, \
                 'Constant is {0}-bit but a {1}-bit value is required: {2}' \
                 .format(val.bit_size, bit_class, str(val))
 
       elif isinstance(val, Variable):
-         self._set_var_bit_class(val, bit_class)
+         if bit_class == 0:
+            self._set_var_bit_class(val, self._new_class())
+         else:
+            self._set_var_bit_class(val, bit_class)
 
       elif isinstance(val, Expression):
+         # These fake opcodes don't require a source size, don't have a
+         # destination size, and don't require the two to match
+         if val.opcode in search_opcode_types:
+            assert val.common_size == 0 or val.common_size == bit_class, \
+                   'Variable-width expression produces a {0}-bit result ' \
+                   'based on the source widths but the parent expression ' \
+                   'wants a {1}-bit value: {2}' \
+                   .format(val.common_size, bit_class, str(val))
+            val.common_size = bit_class
+            self._propagate_bit_class_down(val.sources[0], 0)
+            return
+
          nir_op = opcodes[val.opcode]
          dst_type_bits = type_bits(nir_op.output_type)
          if dst_type_bits != 0:
@@ -449,13 +496,15 @@ class BitSizeValidator(object):
                    .format(val.opcode, dst_type_bits,
                            self._bit_class_to_str(bit_class))
          else:
-            assert val.common_size == 0 or val.common_size == bit_class, \
+            assert bit_class == 0 or val.common_size == 0 or \
+                   val.common_size == bit_class, \
                    'Variable-width expression produces a {0}-bit result ' \
                    'based on the source widths but the parent expression ' \
                    'wants a {1} value: {2}' \
                    .format(val.common_size,
                            self._bit_class_to_str(bit_class), str(val))
-            val.common_size = bit_class
+            if val.common_size == 0:
+               val.common_size = bit_class
 
          if val.common_size:
             common_class = val.common_size
@@ -493,6 +542,12 @@ class BitSizeValidator(object):
          return var_class
 
       elif isinstance(val, Expression):
+         # These fake opcodes don't require a source size, don't have a
+         # destination size, and don't require the two to match
+         if val.opcode in search_opcode_types:
+            self._validate_bit_class_up(val.sources[0])
+            return 0
+
          nir_op = opcodes[val.opcode]
          val.common_class = 0
          for i in range(nir_op.num_inputs):
@@ -538,43 +593,56 @@ class BitSizeValidator(object):
             return val.common_class
 
    def _validate_bit_class_down(self, val, bit_class):
-      # At this point, everything *must* have a bit class.  Otherwise, we have
-      # a value we don't know how to define.
-      assert bit_class != 0, \
-             'Value cannot be constructed because no bit-size is implied '\
-             '{}'.format(str(val))
-
       if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class, \
+         assert bit_class != 0 or val.bit_size != 0, \
+                'Constant value {} cannot be constructed because ' \
+                'nothing provides or implies a bit size'.format(val)
+
+         assert bit_class == 0 or val.bit_size == 0 or \
+                val.bit_size == bit_class, \
                 'Constant value {} explicitly requests being {}-bit but ' \
                 'must be {} thanks to its consumer' \
                 .format(str(val), val.bit_size,
                         self._bit_class_to_str(bit_class))
 
       elif isinstance(val, Variable):
-         assert val.bit_size == 0 or val.bit_size == bit_class, \
+         assert bit_class == 0 or val.bit_size == 0 or \
+                val.bit_size == bit_class, \
                 'Variable {} explicitly only matches {}-bit values but ' \
                 'must be {} thanks to its consumer' \
                 .format(str(val), val.bit_size,
                         self._bit_class_to_str(bit_class))
 
       elif isinstance(val, Expression):
+         # These fake opcodes don't require a source size, don't have a
+         # destination size, and don't require the two to match
+         if val.opcode in search_opcode_types:
+            self._validate_bit_class_down(val.sources[0], 0)
+            return
+
          nir_op = opcodes[val.opcode]
          dst_type_bits = type_bits(nir_op.output_type)
          if dst_type_bits != 0:
-            assert bit_class == dst_type_bits, \
+            assert bit_class == 0 or bit_class == dst_type_bits, \
                    'Result of nir_op_{} must be a {}-bit value but the ' \
                    'consumer requires a {} value: {}' \
                    .format(val.opcode, dst_type_bits,
                            self._bit_class_to_str(bit_class), str(val))
          else:
-            assert val.common_class == 0 or val.common_class == bit_class, \
+            assert bit_class != 0 or val.common_class != 0, \
+                   'Expression cannot be constructed because nothing ' \
+                   'provides or implies a result bit size: {}' \
+                   .format(str(val))
+
+            assert bit_class == 0 or val.common_class == 0 or \
+                   val.common_class == bit_class, \
                    'Result of nir_op_{} must be a {} value but based on ' \
                    'the sources but the consumer requires a {} value: {}' \
                    .format(val.opcode,
                            self._bit_class_to_str(val.common_class),
                            self._bit_class_to_str(bit_class), str(val))
-            val.common_class = bit_class
+            if val.common_class == 0:
+               val.common_class = bit_class
 
          for i in range(nir_op.num_inputs):
             src_type_bits = type_bits(nir_op.input_types[i])
@@ -744,7 +812,13 @@ class AlgebraicPass(object):
                continue
 
          self.xforms.append(xform)
-         self.opcode_xforms[xform.search.opcode].append(xform)
+         if xform.search.opcode in search_opcode_types:
+            dst_type = search_opcode_types[xform.search.opcode]
+            for size in type_sizes(dst_type):
+               sized_opcode = xform.search.opcode + str(size)
+               self.opcode_xforms[sized_opcode].append(xform)
+         else:
+            self.opcode_xforms[xform.search.opcode].append(xform)
 
       if error:
          sys.exit(1)
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 0270302fd3d..067a277d791 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -89,6 +89,82 @@ src_is_type(nir_src src, nir_alu_type type)
    return false;
 }
 
+static bool
+nir_op_matches_search_op(nir_op nop, uint16_t sop)
+{
+   if (sop <= nir_last_opcode)
+      return nop == sop;
+
+#define MATCH_FCONV_CASE(op) \
+   case nir_search_op_##op: \
+      return nop == nir_op_##op##16 || \
+             nop == nir_op_##op##32 || \
+             nop == nir_op_##op##64;
+
+#define MATCH_ICONV_CASE(op) \
+   case nir_search_op_##op: \
+      return nop == nir_op_##op##8 || \
+             nop == nir_op_##op##16 || \
+             nop == nir_op_##op##32 || \
+             nop == nir_op_##op##64;
+
+   switch (sop) {
+   MATCH_FCONV_CASE(i2f)
+   MATCH_FCONV_CASE(u2f)
+   MATCH_FCONV_CASE(f2f)
+   MATCH_ICONV_CASE(f2u)
+   MATCH_ICONV_CASE(f2i)
+   MATCH_ICONV_CASE(u2u)
+   MATCH_ICONV_CASE(i2i)
+   default:
+      unreachable("Invalid nir_search_op");
+   }
+
+#undef MATCH_FCONV_CASE
+#undef MATCH_ICONV_CASE
+}
+
+static nir_op
+nir_op_for_search_op(uint16_t sop, unsigned bit_size)
+{
+   if (sop <= nir_last_opcode)
+      return sop;
+
+#define RET_FCONV_CASE(op) \
+   case nir_search_op_##op: \
+      switch (bit_size) { \
+      case 16: return nir_op_##op##16; \
+      case 32: return nir_op_##op##32; \
+      case 64: return nir_op_##op##64; \
+      default: unreachable("Invalid bit size"); \
+      }
+
+#define RET_ICONV_CASE(op) \
+   case nir_search_op_##op: \
+      switch (bit_size) { \
+      case 8:  return nir_op_##op##8; \
+      case 16: return nir_op_##op##16; \
+      case 32: return nir_op_##op##32; \
+      case 64: return nir_op_##op##64; \
+      default: unreachable("Invalid bit size"); \
+      }
+
+   switch (sop) {
+   RET_FCONV_CASE(i2f)
+   RET_FCONV_CASE(u2f)
+   RET_FCONV_CASE(f2f)
+   RET_ICONV_CASE(f2u)
+   RET_ICONV_CASE(f2i)
+   RET_ICONV_CASE(u2u)
+   RET_ICONV_CASE(i2i)
+   default:
+      unreachable("Invalid nir_search_op");
+   }
+
+#undef RET_FCONV_CASE
+#undef RET_ICONV_CASE
+}
+
 static bool
 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
             unsigned num_components, const uint8_t *swizzle,
@@ -223,7 +299,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
    if (expr->cond && !expr->cond(instr))
       return false;
 
-   if (instr->op != expr->opcode)
+   if (!nir_op_matches_search_op(instr->op, expr->opcode))
       return false;
 
    assert(instr->dest.dest.is_ssa);
@@ -297,6 +373,7 @@ typedef struct bitsize_tree {
    unsigned common_size;
    bool is_src_sized[4];
    bool is_dest_sized;
+   bool is_unsized_conversion;
 
    unsigned dest_size;
    unsigned src_size[4];
@@ -311,18 +388,24 @@ build_bitsize_tree(void *mem_ctx, struct match_state *state,
    switch (value->type) {
    case nir_search_value_expression: {
       nir_search_expression *expr = nir_search_value_as_expression(value);
-      nir_op_info info = nir_op_infos[expr->opcode];
-      tree->num_srcs = info.num_inputs;
-      tree->common_size = 0;
-      for (unsigned i = 0; i < info.num_inputs; i++) {
-         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
-         if (tree->is_src_sized[i])
-            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
-         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
+      if (expr->opcode <= nir_last_opcode) {
+         nir_op_info info = nir_op_infos[expr->opcode];
+         tree->num_srcs = info.num_inputs;
+         tree->common_size = 0;
+         for (unsigned i = 0; i < info.num_inputs; i++) {
+            tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
+            if (tree->is_src_sized[i])
+               tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
+            tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
+         }
+         tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
+         if (tree->is_dest_sized)
+            tree->dest_size = nir_alu_type_get_type_size(info.output_type);
+      } else {
+         tree->num_srcs = 1;
+         tree->srcs[0] = build_bitsize_tree(mem_ctx, state, expr->srcs[0]);
+         tree->is_unsized_conversion = true;
       }
-      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
-      if (tree->is_dest_sized)
-         tree->dest_size = nir_alu_type_get_type_size(info.output_type);
       break;
    }
 
@@ -360,6 +443,9 @@ bitsize_tree_filter_up(bitsize_tree *tree)
 
       if (tree->is_src_sized[i]) {
          assert(src_size == tree->src_size[i]);
+      } else if (tree->is_unsized_conversion) {
+         assert(src_size);
+         tree->src_size[i] = src_size;
       } else if (tree->common_size != 0) {
          assert(src_size == tree->common_size);
          tree->src_size[i] = src_size;
@@ -369,11 +455,11 @@ bitsize_tree_filter_up(bitsize_tree *tree)
       }
    }
 
-   if (tree->num_srcs && tree->common_size) {
-      if (tree->dest_size == 0)
+   if (tree->common_size) {
+      if (!tree->is_dest_sized && !tree->is_unsized_conversion) {
+         assert(tree->dest_size == 0 || tree->dest_size == tree->common_size);
          tree->dest_size = tree->common_size;
-      else if (!tree->is_dest_sized)
-         assert(tree->dest_size == tree->common_size);
+      }
 
       for (unsigned i = 0; i < tree->num_srcs; i++) {
          if (!tree->src_size[i])
@@ -388,13 +474,13 @@ static void
 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
 {
    if (tree->dest_size)
-      assert(tree->dest_size == size);
+      assert(size == 0 || tree->dest_size == size);
    else
       tree->dest_size = size;
 
-   if (!tree->is_dest_sized) {
+   if (!tree->is_dest_sized && !tree->is_unsized_conversion) {
       if (tree->common_size)
-         assert(tree->common_size == size);
+         assert(size == 0 || tree->common_size == size);
       else
          tree->common_size = size;
    }
@@ -418,11 +504,12 @@ construct_value(nir_builder *build,
    switch (value->type) {
    case nir_search_value_expression: {
       const nir_search_expression *expr = nir_search_value_as_expression(value);
+      nir_op op = nir_op_for_search_op(expr->opcode, bitsize->dest_size);
 
-      if (nir_op_infos[expr->opcode].output_size != 0)
-         num_components = nir_op_infos[expr->opcode].output_size;
+      if (nir_op_infos[op].output_size != 0)
+         num_components = nir_op_infos[op].output_size;
 
-      nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
+      nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
                         bitsize->dest_size, NULL);
       alu->dest.write_mask = (1 << num_components) - 1;
@@ -435,7 +522,7 @@ construct_value(nir_builder *build,
        */
       alu->exact = state->has_exact_alu;
 
-      for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
          /* If the source is an explicitly sized source, then we need to reset
           * the number of components to match.
           */
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index df4189ede74..3850cfedcde 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -94,6 +94,16 @@ typedef struct {
    } data;
 } nir_search_constant;
 
+enum nir_search_op {
+   nir_search_op_i2f = nir_last_opcode + 1,
+   nir_search_op_u2f,
+   nir_search_op_f2f,
+   nir_search_op_f2u,
+   nir_search_op_f2i,
+   nir_search_op_u2u,
+   nir_search_op_i2i,
+};
+
 typedef struct {
    nir_search_value value;
 
@@ -103,7 +113,8 @@ typedef struct {
     */
    bool inexact;
 
-   nir_op opcode;
+   /* One of nir_op or nir_search_op */
+   uint16_t opcode;
    const nir_search_value *srcs[4];
 
    /** Optional condition fxn ptr
-- 
2.19.1



More information about the mesa-dev mailing list