Mesa (staging/21.0): aco: simplify multiply-add combining

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jan 13 19:31:07 UTC 2021


Module: Mesa
Branch: staging/21.0
Commit: 7f40dc9760eea25f6bc4936f26cec60481c8fc7d
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=7f40dc9760eea25f6bc4936f26cec60481c8fc7d

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Wed Sep  2 15:19:21 2020 +0100

aco: simplify multiply-add combining

When both operands of a v_sub (the same applies to v_add) are muls and one
already uses clamp/omod, pick the other operand to get a chance to
combine into a MAD.

No fossils-db changes.

Co-authored-by: Samuel Pitoiset <samuel.pitoiset at gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>
(cherry picked from commit 01134b0bfe407f43d8089551301ffedaeeb459ff)

---

 .pick_status.json                  |  2 +-
 src/amd/compiler/aco_optimizer.cpp | 84 ++++++++++++++++----------------------
 2 files changed, 37 insertions(+), 49 deletions(-)

diff --git a/.pick_status.json b/.pick_status.json
index 0e0301fa208..5ec58098d3e 100644
--- a/.pick_status.json
+++ b/.pick_status.json
@@ -328,7 +328,7 @@
         "description": "aco: simplify multiply-add combining",
         "nominated": false,
         "nomination_type": null,
-        "resolution": 4,
+        "resolution": 1,
         "master_sha": null,
         "because_sha": null
     },
diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index c986c6f69e7..967656c3c72 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -2797,49 +2797,50 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       if (need_fma && mad32 && !ctx.program->has_fast_fma32)
          return;
 
-      uint32_t uses_src0 = UINT32_MAX;
-      uint32_t uses_src1 = UINT32_MAX;
       Instruction* mul_instr = nullptr;
       unsigned add_op_idx;
-      /* check if any of the operands is a multiplication */
-      ssa_info *op0_info = instr->operands[0].isTemp() ? &ctx.info[instr->operands[0].tempId()] : NULL;
-      ssa_info *op1_info = instr->operands[1].isTemp() ? &ctx.info[instr->operands[1].tempId()] : NULL;
-      if (op0_info && op0_info->is_mul() && (!need_fma || !op0_info->instr->definitions[0].isPrecise()))
-         uses_src0 = ctx.uses[instr->operands[0].tempId()];
-      if (op1_info && op1_info->is_mul() && (!need_fma || !op1_info->instr->definitions[0].isPrecise()))
-         uses_src1 = ctx.uses[instr->operands[1].tempId()];
-
+      uint32_t uses = UINT32_MAX;
       /* find the 'best' mul instruction to combine with the add */
-      if (uses_src0 < uses_src1) {
-         mul_instr = op0_info->instr;
-         add_op_idx = 1;
-      } else if (uses_src1 < uses_src0) {
-         mul_instr = op1_info->instr;
-         add_op_idx = 0;
-      } else if (uses_src0 != UINT32_MAX) {
-         /* tiebreaker: quite random what to pick */
-         if (op0_info->instr->operands[0].isLiteral()) {
-            mul_instr = op1_info->instr;
-            add_op_idx = 0;
-         } else {
-            mul_instr = op0_info->instr;
-            add_op_idx = 1;
-         }
+      for (unsigned i = 0; i < 2; i++) {
+         if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
+            continue;
+         /* check precision requirements */
+         ssa_info& info = ctx.info[instr->operands[i].tempId()];
+         if (need_fma && info.instr->definitions[0].isPrecise())
+            continue;
+
+         /* no clamp/omod allowed between mul and add */
+         if (info.instr->isVOP3() &&
+             (static_cast<VOP3A_instruction*>(info.instr)->clamp ||
+              static_cast<VOP3A_instruction*>(info.instr)->omod))
+            continue;
+
+         Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
+         if (info.instr->isSDWA() ||
+             !check_vop3_operands(ctx, 3, op) ||
+             ctx.uses[instr->operands[i].tempId()] >= uses)
+            continue;
+
+         mul_instr = info.instr;
+         add_op_idx = 1 - i;
+         uses = ctx.uses[instr->operands[i].tempId()];
       }
+
       if (mul_instr) {
-         Operand op[3] = {Operand(v1), Operand(v1), Operand(v1)};
+         /* turn mul+add into v_mad/v_fma */
+         Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
+         ctx.uses[mul_instr->definitions[0].tempId()]--;
+         if (ctx.uses[mul_instr->definitions[0].tempId()]) {
+            if (op[0].isTemp())
+               ctx.uses[op[0].tempId()]++;
+            if (op[1].isTemp())
+               ctx.uses[op[1].tempId()]++;
+         }
+
          bool neg[3] = {false, false, false};
          bool abs[3] = {false, false, false};
          unsigned omod = 0;
          bool clamp = false;
-         op[0] = mul_instr->operands[0];
-         op[1] = mul_instr->operands[1];
-         op[2] = instr->operands[add_op_idx];
-         // TODO: would be better to check this before selecting a mul instr?
-         if (!check_vop3_operands(ctx, 3, op))
-            return;
-         if (mul_instr->isSDWA())
-            return;
 
          if (mul_instr->isVOP3()) {
             VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
@@ -2847,18 +2848,6 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
             neg[1] = vop3->neg[1];
             abs[0] = vop3->abs[0];
             abs[1] = vop3->abs[1];
-            /* we cannot use these modifiers between mul and add */
-            if (vop3->clamp || vop3->omod)
-               return;
-         }
-
-         /* convert to mad */
-         ctx.uses[mul_instr->definitions[0].tempId()]--;
-         if (ctx.uses[mul_instr->definitions[0].tempId()]) {
-            if (op[0].isTemp())
-               ctx.uses[op[0].tempId()]++;
-            if (op[1].isTemp())
-               ctx.uses[op[1].tempId()]++;
          }
 
          if (instr->isVOP3()) {
@@ -2888,8 +2877,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
                                 (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16 : aco_opcode::v_mad_f16);
 
          aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(mad_op, Format::VOP3A, 3, 1)};
-         for (unsigned i = 0; i < 3; i++)
-         {
+         for (unsigned i = 0; i < 3; i++) {
             mad->operands[i] = op[i];
             mad->neg[i] = neg[i];
             mad->abs[i] = abs[i];



More information about the mesa-commit mailing list