[Mesa-dev] [PATCH 08/10] nir: add a loop unrolling pass

Timothy Arceri timothy.arceri at collabora.com
Thu Sep 15 07:03:19 UTC 2016


V2:
- tidy-ups suggested by Connor.
- tidy up cloning logic and handle copy propagation
  based on a suggestion by Connor.
- use nir_ssa_def_rewrite_uses to fix up lcssa phis
  suggested by Connor.
- add support for complex loop unrolling (two terminators)
- handle the case where an SSA def's use outside the loop is already a phi
- support unrolling loops with multiple terminators when the trip count
  is known for each terminator
---
 src/compiler/Makefile.sources          |   1 +
 src/compiler/nir/nir.h                 |   2 +
 src/compiler/nir/nir_opt_loop_unroll.c | 820 +++++++++++++++++++++++++++++++++
 3 files changed, 823 insertions(+)
 create mode 100644 src/compiler/nir/nir_opt_loop_unroll.c

diff --git a/src/compiler/Makefile.sources b/src/compiler/Makefile.sources
index 8ef6080..b3512bb 100644
--- a/src/compiler/Makefile.sources
+++ b/src/compiler/Makefile.sources
@@ -233,6 +233,7 @@ NIR_FILES = \
 	nir/nir_opt_dead_cf.c \
 	nir/nir_opt_gcm.c \
 	nir/nir_opt_global_to_local.c \
+	nir/nir_opt_loop_unroll.c \
 	nir/nir_opt_peephole_select.c \
 	nir/nir_opt_remove_phis.c \
 	nir/nir_opt_undef.c \
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 9887432..0513d81 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2661,6 +2661,8 @@ bool nir_opt_dead_cf(nir_shader *shader);
 
 bool nir_opt_gcm(nir_shader *shader, bool value_number);
 
+bool nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask);
+
 bool nir_opt_peephole_select(nir_shader *shader);
 
 bool nir_opt_remove_phis(nir_shader *shader);
diff --git a/src/compiler/nir/nir_opt_loop_unroll.c b/src/compiler/nir/nir_opt_loop_unroll.c
new file mode 100644
index 0000000..1de02f6
--- /dev/null
+++ b/src/compiler/nir/nir_opt_loop_unroll.c
@@ -0,0 +1,820 @@
+/*
+ * Copyright © 2016 Intel Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+#include "nir.h"
+#include "nir_builder.h"
+#include "nir_control_flow.h"
+
+/* Extract every CF node from `node` up to and including the last node of
+ * its containing list into `extracted`.
+ */
+static void
+extract_loop_body(nir_cf_list *extracted, nir_cf_node *node)
+{
+   nir_cf_node *last;
+   for (last = node; !nir_cf_node_is_last(last); last = nir_cf_node_next(last)) {
+      /* walk forward to the final node of the list */
+   }
+
+   nir_cursor start = nir_before_cf_node(node);
+   nir_cursor stop = nir_after_cf_node(last);
+   nir_cf_extract(extracted, start, stop);
+}
+
+/* Clone the CF nodes of `src_cf_list` onto the tail of `cloned_cf_list`,
+ * remapping SSA defs through `remap_table`.
+ *
+ * A newly created block is pushed first because the destination list needs
+ * to contain at least one block; that block is parented to the CF node that
+ * contains `loop`.
+ */
+static void
+clone_list(nir_shader *ns, nir_loop *loop, nir_cf_list *src_cf_list,
+           nir_cf_list *cloned_cf_list, struct hash_table *remap_table)
+{
+   nir_block *seed = nir_block_create(ns);
+   seed->cf_node.parent = loop->cf_node.parent;
+
+   struct exec_list *dst = &cloned_cf_list->list;
+   exec_list_push_tail(dst, &seed->cf_node.node);
+
+   nir_clone_loop_list(dst, &src_cf_list->list, remap_table, ns);
+}
+
+/* Append `lst` after the last node of one branch of the if at `if_node`:
+ * the then branch when continue_from_then_branch is true, otherwise the
+ * else branch.  Then remove the last instruction of block `last_node`,
+ * which callers pass as the block ending in the loop's break.
+ */
+static void
+move_cf_list_into_if(nir_cf_list *lst, nir_cf_node *if_node,
+                     nir_cf_node *last_node, bool continue_from_then_branch)
+{
+   nir_if *nif = nir_cf_node_as_if(if_node);
+   nir_cf_node *branch_tail = continue_from_then_branch ?
+                              nir_if_last_then_node(nif) :
+                              nir_if_last_else_node(nif);
+
+   /* The rest of the loop now lives at the end of the chosen branch. */
+   nir_cf_reinsert(lst, nir_after_cf_node(branch_tail));
+
+   /* Drop the break. */
+   nir_block *break_blk = nir_cf_node_as_block(last_node);
+   nir_instr_remove(nir_block_last_instr(break_blk));
+}
+
+/* Returns true when both `def` and `src` are produced by phi instructions
+ * that live in the same block.  update_remap_tables() uses this to detect a
+ * loop-header phi whose back-edge source is another loop-header phi.
+ *
+ * Blocks are compared by identity rather than by block->index, so the
+ * result does not depend on block-index metadata being up to date.
+ */
+static bool
+is_phi_src_phi_from_loop_header(nir_ssa_def *def, nir_ssa_def *src)
+{
+   nir_instr *def_instr = def->parent_instr;
+   nir_instr *src_instr = src->parent_instr;
+
+   return def_instr->type == nir_instr_type_phi &&
+          src_instr->type == nir_instr_type_phi &&
+          def_instr->block == src_instr->block;
+}
+
+/* Build two tables describing the phis at the top of the block that follows
+ * `loop`:
+ *
+ *  - *loop_term_phis maps each such phi to its source whose predecessor is
+ *    the last block of the then or else branch of `loop_term_if`;
+ *  - *lcssa_phis maps each such phi to a source coming from any other
+ *    predecessor.
+ *
+ * Both tables are newly allocated; the caller must destroy them.
+ */
+static void
+get_table_of_lcssa_and_loop_term_phis(nir_cf_node *loop,
+                                      struct hash_table **lcssa_phis,
+                                      struct hash_table **loop_term_phis,
+                                      nir_if *loop_term_if)
+{
+   *lcssa_phis = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                         _mesa_key_pointer_equal);
+   *loop_term_phis = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                             _mesa_key_pointer_equal);
+
+   /* The terminator's branch tails are the same for every phi source, so
+    * look them up once rather than once per source.
+    */
+   nir_block *then_blk =
+      nir_cf_node_as_block(nir_if_last_then_node(loop_term_if));
+   nir_block *else_blk =
+      nir_cf_node_as_block(nir_if_last_else_node(loop_term_if));
+
+   nir_block *block = nir_cf_node_as_block(nir_cf_node_next(loop));
+   nir_foreach_instr(instr, block) {
+      /* There should be no more phis after the first non-phi. */
+      if (instr->type != nir_instr_type_phi)
+         break;
+
+      nir_phi_instr *phi = nir_instr_as_phi(instr);
+      nir_foreach_phi_src(src, phi) {
+         bool from_terminator = src->pred == then_blk || src->pred == else_blk;
+         _mesa_hash_table_insert(from_terminator ? *loop_term_phis : *lcssa_phis,
+                                 phi, src->src.ssa);
+      }
+   }
+}
+
+/* Allocate the four pointer-keyed tables used while unrolling, and seed them
+ * from the phis at the top of `loop_header_blk`.  Every entry maps a phi's
+ * SSA def to the SSA def of one of its sources:
+ *
+ *  - a source whose predecessor block has a larger index than the header
+ *    and is a direct child of `loop` (presumably the back edge) goes into
+ *    *phi_remap and *src_after_loop;
+ *  - every other source goes into *remap_table and *src_before_loop.
+ *
+ * The predecessor test reads block->index, so block indices must be valid.
+ */
+static void
+create_remap_tables(nir_loop *loop, nir_block *loop_header_blk,
+                    struct hash_table **remap_table,
+                    struct hash_table **phi_remap,
+                    struct hash_table **src_before_loop,
+                    struct hash_table **src_after_loop)
+{
+   struct hash_table **tables[] = {
+      remap_table, phi_remap, src_before_loop, src_after_loop
+   };
+   for (unsigned t = 0; t < sizeof(tables) / sizeof(tables[0]); t++) {
+      *tables[t] = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                           _mesa_key_pointer_equal);
+   }
+
+   nir_foreach_instr(instr, loop_header_blk) {
+      if (instr->type != nir_instr_type_phi)
+         break;
+
+      nir_phi_instr *phi = nir_instr_as_phi(instr);
+      nir_ssa_def *phi_def = &phi->dest.ssa;
+
+      nir_foreach_phi_src(src, phi) {
+         bool from_back_edge =
+            src->pred->index > phi->instr.block->index &&
+            src->pred->cf_node.parent == &loop->cf_node;
+
+         struct hash_table *remap = from_back_edge ? *phi_remap : *remap_table;
+         struct hash_table *record =
+            from_back_edge ? *src_after_loop : *src_before_loop;
+
+         _mesa_hash_table_insert(remap, phi_def, src->src.ssa);
+         _mesa_hash_table_insert(record, phi_def, src->src.ssa);
+      }
+   }
+}
+
+/* Ready the remap tables for the next unrolled iteration.
+ *
+ * phi_remap maps each loop-header phi def to its back-edge source (as seeded
+ * by create_remap_tables()).  For every such pair:
+ *  - on the first iteration, if the back-edge source has no entry in
+ *    remap_table, the phi def is mapped to the back-edge source itself;
+ *  - otherwise, if the back-edge source is itself a phi in the same block as
+ *    the phi def, the phi def is mapped to that source's pre-loop value
+ *    (src_before_loop), and the phi_remap entry is advanced to that source's
+ *    back-edge value (src_after_loop);
+ *  - otherwise, if the back-edge source has an entry in remap_table, the phi
+ *    def is mapped to that remapped value.
+ */
+static void
+update_remap_tables(bool is_first_iteration, struct hash_table *remap_table,
+                    struct hash_table *phi_remap,
+                    struct hash_table *src_before_loop,
+                    struct hash_table *src_after_loop)
+{
+   struct hash_entry *phi_hte;
+   hash_table_foreach(phi_remap, phi_hte) {
+      /* Has the back-edge source been remapped by the clone? */
+      struct hash_entry *remap_hte =
+         _mesa_hash_table_search(remap_table, phi_hte->data);
+
+      nir_ssa_def *phi_def = (nir_ssa_def *) phi_hte->key;
+      nir_ssa_def *phi_src = (nir_ssa_def *) phi_hte->data;
+
+      if (!remap_hte && is_first_iteration) {
+         _mesa_hash_table_insert(remap_table, phi_hte->key, phi_hte->data);
+         continue;
+      }
+
+      if (is_phi_src_phi_from_loop_header(phi_def, phi_src)) {
+          /* After copy propagation we can end up with phis inside loops
+           * that look like this:
+           *
+           *    vec1 32 ssa_14 = phi block_0: ssa_9, block_4: ssa_13
+           *    vec1 32 ssa_13 = phi block_0: ssa_8, block_4: ssa_12
+           *    vec1 32 ssa_12 = phi block_0: ssa_7, block_4: ssa_11
+           *    vec1 32 ssa_11 = phi block_0: ssa_6, block_4: ssa_14
+           *
+           * For each iteration of the loop we need to update the phi and
+           * cloning remap tables so that we use the correct src for the
+           * next iteration.
+           */
+          /* NOTE(review): sbl_hte and sal_hte are dereferenced without a
+           * NULL check; this relies on every loop-header phi having both a
+           * pre-loop and a back-edge source -- confirm.
+           */
+          struct hash_entry *sbl_hte =
+             _mesa_hash_table_search(src_before_loop, phi_hte->data);
+          _mesa_hash_table_insert(remap_table, phi_hte->key, sbl_hte->data);
+
+          struct hash_entry *sal_hte =
+             _mesa_hash_table_search(src_after_loop, phi_hte->data);
+          phi_hte->data = sal_hte->data;
+      } else if (remap_hte) {
+          _mesa_hash_table_insert(remap_table, phi_hte->key, remap_hte->data);
+      }
+   }
+}
+
+/* Insert `phi_instr` at b->cursor.  Then, for every phi that reads its
+ * result, give each source that still has no predecessor (pred == NULL)
+ * the block the phi was just inserted into.  All uses of the phi are
+ * expected to be phi sources.
+ */
+static void
+insert_phi_and_set_block_on_uses(nir_builder *b, nir_phi_instr *phi_instr)
+{
+   nir_instr_insert(b->cursor, &phi_instr->instr);
+
+   nir_block *home = phi_instr->dest.ssa.parent_instr->block;
+
+   nir_foreach_use_safe(use, &phi_instr->dest.ssa) {
+      nir_phi_instr *user = nir_instr_as_phi(use->parent_instr);
+      foreach_list_typed(nir_phi_src, psrc, node, &user->srcs) {
+         if (psrc->pred == NULL)
+            psrc->pred = home;
+      }
+   }
+}
+
+/* Create a new, not yet inserted, phi with the same vector width and bit
+ * size as `prev_phi_instr`, and append it as a source of `prev_phi_instr`.
+ *
+ * The new source's predecessor is deliberately left NULL:
+ * insert_phi_and_set_block_on_uses() fills in NULL predecessors once the new
+ * phi has been placed in a block.
+ */
+static nir_phi_instr *
+create_complex_unroll_phi(nir_shader *ns, nir_phi_instr *prev_phi_instr)
+{
+   nir_phi_instr *new_phi = nir_phi_instr_create(ns);
+   /* Match the component count as well as the bit size of the phi being
+    * replaced; a hard-coded single component breaks vector values.
+    */
+   nir_ssa_dest_init(&new_phi->instr, &new_phi->dest,
+                     prev_phi_instr->dest.ssa.num_components,
+                     prev_phi_instr->dest.ssa.bit_size, NULL);
+
+   /* Add the new phi as a src to the phi from the previous iteration */
+   nir_phi_src *new_src = ralloc(prev_phi_instr, nir_phi_src);
+   /* ralloc() does not zero memory, and the pred == NULL test in
+    * insert_phi_and_set_block_on_uses() depends on this being NULL.
+    */
+   new_src->pred = NULL;
+   new_src->src = nir_src_for_ssa(&new_phi->dest.ssa);
+   new_src->src.parent_instr = &prev_phi_instr->instr;
+   exec_list_push_tail(&prev_phi_instr->srcs, &new_src->node);
+   list_addtail(&new_src->src.use_link, &new_src->src.ssa->uses);
+
+   return new_phi;
+}
+
+/* Append a source with predecessor `blk` to `phi_instr`.  The source value
+ * is `phi_src` remapped through `remap_table` when an entry exists, and
+ * `phi_src` itself otherwise.
+ */
+static void
+add_complex_unroll_phi_src(nir_ssa_def *phi_src, nir_phi_instr *phi_instr,
+                           struct hash_table *remap_table, nir_block *blk)
+{
+   struct hash_entry *entry = _mesa_hash_table_search(remap_table, phi_src);
+   nir_ssa_def *value = phi_src;
+   if (entry)
+      value = (nir_ssa_def *) entry->data;
+
+   nir_phi_src *psrc = ralloc(phi_instr, nir_phi_src);
+   psrc->pred = blk;
+   psrc->src = nir_src_for_ssa(value);
+   psrc->src.parent_instr = &phi_instr->instr;
+   list_addtail(&psrc->src.use_link, &value->uses);
+
+   exec_list_push_tail(&phi_instr->srcs, &psrc->node);
+}
+
+/* Fix up the phis at the top of the block that follows `loop` after a
+ * simple unroll, before the loop itself is removed.
+ *
+ * Every source of each such phi gets the block immediately preceding the
+ * loop as its predecessor.  If a source value has an entry in `remap_table`:
+ *  - for an LCSSA phi, or a phi with a single source, every use of the phi
+ *    is rewritten to the remapped value;
+ *  - otherwise only that source is rewritten.
+ * If it has no entry, a single-source phi has its uses rewritten to that
+ * source.  LCSSA phis and single-source phis are then removed.
+ */
+static void
+simple_loop_fix_lcssa_phis(nir_cf_node *loop, struct hash_table *remap_table)
+{
+   nir_block *prev_block = nir_cf_node_as_block(nir_cf_node_prev(loop));
+   nir_cf_node *cf_node = nir_cf_node_next(loop);
+   assert(cf_node->type == nir_cf_node_block);
+
+   nir_block *block = nir_cf_node_as_block(cf_node);
+   nir_foreach_instr_safe(instr, block) {
+      if (instr->type == nir_instr_type_phi) {
+         nir_phi_instr *phi = nir_instr_as_phi(instr);
+
+         nir_foreach_phi_src_safe(src, phi) {
+            /* Update predecessor: the unrolled code now sits before the
+             * loop, so control reaches this block from there.
+             */
+            src->pred = prev_block;
+
+            /* Update src */
+            struct hash_entry *hte =
+               _mesa_hash_table_search(remap_table, src->src.ssa);
+            assert(hte || !phi->is_lcssa_phi);
+            if (hte) {
+               nir_src new_src = nir_src_for_ssa((nir_ssa_def *) hte->data);
+               if (phi->is_lcssa_phi || exec_list_length(&phi->srcs) == 1) {
+                  /* The phi is about to be removed; forward its users. */
+                  nir_ssa_def_rewrite_uses(&phi->dest.ssa, new_src);
+               } else {
+                  nir_instr_rewrite_src(instr, &src->src, new_src);
+               }
+            } else {
+               /* If a non lcssa phi now only has 1 src rewrite its uses here.
+                * This avoids the src getting rewritten to an undefined def,
+                * which appears to be done in nir_cf_node_remove() when
+                * removing the loop.
+                */
+               if (exec_list_length(&phi->srcs) == 1) {
+                  struct exec_node *head = exec_list_get_head(&phi->srcs);
+                  nir_phi_src *phi_src = exec_node_data(nir_phi_src, head, node);
+                  nir_ssa_def_rewrite_uses(&phi->dest.ssa, phi_src->src);
+               }
+            }
+         }
+         /* Users of these phis were redirected above, so drop them. */
+         if (phi->is_lcssa_phi || exec_list_length(&phi->srcs) == 1)
+            nir_instr_remove(&phi->instr);
+      } else {
+         /* There should be no more LCSSA-phis */
+         break;
+      }
+   }
+}
+
+/* Returns true when the last instruction of `block` is a break jump.  An
+ * empty block never ends in a break.
+ */
+static bool
+ends_in_break(nir_block *block)
+{
+   if (!exec_list_is_empty(&block->instr_list)) {
+      nir_instr *last = nir_block_last_instr(block);
+      if (last->type == nir_instr_type_jump)
+         return nir_instr_as_jump(last)->type == nir_jump_break;
+   }
+   return false;
+}
+
+/**
+ * Unroll a loop whose limiting terminator has a known trip count.
+ *
+ * All terminators other than the limiting one are removed first, on the
+ * grounds that their exit conditions can never be met.  The code before the
+ * limiting terminator (the "header") and the code after it (the "body") are
+ * then cloned in front of the loop as:
+ *
+ *     header (body header) x trip_count
+ *
+ * with the remap tables advanced after each body clone.  Finally the phis
+ * following the loop are fixed up and the loop is removed.
+ */
+static void
+simple_unroll(nir_function *fn, nir_loop *loop, nir_builder *b)
+{
+   nir_shader *ns = fn->shader;
+
+   /* The first block of the loop holds the loop-header phis that seed the
+    * remap tables.
+    */
+   nir_cf_node *lp_header_cf_node = nir_loop_first_cf_node(loop);
+   nir_block *loop_header_blk = nir_cf_node_as_block(lp_header_cf_node);
+
+   struct hash_table *remap_table;
+   struct hash_table *phi_remap;
+   struct hash_table *src_before_loop;
+   struct hash_table *src_after_loop;
+   create_remap_tables(loop, loop_header_blk, &remap_table, &phi_remap,
+                       &src_before_loop, &src_after_loop);
+
+   /* The limiting terminator splits the loop into header and body. */
+   nir_cf_node *if_node = &loop->info->limiting_terminator->nif->cf_node;
+   list_for_each_entry(nir_loop_terminator, terminator,
+                       &loop->info->loop_terminator_list, loop_terminator_link) {
+       nir_cf_node *loop_node = &terminator->nif->cf_node;
+
+      /* Remove all but the limiting terminator as we know the other exit
+       * conditions can never be met.
+       */
+      if (loop_node != &loop->info->limiting_terminator->nif->cf_node) {
+         nir_cf_node_remove(loop_node);
+      }
+   }
+
+   /* The body starts right after the limiting terminator. */
+   nir_cf_node *cf_node = nir_cf_node_next(if_node);
+
+   /* Pluck out the loop header (everything before the terminator) */
+   nir_cf_list lp_header;
+   nir_cf_extract(&lp_header, nir_before_cf_node(lp_header_cf_node),
+                  nir_before_cf_node(if_node));
+
+   /* Pluck out the loop body (everything after the terminator) */
+   nir_cf_list loop_body;
+   extract_loop_body(&loop_body, cf_node);
+
+   /* Clone the loop header */
+   nir_cf_list cloned_header;
+   exec_list_make_empty(&cloned_header.list);
+   cloned_header.impl = loop_body.impl;
+
+   clone_list(ns, loop, &lp_header, &cloned_header, remap_table);
+
+   /* Insert cloned loop header before the loop */
+   b->cursor = nir_before_cf_node(&loop->cf_node);
+   nir_cf_reinsert(&cloned_header, b->cursor);
+
+   /* Temporary list holding each cloned copy of the body */
+   nir_cf_list unrolled_lp_body;
+   exec_list_make_empty(&unrolled_lp_body.list);
+   unrolled_lp_body.impl = loop_body.impl;
+
+   /* Each iteration appends a clone of the body followed by a clone of the
+    * header in front of the loop.
+    */
+   for (unsigned i = 0; i < loop->info->trip_count; i++) {
+      /* Clone loop body */
+      clone_list(ns, loop, &loop_body, &unrolled_lp_body, remap_table);
+
+      update_remap_tables(i == 0, remap_table, phi_remap, src_before_loop,
+                          src_after_loop);
+
+      /* Insert unrolled loop body before the loop */
+      b->cursor = nir_before_cf_node(&loop->cf_node);
+      nir_cf_reinsert(&unrolled_lp_body, b->cursor);
+
+      /* Clone loop header */
+      clone_list(ns, loop, &lp_header, &cloned_header, remap_table);
+
+      /* Insert loop header after loop body */
+      b->cursor = nir_before_cf_node(&loop->cf_node);
+      nir_cf_reinsert(&cloned_header, b->cursor);
+   }
+
+   /* The loop has been unrolled; fix up the phis that follow it before it
+    * is removed.
+    */
+   simple_loop_fix_lcssa_phis(&loop->cf_node, remap_table);
+
+   /* Remove the loop */
+   nir_cf_node_remove(&loop->cf_node);
+
+   /* Delete the original loop body & header */
+   nir_cf_delete(&lp_header);
+   nir_cf_delete(&loop_body);
+
+   _mesa_hash_table_destroy(remap_table, NULL);
+   _mesa_hash_table_destroy(phi_remap, NULL);
+   _mesa_hash_table_destroy(src_before_loop, NULL);
+   _mesa_hash_table_destroy(src_after_loop, NULL);
+}
+
+/* If `old_key` has an entry in `table`, move that entry's value under
+ * `new_key`.
+ *
+ * The old entry is removed before the new one is inserted, because
+ * _mesa_hash_table_insert() may resize/rehash the table.  That would leave
+ * a hash_entry pointer obtained before the insert dangling.
+ */
+static void
+rekey_loop_term_phi(struct hash_table *table, const void *old_key,
+                    const void *new_key)
+{
+   struct hash_entry *entry = _mesa_hash_table_search(table, old_key);
+   if (!entry)
+      return;
+
+   void *data = entry->data;
+   _mesa_hash_table_remove(table, entry);
+   _mesa_hash_table_insert(table, new_key, data);
+}
+
+/**
+ * Unroll a loop with two exits when the trip count of one of the exits is
+ * unknown.  If continue_from_then_branch is true, the loop is repeated only
+ * when the "then" branch of the if is taken; otherwise it is repeated only
+ * when the "else" branch of the if is taken.
+ *
+ * For example, if the input is:
+ *
+ *     (loop (...)
+ *      ...body...
+ *      (if (cond)
+ *          (...then_instrs...)
+ *        (...else_instrs...)))
+ *
+ * And the iteration count is 3, and \c continue_from_then_branch is true,
+ * then the output will be:
+ *
+ *     ...body...
+ *     (if (cond)
+ *         (...then_instrs...
+ *          ...body...
+ *          (if (cond)
+ *              (...then_instrs...
+ *               ...body...
+ *               (if (cond)
+ *                   (...then_instrs...)
+ *                 (...else_instrs...)))
+ *            (...else_instrs...)))
+ *       (...else_instrs))
+ *
+ * \param if_node     the terminator without a known trip count.
+ * \param last_node   the block of if_node's branch that ends in the break.
+ * \param limiting_term_second
+ *                    whether the limiting terminator (the exit with a known
+ *                    trip count) is handled as the second exit; see the
+ *                    special case below.  In that case
+ *                    loop->info->trip_count is incremented.
+ */
+static void
+complex_unroll(nir_function *fn, nir_loop *loop, nir_builder *b,
+               nir_cf_node *if_node, nir_cf_node *last_node,
+               bool continue_from_then_branch, bool limiting_term_second)
+{
+   nir_cf_node *limiting_trm = &loop->info->limiting_terminator->nif->cf_node;
+   nir_cf_node *lp_header_cf_node = nir_loop_first_cf_node(loop);
+   nir_block *loop_header_blk = nir_cf_node_as_block(lp_header_cf_node);
+
+   struct hash_table *remap_table;
+   struct hash_table *phi_remap;
+   struct hash_table *src_before_loop;
+   struct hash_table *src_after_loop;
+   create_remap_tables(loop, loop_header_blk, &remap_table, &phi_remap,
+                       &src_before_loop, &src_after_loop);
+
+   struct hash_table *loop_phis;
+   struct hash_table *loop_term_phis;
+   get_table_of_lcssa_and_loop_term_phis(&loop->cf_node, &loop_phis,
+                                         &loop_term_phis,
+                                         loop->info->limiting_terminator->nif);
+
+   if (limiting_term_second) {
+      /* We need some special handling when it's the second terminator causing
+       * us to exit the loop, for example:
+       *
+       *   for (int i = 0; i < uniform_lp_count; i++) {
+       *      colour = vec4(0.0, 1.0, 0.0, 1.0);
+       *
+       *      if (i == 1)
+       *         break;
+       *
+       *      ... any further code is unreachable after i == 1 ...
+       *   }
+       *
+       * Bump the trip count by one so we actually clone something. Also
+       * extract everything after the limiting terminator and insert it into
+       * the branch we will continue from.
+       */
+      loop->info->trip_count++;
+
+      nir_cf_list after_lt;
+      extract_loop_body(&after_lt, nir_cf_node_next(limiting_trm));
+
+      nir_if *lt_if = loop->info->limiting_terminator->nif;
+      nir_cf_node *last_then = nir_if_last_then_node(lt_if);
+      if (last_then->type == nir_cf_node_block &&
+          ends_in_break(nir_cf_node_as_block(last_then))) {
+         move_cf_list_into_if(&after_lt, limiting_trm, last_then, false);
+      } else {
+         nir_cf_node *last_else = nir_if_last_else_node(lt_if);
+         if (last_else->type == nir_cf_node_block &&
+             ends_in_break(nir_cf_node_as_block(last_else))) {
+            move_cf_list_into_if(&after_lt, limiting_trm, last_else, true);
+         }
+      }
+   } else {
+      /* Remove the limiting terminator.  Loop analysis will only find a
+       * terminator for trivial if statements (then only contains break, else
+       * is empty) so it's safe to remove the whole thing.
+       */
+      nir_cf_node_remove(limiting_trm);
+   }
+
+   nir_shader *ns = fn->shader;
+   struct hash_table *lcssa_phis =
+      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                              _mesa_key_pointer_equal);
+
+   /* Create phis to be used post-if (replacements for the post-loop phis) */
+   struct hash_entry *phi_hte;
+   hash_table_foreach(loop_phis, phi_hte) {
+      nir_phi_instr *phi_instr = (nir_phi_instr *) phi_hte->key;
+      nir_phi_instr *new_phi = create_complex_unroll_phi(ns, phi_instr);
+
+      nir_ssa_def *ssa_def = (nir_ssa_def *) phi_hte->data;
+      _mesa_hash_table_insert(lcssa_phis, new_phi, ssa_def);
+
+      /* Update loop_phis to point to the replacement phi */
+      phi_hte->data = &new_phi->dest.ssa;
+
+      rekey_loop_term_phi(loop_term_phis, phi_instr, new_phi);
+   }
+
+   /* Move everything after the terminator we don't have a trip count for
+    * inside the if.
+    */
+   nir_cf_list loop_end;
+   extract_loop_body(&loop_end, nir_cf_node_next(if_node));
+   move_cf_list_into_if(&loop_end, if_node, last_node,
+                        continue_from_then_branch);
+
+   /* Pluck out the loop body. Unlike the simple unroll pass there are no
+    * breaks remaining in the loop so we do not have the concept of a loop
+    * header and a loop body, instead we just extract everything.
+    */
+   nir_cf_list loop_body;
+   extract_loop_body(&loop_body, lp_header_cf_node);
+
+   /* Create temp list to store the cloned loop body as we unroll */
+   nir_cf_list unrolled_lp_body;
+   exec_list_make_empty(&unrolled_lp_body.list);
+   unrolled_lp_body.impl = loop_body.impl;
+
+   /* Set the cursor to before the loop */
+   b->cursor = nir_before_cf_node(&loop->cf_node);
+
+   nir_cf_node *continue_from_node = NULL;
+   for (unsigned i = 0; i < loop->info->trip_count; i++) {
+      /* Clone loop body */
+      clone_list(ns, loop, &loop_body, &unrolled_lp_body, remap_table);
+
+      nir_cf_node *body_last_node =
+         exec_node_data(nir_cf_node,
+                        exec_list_get_tail(&unrolled_lp_body.list), node);
+      assert(body_last_node->type == nir_cf_node_block &&
+             exec_list_is_empty(&nir_cf_node_as_block(body_last_node)->instr_list));
+
+      /* Insert unrolled loop body */
+      nir_cf_reinsert(&unrolled_lp_body, b->cursor);
+
+      /* The cloned body ends with the cloned terminator if. */
+      nir_cf_node *body_if_node = nir_cf_node_prev(body_last_node);
+      assert(body_if_node->type == nir_cf_node_if);
+      nir_if *body_if = nir_cf_node_as_if(body_if_node);
+
+      nir_cf_node *exit_from_node;
+      if (continue_from_then_branch) {
+         continue_from_node = nir_if_last_then_node(body_if);
+         exit_from_node = nir_if_last_else_node(body_if);
+      } else {
+         exit_from_node = nir_if_last_then_node(body_if);
+         continue_from_node = nir_if_last_else_node(body_if);
+      }
+
+      b->cursor = nir_after_cf_node(body_if_node);
+      if (i < loop->info->trip_count - 1) {
+         struct hash_table *tmp =
+            _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                    _mesa_key_pointer_equal);
+
+         hash_table_foreach(lcssa_phis, phi_hte) {
+            /* Insert phi created in previous iteration */
+            nir_phi_instr *phi_instr = (nir_phi_instr *) phi_hte->key;
+            insert_phi_and_set_block_on_uses(b, phi_instr);
+
+            nir_ssa_def *ssa_def = (nir_ssa_def *) phi_hte->data;
+            add_complex_unroll_phi_src(ssa_def, phi_instr, remap_table,
+                                       nir_cf_node_as_block(exit_from_node));
+
+            /* Create phi to be fixed up by next iteration */
+            nir_phi_instr *new_phi = create_complex_unroll_phi(ns, phi_instr);
+            _mesa_hash_table_insert(tmp, new_phi, ssa_def);
+
+            rekey_loop_term_phi(loop_term_phis, phi_instr, new_phi);
+         }
+
+         /* Now that the phis have been processed replace the table with the
+          * phis to be fixed up in the next iteration.
+          */
+         _mesa_hash_table_destroy(lcssa_phis, NULL);
+         lcssa_phis = tmp;
+      } else {
+         hash_table_foreach(lcssa_phis, phi_hte) {
+            /* Insert phi created in previous iteration */
+            nir_phi_instr *phi_instr = (nir_phi_instr *) phi_hte->key;
+            insert_phi_and_set_block_on_uses(b, phi_instr);
+
+            nir_ssa_def *ssa_def = (nir_ssa_def *) phi_hte->data;
+            add_complex_unroll_phi_src(ssa_def, phi_instr, remap_table,
+                                       nir_cf_node_as_block(exit_from_node));
+         }
+      }
+
+      /* Ready the remap tables for the next iteration */
+      update_remap_tables(i == 0, remap_table, phi_remap, src_before_loop,
+                          src_after_loop);
+
+      /* Set the cursor to the end of the branch we continue from, ready
+       * for the next iteration.
+       */
+      b->cursor = nir_after_cf_node(continue_from_node);
+   }
+
+   /* Now that the remap table is updated add the second src to the innermost
+    * phis.
+    */
+   hash_table_foreach(lcssa_phis, phi_hte) {
+      nir_phi_instr *phi_instr = (nir_phi_instr *) phi_hte->key;
+      nir_ssa_def *phi_src = (nir_ssa_def *) phi_hte->data;
+
+      assert(exec_list_length(&phi_instr->srcs) == 1);
+
+      /* Get the src for when exiting by the loop terminator */
+      struct hash_entry *loop_term_hte =
+         _mesa_hash_table_search(loop_term_phis, phi_instr);
+      if (loop_term_hte)
+         phi_src = (nir_ssa_def *) loop_term_hte->data;
+
+      add_complex_unroll_phi_src(phi_src, phi_instr, remap_table,
+                                 nir_cf_node_as_block(continue_from_node));
+   }
+
+   /* Rewrite the uses of the old loop phis */
+   hash_table_foreach(loop_phis, phi_hte) {
+      nir_phi_instr *phi_instr = (nir_phi_instr *) phi_hte->key;
+
+      nir_foreach_use_safe(use_src, &phi_instr->dest.ssa) {
+         nir_src new_src = nir_src_for_ssa((nir_ssa_def *) phi_hte->data);
+         nir_instr_rewrite_src(use_src->parent_instr, use_src, new_src);
+      }
+
+      nir_foreach_if_use_safe(use_src, &phi_instr->dest.ssa) {
+         nir_src new_src = nir_src_for_ssa((nir_ssa_def *) phi_hte->data);
+         nir_if_rewrite_condition(use_src->parent_if, new_src);
+      }
+   }
+
+   /* The loop has been unrolled so remove it. */
+   nir_cf_node_remove(&loop->cf_node);
+
+   /* Delete the original loop body */
+   nir_cf_delete(&loop_body);
+
+   _mesa_hash_table_destroy(loop_phis, NULL);
+   _mesa_hash_table_destroy(loop_term_phis, NULL);
+   _mesa_hash_table_destroy(lcssa_phis, NULL);
+   _mesa_hash_table_destroy(remap_table, NULL);
+   _mesa_hash_table_destroy(phi_remap, NULL);
+   _mesa_hash_table_destroy(src_before_loop, NULL);
+   _mesa_hash_table_destroy(src_after_loop, NULL);
+}
+
+/* Recursively walk `cf_node` and unroll at most one loop.
+ *
+ * Children are visited before the enclosing loop is considered.
+ * *innermost_loop is shared with the caller: it is cleared as soon as a
+ * loop has been considered for unrolling (whether or not it was unrolled).
+ * As a result, enclosing loops and later sibling loops are skipped until
+ * the next invocation of the pass.
+ *
+ * A loop is only unrolled when loop analysis found a limiting terminator.
+ * Loops accepted by is_simple_loop() get simple_unroll(); loops accepted by
+ * is_complex_loop() get complex_unroll(), or simple_unroll() in the
+ * trip-count-zero case below.  NOTE(review): is_simple_loop() and
+ * is_complex_loop() are not defined in this file -- presumably provided by
+ * the loop analysis; confirm.
+ *
+ * Returns true if a loop was unrolled.
+ */
+static bool
+process_loops(nir_cf_node *cf_node, nir_builder *b, bool *innermost_loop)
+{
+   bool progress = false;
+   nir_loop *loop;
+
+   switch (cf_node->type) {
+   case nir_cf_node_block:
+      return progress;
+   case nir_cf_node_if: {
+      nir_if *if_stmt = nir_cf_node_as_if(cf_node);
+      foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
+         progress |= process_loops(nested_node, b, innermost_loop);
+      foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
+         progress |= process_loops(nested_node, b, innermost_loop);
+      return progress;
+   }
+   case nir_cf_node_loop: {
+      loop = nir_cf_node_as_loop(cf_node);
+      foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
+         progress |= process_loops(nested_node, b, innermost_loop);
+      break;
+   }
+   default:
+      unreachable("unknown cf node type");
+   }
+
+   /* Only the loop case reaches here, so `loop` is set. */
+   if (*innermost_loop) {
+      nir_function *fn = nir_cf_node_get_function(&loop->cf_node)->function;
+
+      /* Don't attempt to unroll outer loops or a second inner loop in
+       * this pass; wait until the next pass, as we have altered the CF.
+       */
+      *innermost_loop = false;
+
+      if (loop->info->limiting_terminator == NULL) {
+         return progress;
+      }
+
+      if (is_simple_loop(fn->shader, loop->info)) {
+         simple_unroll(fn, loop, b);
+         progress = true;
+      } else {
+         /* Attempt to unroll loops with two terminators. */
+         if (is_complex_loop(fn->shader, loop->info)) {
+            /* Stays true until the limiting terminator is seen in the list. */
+            bool first_terminator = true;
+            list_for_each_entry(nir_loop_terminator, terminator,
+                                &loop->info->loop_terminator_list,
+                                loop_terminator_link) {
+
+               nir_cf_node *if_node = &terminator->nif->cf_node;
+
+               if (if_node == &loop->info->limiting_terminator->nif->cf_node) {
+                  first_terminator = false;
+                  continue;
+               }
+
+               /* If the first terminator has a trip count of zero just do a
+                * simple unroll as the second terminator can never be reached.
+                */
+               if (loop->info->trip_count == 0 && first_terminator) {
+                  simple_unroll(fn, loop, b);
+                  progress = true;
+                  break;
+               }
+
+               nir_if *if_stmt = nir_cf_node_as_if(if_node);
+
+               /* Determine which if-statement branch, if any, ends with a
+                * break.  The code assumes both branches cannot end with a
+                * break.  NOTE(review): the original comment attributed this
+                * to predicted_num_loop_jumps == 1, which is not referenced
+                * anywhere in this file; confirm which check guarantees it.
+                */
+               nir_cf_node *last_then = nir_if_last_then_node(if_stmt);
+               if (last_then->type == nir_cf_node_block &&
+                   ends_in_break(nir_cf_node_as_block(last_then))) {
+
+                  complex_unroll(fn, loop, b, if_node, last_then, false,
+                                 !first_terminator);
+
+                  progress = true;
+                  break;
+               } else {
+                  nir_cf_node *last_else = nir_if_last_else_node(if_stmt);
+                  if (last_else->type == nir_cf_node_block &&
+                      ends_in_break(nir_cf_node_as_block(last_else))) {
+
+                     complex_unroll(fn, loop, b, if_node, last_else, true,
+                                    !first_terminator);
+
+                     progress = true;
+                     break;
+                  }
+               }
+            }
+         }
+      }
+   }
+
+   return progress;
+}
+
+/* Run one round of loop unrolling over `impl`.
+ *
+ * Loop analysis (parameterised by `indirect_mask`) and block indices are
+ * required up front.  The unroll helpers read loop->info and compare
+ * block->index values (create_remap_tables()).  Each top-level CF node of
+ * the function body is processed with its own innermost-loop flag.
+ *
+ * Returns true if any loop was unrolled.
+ */
+static bool
+nir_opt_loop_unroll_impl(nir_function_impl *impl,
+                         nir_variable_mode indirect_mask)
+{
+   bool progress = false;
+   nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
+   nir_metadata_require(impl, nir_metadata_block_index);
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
+      bool innermost_loop = true;
+      progress |= process_loops(node, &b, &innermost_loop);
+   }
+
+   /* Unrolling rewrites the control flow, so none of the previously
+    * computed metadata (loop analysis, block indices, dominance, ...) can be
+    * trusted afterwards.
+    */
+   if (progress)
+      nir_metadata_preserve(impl, nir_metadata_none);
+
+   return progress;
+}
+
+/* Unroll loops in every function of `shader` that has an implementation.
+ * `indirect_mask` is forwarded to the loop analysis.
+ *
+ * Returns true if any loop in any function was unrolled.
+ */
+bool
+nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
+{
+   bool progress = false;
+
+   nir_foreach_function(function, shader) {
+      if (function->impl) {
+         progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
+      }
+   }
+
+   /* Report the accumulated result; a constant false would hide from
+    * callers the fact that the shader was modified.
+    */
+   return progress;
+}
-- 
2.7.4



More information about the mesa-dev mailing list