[Mesa-dev] [PATCH 7/7] nir/algebraic: Use compact expressions and don't emit duplicates

Jason Ekstrand jason at jlekstrand.net
Fri May 27 01:30:37 UTC 2016


The deduplication helps a lot for variables and constants, where we have a
*lot* of duplicates.  The compact expressions are also a *lot* smaller.
This cuts about 56K from nir_opt_algebraic.o, mostly from .data (size
output before and after this patch):

   text     data      bss      dec      hex  filename
  17951    64584        0    82535    14267  nir_opt_algebraic.o
  10780    15536        0    26316     66cc  nir_opt_algebraic.o
---
 src/compiler/nir/nir_algebraic.py | 135 +++++++++++++++++++++++++-------------
 src/compiler/nir/nir_search.h     |   2 +-
 2 files changed, 92 insertions(+), 45 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 0d1ed3a..cccc34d 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -63,6 +63,28 @@ class VarSet(object):
    def lock(self):
       self.immutable = True
 
+class ValueCache(object):
+   def __init__(self):
+      self.cache = {}
+      self.list = []
+
+   def __iter__(self):
+      return self.list.__iter__()
+
+   def add_value(self, val):
+      if val in self.cache:
+         return;
+
+      if isinstance(val, Expression):
+         for src in val.sources:
+            self.add_value(src)
+
+      self.cache[val] = len(self.list)
+      self.list.append(val)
+
+   def get_index(self, value):
+      return self.cache[value]
+
 class Value(object):
    @staticmethod
    def create(val, name_base, varset):
@@ -75,44 +97,14 @@ class Value(object):
       elif isinstance(val, (bool, int, long, float)):
          return Constant(val, name_base)
 
-   __template = mako.template.Template("""
-static const ${val.c_type} ${val.name} = {
-   { ${val.type_enum}, ${val.bit_size} },
-% if isinstance(val, Constant):
-   ${val.type()}, { ${hex(val)} /* ${val.value} */ },
-% elif isinstance(val, Variable):
-   ${val.index}, /* ${val.var_name} */
-   ${'true' if val.is_constant else 'false'},
-   ${val.type() or 'nir_type_invalid' },
-% elif isinstance(val, Expression):
-   ${'true' if val.inexact else 'false'},
-   nir_op_${val.opcode},
-   { ${', '.join(src.c_ptr for src in val.sources)} },
-% endif
-};""")
-
    def __init__(self, name, type_str):
       self.name = name
       self.type_str = type_str
 
    @property
-   def type_enum(self):
-      return "nir_search_value_" + self.type_str
-
-   @property
    def c_type(self):
       return "nir_search_" + self.type_str
 
-   @property
-   def c_ptr(self):
-      return "&{0}.value".format(self.name)
-
-   def render(self):
-      return self.__template.render(val=self,
-                                    Constant=Constant,
-                                    Variable=Variable,
-                                    Expression=Expression)
-
 _constant_re = re.compile(r"(?P<value>[^@]+)(?:@(?P<bits>\d+))?")
 
 class Constant(Value):
@@ -151,6 +143,10 @@ class Constant(Value):
       else:
          assert False
 
+   @property
+   def comment_str(self):
+      return "{0}({1})".format(type(self.value).__name__, self.value)
+
    def type(self):
       if isinstance(self.value, (bool)):
          return "nir_type_bool32"
@@ -159,6 +155,16 @@ class Constant(Value):
       elif isinstance(self.value, float):
          return "nir_type_float"
 
+   __union_template = mako.template.Template("""
+   { .constant = {
+      { nir_search_value_constant, ${val.bit_size} },
+      ${val.type()},
+      { .u = ${hex(val)} /* ${val.value} */ },
+   } },""")
+
+   def render_union(self, cache):
+      return self.__union_template.render(val=self)
+
 _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                           r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?")
 
@@ -199,6 +205,10 @@ class Variable(Value):
          and self.required_type == other.required_type \
          and self.bit_size == other.bit_size
 
+   @property
+   def comment_str(self):
+      return "variable"
+
    def type(self):
       if self.required_type == 'bool':
          return "nir_type_bool32"
@@ -207,6 +217,17 @@ class Variable(Value):
       elif self.required_type == 'float':
          return "nir_type_float"
 
+   __union_template = mako.template.Template("""
+   { .variable = {
+      { nir_search_value_variable, ${val.bit_size} },
+      ${val.index}, /* ${val.var_name} */
+      ${'true' if val.is_constant else 'false'},
+      ${val.type() or 'nir_type_invalid' },
+   } },""")
+
+   def render_union(self, cache):
+      return self.__union_template.render(val=self)
+
 _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?")
 
 class Expression(Value):
@@ -217,6 +238,7 @@ class Expression(Value):
       m = _opcode_re.match(expr[0])
       assert m and m.group('opcode') is not None
 
+      self.expr = expr
       self.opcode = m.group('opcode')
       self.bit_size = int(m.group('bits')) if m.group('bits') else 0
       self.inexact = m.group('inexact') is not None
@@ -247,9 +269,20 @@ class Expression(Value):
       return self.inexact == other.inexact \
          and self.bit_size == other.bit_size
 
-   def render(self):
-      srcs = "\n".join(src.render() for src in self.sources)
-      return srcs + super(Expression, self).render()
+   @property
+   def comment_str(self):
+      return str(self.expr)
+
+   __union_template = mako.template.Template("""
+   { .expression = {
+      { nir_search_value_compact_expression, ${val.bit_size} },
+      ${'true' if val.inexact else 'false'},
+      nir_op_${val.opcode},
+      { ${', '.join(str(cache.get_index(src)) for src in val.sources)} },
+   } },""")
+
+   def render_union(self, cache):
+      return self.__union_template.render(val=self, cache=cache)
 
 class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.
@@ -548,22 +581,29 @@ _algebraic_pass_template = mako.template.Template("""
 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
 
 struct transform {
-   const nir_search_value *search;
-   const nir_search_value *replace;
-   unsigned condition_offset;
+   uint16_t search;
+   uint16_t replace;
+   uint16_t condition_offset;
 };
 
 #endif
 
-% for (opcode, xform_list) in xform_dict.iteritems():
-% for xform in xform_list:
-   ${xform.search.render()}
-   ${xform.replace.render()}
+nir_search_value_union ${pass_name}_values[] = {
+% for index, value in enumerate(cache):
+
+   /* ${index}: ${value.comment_str} */
+   ${value.render_union(cache)}
 % endfor
+};
 
+% for (opcode, xform_list) in xform_dict.iteritems():
 static const struct transform ${pass_name}_${opcode}_xforms[] = {
 % for xform in xform_list:
-   { ${xform.search.c_ptr}, ${xform.replace.c_ptr}, ${xform.condition_index} },
+   {
+      ${cache.get_index(xform.search)},
+      ${cache.get_index(xform.replace)},
+      ${xform.condition_index}
+   },
 % endfor
 };
 % endfor
@@ -588,8 +628,10 @@ ${pass_name}_block(nir_block *block, const bool *condition_flags,
          for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
             const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
             if (condition_flags[xform->condition_offset] &&
-                nir_replace_instr(alu, xform->search, xform->replace,
-                                  NULL, mem_ctx)) {
+                nir_replace_instr(alu,
+                                  &${pass_name}_values[xform->search].value,
+                                  &${pass_name}_values[xform->replace].value,
+                                  ${pass_name}_values, mem_ctx)) {
                progress = true;
                break;
             }
@@ -647,6 +689,7 @@ class AlgebraicPass(object):
    def __init__(self, pass_name, transforms):
       self.xform_dict = {}
       self.pass_name = pass_name
+      self.cache = ValueCache()
 
       error = False
 
@@ -662,6 +705,9 @@ class AlgebraicPass(object):
                error = True
                continue
 
+         self.cache.add_value(xform.search)
+         self.cache.add_value(xform.replace)
+
          if xform.search.opcode not in self.xform_dict:
             self.xform_dict[xform.search.opcode] = []
 
@@ -673,4 +719,5 @@ class AlgebraicPass(object):
    def render(self):
       return _algebraic_pass_template.render(pass_name=self.pass_name,
                                              xform_dict=self.xform_dict,
-                                             condition_list=condition_list)
+                                             condition_list=condition_list,
+                                             cache=self.cache)
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index 1588628..7ce73fb 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -105,7 +105,7 @@ typedef struct {
     */
    bool inexact:1;
 
-   nir_op opcode:15;
+   unsigned opcode:15; /* enum nir_op */
    uint16_t srcs[4];
 } nir_search_compact_expression;
 
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list