Mesa (main): nir/algebraic: Replace relocations for nir_search values with a table.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Dec 7 07:49:39 UTC 2021


Module: Mesa
Branch: main
Commit: 5d82c61a30b390c0216ab6c70082268dc21e7630
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=5d82c61a30b390c0216ab6c70082268dc21e7630

Author: Emma Anholt <emma at anholt.net>
Date:   Mon Nov 29 15:24:47 2021 -0800

nir/algebraic: Replace relocations for nir_search values with a table.

Even with all 3 value types packed into a single 40-byte union
(nir_search_constant being 24 bytes and nir_search_expression formerly
being 32) and stored in one shared array, this cuts 1.7MB from each of
libvulkan_intel.so and libgallium_dri.so.

Reviewed-by: Adam Jackson <ajax at redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13987>

---

 src/compiler/nir/nir_algebraic.py | 73 +++++++++++++++++----------------------
 src/compiler/nir/nir_search.c     | 25 ++++++++------
 src/compiler/nir/nir_search.h     | 18 +++++++---
 3 files changed, 60 insertions(+), 56 deletions(-)

diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
index 1b73fc019b7..95af2471dff 100644
--- a/src/compiler/nir/nir_algebraic.py
+++ b/src/compiler/nir/nir_algebraic.py
@@ -148,22 +148,6 @@ class Value(object):
    def type_enum(self):
       return "nir_search_value_" + self.type_str
 
-   @property
-   def c_type(self):
-      return "nir_search_" + self.type_str
-
-   def __c_name(self, cache):
-      if self.name in cache:
-         return cache[self.name]
-      else:
-         return self.name
-
-   def c_value_ptr(self, cache):
-      return "&{0}.value".format(self.__c_name(cache))
-
-   def c_ptr(self, cache):
-      return "&{0}".format(self.__c_name(cache))
-
    @property
    def c_bit_size(self):
       bit_size = self.get_bit_size()
@@ -179,40 +163,42 @@ class Value(object):
          # We represent these cases with a 0 bit-size.
          return 0
 
-   __template = mako.template.Template("""{
-   { ${val.type_enum}, ${val.c_bit_size} },
+   __template = mako.template.Template("""   { .${val.type_str} = {
+      { ${val.type_enum}, ${val.c_bit_size} },
 % if isinstance(val, Constant):
-   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
+      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
 % elif isinstance(val, Variable):
-   ${val.index}, /* ${val.var_name} */
-   ${'true' if val.is_constant else 'false'},
-   ${val.type() or 'nir_type_invalid' },
-   ${val.cond if val.cond else 'NULL'},
-   ${val.swizzle()},
+      ${val.index}, /* ${val.var_name} */
+      ${'true' if val.is_constant else 'false'},
+      ${val.type() or 'nir_type_invalid' },
+      ${val.cond if val.cond else 'NULL'},
+      ${val.swizzle()},
 % elif isinstance(val, Expression):
-   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
-   ${val.comm_expr_idx}, ${val.comm_exprs},
-   ${val.c_opcode()},
-   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
-   ${val.cond if val.cond else 'NULL'},
+      ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
+      ${val.comm_expr_idx}, ${val.comm_exprs},
+      ${val.c_opcode()},
+      { ${', '.join(src.array_index for src in val.sources)} },
+      ${val.cond if val.cond else 'NULL'},
 % endif
-};""")
+   } },
+""")
 
    def render(self, cache):
-      struct_init = self.__template.render(val=self, cache=cache,
+      struct_init = self.__template.render(val=self,
                                            Constant=Constant,
                                            Variable=Variable,
                                            Expression=Expression)
       if struct_init in cache:
          # If it's in the cache, register a name remap in the cache and render
          # only a comment saying it's been remapped
-         cache[self.name] = cache[struct_init]
-         return "/* {} -> {} in the cache */\n".format(self.name,
+         self.array_index = cache[struct_init]
+         return "   /* {} -> {} in the cache */\n".format(self.name,
                                                        cache[struct_init])
       else:
-         cache[struct_init] = self.name
-         return "static const {} {} = {}\n".format(self.c_type, self.name,
-                                                   struct_init)
+         self.array_index = str(cache["next_index"])
+         cache[struct_init] = self.array_index
+         cache["next_index"] += 1
+         return struct_init
 
 _constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
 
@@ -433,7 +419,7 @@ class Expression(Value):
       return get_c_opcode(self.opcode)
 
    def render(self, cache):
-      srcs = "\n".join(src.render(cache) for src in self.sources)
+      srcs = "".join(src.render(cache) for src in self.sources)
       return srcs + super(Expression, self).render(cache)
 
 class BitSizeValidator(object):
@@ -1045,17 +1031,20 @@ _algebraic_pass_template = mako.template.Template("""
 % endfor
  */
 
-<% cache = {} %>
+<% cache = {"next_index": 0} %>
+static const nir_search_value_union ${pass_name}_values[] = {
 % for xform in xforms:
-   ${xform.search.render(cache)}
-   ${xform.replace.render(cache)}
+   /* ${xform.search} => ${xform.replace} */
+${xform.search.render(cache)}
+${xform.replace.render(cache)}
 % endfor
+};
 
 % for state_id, state_xforms in enumerate(automaton.state_patterns):
 % if state_xforms: # avoid emitting a 0-length array for MSVC
 static const struct transform ${pass_name}_state${state_id}_xforms[] = {
 % for i in state_xforms:
-  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
+  { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
 % endfor
 };
 % endif
@@ -1109,6 +1098,7 @@ static const nir_algebraic_table ${pass_name}_table = {
    .transforms = ${pass_name}_transforms,
    .transform_counts = ${pass_name}_transform_counts,
    .pass_op_table = ${pass_name}_pass_op_table,
+   .values = ${pass_name}_values,
 };
 
 bool
@@ -1121,6 +1111,7 @@ ${pass_name}(nir_shader *shader)
    (void) options;
    (void) info;
 
+   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
    % for index, condition in enumerate(condition_list):
    condition_flags[${index}] = ${condition};
    % endfor
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index e76853f38dd..9d1061cafb0 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -43,6 +43,7 @@ struct match_state {
    /* Used for running the automaton on newly-constructed instructions. */
    struct util_dynarray *states;
    const struct per_op_table *pass_op_table;
+   const nir_algebraic_table *table;
 
    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
    struct hash_table *range_ht;
@@ -440,7 +441,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
       /* 2src_commutative instructions that have 3 sources are only commutative
        * in the first two sources.  Source 2 is always source 2.
        */
-      if (!match_value(expr->srcs[i], instr,
+      if (!match_value(&state->table->values[expr->srcs[i]].value, instr,
                        i < 2 ? i ^ comm_op_flip : i,
                        num_components, swizzle, state)) {
          matched = false;
@@ -498,7 +499,7 @@ construct_value(nir_builder *build,
          if (nir_op_infos[alu->op].input_sizes[i] != 0)
             num_components = nir_op_infos[alu->op].input_sizes[i];
 
-         alu->src[i] = construct_value(build, expr->srcs[i],
+         alu->src[i] = construct_value(build, &state->table->values[expr->srcs[i]].value,
                                        num_components, search_bitsize,
                                        state, instr);
       }
@@ -576,7 +577,7 @@ construct_value(nir_builder *build,
    }
 }
 
-UNUSED static void dump_value(const nir_search_value *val)
+UNUSED static void dump_value(const nir_algebraic_table *table, const nir_search_value *val)
 {
    switch (val->type) {
    case nir_search_value_constant: {
@@ -634,7 +635,7 @@ UNUSED static void dump_value(const nir_search_value *val)
 
       for (unsigned i = 0; i < num_srcs; i++) {
          fprintf(stderr, " ");
-         dump_value(expr->srcs[i]);
+         dump_value(table, &table->values[expr->srcs[i]].value);
       }
 
       fprintf(stderr, ")");
@@ -687,7 +688,7 @@ nir_ssa_def *
 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
                   struct hash_table *range_ht,
                   struct util_dynarray *states,
-                  const struct per_op_table *pass_op_table,
+                  const nir_algebraic_table *table,
                   const nir_search_expression *search,
                   const nir_search_value *replace,
                   nir_instr_worklist *algebraic_worklist)
@@ -703,7 +704,8 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
    state.inexact_match = false;
    state.has_exact_alu = false;
    state.range_ht = range_ht;
-   state.pass_op_table = pass_op_table;
+   state.pass_op_table = table->pass_op_table;
+   state.table = table;
 
    STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
 
@@ -786,7 +788,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
       nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
    if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
       util_dynarray_append(states, uint16_t, 0);
-      nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
+      nir_algebraic_automaton(ssa_val->parent_instr, states, table->pass_op_table);
    }
 
    /* Rewrite the uses of the old SSA value to the new one, and recurse
@@ -794,7 +796,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
     */
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, ssa_val);
    nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
-                                  states, pass_op_table);
+                                  states, table->pass_op_table);
 
    /* Nothing uses the instr any more, so drop it out of the program.  Note
     * that the instr may be in the worklist still, so we can't free it
@@ -883,9 +885,10 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr,
    for (uint16_t i = 0; i < table->transform_counts[xform_idx]; i++) {
       const struct transform *xform = &table->transforms[xform_idx][i];
       if (condition_flags[xform->condition_offset] &&
-          !(xform->search->inexact && ignore_inexact) &&
-          nir_replace_instr(build, alu, range_ht, states, table->pass_op_table,
-                            xform->search, xform->replace, worklist)) {
+          !(table->values[xform->search].expression.inexact && ignore_inexact) &&
+          nir_replace_instr(build, alu, range_ht, states, table,
+                            &table->values[xform->search].expression,
+                            &table->values[xform->replace].value, worklist)) {
          _mesa_hash_table_clear(range_ht, NULL);
          return true;
       }
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index 3699e215fb5..88f38e9d663 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -157,7 +157,8 @@ typedef struct {
 
    /* One of nir_op or nir_search_op */
    uint16_t opcode;
-   const nir_search_value *srcs[4];
+   /* Index in table->values[] for the expression operands */
+   uint16_t srcs[4];
 
    /** Optional condition fxn ptr
     *
@@ -175,16 +176,25 @@ struct per_op_table {
 };
 
 struct transform {
-   const nir_search_expression *search;
-   const nir_search_value *replace;
+   uint16_t search; /* Index in table->values[] for the search expression. */
+   uint16_t replace; /* Index in table->values[] for the replace value. */
    unsigned condition_offset;
 };
 
+typedef union {
+   nir_search_value value; /* base type of the union, first element of each variant struct */
+
+   nir_search_constant constant;
+   nir_search_variable variable;
+   nir_search_expression expression;
+} nir_search_value_union;
+
 /* Generated data table for an algebraic optimization pass. */
 typedef struct {
    const struct transform **transforms;
    const uint16_t *transform_counts;
    const struct per_op_table *pass_op_table;
+   const nir_search_value_union *values;
 } nir_algebraic_table;
 
 /* Note: these must match the start states created in
@@ -208,7 +218,7 @@ nir_ssa_def *
 nir_replace_instr(struct nir_builder *b, nir_alu_instr *instr,
                   struct hash_table *range_ht,
                   struct util_dynarray *states,
-                  const struct per_op_table *pass_op_table,
+                  const nir_algebraic_table *table,
                   const nir_search_expression *search,
                   const nir_search_value *replace,
                   nir_instr_worklist *algebraic_worklist);



More information about the mesa-commit mailing list