Mesa (master): aco: emit packed 16bit instructions

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jan 13 18:03:29 UTC 2021


Module: Mesa
Branch: master
Commit: 454bbf8f230e44e54b1dfc04e87dff353fa3fd1f
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=454bbf8f230e44e54b1dfc04e87dff353fa3fd1f

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Mon Aug 24 20:36:06 2020 +0100

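aco: emit packed 16bit instructions

When an ALU result has bit_size 16 but is held in a full v1 register,
select the packed VOP3P opcodes (v_pk_*) instead of the 32-bit ones.
The integer add/sub/mul/min/max/shift ops are given a VGPR destination
whenever their result has two components.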

Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>

---

 src/amd/compiler/aco_instruction_selection.cpp     | 53 ++++++++++++++++++++++
 .../compiler/aco_instruction_selection_setup.cpp   | 17 +++++--
 2 files changed, 66 insertions(+), 4 deletions(-)

diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 0dad702e419..1c47c45eacb 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -1406,6 +1406,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_i16_e64, dst);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_i16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i32, dst, true);
       } else if (dst.regClass() == s1) {
@@ -1420,6 +1422,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_u16_e64, dst);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_u16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u32, dst, true);
       } else if (dst.regClass() == s1) {
@@ -1434,6 +1438,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_i16_e64, dst);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_i16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i32, dst, true);
       } else if (dst.regClass() == s1) {
@@ -1448,6 +1454,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_u16_e64, dst);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_u16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u32, dst, true);
       } else if (dst.regClass() == s1) {
@@ -1510,6 +1518,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshrrev_b16_e64, dst, false, 2, true);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b16, dst, false, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshrrev_b16, dst, true);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b32, dst, false, true);
       } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) {
@@ -1531,6 +1541,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshlrev_b16_e64, dst, false, 2, true);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b16, dst, false, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshlrev_b16, dst, true);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b32, dst, false, true, false, false, 1);
       } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) {
@@ -1552,6 +1564,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_ashrrev_i16_e64, dst, false, 2, true);
       } else if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i16, dst, false, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_ashrrev_i16, dst, true);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i32, dst, false, true);
       } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) {
@@ -1628,6 +1642,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_add_u16, dst, true);
          break;
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_u16, dst);
+         break;
       }
 
       Temp src0 = get_alu_src(ctx, instr->src[0]);
@@ -1737,6 +1754,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (dst.regClass() == s1) {
          emit_sop2_instruction(ctx, instr, aco_opcode::s_sub_i32, dst, true);
          break;
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_sub_u16, dst);
+         break;
       }
 
       Temp src0 = get_alu_src(ctx, instr->src[0]);
@@ -1815,6 +1835,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u16_e64, dst);
       } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_lo_u16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_mul_lo_u16, dst);
       } else if (dst.type() == RegType::vgpr) {
          Temp src0 = get_alu_src(ctx, instr->src[0]);
          Temp src1 = get_alu_src(ctx, instr->src[1]);
@@ -1896,6 +1918,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_fmul: {
       if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_mul_f16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true);
       } else if (dst.regClass() == v2) {
@@ -1908,6 +1932,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_fadd: {
       if (dst.regClass() == v2b) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_f16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true);
       } else if (dst.regClass() == v2) {
@@ -1918,6 +1944,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fsub: {
+      if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         Instruction* add = emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_f16, dst);
+         VOP3P_instruction* sub = static_cast<VOP3P_instruction*>(add);
+         sub->neg_lo[1] = true;
+         sub->neg_hi[1] = true;
+         break;
+      }
+
       Temp src0 = get_alu_src(ctx, instr->src[0]);
       Temp src1 = get_alu_src(ctx, instr->src[1]);
       if (dst.regClass() == v2b) {
@@ -1944,6 +1978,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (dst.regClass() == v2b) {
          // TODO: check fp_mode.must_flush_denorms16_64
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_f16, dst);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32);
       } else if (dst.regClass() == v2) {
@@ -1957,6 +1993,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (dst.regClass() == v2b) {
          // TODO: check fp_mode.must_flush_denorms16_64
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, dst, true);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_f16, dst, true);
       } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32);
       } else if (dst.regClass() == v2) {
@@ -2011,6 +2049,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fneg: {
+      if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         Temp src = get_alu_src_vop3p(ctx, instr->src[0]);
+         bld.vop3p(aco_opcode::v_pk_mul_f16, Definition(dst), src, Operand(uint16_t(0xBC00)),
+                   instr->src[0].swizzle[0] & 1, instr->src[0].swizzle[1] & 1);
+         emit_split_vector(ctx, dst, 2);
+         break;
+      }
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.regClass() == v2b) {
          if (ctx->block->fp_mode.must_flush_denorms16_64)
@@ -2055,6 +2100,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fsat: {
+      if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         Temp src = get_alu_src_vop3p(ctx, instr->src[0]);
+         Instruction* vop3p = bld.vop3p(aco_opcode::v_pk_mul_f16, Definition(dst), src, Operand(uint16_t(0x3C00)),
+                                        instr->src[0].swizzle[0] & 1, instr->src[0].swizzle[1] & 1);
+         static_cast<VOP3P_instruction*>(vop3p)->clamp = true;
+         emit_split_vector(ctx, dst, 2);
+         break;
+      }
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.regClass() == v2b) {
          bld.vop3(aco_opcode::v_med3_f16, Definition(dst), Operand((uint16_t)0u), Operand((uint16_t)0x3c00), src);
diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp
index e0c2b93d2b8..04aa7be7f3d 100644
--- a/src/amd/compiler/aco_instruction_selection_setup.cpp
+++ b/src/amd/compiler/aco_instruction_selection_setup.cpp
@@ -680,7 +680,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
             switch(instr->type) {
             case nir_instr_type_alu: {
                nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
-               RegType type = RegType::sgpr;
+               RegType type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr;
                switch(alu_instr->op) {
                   case nir_op_fmul:
                   case nir_op_fadd:
@@ -745,10 +745,19 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_op_b2f16:
                   case nir_op_b2f32:
                   case nir_op_mov:
-                     type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr;
                      break;
-                  case nir_op_bcsel:
-                     type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr;
+                  case nir_op_iadd:
+                  case nir_op_isub:
+                  case nir_op_imul:
+                  case nir_op_imin:
+                  case nir_op_imax:
+                  case nir_op_umin:
+                  case nir_op_umax:
+                  case nir_op_ishl:
+                  case nir_op_ishr:
+                  case nir_op_ushr:
+                     /* packed 16bit instructions have to be VGPR */
+                     type = alu_instr->dest.dest.ssa.num_components == 2 ? RegType::vgpr : type;
                      FALLTHROUGH;
                   default:
                      for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {



More information about the mesa-commit mailing list