Mesa (main): aco: add more D16 load/store instructions to RA and validator

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Mon Nov 15 18:57:32 UTC 2021


Module: Mesa
Branch: main
Commit: 8f1483cd5ca83b78843b2c46ea3b593a61c24f4e
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=8f1483cd5ca83b78843b2c46ea3b593a61c24f4e

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Mon Oct 25 14:26:05 2021 +0200

aco: add more D16 load/store instructions to RA and validator

This enables correct handling for
buffer_load/store_format_d16_x and
D16 Image instructions.

Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13592>

---

 src/amd/compiler/aco_register_allocation.cpp | 26 +++++++++++-
 src/amd/compiler/aco_validate.cpp            | 63 +++++++++++++++++++++++++---
 2 files changed, 81 insertions(+), 8 deletions(-)

diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp
index a037745aaaf..ae96d02bdcb 100644
--- a/src/amd/compiler/aco_register_allocation.cpp
+++ b/src/amd/compiler/aco_register_allocation.cpp
@@ -495,6 +495,7 @@ get_subdword_operand_stride(chip_class chip, const aco_ptr<Instruction>& instr,
    case aco_opcode::ds_write_b16: return chip >= GFX9 ? 2 : 4;
    case aco_opcode::buffer_store_byte:
    case aco_opcode::buffer_store_short:
+   case aco_opcode::buffer_store_format_d16_x:
    case aco_opcode::flat_store_byte:
    case aco_opcode::flat_store_short:
    case aco_opcode::scratch_store_byte:
@@ -552,6 +553,8 @@ add_subdword_operand(ra_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, uns
       instr->opcode = aco_opcode::buffer_store_byte_d16_hi;
    else if (instr->opcode == aco_opcode::buffer_store_short)
       instr->opcode = aco_opcode::buffer_store_short_d16_hi;
+   else if (instr->opcode == aco_opcode::buffer_store_format_d16_x)
+      instr->opcode = aco_opcode::buffer_store_format_d16_hi_x;
    else if (instr->opcode == aco_opcode::flat_store_byte)
       instr->opcode = aco_opcode::flat_store_byte_d16_hi;
    else if (instr->opcode == aco_opcode::flat_store_short)
@@ -601,6 +604,7 @@ get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr
    }
 
    switch (instr->opcode) {
+   /* D16 loads with _hi version */
    case aco_opcode::ds_read_u8_d16:
    case aco_opcode::ds_read_i8_d16:
    case aco_opcode::ds_read_u16_d16:
@@ -615,16 +619,32 @@ get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr
    case aco_opcode::scratch_load_short_d16:
    case aco_opcode::buffer_load_ubyte_d16:
    case aco_opcode::buffer_load_sbyte_d16:
-   case aco_opcode::buffer_load_short_d16: {
+   case aco_opcode::buffer_load_short_d16:
+   case aco_opcode::buffer_load_format_d16_x: {
       assert(chip >= GFX9);
       if (!program->dev.sram_ecc_enabled)
          return std::make_pair(2u, 2u);
       else
          return std::make_pair(2u, 4u);
    }
+   /* 3-component D16 loads */
+   case aco_opcode::buffer_load_format_d16_xyz:
+   case aco_opcode::tbuffer_load_format_d16_xyz: {
+      assert(chip >= GFX9);
+      if (!program->dev.sram_ecc_enabled)
+         return std::make_pair(4u, 6u);
+      break;
+   }
 
-   default: return std::make_pair(4, rc.size() * 4u);
+   default: break;
    }
+
+   if (instr->isMIMG() && instr->mimg().d16 && !program->dev.sram_ecc_enabled) {
+      assert(chip >= GFX9);
+      return std::make_pair(4u, rc.bytes());
+   }
+
+   return std::make_pair(4, rc.size() * 4u);
 }
 
 void
@@ -667,6 +687,8 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
       instr->opcode = aco_opcode::buffer_load_sbyte_d16_hi;
    else if (instr->opcode == aco_opcode::buffer_load_short_d16)
       instr->opcode = aco_opcode::buffer_load_short_d16_hi;
+   else if (instr->opcode == aco_opcode::buffer_load_format_d16_x)
+      instr->opcode = aco_opcode::buffer_load_format_d16_hi_x;
    else if (instr->opcode == aco_opcode::flat_load_ubyte_d16)
       instr->opcode = aco_opcode::flat_load_ubyte_d16_hi;
    else if (instr->opcode == aco_opcode::flat_load_sbyte_d16)
diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp
index 8c5a84ac3f9..445b4cd4918 100644
--- a/src/amd/compiler/aco_validate.cpp
+++ b/src/amd/compiler/aco_validate.cpp
@@ -256,8 +256,8 @@ validate_ir(Program* program)
          /* check subdword definitions */
          for (unsigned i = 0; i < instr->definitions.size(); i++) {
             if (instr->definitions[i].regClass().is_subdword())
-               check(instr->isPseudo() || instr->definitions[i].bytes() <= 4,
-                     "Only Pseudo instructions can write subdword registers larger than 4 bytes",
+               check(instr->definitions[i].bytes() <= 4 || instr->isPseudo() || instr->isVMEM(),
+                     "Only Pseudo and VMEM instructions can write subdword registers > 4 bytes",
                      instr.get());
          }
 
@@ -542,6 +542,36 @@ validate_ir(Program* program)
                      (instr->operands[3].isTemp() &&
                       instr->operands[3].regClass().type() == RegType::vgpr),
                   "VMEM write data must be vgpr", instr.get());
+
+            const bool d16 = instr->opcode == aco_opcode::buffer_load_dword || // FIXME: used to spill subdword variables
+                             instr->opcode == aco_opcode::buffer_load_ubyte ||
+                             instr->opcode == aco_opcode::buffer_load_sbyte ||
+                             instr->opcode == aco_opcode::buffer_load_ushort ||
+                             instr->opcode == aco_opcode::buffer_load_sshort ||
+                             instr->opcode == aco_opcode::buffer_load_ubyte_d16 ||
+                             instr->opcode == aco_opcode::buffer_load_ubyte_d16_hi ||
+                             instr->opcode == aco_opcode::buffer_load_sbyte_d16 ||
+                             instr->opcode == aco_opcode::buffer_load_sbyte_d16_hi ||
+                             instr->opcode == aco_opcode::buffer_load_short_d16 ||
+                             instr->opcode == aco_opcode::buffer_load_short_d16_hi ||
+                             instr->opcode == aco_opcode::buffer_load_format_d16_x ||
+                             instr->opcode == aco_opcode::buffer_load_format_d16_hi_x ||
+                             instr->opcode == aco_opcode::buffer_load_format_d16_xy ||
+                             instr->opcode == aco_opcode::buffer_load_format_d16_xyz ||
+                             instr->opcode == aco_opcode::buffer_load_format_d16_xyzw ||
+                             instr->opcode == aco_opcode::tbuffer_load_format_d16_x ||
+                             instr->opcode == aco_opcode::tbuffer_load_format_d16_xy ||
+                             instr->opcode == aco_opcode::tbuffer_load_format_d16_xyz ||
+                             instr->opcode == aco_opcode::tbuffer_load_format_d16_xyzw;
+            if (instr->definitions.size()) {
+               check(instr->definitions[0].isTemp() &&
+                        instr->definitions[0].regClass().type() == RegType::vgpr,
+                     "VMEM definitions[0] (VDATA) must be VGPR", instr.get());
+               check(d16 || !instr->definitions[0].regClass().is_subdword(),
+                     "Only D16 opcodes can load subdword values.", instr.get());
+               check(instr->definitions[0].bytes() <= 8 || !d16,
+                     "D16 opcodes can only load up to 8 bytes.", instr.get());
+            }
             break;
          }
          case Format::MIMG: {
@@ -575,10 +605,16 @@ validate_ir(Program* program)
                         instr.get());
                }
             }
-            check(instr->definitions.empty() ||
-                     (instr->definitions[0].isTemp() &&
-                      instr->definitions[0].regClass().type() == RegType::vgpr),
-                  "MIMG definitions[0] (VDATA) must be VGPR", instr.get());
+
+            if (instr->definitions.size()) {
+               check(instr->definitions[0].isTemp() &&
+                        instr->definitions[0].regClass().type() == RegType::vgpr,
+                     "MIMG definitions[0] (VDATA) must be VGPR", instr.get());
+               check(instr->mimg().d16 || !instr->definitions[0].regClass().is_subdword(),
+                     "Only D16 MIMG instructions can load subdword values.", instr.get());
+               check(instr->definitions[0].bytes() <= 8 || !instr->mimg().d16,
+                     "D16 MIMG instructions can only load up to 8 bytes.", instr.get());
+            }
             break;
          }
          case Format::DS: {
@@ -744,6 +780,7 @@ validate_subdword_operand(chip_class chip, const aco_ptr<Instruction>& instr, un
       break;
    case aco_opcode::buffer_store_byte_d16_hi:
    case aco_opcode::buffer_store_short_d16_hi:
+   case aco_opcode::buffer_store_format_d16_hi_x:
       if (byte == 2 && index == 3)
          return true;
       break;
@@ -778,7 +815,9 @@ validate_subdword_definition(chip_class chip, const aco_ptr<Instruction>& instr)
 
    switch (instr->opcode) {
    case aco_opcode::buffer_load_ubyte_d16_hi:
+   case aco_opcode::buffer_load_sbyte_d16_hi:
    case aco_opcode::buffer_load_short_d16_hi:
+   case aco_opcode::buffer_load_format_d16_hi_x:
    case aco_opcode::flat_load_ubyte_d16_hi:
    case aco_opcode::flat_load_short_d16_hi:
    case aco_opcode::scratch_load_ubyte_d16_hi:
@@ -812,9 +851,17 @@ get_subdword_bytes_written(Program* program, const aco_ptr<Instruction>& instr,
       return 4;
    }
 
+   if (instr->isMIMG()) {
+      assert(instr->mimg().d16);
+      return program->dev.sram_ecc_enabled ? def.size() * 4u : def.bytes();
+   }
+
    switch (instr->opcode) {
    case aco_opcode::buffer_load_ubyte_d16:
+   case aco_opcode::buffer_load_sbyte_d16:
    case aco_opcode::buffer_load_short_d16:
+   case aco_opcode::buffer_load_format_d16_x:
+   case aco_opcode::tbuffer_load_format_d16_x:
    case aco_opcode::flat_load_ubyte_d16:
    case aco_opcode::flat_load_short_d16:
    case aco_opcode::scratch_load_ubyte_d16:
@@ -824,7 +871,9 @@ get_subdword_bytes_written(Program* program, const aco_ptr<Instruction>& instr,
    case aco_opcode::ds_read_u8_d16:
    case aco_opcode::ds_read_u16_d16:
    case aco_opcode::buffer_load_ubyte_d16_hi:
+   case aco_opcode::buffer_load_sbyte_d16_hi:
    case aco_opcode::buffer_load_short_d16_hi:
+   case aco_opcode::buffer_load_format_d16_hi_x:
    case aco_opcode::flat_load_ubyte_d16_hi:
    case aco_opcode::flat_load_short_d16_hi:
    case aco_opcode::scratch_load_ubyte_d16_hi:
@@ -833,6 +882,8 @@ get_subdword_bytes_written(Program* program, const aco_ptr<Instruction>& instr,
    case aco_opcode::global_load_short_d16_hi:
    case aco_opcode::ds_read_u8_d16_hi:
    case aco_opcode::ds_read_u16_d16_hi: return program->dev.sram_ecc_enabled ? 4 : 2;
+   case aco_opcode::buffer_load_format_d16_xyz:
+   case aco_opcode::tbuffer_load_format_d16_xyz: return program->dev.sram_ecc_enabled ? 8 : 6;
    default: return def.size() * 4;
    }
 }



More information about the mesa-commit mailing list