Mesa (main): aco: Implement usub_sat.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Jun 1 17:59:53 UTC 2022


Module: Mesa
Branch: main
Commit: faa2a894876a387c8945bb46f6ce71f495db1d44
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=faa2a894876a387c8945bb46f6ce71f495db1d44

Author: Georg Lehmann <dadschoorse at gmail.com>
Date:   Fri Nov 19 16:28:52 2021 +0100

aco: Implement usub_sat.

Signed-off-by: Georg Lehmann <dadschoorse at gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13895>

---

 src/amd/compiler/aco_instruction_selection.cpp | 96 ++++++++++++++++++++++++++
 1 file changed, 96 insertions(+)

diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index e09c2c281c9..b98fcaf85b6 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -1345,6 +1345,25 @@ uadd32_sat(Builder& bld, Definition dst, Temp src0, Temp src1)
    return dst.getTemp();
 }
 
+Temp
+usub32_sat(Builder& bld, Definition dst, Temp src0, Temp src1)
+{
+   if (bld.program->gfx_level < GFX8) {
+      Builder::Result sub = bld.vsub32(bld.def(v1), src0, src1, true);
+      return bld.vop2_e64(aco_opcode::v_cndmask_b32, dst, sub.def(0).getTemp(), Operand::c32(0u),
+                          sub.def(1).getTemp());
+   }
+
+   Builder::Result sub(NULL);
+   if (bld.program->gfx_level >= GFX9) {
+      sub = bld.vop2_e64(aco_opcode::v_sub_u32, dst, src0, src1);
+   } else {
+      sub = bld.vop2_e64(aco_opcode::v_sub_co_u32, dst, bld.def(bld.lm), src0, src1);
+   }
+   sub.instr->vop3().clamp = 1;
+   return dst.getTemp();
+}
+
 void
 visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
 {
@@ -2082,6 +2101,83 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       }
       break;
    }
+   case nir_op_usub_sat: {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = get_alu_src(ctx, instr->src[1]);
+      if (dst.regClass() == s1) {
+         Temp tmp = bld.tmp(s1), carry = bld.tmp(s1);
+         bld.sop2(aco_opcode::s_sub_u32, Definition(tmp), bld.scc(Definition(carry)), src0, src1);
+         bld.sop2(aco_opcode::s_cselect_b32, Definition(dst), Operand::c32(0), tmp, bld.scc(carry));
+         break;
+      } else if (dst.regClass() == v2b) {
+         Instruction* sub_instr;
+         if (ctx->program->gfx_level >= GFX10) {
+            sub_instr = bld.vop3(aco_opcode::v_sub_u16_e64, Definition(dst), src0, src1).instr;
+         } else {
+            aco_opcode op = aco_opcode::v_sub_u16;
+            if (src1.type() == RegType::sgpr) {
+               std::swap(src0, src1);
+               op = aco_opcode::v_subrev_u16;
+            }
+            sub_instr = bld.vop2_e64(op, Definition(dst), src0, as_vgpr(ctx, src1)).instr;
+         }
+         sub_instr->vop3().clamp = 1;
+         break;
+      } else if (dst.regClass() == v1) {
+         usub32_sat(bld, Definition(dst), src0, as_vgpr(ctx, src1));
+         break;
+      }
+
+      assert(src0.size() == 2 && src1.size() == 2);
+      Temp src00 = bld.tmp(src0.type(), 1);
+      Temp src01 = bld.tmp(src0.type(), 1);
+      bld.pseudo(aco_opcode::p_split_vector, Definition(src00), Definition(src01), src0);
+      Temp src10 = bld.tmp(src1.type(), 1);
+      Temp src11 = bld.tmp(src1.type(), 1);
+      bld.pseudo(aco_opcode::p_split_vector, Definition(src10), Definition(src11), src1);
+
+      if (dst.regClass() == s2) {
+         Temp carry0 = bld.tmp(s1);
+         Temp carry1 = bld.tmp(s1);
+
+         Temp no_sat0 =
+            bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.scc(Definition(carry0)), src00, src10);
+         Temp no_sat1 = bld.sop2(aco_opcode::s_subb_u32, bld.def(s1), bld.scc(Definition(carry1)),
+                                 src01, src11, bld.scc(carry0));
+
+         Temp no_sat = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), no_sat0, no_sat1);
+
+         bld.sop2(aco_opcode::s_cselect_b64, Definition(dst), Operand::c64(0ull), no_sat,
+                  bld.scc(carry1));
+      } else if (dst.regClass() == v2) {
+         Temp no_sat0 = bld.tmp(v1);
+         Temp dst0 = bld.tmp(v1);
+         Temp dst1 = bld.tmp(v1);
+
+         Temp carry0 = bld.vsub32(Definition(no_sat0), src00, src10, true).def(1).getTemp();
+         Temp carry1;
+
+         if (ctx->program->gfx_level >= GFX8) {
+            carry1 = bld.tmp(bld.lm);
+            bld.vop2_e64(aco_opcode::v_subb_co_u32, Definition(dst1), Definition(carry1),
+                         as_vgpr(ctx, src01), as_vgpr(ctx, src11), carry0)
+               .instr->vop3()
+               .clamp = 1;
+         } else {
+            Temp no_sat1 = bld.tmp(v1);
+            carry1 = bld.vsub32(Definition(no_sat1), src01, src11, true, carry0).def(1).getTemp();
+            bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst1), no_sat1, Operand::c32(0u),
+                         carry1);
+         }
+
+         bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst0), no_sat0, Operand::c32(0u),
+                      carry1);
+         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), dst0, dst1);
+      } else {
+         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
+      }
+      break;
+   }
    case nir_op_imul: {
       if (dst.bytes() <= 2 && ctx->program->gfx_level >= GFX10) {
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u16_e64, dst);



More information about the mesa-commit mailing list