Mesa (main): aco/ra: refactor subdword definition info

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Mon Aug 23 10:48:49 UTC 2021


Module: Mesa
Branch: main
Commit: c75138ed64054e6a8f6010cc1fd9ae19a59a62b6
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=c75138ed64054e6a8f6010cc1fd9ae19a59a62b6

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Fri Aug 13 12:54:59 2021 +0200

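aco/ra: refactor subdword definition info

Restructure get_subdword_definition_info() and add_subdword_definition()
around instruction classes: pseudo instructions, VALU instructions
(SDWA, opsel, v_fma_mixlo_f16) and the d16 memory loads.
add_subdword_definition() no longer takes a definition index.
The sign-extending *_sbyte_d16 and ds_read_i8_d16 loads (and their
_hi variants) are now handled alongside their unsigned counterparts.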

Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12364>

---

 src/amd/compiler/aco_register_allocation.cpp | 172 +++++++++++++++------------
 1 file changed, 98 insertions(+), 74 deletions(-)

diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp
index 3e7e87e4ad2..6ba71eb34b6 100644
--- a/src/amd/compiler/aco_register_allocation.cpp
+++ b/src/amd/compiler/aco_register_allocation.cpp
@@ -43,8 +43,7 @@ void add_subdword_operand(ra_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx
                           RegClass rc);
 std::pair<unsigned, unsigned>
 get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr, RegClass rc);
-void add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, unsigned idx,
-                             PhysReg reg);
+void add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg reg);
 
 struct assignment {
    PhysReg reg;
@@ -565,99 +564,124 @@ get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr
 {
    chip_class chip = program->chip_class;
 
-   if (instr->isPseudo() && chip >= GFX8)
-      return std::make_pair(rc.bytes() % 2 == 0 ? 2 : 1, rc.bytes());
-   else if (instr->isPseudo())
-      return std::make_pair(4, rc.size() * 4u);
-
-   unsigned bytes_written = chip >= GFX10 ? rc.bytes() : 4u;
-   switch (instr->opcode) {
-   case aco_opcode::v_mad_f16:
-   case aco_opcode::v_mad_u16:
-   case aco_opcode::v_mad_i16:
-   case aco_opcode::v_fma_f16:
-   case aco_opcode::v_div_fixup_f16:
-   case aco_opcode::v_interp_p2_f16: bytes_written = chip >= GFX9 ? rc.bytes() : 4u; break;
-   default: break;
+   if (instr->isPseudo()) {
+      if (chip >= GFX8)
+         return std::make_pair(rc.bytes() % 2 == 0 ? 2 : 1, rc.bytes());
+      else
+         return std::make_pair(4, rc.size() * 4u);
    }
-   bytes_written = bytes_written > 4 ? align(bytes_written, 4) : bytes_written;
-   bytes_written = MAX2(bytes_written, instr_info.definition_size[(int)instr->opcode] / 8u);
 
-   if (can_use_SDWA(chip, instr, false)) {
-      return std::make_pair(rc.bytes(), rc.bytes());
-   } else if (rc.bytes() == 2 && can_use_opsel(chip, instr->opcode, -1, 1)) {
-      return std::make_pair(2u, bytes_written);
+   if (instr->isVALU() || instr->isVINTRP()) {
+      assert(rc.bytes() <= 2);
+
+      if (can_use_SDWA(chip, instr, false))
+         return std::make_pair(rc.bytes(), rc.bytes());
+
+      unsigned bytes_written = 4u;
+      if (instr_is_16bit(chip, instr->opcode))
+         bytes_written = 2u;
+
+      unsigned stride = 4u;
+      if (instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
+          can_use_opsel(chip, instr->opcode, -1, true))
+         stride = 2u;
+
+      return std::make_pair(stride, bytes_written);
    }
 
    switch (instr->opcode) {
-   case aco_opcode::buffer_load_ubyte_d16:
-   case aco_opcode::buffer_load_short_d16:
+   case aco_opcode::ds_read_u8_d16:
+   case aco_opcode::ds_read_i8_d16:
+   case aco_opcode::ds_read_u16_d16:
    case aco_opcode::flat_load_ubyte_d16:
+   case aco_opcode::flat_load_sbyte_d16:
    case aco_opcode::flat_load_short_d16:
-   case aco_opcode::scratch_load_ubyte_d16:
-   case aco_opcode::scratch_load_short_d16:
    case aco_opcode::global_load_ubyte_d16:
+   case aco_opcode::global_load_sbyte_d16:
    case aco_opcode::global_load_short_d16:
-   case aco_opcode::ds_read_u8_d16:
-   case aco_opcode::ds_read_u16_d16:
-      if (chip >= GFX9 && !program->dev.sram_ecc_enabled)
+   case aco_opcode::scratch_load_ubyte_d16:
+   case aco_opcode::scratch_load_sbyte_d16:
+   case aco_opcode::scratch_load_short_d16:
+   case aco_opcode::buffer_load_ubyte_d16:
+   case aco_opcode::buffer_load_sbyte_d16:
+   case aco_opcode::buffer_load_short_d16: {
+      assert(chip >= GFX9);
+      if (!program->dev.sram_ecc_enabled)
          return std::make_pair(2u, 2u);
       else
          return std::make_pair(2u, 4u);
-   case aco_opcode::v_fma_mixlo_f16: return std::make_pair(2u, 2u);
-   default: break;
    }
 
-   return std::make_pair(4u, bytes_written);
+   default: return std::make_pair(4, rc.size() * 4u);
+   }
 }
 
 void
-add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, unsigned idx, PhysReg reg)
+add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg reg)
 {
-   RegClass rc = instr->definitions[idx].regClass();
-   chip_class chip = program->chip_class;
-
-   if (instr->isPseudo()) {
-      return;
-   } else if (can_use_SDWA(chip, instr, false)) {
-      unsigned def_size = instr_info.definition_size[(int)instr->opcode];
-      if (reg.byte() || chip < GFX10 || def_size > rc.bytes() * 8u)
-         convert_to_SDWA(chip, instr);
-      return;
-   } else if (reg.byte() && rc.bytes() == 2 &&
-              can_use_opsel(chip, instr->opcode, -1, reg.byte() / 2)) {
-      VOP3_instruction& vop3 = instr->vop3();
-      if (reg.byte() == 2)
-         vop3.opsel |= (1 << 3); /* dst in high half */
+   if (instr->isPseudo())
       return;
-   }
 
-   if (reg.byte() == 2) {
-      if (instr->opcode == aco_opcode::v_fma_mixlo_f16)
+   if (instr->isVALU()) {
+      chip_class chip = program->chip_class;
+      assert(instr->definitions[0].bytes() <= 2);
+
+      if (reg.byte() == 0 && instr_is_16bit(chip, instr->opcode))
+         return;
+
+      /* check if we can use opsel */
+      if (instr->format == Format::VOP3) {
+         assert(reg.byte() == 2);
+         assert(can_use_opsel(chip, instr->opcode, -1, true));
+         instr->vop3().opsel |= (1 << 3); /* dst in high half */
+         return;
+      }
+
+      if (instr->opcode == aco_opcode::v_fma_mixlo_f16) {
          instr->opcode = aco_opcode::v_fma_mixhi_f16;
-      else if (instr->opcode == aco_opcode::buffer_load_ubyte_d16)
-         instr->opcode = aco_opcode::buffer_load_ubyte_d16_hi;
-      else if (instr->opcode == aco_opcode::buffer_load_short_d16)
-         instr->opcode = aco_opcode::buffer_load_short_d16_hi;
-      else if (instr->opcode == aco_opcode::flat_load_ubyte_d16)
-         instr->opcode = aco_opcode::flat_load_ubyte_d16_hi;
-      else if (instr->opcode == aco_opcode::flat_load_short_d16)
-         instr->opcode = aco_opcode::flat_load_short_d16_hi;
-      else if (instr->opcode == aco_opcode::scratch_load_ubyte_d16)
-         instr->opcode = aco_opcode::scratch_load_ubyte_d16_hi;
-      else if (instr->opcode == aco_opcode::scratch_load_short_d16)
-         instr->opcode = aco_opcode::scratch_load_short_d16_hi;
-      else if (instr->opcode == aco_opcode::global_load_ubyte_d16)
-         instr->opcode = aco_opcode::global_load_ubyte_d16_hi;
-      else if (instr->opcode == aco_opcode::global_load_short_d16)
-         instr->opcode = aco_opcode::global_load_short_d16_hi;
-      else if (instr->opcode == aco_opcode::ds_read_u8_d16)
-         instr->opcode = aco_opcode::ds_read_u8_d16_hi;
-      else if (instr->opcode == aco_opcode::ds_read_u16_d16)
-         instr->opcode = aco_opcode::ds_read_u16_d16_hi;
-      else
-         unreachable("Something went wrong: Impossible register assignment.");
+         return;
+      }
+
+      /* use SDWA */
+      assert(can_use_SDWA(chip, instr, false));
+      convert_to_SDWA(chip, instr);
+      return;
    }
+
+   if (reg.byte() == 0)
+      return;
+   else if (instr->opcode == aco_opcode::buffer_load_ubyte_d16)
+      instr->opcode = aco_opcode::buffer_load_ubyte_d16_hi;
+   else if (instr->opcode == aco_opcode::buffer_load_sbyte_d16)
+      instr->opcode = aco_opcode::buffer_load_sbyte_d16_hi;
+   else if (instr->opcode == aco_opcode::buffer_load_short_d16)
+      instr->opcode = aco_opcode::buffer_load_short_d16_hi;
+   else if (instr->opcode == aco_opcode::flat_load_ubyte_d16)
+      instr->opcode = aco_opcode::flat_load_ubyte_d16_hi;
+   else if (instr->opcode == aco_opcode::flat_load_sbyte_d16)
+      instr->opcode = aco_opcode::flat_load_sbyte_d16_hi;
+   else if (instr->opcode == aco_opcode::flat_load_short_d16)
+      instr->opcode = aco_opcode::flat_load_short_d16_hi;
+   else if (instr->opcode == aco_opcode::scratch_load_ubyte_d16)
+      instr->opcode = aco_opcode::scratch_load_ubyte_d16_hi;
+   else if (instr->opcode == aco_opcode::scratch_load_sbyte_d16)
+      instr->opcode = aco_opcode::scratch_load_sbyte_d16_hi;
+   else if (instr->opcode == aco_opcode::scratch_load_short_d16)
+      instr->opcode = aco_opcode::scratch_load_short_d16_hi;
+   else if (instr->opcode == aco_opcode::global_load_ubyte_d16)
+      instr->opcode = aco_opcode::global_load_ubyte_d16_hi;
+   else if (instr->opcode == aco_opcode::global_load_sbyte_d16)
+      instr->opcode = aco_opcode::global_load_sbyte_d16_hi;
+   else if (instr->opcode == aco_opcode::global_load_short_d16)
+      instr->opcode = aco_opcode::global_load_short_d16_hi;
+   else if (instr->opcode == aco_opcode::ds_read_u8_d16)
+      instr->opcode = aco_opcode::ds_read_u8_d16_hi;
+   else if (instr->opcode == aco_opcode::ds_read_i8_d16)
+      instr->opcode = aco_opcode::ds_read_i8_d16_hi;
+   else if (instr->opcode == aco_opcode::ds_read_u16_d16)
+      instr->opcode = aco_opcode::ds_read_u16_d16_hi;
+   else
+      unreachable("Something went wrong: Impossible register assignment.");
 }
 
 void
@@ -2576,7 +2600,7 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
                   PhysReg reg = get_reg(ctx, register_file, tmp, parallelcopy, instr);
                   definition->setFixed(reg);
                   if (reg.byte() || register_file.test(reg, 4)) {
-                     add_subdword_definition(program, instr, i, reg);
+                     add_subdword_definition(program, instr, reg);
                      definition = &instr->definitions[i]; /* add_subdword_definition can invalidate
                                                              the reference */
                   }



More information about the mesa-commit mailing list