Mesa (master): aco: optimize packed mul+add to v_pk_fma_f16

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jan 13 18:03:29 UTC 2021


Module: Mesa
Branch: master
Commit: a9fd9187e830b6665984f2f9cf651465c266dc85
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=a9fd9187e830b6665984f2f9cf651465c266dc85

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Thu Sep  3 12:02:55 2020 +0100

aco: optimize packed mul+add to v_pk_fma_f16

Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>

---

 src/amd/compiler/aco_optimizer.cpp | 99 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 97 insertions(+), 2 deletions(-)

diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 967656c3c72..d4b95f50538 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -116,9 +116,10 @@ enum Label {
    label_b2i = 1 << 27,
    label_constant_16bit = 1 << 29,
    label_usedef = 1 << 30, /* generic label */
+   label_vop3p = 1 << 31,
 };
 
-static constexpr uint64_t instr_usedef_labels = label_vec | label_mul | label_mad | label_add_sub |
+static constexpr uint64_t instr_usedef_labels = label_vec | label_mul | label_mad | label_add_sub | label_vop3p |
                                                 label_bitwise | label_uniform_bitwise | label_minmax | label_vopc | label_usedef;
 static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp;
 
@@ -499,6 +500,17 @@ struct ssa_info {
    {
       return label & label_usedef;
    }
+
+   void set_vop3p(Instruction* vop3p_instr) /* label this temp as being defined by a VOP3P-format instruction */
+   {
+      add_label(label_vop3p);
+      instr = vop3p_instr; /* remember the defining instruction; combine_vop3p reads it back as info.instr */
+   }
+
+   bool is_vop3p() /* returns true if the label_vop3p bit is set in label */
+   {
+      return label & label_vop3p; /* NOTE(review): label_vop3p is 1 << 31, the sign bit of a 32-bit int, while the label masks are uint64_t - confirm no unwanted sign extension */
+   }
 };
 
 struct opt_ctx {
@@ -1084,6 +1096,10 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
       return;
    }
+   if (instr->format == Format::VOP3P) {
+      ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
+      return;
+   }
 
    switch (instr->opcode) {
    case aco_opcode::p_create_vector: {
@@ -2710,6 +2726,82 @@ bool combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    return false;
 }
 
+void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr) /* note: block is not referenced in this body */
+{
+   // TODO: clamp, fneg?
+   /* Fuses a non-precise v_pk_add_f16 with a non-precise v_pk_mul_f16 feeding one of its operands into a single v_pk_fma_f16; instr is replaced in place. */
+   if (instr->opcode == aco_opcode::v_pk_add_f16) {
+      if (instr->definitions[0].isPrecise())
+         return;
+
+      Instruction* mul_instr = nullptr; /* best candidate found so far, nullptr if none */
+      unsigned add_op_idx = 0; /* index of the add operand that is NOT the mul result (the addend) */
+      uint32_t uses = UINT32_MAX; /* use count of the best candidate's result */
+
+      /* find the 'best' mul instruction to combine with the add: the candidate whose result has the fewest uses */
+      for (unsigned i = 0; i < 2; i++) {
+         if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
+            continue;
+         ssa_info& info = ctx.info[instr->operands[i].tempId()];
+         if (info.instr->opcode != aco_opcode::v_pk_mul_f16 || info.instr->definitions[0].isPrecise())
+            continue;
+
+         Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]}; /* operands the fused instruction would have */
+         if (ctx.uses[instr->operands[i].tempId()] >= uses ||
+             !check_vop3_operands(ctx, 3, op))
+            continue;
+
+         /* opsel of mul needs to be .xy: for operand i of the add, the opsel_lo bit must be clear and the opsel_hi bit set */
+         if (static_cast<VOP3P_instruction*>(instr.get())->opsel_lo & (1 << i) ||
+             !(static_cast<VOP3P_instruction*>(instr.get())->opsel_hi & (1 << i)))
+            continue;
+         /* no clamp allowed between mul and add */
+         if (static_cast<VOP3P_instruction*>(info.instr)->clamp)
+            continue;
+
+         mul_instr = info.instr;
+         add_op_idx = 1 - i;
+         uses = ctx.uses[instr->operands[i].tempId()];
+      }
+
+      if (!mul_instr)
+         return;
+
+      /* convert to mad: the add stops using the mul result; if the mul still has other users it stays, so its sources gain a use from the new fma */
+      Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
+      ctx.uses[mul_instr->definitions[0].tempId()]--;
+      if (ctx.uses[mul_instr->definitions[0].tempId()]) {
+         if (op[0].isTemp())
+            ctx.uses[op[0].tempId()]++;
+         if (op[1].isTemp())
+            ctx.uses[op[1].tempId()]++;
+      }
+
+      /* turn packed mul+add into v_pk_fma_f16 */
+      assert(mul_instr->format == Format::VOP3P);
+      aco_ptr<VOP3P_instruction> fma{create_instruction<VOP3P_instruction>(aco_opcode::v_pk_fma_f16, Format::VOP3P, 3, 1)};
+      VOP3P_instruction* mul = static_cast<VOP3P_instruction*>(mul_instr);
+      VOP3P_instruction* vop3p = static_cast<VOP3P_instruction*>(instr.get()); /* the original add */
+      for (unsigned i = 0; i < 2; i++) { /* the two factors and their neg_lo/neg_hi modifiers come from the mul */
+         fma->operands[i] = op[i];
+         fma->neg_lo[i] = mul->neg_lo[i];
+         fma->neg_hi[i] = mul->neg_hi[i];
+      }
+      fma->operands[2] = op[2]; /* the addend and its modifiers come from the add */
+      fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
+      fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
+      fma->clamp = vop3p->clamp;
+      fma->opsel_lo = mul->opsel_lo | ((vop3p->opsel_lo << (2 - add_op_idx)) & 0x4); /* moves the add's opsel bit for the addend (bit add_op_idx) to bit 2 */
+      fma->opsel_hi = mul->opsel_hi | ((vop3p->opsel_hi << (2 - add_op_idx)) & 0x4);
+      fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx]; /* a negate the add applied to the mul result is folded into factor 1 */
+      fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
+      fma->definitions[0] = instr->definitions[0]; /* the fma defines the same temp the add did */
+      instr.reset(fma.release()); /* replace the add with the fma */
+      ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get()); /* relabel so the info points at the new instruction */
+      return;
+   }
+}
+
 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
 // this would mean that we'd have to fix the instruction uses while value propagation
 
@@ -2724,6 +2816,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       while (apply_omod_clamp(ctx, block, instr)) ;
    }
 
+   if (instr->format == Format::VOP3P)
+      return combine_vop3p(ctx, block, instr);
+
    if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) {
       instr->definitions[0].setHint(vcc);
    }
@@ -2798,7 +2893,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          return;
 
       Instruction* mul_instr = nullptr;
-      unsigned add_op_idx;
+      unsigned add_op_idx = 0;
       uint32_t uses = UINT32_MAX;
       /* find the 'best' mul instruction to combine with the add */
       for (unsigned i = 0; i < 2; i++) {



More information about the mesa-commit mailing list