[Mesa-dev] [PATCH v7 18/35] nvir/nir: implement nir_alu_instr handling

Karol Herbst kherbst at redhat.com
Mon Apr 16 13:25:58 UTC 2018


v2: use bitfield_insert instead of bfi
    rework switch helper macros
    remove some lowering code (LoweringHelper is now used for this)
v3: add pack_half_2x16_split
    add unpack_half_2x16_split_x/y
v5: replace first argument with nullptr in loadImm calls
    prefer getSSA over getScratch

Signed-off-by: Karol Herbst <kherbst at redhat.com>
---
 .../drivers/nouveau/codegen/nv50_ir_from_nir.cpp   | 489 ++++++++++++++++++++-
 1 file changed, 488 insertions(+), 1 deletion(-)

diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
index 8a474eb1a8c..8368bbcc015 100644
--- a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
+++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
@@ -34,6 +34,31 @@
 #include <unordered_map>
 #include <vector>
 
+/* Switch-label helpers for NIR opcodes that exist in float (f), signed
+ * integer (i) and/or unsigned integer (u) flavours of the same operation.
+ * CASE_OP<XY>(name) expands to "case nir_op_<x>name : case nir_op_<y>name"
+ * and is followed by a ':' at the use site; the *_RET(name, val) variants
+ * additionally return val for all of those labels.
+ */
+#define CASE_OPFI(ni) \
+   case nir_op_f ## ni : case nir_op_i ## ni
+#define CASE_OPFIU(ni) \
+   case nir_op_f ## ni : case nir_op_i ## ni : case nir_op_u ## ni
+#define CASE_OPIU(ni) \
+   case nir_op_i ## ni : case nir_op_u ## ni
+
+#define CASE_OPFI_RET(ni, val) \
+   case nir_op_f ## ni : case nir_op_i ## ni : return val
+#define CASE_OPFIU_RET(ni, val) \
+   case nir_op_f ## ni : case nir_op_i ## ni : case nir_op_u ## ni : return val
+#define CASE_OPIU_RET(ni, val) \
+   case nir_op_i ## ni : case nir_op_u ## ni : return val
+
 static int
 type_size(const struct glsl_type *type)
 {
@@ -97,9 +122,17 @@ private:
    std::vector<DataType> getSTypes(nir_alu_instr*);
    DataType getSType(nir_src&, bool isFloat, bool isSigned);
 
+   operation getOperation(nir_op);
+   operation preOperationNeeded(nir_op);
+
+   int getSubOp(nir_op);
+
+   CondCode getCondCode(nir_op);
+
    bool assignSlots();
    bool parseNIR();
 
+   bool visit(nir_alu_instr *);
    bool visit(nir_block *);
    bool visit(nir_cf_node *);
    bool visit(nir_function *);
@@ -118,6 +151,7 @@ private:
    unsigned int curLoopDepth;
 
    BasicBlock *exit;
+   Value *zero;
 
    union {
       struct {
@@ -129,7 +163,10 @@ private:
 Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
    : ConverterCommon(prog, info),
      nir(nir),
-     curLoopDepth(0) {}
+     curLoopDepth(0),
+     zero(mkImm(static_cast<uint32_t>(0)))
+{
+   /* zero: a 32-bit immediate 0 shared by the ALU translation (for example as
+    * the comparison operand of the sign and x2b conversions). */
+}
 
 BasicBlock *
 Converter::convert(nir_block *block)
@@ -246,6 +283,137 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
    return typeOfSize(bitSize / 8, isFloat, isSigned);
 }
 
+/* Map a NIR ALU opcode to the nv50_ir operation that implements it.
+ * Opcodes without a direct equivalent are reported and yield OP_NOP.
+ */
+operation
+Converter::getOperation(nir_op op)
+{
+   switch (op) {
+   /* arithmetic */
+   CASE_OPFI_RET(abs, OP_ABS);
+   CASE_OPFI_RET(add, OP_ADD);
+   CASE_OPFI_RET(sub, OP_SUB);
+   CASE_OPFI_RET(mul, OP_MUL);
+   CASE_OPIU_RET(mul_high, OP_MUL);
+   CASE_OPFIU_RET(div, OP_DIV);
+   CASE_OPFIU_RET(mod, OP_MOD);
+   CASE_OPFI_RET(rem, OP_MOD);
+   CASE_OPFIU_RET(max, OP_MAX);
+   CASE_OPFIU_RET(min, OP_MIN);
+   CASE_OPFI_RET(neg, OP_NEG);
+   case nir_op_ffma:
+      return OP_FMA;
+   case nir_op_fsat:
+      return OP_SAT;
+   /* bitwise logic, shifts and bit scans */
+   CASE_OPFI_RET(and, OP_AND);
+   CASE_OPFI_RET(or, OP_OR);
+   CASE_OPFI_RET(xor, OP_XOR);
+   CASE_OPFI_RET(not, OP_NOT);
+   case nir_op_ishl:
+      return OP_SHL;
+   CASE_OPIU_RET(shr, OP_SHR);
+   CASE_OPIU_RET(find_msb, OP_BFIND);
+   /* comparisons; the condition itself comes from getCondCode() */
+   CASE_OPFI_RET(eq, OP_SET);
+   CASE_OPFI_RET(ne, OP_SET);
+   CASE_OPFIU_RET(ge, OP_SET);
+   CASE_OPFIU_RET(lt, OP_SET);
+   /* rounding */
+   case nir_op_fceil:
+      return OP_CEIL;
+   case nir_op_ffloor:
+      return OP_FLOOR;
+   case nir_op_ftrunc:
+      return OP_TRUNC;
+   /* transcendental and reciprocal functions */
+   case nir_op_fcos:
+      return OP_COS;
+   case nir_op_fsin:
+      return OP_SIN;
+   case nir_op_fexp2:
+      return OP_EX2;
+   case nir_op_flog2:
+      return OP_LG2;
+   case nir_op_frcp:
+      return OP_RCP;
+   case nir_op_frsq:
+      return OP_RSQ;
+   case nir_op_fsqrt:
+      return OP_SQRT;
+   /* derivatives; coarse/fine variants share one operation */
+   case nir_op_fddx:
+   case nir_op_fddx_coarse:
+   case nir_op_fddx_fine:
+      return OP_DFDX;
+   case nir_op_fddy:
+   case nir_op_fddy_coarse:
+   case nir_op_fddy_fine:
+      return OP_DFDY;
+   /* type conversions and packing */
+   CASE_OPFIU(2f32):
+   CASE_OPFIU(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2i32:
+   case nir_op_i2i64:
+   case nir_op_u2u32:
+   case nir_op_u2u64:
+      return OP_CVT;
+   case nir_op_pack_64_2x32_split:
+      return OP_MERGE;
+   default:
+      ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return OP_NOP;
+   }
+}
+
+/* Return the operation that must be applied to the source before the main
+ * operation: fsin/fcos get an OP_PRESIN step, all other opcodes need none
+ * (OP_NOP).
+ */
+operation
+Converter::preOperationNeeded(nir_op op)
+{
+   const bool isSinOrCos = op == nir_op_fsin || op == nir_op_fcos;
+   return isSinOrCos ? OP_PRESIN : OP_NOP;
+}
+
+/* Sub-operation to set on the emitted instruction: imul_high/umul_high use
+ * NV50_IR_SUBOP_MUL_HIGH, every other opcode uses 0.
+ */
+int
+Converter::getSubOp(nir_op op)
+{
+   if (op == nir_op_imul_high || op == nir_op_umul_high)
+      return NV50_IR_SUBOP_MUL_HIGH;
+   return 0;
+}
+
+/* Condition code used with OP_SET for a NIR comparison opcode.
+ * Both the float and the integer "ne" variants map to CC_NEU.
+ * Unknown opcodes are reported and yield CC_FL.
+ */
+CondCode
+Converter::getCondCode(nir_op op)
+{
+   CondCode cc;
+
+   switch (op) {
+   case nir_op_feq:
+   case nir_op_ieq:
+      cc = CC_EQ;
+      break;
+   case nir_op_fge:
+   case nir_op_ige:
+   case nir_op_uge:
+      cc = CC_GE;
+      break;
+   case nir_op_flt:
+   case nir_op_ilt:
+   case nir_op_ult:
+      cc = CC_LT;
+      break;
+   case nir_op_fne:
+   case nir_op_ine:
+      cc = CC_NEU;
+      break;
+   default:
+      ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      cc = CC_FL;
+      break;
+   }
+
+   return cc;
+}
+
+/* An ALU destination wraps a plain nir_dest; forward to its conversion. */
+Converter::LValues&
+Converter::convert(nir_alu_dest *dest)
+{
+   nir_dest *const inner = &dest->dest;
+   return convert(inner);
+}
+
 Converter::LValues&
 Converter::convert(nir_dest *dest)
 {
@@ -1278,6 +1446,8 @@ bool
 Converter::visit(nir_instr *insn)
 {
    switch (insn->type) {
+   case nir_instr_type_alu:
+      return visit(nir_instr_as_alu(insn));
    case nir_instr_type_intrinsic:
       return visit(nir_instr_as_intrinsic(insn));
    case nir_instr_type_jump:
@@ -1347,6 +1517,323 @@ Converter::visit(nir_load_const_instr *insn)
    return true;
 }
 
+/* Guard shared by the scalar-only ALU paths below: bail out unless the
+ * destination has a single component written through write mask 0x1.
+ */
+#define DEFAULT_CHECKS \
+      if (insn->dest.dest.ssa.num_components > 1) { \
+         ERROR("nir_alu_instr only supported with 1 component!\n"); \
+         return false; \
+      } \
+      if (insn->dest.write_mask != 1) { \
+         ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
+         return false; \
+      }
+/* Translate one NIR ALU instruction into nv50_ir instructions appended to the
+ * current basic block.
+ *
+ * Returns false for opcodes that are not handled yet, and for scalar-only
+ * opcodes used with a multi-component or partially written destination
+ * (see DEFAULT_CHECKS). Every instruction emitted here gets the NIR "exact"
+ * flag as "precise"; the last one additionally gets the destination's
+ * saturate flag.
+ */
+bool
+Converter::visit(nir_alu_instr *insn)
+{
+   const nir_op op = insn->op;
+   const nir_op_info &info = nir_op_infos[op];
+   DataType dType = getDType(insn);
+   const std::vector<DataType> sTypes = getSTypes(insn);
+
+   /* remember the current end of the block so that the flags below can be
+    * applied to exactly the instructions emitted for this NIR instruction */
+   Instruction *oldPos = this->bb->getExit();
+
+   switch (op) {
+   CASE_OPFI(abs):
+   CASE_OPFI(add):
+   CASE_OPFI(and):
+   case nir_op_fceil:
+   case nir_op_fcos:
+   case nir_op_fddx:
+   case nir_op_fddx_coarse:
+   case nir_op_fddx_fine:
+   case nir_op_fddy:
+   case nir_op_fddy_coarse:
+   case nir_op_fddy_fine:
+   CASE_OPFIU(div):
+   case nir_op_fexp2:
+   case nir_op_ffloor:
+   case nir_op_ffma:
+   case nir_op_flog2:
+   CASE_OPFIU(max):
+   CASE_OPFIU(min):
+   CASE_OPFIU(mod):
+   CASE_OPFI(mul):
+   CASE_OPIU(mul_high):
+   CASE_OPFI(neg):
+   CASE_OPFI(not):
+   CASE_OPFI(or):
+   case nir_op_pack_64_2x32_split:
+   case nir_op_frcp:
+   CASE_OPFI(rem):
+   case nir_op_frsq:
+   case nir_op_fsat:
+   CASE_OPIU(shr):
+   case nir_op_fsin:
+   case nir_op_fsqrt:
+   CASE_OPFI(sub):
+   case nir_op_ftrunc:
+   case nir_op_ishl:
+   CASE_OPFI(xor): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      operation preOp = preOperationNeeded(op);
+      if (preOp != OP_NOP) {
+         /* two-step form: preOp(src) -> tmp, op(tmp) -> dst */
+         assert(info.num_inputs < 2);
+         Value *tmp = getSSA(typeSizeof(dType));
+         Instruction *i0 = mkOp(preOp, dType, tmp);
+         Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]);
+         if (info.num_inputs) {
+            i0->setSrc(0, getSrc(&insn->src[0]));
+            i1->setSrc(0, tmp);
+         }
+         i1->subOp = getSubOp(op);
+      } else {
+         Instruction *i = mkOp(getOperation(op), dType, newDefs[0]);
+         for (auto s = 0u; s < info.num_inputs; ++s) {
+            i->setSrc(s, getSrc(&insn->src[s]));
+         }
+         i->subOp = getSubOp(op);
+      }
+      break;
+   }
+   CASE_OPIU(find_msb): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      /* the signedness of the source selects the BFIND flavour */
+      dType = sTypes[0];
+      mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_fround_even: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
+      break;
+   }
+   /* convert instructions */
+   CASE_OPFIU(2f32):
+   CASE_OPFIU(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2i32:
+   case nir_op_i2i64:
+   case nir_op_u2u32:
+   case nir_op_u2u64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      /* float -> integer conversions round towards zero */
+      if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
+         i->rnd = ROUND_Z;
+      i->sType = sTypes[0];
+      break;
+   }
+   /* compare instructions */
+   CASE_OPFI(eq):
+   CASE_OPFIU(ge):
+   CASE_OPFIU(lt):
+   CASE_OPFI(ne): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkCmp(getOperation(op),
+                             getCondCode(op),
+                             dType,
+                             newDefs[0],
+                             dType,
+                             getSrc(&insn->src[0]),
+                             getSrc(&insn->src[1]));
+      if (info.num_inputs == 3)
+         i->setSrc(2, getSrc(&insn->src[2]));
+      i->sType = sTypes[0];
+      break;
+   }
+   /* These are special ALU ops that need their own handling, because
+    *   1. they are always component based
+    *   2. they basically just merge multiple values into one data type
+    */
+   CASE_OPFI(mov):
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4: {
+      LValues &newDefs = convert(&insn->dest);
+      for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
+         mkMov(newDefs[c], getSrc(&insn->src[c]), dType);
+      }
+      break;
+   }
+   /* (un)pack */
+   case nir_op_pack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
+      merge->setSrc(0, getSrc(&insn->src[0], 0));
+      merge->setSrc(1, getSrc(&insn->src[0], 1));
+      break;
+   }
+   case nir_op_pack_half_2x16_split: {
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmpH = getSSA();
+      Value *tmpL = getSSA();
+
+      /* convert both halves to f16, then insert the high one at bit 16 */
+      mkCvt(OP_CVT, TYPE_F16, tmpL, TYPE_F32, getSrc(&insn->src[0]));
+      mkCvt(OP_CVT, TYPE_F16, tmpH, TYPE_F32, getSrc(&insn->src[1]));
+      mkOp3(OP_INSBF, TYPE_U32, newDefs[0], tmpH, mkImm(0x1010), tmpL);
+      break;
+   }
+   case nir_op_unpack_half_2x16_split_x:
+   case nir_op_unpack_half_2x16_split_y: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *cvt = mkCvt(OP_CVT, TYPE_F32, newDefs[0], TYPE_F16, getSrc(&insn->src[0]));
+      if (op == nir_op_unpack_half_2x16_split_y)
+         cvt->subOp = 1;
+      break;
+   }
+   case nir_op_unpack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
+      break;
+   }
+   case nir_op_unpack_64_2x32_split_x: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, getSSA());
+      break;
+   }
+   case nir_op_unpack_64_2x32_split_y: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, getSSA(), getSrc(&insn->src[0]))->setDef(1, newDefs[0]);
+      break;
+   }
+   /* special instructions */
+   CASE_OPFI(sign): {
+      DEFAULT_CHECKS;
+      DataType iType;
+      if (::isFloatType(dType))
+         iType = TYPE_F32;
+      else
+         iType = TYPE_S32;
+
+      LValues &newDefs = convert(&insn->dest);
+      /* val0 = (src > 0), val1 = (src < 0); scratch values because the F64
+       * path below redefines val0 */
+      LValue *val0 = getScratch();
+      LValue *val1 = getScratch();
+      mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
+      mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
+
+      if (dType == TYPE_F64) {
+         mkOp2(OP_SUB, iType, val0, val0, val1);
+         mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
+      } else if (dType == TYPE_S64 || dType == TYPE_U64) {
+         mkOp2(OP_SUB, iType, val0, val1, val0);
+         mkOp2(OP_SHR, iType, val1, val0, loadImm(nullptr, 31));
+         mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
+      } else if (::isFloatType(dType))
+         mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
+      else
+         mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
+      break;
+   }
+   case nir_op_fcsel:
+   case nir_op_bcsel: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
+      break;
+   }
+   CASE_OPIU(bfe):
+   CASE_OPIU(bitfield_extract): {
+      DEFAULT_CHECKS;
+      /* tmp = bits << 8 | offset, the field descriptor EXTBF consumes */
+      Value *tmp = getSSA();
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(nullptr, 0x808), getSrc(&insn->src[1]));
+      mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
+      break;
+   }
+   case nir_op_bfm: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      /* bfm(bits, offset) must produce the mask ((1 << bits) - 1) << offset.
+       * First pack the operands into an INSBF field descriptor
+       * (bits << 8 | offset), then insert an all-ones field described by it
+       * into zero to materialize the mask. */
+      Value *field = getSSA();
+      mkOp3(OP_INSBF, TYPE_U32, field, getSrc(&insn->src[0]), loadImm(nullptr, 0x808), getSrc(&insn->src[1]));
+      mkOp3(OP_INSBF, dType, newDefs[0], loadImm(nullptr, ~0u), field, loadImm(nullptr, 0));
+      break;
+   }
+   case nir_op_bitfield_insert: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      /* temp = bits << 8 | offset, then insert src1 into src0 with it */
+      LValue *temp = getSSA();
+      mkOp3(OP_INSBF, TYPE_U32, temp, getSrc(&insn->src[3]), mkImm(0x808), getSrc(&insn->src[2]));
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), temp, getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bit_count: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bitfield_reverse: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      break;
+   }
+   case nir_op_find_lsb: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      /* reverse the bits, then search with the SAMT variant of BFIND */
+      Value *tmp = getSSA();
+      mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
+      break;
+   }
+   /* boolean conversions */
+   case nir_op_b2f: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(nullptr, 1.0f));
+      break;
+   }
+   CASE_OPFI(2b): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      /* 64-bit sources are compared against a 64-bit zero */
+      Value *src1;
+      if (typeSizeof(sTypes[0]) == 8) {
+         src1 = loadImm(getSSA(8), 0.0);
+      } else {
+         src1 = zero;
+      }
+      mkCmp(OP_SET, CC_NEU, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
+      break;
+   }
+   case nir_op_b2i: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *def;
+      if (typeSizeof(dType) == 8)
+         def = getScratch();
+      else
+         def = newDefs[0];
+
+      /* bools are always 32bit values */
+      mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(nullptr, 1));
+      if (typeSizeof(dType) == 8)
+         mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(nullptr, 0));
+
+      break;
+   }
+   default:
+      ERROR("unknown nir_op %s\n", info.name);
+      return false;
+   }
+
+   /* the block was empty before: its first instruction is ours */
+   if (!oldPos) {
+      oldPos = this->bb->getExit();
+      oldPos->precise = insn->exact;
+   }
+
+   while (oldPos->next) {
+      oldPos = oldPos->next;
+      oldPos->precise = insn->exact;
+   }
+   oldPos->saturate = insn->dest.saturate;
+
+   return true;
+}
+#undef DEFAULT_CHECKS
+
 bool
 Converter::run()
 {
-- 
2.14.3



More information about the mesa-dev mailing list