Mesa (master): nir/loop_analyze: Properly handle swizzles in loop conditions

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jul 10 00:30:19 UTC 2019


Module: Mesa
Branch: master
Commit: ff972c7a3a7e80a426b72f285902d35f6ca3b820
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=ff972c7a3a7e80a426b72f285902d35f6ca3b820

Author: Jason Ekstrand <jason at jlekstrand.net>
Date:   Fri Jun 21 09:18:16 2019 -0500

nir/loop_analyze: Properly handle swizzles in loop conditions

This commit re-plumbs all of nir_loop_analyze to use nir_ssa_scalar for
all intermediate values so that we can properly handle swizzles.  Even
though if-conditions are required to be scalars, they may still consume
swizzles, so you could have ((a.yzw < b.zzx).xz && c.xx).y == 0 as your
loop termination condition.  The old code would just bail the moment it
saw its first non-zero swizzle, but we can now properly chase the scalar
from the if-condition all the way to a, b, and c.

Shader-db results on Kaby Lake:

    total loops in shared programs: 4388 -> 4364 (-0.55%)
    loops in affected programs: 29 -> 5 (-82.76%)
    helped: 29
    HURT: 5

Shader-db results on Haswell:

    total loops in shared programs: 4370 -> 4373 (0.07%)
    loops in affected programs: 2 -> 5 (150.00%)
    helped: 2
    HURT: 5

Reviewed-by: Timothy Arceri <tarceri at itsqueeze.com>

---

 src/compiler/nir/nir_loop_analyze.c | 289 +++++++++++++++++++-----------------
 1 file changed, 149 insertions(+), 140 deletions(-)

diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c
index c64314aa378..587cf08fa02 100644
--- a/src/compiler/nir/nir_loop_analyze.c
+++ b/src/compiler/nir/nir_loop_analyze.c
@@ -32,7 +32,10 @@ typedef enum {
    basic_induction
 } nir_loop_variable_type;
 
-struct nir_basic_induction_var;
+typedef struct nir_basic_induction_var {
+   nir_alu_instr *alu;                      /* The def of the alu-operation */
+   nir_ssa_def *def_outside_loop;           /* The phi-src outside the loop */
+} nir_basic_induction_var;
 
 typedef struct {
    /* A link for the work list */
@@ -57,13 +60,6 @@ typedef struct {
 
 } nir_loop_variable;
 
-typedef struct nir_basic_induction_var {
-   nir_op alu_op;                           /* The type of alu-operation    */
-   nir_loop_variable *alu_def;              /* The def of the alu-operation */
-   nir_loop_variable *invariant;            /* The invariant alu-operand    */
-   nir_loop_variable *def_outside_loop;     /* The phi-src outside the loop */
-} nir_basic_induction_var;
-
 typedef struct {
    /* The loop we store information for */
    nir_loop *loop;
@@ -300,6 +296,19 @@ phi_instr_as_alu(nir_phi_instr *phi)
 }
 
 static bool
+alu_src_has_identity_swizzle(nir_alu_instr *alu, unsigned src_idx)
+{
+   assert(nir_op_infos[alu->op].input_sizes[src_idx] == 0);
+   assert(alu->dest.dest.is_ssa);
+   for (unsigned i = 0; i < alu->dest.dest.ssa.num_components; i++) {
+      if (alu->src[src_idx].swizzle[i] != i)
+         return false;
+   }
+
+   return true;
+}
+
+static bool
 compute_induction_information(loop_info_state *state)
 {
    bool found_induction_var = false;
@@ -320,15 +329,10 @@ compute_induction_information(loop_info_state *state)
       if (!is_var_phi(var))
          continue;
 
-      /* We only handle scalars because none of the rest of the loop analysis
-       * code can properly handle swizzles.
-       */
-      if (var->def->num_components > 1)
-         continue;
-
       nir_phi_instr *phi = nir_instr_as_phi(var->def->parent_instr);
       nir_basic_induction_var *biv = rzalloc(state, nir_basic_induction_var);
 
+      nir_loop_variable *alu_src_var = NULL;
       nir_foreach_phi_src(src, phi) {
          nir_loop_variable *src_var = get_loop_var(src->src.ssa, state);
 
@@ -352,32 +356,36 @@ compute_induction_information(loop_info_state *state)
             }
          }
 
-         if (!src_var->in_loop) {
-            biv->def_outside_loop = src_var;
-         } else if (is_var_alu(src_var)) {
+         if (!src_var->in_loop && !biv->def_outside_loop) {
+            biv->def_outside_loop = src_var->def;
+         } else if (is_var_alu(src_var) && !biv->alu) {
+            alu_src_var = src_var;
             nir_alu_instr *alu = nir_instr_as_alu(src_var->def->parent_instr);
 
             if (nir_op_infos[alu->op].num_inputs == 2) {
-               biv->alu_def = src_var;
-               biv->alu_op = alu->op;
-
                for (unsigned i = 0; i < 2; i++) {
-                  /* Is one of the operands const, and the other the phi */
-                  if (alu->src[i].src.ssa->parent_instr->type == nir_instr_type_load_const &&
-                      alu->src[i].swizzle[0] == 0 &&
-                      alu->src[1-i].src.ssa == &phi->dest.ssa)
-                     assert(alu->src[1-i].swizzle[0] == 0);
-                     biv->invariant = get_loop_var(alu->src[i].src.ssa, state);
+                  /* Is one of the operands const, and the other the phi.  The
+                   * phi source can't be swizzled in any way.
+                   */
+                  if (nir_src_is_const(alu->src[i].src) &&
+                      alu->src[1-i].src.ssa == &phi->dest.ssa &&
+                      alu_src_has_identity_swizzle(alu, 1 - i))
+                     biv->alu = alu;
                }
             }
+
+            if (!biv->alu)
+               break;
+         } else {
+            biv->alu = NULL;
+            break;
          }
       }
 
-      if (biv->alu_def && biv->def_outside_loop && biv->invariant &&
-          is_var_constant(biv->def_outside_loop)) {
-         assert(is_var_constant(biv->invariant));
-         biv->alu_def->type = basic_induction;
-         biv->alu_def->ind = biv;
+      if (biv->alu && biv->def_outside_loop &&
+          biv->def_outside_loop->parent_instr->type == nir_instr_type_load_const) {
+         alu_src_var->type = basic_induction;
+         alu_src_var->ind = biv;
          var->type = basic_induction;
          var->ind = biv;
 
@@ -504,7 +512,7 @@ find_array_access_via_induction(loop_info_state *state,
 
 static bool
 guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
-                 nir_loop_variable *basic_ind)
+                 nir_ssa_scalar basic_ind)
 {
    unsigned min_array_size = 0;
 
@@ -525,8 +533,10 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
                find_array_access_via_induction(state,
                                                nir_src_as_deref(intrin->src[0]),
                                                &array_idx);
-            if (basic_ind == array_idx &&
+            if (array_idx && basic_ind.def == array_idx->def &&
                 (min_array_size == 0 || min_array_size > array_size)) {
+               /* Array indices are scalars */
+               assert(basic_ind.def->num_components == 1);
                min_array_size = array_size;
             }
 
@@ -537,8 +547,10 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
                find_array_access_via_induction(state,
                                                nir_src_as_deref(intrin->src[1]),
                                                &array_idx);
-            if (basic_ind == array_idx &&
+            if (array_idx && basic_ind.def == array_idx->def &&
                 (min_array_size == 0 || min_array_size > array_size)) {
+               /* Array indices are scalars */
+               assert(basic_ind.def->num_components == 1);
                min_array_size = array_size;
             }
          }
@@ -547,7 +559,7 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
 
    if (min_array_size) {
       *limit_val = nir_const_value_for_uint(min_array_size,
-                                            basic_ind->def->bit_size);
+                                            basic_ind.def->bit_size);
       return true;
    }
 
@@ -555,33 +567,22 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
 }
 
 static bool
-try_find_limit_of_alu(nir_loop_variable *limit, nir_const_value *limit_val,
+try_find_limit_of_alu(nir_ssa_scalar limit, nir_const_value *limit_val,
                       nir_loop_terminator *terminator, loop_info_state *state)
 {
-   if(!is_var_alu(limit))
+   if (!nir_ssa_scalar_is_alu(limit))
       return false;
 
-   nir_alu_instr *limit_alu = nir_instr_as_alu(limit->def->parent_instr);
-
-   if (limit_alu->op == nir_op_imin ||
-       limit_alu->op == nir_op_fmin) {
-      /* We don't handle swizzles here */
-      if (limit_alu->src[0].swizzle[0] > 0 || limit_alu->src[1].swizzle[0] > 0)
-         return false;
-
-      limit = get_loop_var(limit_alu->src[0].src.ssa, state);
-
-      if (!is_var_constant(limit))
-         limit = get_loop_var(limit_alu->src[1].src.ssa, state);
-
-      if (!is_var_constant(limit))
-         return false;
-
-      *limit_val = nir_instr_as_load_const(limit->def->parent_instr)->value[0];
-
-      terminator->exact_trip_count_unknown = true;
-
-      return true;
+   nir_op limit_op = nir_ssa_scalar_alu_op(limit);
+   if (limit_op == nir_op_imin || limit_op == nir_op_fmin) {
+      for (unsigned i = 0; i < 2; i++) {
+         nir_ssa_scalar src = nir_ssa_scalar_chase_alu_src(limit, i);
+         if (nir_ssa_scalar_is_const(src)) {
+            *limit_val = nir_ssa_scalar_as_const_value(src);
+            terminator->exact_trip_count_unknown = true;
+            return true;
+         }
+      }
    }
 
    return false;
@@ -696,14 +697,12 @@ test_iterations(int32_t iter_int, nir_const_value *step,
 
 static int
 calculate_iterations(nir_const_value *initial, nir_const_value *step,
-                     nir_const_value *limit, nir_loop_variable *alu_def,
-                     nir_alu_instr *cond_alu, nir_op alu_op, bool limit_rhs,
+                     nir_const_value *limit, nir_alu_instr *alu,
+                     nir_ssa_scalar cond, nir_op alu_op, bool limit_rhs,
                      bool invert_cond)
 {
    assert(initial != NULL && step != NULL && limit != NULL);
 
-   nir_alu_instr *alu = nir_instr_as_alu(alu_def->def->parent_instr);
-
    /* nir_op_isub should have been lowered away by this point */
    assert(alu->op != nir_op_isub);
 
@@ -735,8 +734,9 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step,
     * condition and if so we assume we need to step the initial value.
     */
    unsigned trip_offset = 0;
-   if (cond_alu->src[0].src.ssa == alu_def->def ||
-       cond_alu->src[1].src.ssa == alu_def->def) {
+   nir_alu_instr *cond_alu = nir_instr_as_alu(cond.def->parent_instr);
+   if (cond_alu->src[0].src.ssa == &alu->dest.dest.ssa ||
+       cond_alu->src[1].src.ssa == &alu->dest.dest.ssa) {
       trip_offset = 1;
    }
 
@@ -774,9 +774,9 @@ calculate_iterations(nir_const_value *initial, nir_const_value *step,
 }
 
 static nir_op
-inverse_comparison(nir_alu_instr *alu)
+inverse_comparison(nir_op alu_op)
 {
-   switch (alu->op) {
+   switch (alu_op) {
    case nir_op_fge:
       return nir_op_flt;
    case nir_op_ige:
@@ -803,29 +803,33 @@ inverse_comparison(nir_alu_instr *alu)
 }
 
 static bool
-is_supported_terminator_condition(nir_alu_instr *alu)
+is_supported_terminator_condition(nir_ssa_scalar cond)
 {
+   if (!nir_ssa_scalar_is_alu(cond))
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(cond.def->parent_instr);
    return nir_alu_instr_is_comparison(alu) &&
           nir_op_infos[alu->op].num_inputs == 2;
 }
 
 static bool
-get_induction_and_limit_vars(nir_alu_instr *alu,
-                             nir_loop_variable **ind,
-                             nir_loop_variable **limit,
+get_induction_and_limit_vars(nir_ssa_scalar cond,
+                             nir_ssa_scalar *ind,
+                             nir_ssa_scalar *limit,
                              bool *limit_rhs,
                              loop_info_state *state)
 {
-   nir_loop_variable *rhs, *lhs;
-   lhs = get_loop_var(alu->src[0].src.ssa, state);
-   rhs = get_loop_var(alu->src[1].src.ssa, state);
+   nir_ssa_scalar rhs, lhs;
+   lhs = nir_ssa_scalar_chase_alu_src(cond, 0);
+   rhs = nir_ssa_scalar_chase_alu_src(cond, 1);
 
-   if (lhs->type == basic_induction) {
+   if (get_loop_var(lhs.def, state)->type == basic_induction) {
       *ind = lhs;
       *limit = rhs;
       *limit_rhs = true;
       return true;
-   } else if (rhs->type == basic_induction) {
+   } else if (get_loop_var(rhs.def, state)->type == basic_induction) {
       *ind = rhs;
       *limit = lhs;
       *limit_rhs = false;
@@ -836,53 +840,40 @@ get_induction_and_limit_vars(nir_alu_instr *alu,
 }
 
 static bool
-try_find_trip_count_vars_in_iand(nir_alu_instr **alu,
-                                 nir_loop_variable **ind,
-                                 nir_loop_variable **limit,
+try_find_trip_count_vars_in_iand(nir_ssa_scalar *cond,
+                                 nir_ssa_scalar *ind,
+                                 nir_ssa_scalar *limit,
                                  bool *limit_rhs,
                                  loop_info_state *state)
 {
-   assert((*alu)->op == nir_op_ieq || (*alu)->op == nir_op_inot);
-
-   nir_ssa_def *iand_def = (*alu)->src[0].src.ssa;
-   /* This is used directly in an if condition so it must be a scalar */
-   assert(iand_def->num_components == 1);
+   const nir_op alu_op = nir_ssa_scalar_alu_op(*cond);
+   assert(alu_op == nir_op_ieq || alu_op == nir_op_inot);
 
-   if ((*alu)->op == nir_op_ieq) {
-      nir_ssa_def *zero_def = (*alu)->src[1].src.ssa;
+   nir_ssa_scalar iand = nir_ssa_scalar_chase_alu_src(*cond, 0);
 
-      /* We don't handle swizzles here */
-      if ((*alu)->src[0].swizzle[0] > 0 || (*alu)->src[1].swizzle[0] > 0)
-         return false;
-
-      if (iand_def->parent_instr->type != nir_instr_type_alu ||
-          zero_def->parent_instr->type != nir_instr_type_load_const) {
+   if (alu_op == nir_op_ieq) {
+      nir_ssa_scalar zero = nir_ssa_scalar_chase_alu_src(*cond, 1);
 
+      if (!nir_ssa_scalar_is_alu(iand) || !nir_ssa_scalar_is_const(zero)) {
          /* Maybe we had it the wrong way, flip things around */
-         iand_def = (*alu)->src[1].src.ssa;
-         zero_def = (*alu)->src[0].src.ssa;
+         nir_ssa_scalar tmp = zero;
+         zero = iand;
+         iand = tmp;
 
          /* If we still didn't find what we need then return */
-         if (zero_def->parent_instr->type != nir_instr_type_load_const)
+         if (!nir_ssa_scalar_is_const(zero))
             return false;
       }
 
       /* If the loop is not breaking on (x && y) == 0 then return */
-      nir_const_value *zero =
-         nir_instr_as_load_const(zero_def->parent_instr)->value;
-      if (zero[0].i32 != 0)
+      if (nir_ssa_scalar_as_uint(zero) != 0)
          return false;
    }
 
-   if (iand_def->parent_instr->type != nir_instr_type_alu)
-      return false;
-
-   nir_alu_instr *iand = nir_instr_as_alu(iand_def->parent_instr);
-   if (iand->op != nir_op_iand)
+   if (!nir_ssa_scalar_is_alu(iand))
       return false;
 
-   /* We don't handle swizzles here */
-   if ((*alu)->src[0].swizzle[0] > 0 || (*alu)->src[1].swizzle[0] > 0)
+   if (nir_ssa_scalar_alu_op(iand) != nir_op_iand)
       return false;
 
    /* Check if iand src is a terminator condition and try get induction var
@@ -890,19 +881,15 @@ try_find_trip_count_vars_in_iand(nir_alu_instr **alu,
     */
    bool found_induction_var = false;
    for (unsigned i = 0; i < 2; i++) {
-      nir_ssa_def *src = iand->src[i].src.ssa;
-      if (src->parent_instr->type == nir_instr_type_alu) {
-         nir_alu_instr *src_alu = nir_instr_as_alu(src->parent_instr);
-         if (is_supported_terminator_condition(src_alu) &&
-             get_induction_and_limit_vars(src_alu, ind, limit,
-                                          limit_rhs, state)) {
-            *alu = src_alu;
-            found_induction_var = true;
-
-            /* If we've found one with a constant limit, stop. */
-            if (is_var_constant(*limit))
-               return true;
-         }
+      nir_ssa_scalar src = nir_ssa_scalar_chase_alu_src(iand, i);
+      if (is_supported_terminator_condition(src) &&
+          get_induction_and_limit_vars(src, ind, limit, limit_rhs, state)) {
+         *cond = src;
+         found_induction_var = true;
+
+         /* If we've found one with a constant limit, stop. */
+         if (nir_ssa_scalar_is_const(*limit))
+            return true;
       }
    }
 
@@ -926,8 +913,10 @@ find_trip_count(loop_info_state *state)
    list_for_each_entry(nir_loop_terminator, terminator,
                        &state->loop->info->loop_terminator_list,
                        loop_terminator_link) {
+      assert(terminator->nif->condition.is_ssa);
+      nir_ssa_scalar cond = { terminator->nif->condition.ssa, 0 };
 
-      if (terminator->conditional_instr->type != nir_instr_type_alu) {
+      if (!nir_ssa_scalar_is_alu(cond)) {
          /* If we get here the loop is dead and will get cleaned up by the
           * nir_opt_dead_cf pass.
           */
@@ -935,27 +924,27 @@ find_trip_count(loop_info_state *state)
          continue;
       }
 
-      nir_alu_instr *alu = nir_instr_as_alu(terminator->conditional_instr);
-      nir_op alu_op = alu->op;
+      nir_op alu_op = nir_ssa_scalar_alu_op(cond);
 
       bool limit_rhs;
-      nir_loop_variable *basic_ind = NULL;
-      nir_loop_variable *limit;
-      if ((alu->op == nir_op_inot || alu->op == nir_op_ieq) &&
-          try_find_trip_count_vars_in_iand(&alu, &basic_ind, &limit,
+      nir_ssa_scalar basic_ind = { NULL, 0 };
+      nir_ssa_scalar limit;
+      if ((alu_op == nir_op_inot || alu_op == nir_op_ieq) &&
+          try_find_trip_count_vars_in_iand(&cond, &basic_ind, &limit,
                                            &limit_rhs, state)) {
+
          /* The loop is exiting on (x && y) == 0 so we need to get the
           * inverse of x or y (i.e. which ever contained the induction var) in
           * order to compute the trip count.
           */
-         alu_op = inverse_comparison(alu);
+         alu_op = inverse_comparison(nir_ssa_scalar_alu_op(cond));
          trip_count_known = false;
          terminator->exact_trip_count_unknown = true;
       }
 
-      if (!basic_ind) {
-         if (is_supported_terminator_condition(alu)) {
-            get_induction_and_limit_vars(alu, &basic_ind,
+      if (!basic_ind.def) {
+         if (is_supported_terminator_condition(cond)) {
+            get_induction_and_limit_vars(cond, &basic_ind,
                                          &limit, &limit_rhs, state);
          }
       }
@@ -963,7 +952,7 @@ find_trip_count(loop_info_state *state)
       /* The comparison has to have a basic induction variable for us to be
        * able to find trip counts.
        */
-      if (!basic_ind) {
+      if (!basic_ind.def) {
          trip_count_known = false;
          continue;
       }
@@ -972,9 +961,8 @@ find_trip_count(loop_info_state *state)
 
       /* Attempt to find a constant limit for the loop */
       nir_const_value limit_val;
-      if (is_var_constant(limit)) {
-         limit_val =
-            nir_instr_as_load_const(limit->def->parent_instr)->value[0];
+      if (nir_ssa_scalar_is_const(limit)) {
+         limit_val = nir_ssa_scalar_as_const_value(limit);
       } else {
          trip_count_known = false;
 
@@ -996,17 +984,38 @@ find_trip_count(loop_info_state *state)
        * Thats all thats needed to calculate the trip-count
        */
 
-      nir_const_value *initial_val =
-         nir_instr_as_load_const(basic_ind->ind->def_outside_loop->
-                                    def->parent_instr)->value;
+      nir_basic_induction_var *ind_var =
+         get_loop_var(basic_ind.def, state)->ind;
 
-      nir_const_value *step_val =
-         nir_instr_as_load_const(basic_ind->ind->invariant->def->
-                                    parent_instr)->value;
+      /* The basic induction var might be a vector but, because we guarantee
+       * earlier that the phi source has a scalar swizzle, we can take the
+       * component from basic_ind.
+       */
+      nir_ssa_scalar initial_s = { ind_var->def_outside_loop, basic_ind.comp };
+      nir_ssa_scalar alu_s = { &ind_var->alu->dest.dest.ssa, basic_ind.comp };
+
+      nir_const_value initial_val = nir_ssa_scalar_as_const_value(initial_s);
+
+      /* We are guaranteed by earlier code that at least one of these sources
+       * is a constant but we don't know which.
+       */
+      nir_const_value step_val;
+      memset(&step_val, 0, sizeof(step_val));
+      UNUSED bool found_step_value = false;
+      assert(nir_op_infos[ind_var->alu->op].num_inputs == 2);
+      for (unsigned i = 0; i < 2; i++) {
+         nir_ssa_scalar alu_src = nir_ssa_scalar_chase_alu_src(alu_s, i);
+         if (nir_ssa_scalar_is_const(alu_src)) {
+            found_step_value = true;
+            step_val = nir_ssa_scalar_as_const_value(alu_src);
+            break;
+         }
+      }
+      assert(found_step_value);
 
-      int iterations = calculate_iterations(initial_val, step_val,
+      int iterations = calculate_iterations(&initial_val, &step_val,
                                             &limit_val,
-                                            basic_ind->ind->alu_def, alu,
+                                            ind_var->alu, cond,
                                             alu_op, limit_rhs,
                                             terminator->continue_from_then);
 




More information about the mesa-commit mailing list