Mesa (main): aco: add p_extract/p_insert

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Jun 8 09:19:34 UTC 2021


Module: Mesa
Branch: main
Commit: 2f94353735b5ddfe2a72499e7bf6c7bbc80b9a00
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=2f94353735b5ddfe2a72499e7bf6c7bbc80b9a00

Author: Rhys Perry <pendingchaos02 at gmail.com>
Date:   Wed Aug 12 14:35:15 2020 +0100

aco: add p_extract/p_insert

These will let us make the SDWA optimizer much simpler than it would be if
it had to recognize combinations of shift/and/bfe instructions.

Signed-off-by: Rhys Perry <pendingchaos02 at gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3151>

---

 src/amd/compiler/aco_lower_to_hw_instr.cpp | 97 ++++++++++++++++++++++++++++++
 src/amd/compiler/aco_opcodes.py            |  8 +++
 src/amd/compiler/aco_optimizer.cpp         | 86 +++++++++++++++++++++++---
 src/amd/compiler/aco_validate.cpp          | 23 +++++++
 4 files changed, 207 insertions(+), 7 deletions(-)

diff --git a/src/amd/compiler/aco_lower_to_hw_instr.cpp b/src/amd/compiler/aco_lower_to_hw_instr.cpp
index c4b8120ed16..f9dfe0b2d29 100644
--- a/src/amd/compiler/aco_lower_to_hw_instr.cpp
+++ b/src/amd/compiler/aco_lower_to_hw_instr.cpp
@@ -1994,6 +1994,103 @@ void lower_to_hw_instr(Program* program)
                         Operand(reg.advance(4), s1), Operand(0u), Operand(scc, s1));
                break;
             }
+            case aco_opcode::p_extract
+            { /* dst = (src0 >> (index * bits)) & ((1 << bits) - 1), sign-extended if src3 != 0 */
+               assert(instr->operands[1].isConstant());
+               assert(instr->operands[2].isConstant());
+               assert(instr->operands[3].isConstant());
+               if (instr->definitions[0].regClass() == s1)
+                  assert(instr->definitions.size() >= 2 && instr->definitions[1].physReg() == scc); /* SALU forms below may write SCC */
+               Definition dst = instr->definitions[0];
+               Operand op = instr->operands[0];
+               unsigned bits = instr->operands[2].constantValue();
+               unsigned index = instr->operands[1].constantValue();
+               unsigned offset = index * bits; /* bit offset of the field within op */
+               bool signext = !instr->operands[3].constantEquals(0);
+
+               if (dst.regClass() == s1) { /* SALU */
+                  if (offset == (32 - bits)) { /* top field: a single shift right suffices */
+                     bld.sop2(signext ? aco_opcode::s_ashr_i32 : aco_opcode::s_lshr_b32,
+                              dst, bld.def(s1, scc), op, Operand(offset));
+                  } else if (offset == 0 && signext && (bits == 8 || bits == 16)) {
+                     bld.sop1(bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, dst, op);
+                  } else { /* general case: s_bfe takes (width << 16) | offset */
+                     bld.sop2(signext ? aco_opcode::s_bfe_i32 : aco_opcode::s_bfe_u32,
+                              dst, bld.def(s1, scc), op, Operand((bits << 16) | offset));
+                  }
+               } else if (dst.regClass() == v1 || ctx.program->chip_class <= GFX7) { /* VALU result written as a whole dword */
+                  assert(op.physReg().byte() == 0 && dst.physReg().byte() == 0);
+                  if (offset == (32 - bits) && op.regClass() != s1) { /* s1 sources take the VOP3 v_bfe path */
+                     bld.vop2(signext ? aco_opcode::v_ashrrev_i32 : aco_opcode::v_lshrrev_b32,
+                              dst, Operand(offset), op);
+                  } else {
+                     bld.vop3(signext ? aco_opcode::v_bfe_i32 : aco_opcode::v_bfe_u32,
+                              dst, op, Operand(offset), Operand(bits));
+                  }
+               } else if (dst.regClass() == v2b) { /* 16-bit VGPR result via SDWA v_mov_b32 */
+                  aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(
+                     aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)}; /* NOTE(review): an s1 op makes this an SDWA SGPR source; the p_insert lowering only does that on GFX9+ -- confirm GFX8 can't reach this */
+                  sdwa->operands[0] = Operand(op.physReg().advance(-op.physReg().byte()),
+                                              RegClass::get(op.regClass().type(), 4)); /* whole containing dword; its byte offset moves into sel[0] */
+                  sdwa->definitions[0] = dst;
+                  sdwa->sel[0] = (bits == 16 ? sdwa_uword0 : sdwa_ubyte0) + (op.physReg().byte() + offset / 8u) / (bits / 8u); /* byte or word select: 16-bit fields must read a whole word */
+                  if (signext)
+                     sdwa->sel[0] |= sdwa_sext;
+                  sdwa->dst_sel = sdwa_uword;
+                  bld.insert(std::move(sdwa));
+               }
+               break;
+            }
+            case aco_opcode::p_insert
+            { /* dst = (src0 & ((1 << bits) - 1)) << (index * bits) */
+               assert(instr->operands[1].isConstant());
+               assert(instr->operands[2].isConstant());
+               if (instr->definitions[0].regClass() == s1)
+                  assert(instr->definitions.size() >= 2 && instr->definitions[1].physReg() == scc); /* SALU forms below write SCC */
+               Definition dst = instr->definitions[0];
+               Operand op = instr->operands[0];
+               unsigned bits = instr->operands[2].constantValue();
+               unsigned index = instr->operands[1].constantValue();
+               unsigned offset = index * bits; /* bit offset of the field within dst */
+
+               if (dst.regClass() == s1) { /* SALU */
+                  if (offset == (32 - bits)) { /* top field: the shift discards the unmasked bits */
+                     bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), op, Operand(offset));
+                  } else if (offset == 0) { /* bottom field: s_bfe with width bits and offset 0 */
+                     bld.sop2(aco_opcode::s_bfe_u32, dst, bld.def(s1, scc), op, Operand(bits << 16));
+                  } else { /* mask, then shift into place */
+                     bld.sop2(aco_opcode::s_bfe_u32, dst, bld.def(s1, scc), op, Operand(bits << 16));
+                     bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), Operand(dst.physReg(), s1), Operand(offset));
+                  }
+               } else if (dst.regClass() == v1 || ctx.program->chip_class <= GFX7) { /* VALU result written as a whole dword */
+                  if (offset == (dst.bytes() * 8u - bits)) {
+                     bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand(offset), op);
+                  } else if (offset == 0) {
+                     bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand(0u), Operand(bits));
+                  } else if (program->chip_class >= GFX9 || (op.regClass() != s1 && program->chip_class >= GFX8)) { /* SDWA; s1 sources only on GFX9+ */
+                     aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)};
+                     sdwa->operands[0] = op;
+                     sdwa->definitions[0] = dst;
+                     sdwa->sel[0] = sdwa_udword;
+                     sdwa->dst_sel = (bits == 8 ? sdwa_ubyte0 : sdwa_uword0) + (offset / bits);
+                     bld.insert(std::move(sdwa));
+                  } else { /* no usable SDWA form: mask, then shift into place */
+                     bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand(0u), Operand(bits));
+                     bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand(offset), Operand(dst.physReg(), v1));
+                  }
+               } else {
+                  assert(dst.regClass() == v2b && offset + bits <= 16u); /* the field must lie inside the 16-bit dst */
+                  aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(
+                     aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)};
+                  sdwa->operands[0] = op;
+                  sdwa->definitions[0] = Definition(dst.physReg().advance(-dst.physReg().byte()), v1); /* write into the containing dword */
+                  sdwa->sel[0] = sdwa_uword;
+                  sdwa->dst_sel = (bits == 16 ? sdwa_uword0 : sdwa_ubyte0) + (dst.physReg().byte() + offset / 8u) / (bits / 8u); /* byte or word select: 16-bit fields must write a whole word */
+                  sdwa->dst_preserve = 1; /* NOTE(review): besides the rest of the dword this also keeps the other byte of the 16-bit dst, which p_insert defines as zero when bits == 8 -- likely needs a shifting SDWA op instead of v_mov_b32 */
+                  bld.insert(std::move(sdwa));
+               }
+               break;
+            }
             default:
                break;
             }
diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py
index a28f1d5d765..07ac9cf104c 100644
--- a/src/amd/compiler/aco_opcodes.py
+++ b/src/amd/compiler/aco_opcodes.py
@@ -320,6 +320,14 @@ opcode("p_bpermute")
 
 opcode("p_constaddr")
 
+# These could be expressed with regular instructions, but keeping them as
+# pseudo-ops means the optimizer only has to recognize two opcodes.
+# p_extract: (src0 >> (index * bits)) & ((1 << bits) - 1), sign-extended if signext != 0
+opcode("p_extract") # src0=data, src1=index, src2=bits (8 or 16), src3=signext
+# p_insert: (src0 & ((1 << bits) - 1)) << (index * bits)
+opcode("p_insert") # src0=data, src1=index, src2=bits (8 or 16)
+
+
 # SOP2 instructions: 2 scalar inputs, 1 scalar output (+optional scc)
 SOP2 = {
   # GFX6, GFX7, GFX8, GFX9, GFX10, name
diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 42f50cae8fc..b1fadd33c31 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -763,6 +763,8 @@ bool alu_can_accept_constant(aco_opcode opcode, unsigned operand)
    case aco_opcode::v_readlane_b32:
    case aco_opcode::v_readlane_b32_e64:
    case aco_opcode::v_readfirstlane_b32:
+   case aco_opcode::p_extract:
+   case aco_opcode::p_insert:
       return operand != 0;
    default:
       return true;
@@ -1610,6 +1612,16 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       if (instr->operands[0].constantEquals(0x3f800000u))
          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
       break;
+   case aco_opcode::p_extract:
+   case aco_opcode::p_insert: {
+      /* Both bitfield pseudo-ops are labelled identically: when the data
+       * operand is a temporary, the result is marked bitwise and
+       * set_bitwise() is given the defining instruction. */
+      const bool has_temp_data = instr->operands[0].isTemp();
+      if (has_temp_data)
+         ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
+      break;
+   }
    default:
       break;
    }
@@ -2210,6 +2222,70 @@ bool combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode
    return false;
 }
 
+/* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 by folding an and/shift/p_extract/p_insert operand into instr (a v_or_b32 or an add) */
+bool combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   bool is_or = instr->opcode == aco_opcode::v_or_b32;
+   aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
+
+   if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2))
+      return true;
+   if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2))
+      return true;
+   if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
+      return true;
+   if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
+      return true;
+
+   if (instr->isSDWA()) /* SDWA instructions are not rewritten by the loop below */
+      return false;
+
+   /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
+    * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
+    * v_or_b32(p_insert(a, 3/1, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
+    * v_add_u32(p_insert(a, 3/1, 8/16), b) -> v_lshl_add_u32(a, 24/16, b)
+    */
+   for (unsigned i = 0; i < 2; i++) {
+      Instruction *extins = follow_operand(ctx, instr->operands[i]);
+      if (!extins)
+         continue;
+
+      aco_opcode op;
+      Operand operands[3];
+
+      if (extins->opcode == aco_opcode::p_insert &&
+          (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) { /* top field: the shift alone discards the masked-off bits */
+         op = new_op_lshl;
+         operands[1] = Operand(extins->operands[1].constantValue() * extins->operands[2].constantValue());
+      } else if (is_or && (extins->opcode == aco_opcode::p_insert ||
+                           (extins->opcode == aco_opcode::p_extract && extins->operands[3].constantEquals(0))) &&
+                 extins->operands[1].constantEquals(0)) { /* bottom field, zero-extended: a plain mask */
+         op = aco_opcode::v_and_or_b32;
+         operands[1] = Operand(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
+      } else {
+         continue;
+      }
+
+      operands[0] = extins->operands[0];
+      operands[2] = instr->operands[!i];
+
+      if (!check_vop3_operands(ctx, 3, operands))
+         continue;
+
+      bool neg[3] = {}, abs[3] = {};
+      uint8_t opsel = 0, omod = 0;
+      bool clamp = false;
+      if (instr->isVOP3())
+         clamp = instr->vop3().clamp;
+
+      ctx.uses[instr->operands[i].tempId()]--; /* the folded extract/insert result loses this use */
+      create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
+      return true;
+   }
+
+   return false;
+}
+
 bool combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
 {
    if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
@@ -3198,10 +3274,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) {
       if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
       else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_or_b32, "120", 1 | 2)) ;
-      else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_or_b32, "210", 1 | 2);
+      else combine_add_or_then_and_lshl(ctx, instr) ;
    } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->chip_class >= GFX10) {
       if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012", 1 | 2)) ;
       else combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32, "012", 1 | 2);
@@ -3215,9 +3288,8 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
          else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
          else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
-         else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_add_u32, "120", 1 | 2)) ;
-         else if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_add_u32, "210", 1 | 2)) ;
-         else combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16, aco_opcode::v_mad_u32_u16, "120", 1 | 2) ;
+         else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16, aco_opcode::v_mad_u32_u16, "120", 1 | 2)) ;
+         else combine_add_or_then_and_lshl(ctx, instr) ;
       }
    } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
               instr->opcode == aco_opcode::v_add_co_u32_e64) {
diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp
index f5ba8ab4958..51ca2a35ae1 100644
--- a/src/amd/compiler/aco_validate.cpp
+++ b/src/amd/compiler/aco_validate.cpp
@@ -376,6 +376,29 @@ bool validate_ir(Program* program)
                   check(instr->definitions[0].size() == op.size(), "Operand sizes must match Definition size", instr.get());
                }
                check(instr->operands.size() == block.linear_preds.size(), "Number of Operands does not match number of predecessors", instr.get());
+            } else if (instr->opcode == aco_opcode::p_extract || instr->opcode == aco_opcode::p_insert) {
+               /* p_extract reads a bitfield of operand 0; p_insert writes one into the definition. */
+               const bool is_extract = instr->opcode == aco_opcode::p_extract;
+               check(instr->operands[0].isTemp(), "Data operand must be temporary", instr.get());
+               check(instr->operands[1].isConstant(), "Index must be constant", instr.get());
+               if (is_extract)
+                  check(instr->operands[3].isConstant(), "Sign-extend flag must be constant", instr.get());
+
+               check(instr->definitions[0].getTemp().type() != RegType::sgpr ||
+                     instr->operands[0].getTemp().type() == RegType::sgpr,
+                     "Can't extract/insert VGPR to SGPR", instr.get());
+
+               if (instr->operands[0].getTemp().type() == RegType::vgpr)
+                  check(instr->operands[0].bytes() == instr->definitions[0].bytes(),
+                        "Sizes of operand and definition must match", instr.get());
+
+               if (instr->definitions[0].getTemp().type() == RegType::sgpr)
+                  check(instr->definitions.size() >= 2 && instr->definitions[1].isFixed() && instr->definitions[1].physReg() == scc, "SGPR extract/insert needs a SCC definition", instr.get());
+               check(instr->operands[2].constantEquals(8) || instr->operands[2].constantEquals(16), "Size must be 8 or 16", instr.get());
+               /* The field must fit in the value it is read from (p_extract) or written into (p_insert). */
+               const unsigned container_bits = (is_extract ? instr->operands[0].bytes() : instr->definitions[0].bytes()) * 8u;
+               check(instr->operands[2].constantValue() < container_bits, is_extract ? "Size must be smaller than source" : "Size must be smaller than definition", instr.get());
+               check(instr->operands[1].constantValue() < container_bits / MAX2(instr->operands[2].constantValue(), 1), "Index must be in-bounds", instr.get());
             }
             break;
          }



More information about the mesa-commit mailing list