Mesa (main): nir/algebraic: Move relocations for expression conds to a table.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Dec 7 07:49:39 UTC 2021


Module: Mesa
Branch: main
Commit: 8485a789776554548ca19084d638a4496805df77
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=8485a789776554548ca19084d638a4496805df77

Author: Emma Anholt <emma at anholt.net>
Date:   Tue Nov 30 14:23:39 2021 -0800

nir/algebraic: Move relocations for expression conds to a table.

This helps concentrate the dirty pages from the relocations, reduces how
many relocations there are, and reduces the size of each expression
(assuming expressions mostly don't have conditions or the conditions are
mostly reused).  Reduces libvulkan_intel.so size by 8.7kb.

Reviewed-by: Adam Jackson <ajax at redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13987>

---

 src/compiler/nir/nir_algebraic.py               | 44 ++++++++++++++++++++-----
 src/compiler/nir/nir_search.c                   | 15 +++++----
 src/compiler/nir/nir_search.h                   | 12 +++++--
 src/compiler/nir/tests/algebraic_parser_test.py |  6 ++--
 4 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 92114482bc8..57887947912 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -53,6 +53,17 @@ conv_opcode_types = {
     'f2b' : 'bool',
 }
 
+def get_cond_index(conds, cond):
+    if cond:
+        if cond in conds:
+            return conds[cond]
+        else:
+            cond_index = len(conds)
+            conds[cond] = cond_index
+            return cond_index
+    else:
+        return -1
+
 def get_c_opcode(op):
       if op in conv_opcode_types:
          return 'nir_search_op_' + op
@@ -89,12 +100,12 @@ class VarSet(object):
 
 class Value(object):
    @staticmethod
-   def create(val, name_base, varset):
+   def create(val, name_base, varset, algebraic_pass):
       if isinstance(val, bytes):
          val = val.decode('utf-8')
 
       if isinstance(val, tuple):
-         return Expression(val, name_base, varset)
+         return Expression(val, name_base, varset, algebraic_pass)
       elif isinstance(val, Expression):
          return val
       elif isinstance(val, str):
@@ -178,7 +189,7 @@ class Value(object):
       ${val.comm_expr_idx}, ${val.comm_exprs},
       ${val.c_opcode()},
       { ${', '.join(src.array_index for src in val.sources)} },
-      ${val.cond if val.cond else 'NULL'},
+      ${val.cond_index},
 % endif
    } },
 """)
@@ -326,7 +337,7 @@ _opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bit
                         r"(?P<cond>\([^\)]+\))?")
 
 class Expression(Value):
-   def __init__(self, expr, name_base, varset):
+   def __init__(self, expr, name_base, varset, algebraic_pass):
       Value.__init__(self, expr, name_base, "expression")
       assert isinstance(expr, tuple)
 
@@ -356,7 +367,11 @@ class Expression(Value):
          self.cond = c[0] if c else None
          self.many_commutative_expressions = True
 
-      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
+      # Deduplicate references to the condition functions for the expressions
+      # and save the index for the order they were added.
+      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
+
+      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
                        for (i, src) in enumerate(expr[1:]) ]
 
       # nir_search_expression::srcs is hard-coded to 4
@@ -730,7 +745,7 @@ _optimization_ids = itertools.count()
 condition_list = ['true']
 
 class SearchAndReplace(object):
-   def __init__(self, transform):
+   def __init__(self, transform, algebraic_pass):
       self.id = next(_optimization_ids)
 
       search = transform[0]
@@ -748,14 +763,14 @@ class SearchAndReplace(object):
       if isinstance(search, Expression):
          self.search = search
       else:
-         self.search = Expression(search, "search{0}".format(self.id), varset)
+         self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
 
       varset.lock()
 
       if isinstance(replace, Value):
          self.replace = replace
       else:
-         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
+         self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
 
       BitSizeValidator(varset).validate(self.search, self.replace)
 
@@ -1041,6 +1056,14 @@ ${xform.replace.render(cache)}
 % endfor
 };
 
+% if expression_cond:
+static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
+% for cond in expression_cond:
+   ${cond[0]},
+% endfor
+};
+% endif
+
 % for state_id, state_xforms in enumerate(automaton.state_patterns):
 % if state_xforms: # avoid emitting a 0-length array for MSVC
 static const struct transform ${pass_name}_state${state_id}_xforms[] = {
@@ -1100,6 +1123,7 @@ static const nir_algebraic_table ${pass_name}_table = {
    .transform_counts = ${pass_name}_transform_counts,
    .pass_op_table = ${pass_name}_pass_op_table,
    .values = ${pass_name}_values,
+   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
 };
 
 bool
@@ -1134,13 +1158,14 @@ class AlgebraicPass(object):
       self.xforms = []
       self.opcode_xforms = defaultdict(lambda : [])
       self.pass_name = pass_name
+      self.expression_cond = {}
 
       error = False
 
       for xform in transforms:
          if not isinstance(xform, SearchAndReplace):
             try:
-               xform = SearchAndReplace(xform)
+               xform = SearchAndReplace(xform, self)
             except:
                print("Failed to parse transformation:", file=sys.stderr)
                print("  " + str(xform), file=sys.stderr)
@@ -1196,5 +1221,6 @@ class AlgebraicPass(object):
                                              opcode_xforms=self.opcode_xforms,
                                              condition_list=condition_list,
                                              automaton=self.automaton,
+                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
                                              get_c_opcode=get_c_opcode,
                                              itertools=itertools)
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 9d1061cafb0..cbe7ca22a7c 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -50,7 +50,7 @@ struct match_state {
 };
 
 static bool
-match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
+match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state);
 static bool
@@ -253,7 +253,8 @@ nir_op_for_search_op(uint16_t sop, unsigned bit_size)
 }
 
 static bool
-match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
+match_value(const nir_algebraic_table *table,
+            const nir_search_value *value, nir_alu_instr *instr, unsigned src,
             unsigned num_components, const uint8_t *swizzle,
             struct match_state *state)
 {
@@ -289,7 +290,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
          return false;
 
-      return match_expression(nir_search_value_as_expression(value),
+      return match_expression(table, nir_search_value_as_expression(value),
                               nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
                               num_components, new_swizzle, state);
 
@@ -390,11 +391,11 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
 }
 
 static bool
-match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
+match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state)
 {
-   if (expr->cond && !expr->cond(instr))
+   if (expr->cond_index != -1 && !table->expression_cond[expr->cond_index](instr))
       return false;
 
    if (!nir_op_matches_search_op(instr->op, expr->opcode))
@@ -441,7 +442,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
       /* 2src_commutative instructions that have 3 sources are only commutative
        * in the first two sources.  Source 2 is always source 2.
        */
-      if (!match_value(&state->table->values[expr->srcs[i]].value, instr,
+      if (!match_value(table, &state->table->values[expr->srcs[i]].value, instr,
                        i < 2 ? i ^ comm_op_flip : i,
                        num_components, swizzle, state)) {
          matched = false;
@@ -720,7 +721,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
       state.comm_op_direction = comb;
       state.variables_seen = 0;
 
-      if (match_expression(search, instr,
+      if (match_expression(table, search, instr,
                            instr->dest.dest.ssa.num_components,
                            swizzle, &state)) {
          found = true;
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index 88f38e9d663..c0bb6129bb6 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -160,13 +160,13 @@ typedef struct {
    /* Index in table->values[] for the expression operands */
    uint16_t srcs[4];
 
-   /** Optional condition fxn ptr
+   /** Optional table->expression_cond[] fxn ptr index
     *
     * This allows additional constraints on expression matching, it is
     * typically used to match an expressions uses such as the number of times
     * the expression is used, and whether its used by an if.
     */
-   bool (*cond)(nir_alu_instr *instr);
+   int16_t cond_index;
 } nir_search_expression;
 
 struct per_op_table {
@@ -189,12 +189,20 @@ typedef union {
    nir_search_expression expression;
 } nir_search_value_union;
 
+typedef bool (*nir_search_expression_cond)(nir_alu_instr *instr);
+
 /* Generated data table for an algebraic optimization pass. */
 typedef struct {
    const struct transform **transforms;
    const uint16_t *transform_counts;
    const struct per_op_table *pass_op_table;
    const nir_search_value_union *values;
+
+   /**
+    * Array of condition functions for expressions, referenced by
+    * nir_search_expression->cond_index.
+    */
+   const nir_search_expression_cond *expression_cond;
 } nir_algebraic_table;
 
 /* Note: these must match the start states created in
diff --git a/src/compiler/nir/tests/algebraic_parser_test.py b/src/compiler/nir/tests/algebraic_parser_test.py
index d96da7db519..39b5d691a3a 100644
--- a/src/compiler/nir/tests/algebraic_parser_test.py
+++ b/src/compiler/nir/tests/algebraic_parser_test.py
@@ -26,7 +26,7 @@ import sys
 import os
 sys.path.insert(1, os.path.join(sys.path[0], '..'))
 
-from nir_algebraic import SearchAndReplace
+from nir_algebraic import SearchAndReplace, AlgebraicPass
 
 # These tests check that the bitsize validator correctly rejects various
 # different kinds of malformed expressions, and documents what the error
@@ -40,9 +40,11 @@ class ValidatorTests(unittest.TestCase):
     pattern = ()
     message = ''
 
+    algebraic_pass = AlgebraicPass("test", [])
+
     def common(self, pattern, message):
         with self.assertRaises(AssertionError) as context:
-            SearchAndReplace(pattern)
+            SearchAndReplace(pattern, self.algebraic_pass)
 
         self.assertEqual(message, str(context.exception))
 



More information about the mesa-commit mailing list