[Mesa-dev] [PATCH 11/22] nvir/nir: implement nir_alu_instr handling

Karol Herbst kherbst at redhat.com
Thu Dec 21 15:51:29 UTC 2017

TODO: move the lowering code somewhere else. We do the same thing as from_tgsi
for a few ops, and we could move that lowering further down so the input IR
doesn't have to deal with things like slct and min/max with 64-bit destination
types.

TODO: move DEFAULT_HANDLER into its own function

TODO: check whether some of the code duplication can be eliminated with templates
Signed-off-by: Karol Herbst <kherbst at redhat.com>
 .../drivers/nouveau/codegen/nv50_ir_from_nir.cpp   | 524 ++++++++++++++++++++-
 1 file changed, 523 insertions(+), 1 deletion(-)

diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
index fe11280537..d2b2236c17 100644
--- a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
+++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
@@ -64,6 +64,7 @@ public:
    Value* getSrc(nir_src *, uint8_t);
    Value* getSrc(nir_ssa_def *, uint8_t);
+   bool visit(nir_alu_instr *);
    bool visit(nir_block *);
    bool visit(nir_cf_node *);
    bool visit(nir_function *);
@@ -86,6 +87,10 @@ public:
    std::vector<DataType> getSTypes(nir_alu_instr*);
    DataType getSType(nir_src&, bool isFloat, bool isSigned);
+   operation getOperation(nir_op);
+   operation preOperationNeeded(nir_op);
+   int getSubOp(nir_op);
+   CondCode getCondCode(nir_op);
    nir_shader *nir;
@@ -95,6 +100,7 @@ private:
    unsigned int curLoopDepth;
    BasicBlock *exit;
+   Value *zero;
    union {
       struct {
@@ -106,7 +112,10 @@ private:
 Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
    : ConverterCommon(prog, info),
-     curLoopDepth(0) {}
+     curLoopDepth(0)
+   zero = mkImm((uint32_t)0);
 BasicBlock *
 Converter::convert(nir_block *block)
@@ -224,6 +233,157 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
    return typeOfSize(bitSize / 8, isFloat, isSigned);
+#define CASE_OP(ni, no) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+      return OP_ ## no
+#define CASE_OP3(ni, no) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni : \
+      return OP_ ## no
+#define CASE_OPIU(ni, no) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni : \
+      return OP_ ## no
+Converter::getOperation(nir_op op)
+   switch (op) {
+   // basic ops with float and int variants
+   CASE_OP(abs, ABS);
+   CASE_OP(add, ADD);
+   CASE_OP(and, AND);
+   CASE_OP3(div, DIV);
+   CASE_OPIU(find_msb, BFIND);
+   CASE_OP3(max, MAX);
+   CASE_OP3(min, MIN);
+   CASE_OP3(mod, MOD);
+   CASE_OP(mul, MUL);
+   CASE_OPIU(mul_high, MUL);
+   CASE_OP(neg, NEG);
+   CASE_OP(not, NOT);
+   CASE_OP(or, OR);
+   CASE_OP(eq, SET);
+   CASE_OP3(ge, SET);
+   CASE_OP3(lt, SET);
+   CASE_OP(ne, SET);
+   CASE_OPIU(shr, SHR);
+   CASE_OP(sub, SUB);
+   CASE_OP(xor, XOR);
+   case nir_op_fceil:
+      return OP_CEIL;
+   case nir_op_fcos:
+      return OP_COS;
+   case nir_op_f2f32:
+   case nir_op_f2f64:
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2f32:
+   case nir_op_i2f64:
+   case nir_op_u2f32:
+   case nir_op_u2f64:
+      return OP_CVT;
+   case nir_op_fddx:
+      return OP_DFDX;
+   case nir_op_fddy:
+      return OP_DFDY;
+   case nir_op_fexp2:
+      return OP_EX2;
+   case nir_op_ffloor:
+      return OP_FLOOR;
+   case nir_op_ffma:
+      return OP_FMA;
+   case nir_op_flog2:
+      return OP_LG2;
+   case nir_op_frcp:
+      return OP_RCP;
+   case nir_op_frsq:
+      return OP_RSQ;
+   case nir_op_fsat:
+      return OP_SAT;
+   case nir_op_ishl:
+      return OP_SHL;
+   case nir_op_fsin:
+      return OP_SIN;
+   case nir_op_fsqrt:
+      return OP_SQRT;
+   case nir_op_ftrunc:
+      return OP_TRUNC;
+   default:
+      ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return OP_NOP;
+   }
+#undef CASE_OP
+#undef CASE_OP3
+#undef CASE_OPIU
+Converter::preOperationNeeded(nir_op op)
+   switch (op) {
+   case nir_op_fcos:
+   case nir_op_fsin:
+      return OP_PRESIN;
+   default:
+      return OP_NOP;
+   }
+#define CASE_OPIU(ni, no) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni : \
+      return NV50_IR_SUBOP_ ## no
+Converter::getSubOp(nir_op op)
+   switch (op) {
+   CASE_OPIU(mul_high, MUL_HIGH);
+   default:
+      return 0;
+   }
+#undef CASE_OPIU
+#define CASE_OP(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni
+#define CASE_OP3(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+Converter::getCondCode(nir_op op)
+   switch (op) {
+   CASE_OP(eq):
+      return CC_EQ;
+   CASE_OP3(ge):
+      return CC_GE;
+   CASE_OP3(lt):
+      return CC_LT;
+   CASE_OP(ne):
+      return CC_NEU;
+   default:
+      ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return CC_FL;
+   }
+#undef CASE_OP
+#undef CASE_OP3
+Converter::convert(nir_alu_dest *dest)
+   return convert(&dest->dest);
 Converter::convert(nir_dest *dest)
@@ -486,6 +646,10 @@ bool
 Converter::visit(nir_instr *insn)
    switch (insn->type) {
+   case nir_instr_type_alu:
+      if (!visit(nir_instr_as_alu(insn)))
+         return false;
+      break;
    case nir_instr_type_intrinsic:
       if (!visit(nir_instr_as_intrinsic(insn)))
          return false;
@@ -559,6 +723,364 @@ Converter::visit(nir_intrinsic_instr *insn)
    return true;
+#define CASE_OP(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni
+#define CASE_OP3(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+#define CASE_OPIU(ni) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+      if (insn->dest.dest.ssa.num_components > 1) { \
+         ERROR("nir_alu_instr only supported with 1 component!\n"); \
+         return false; \
+      } \
+      if (insn->dest.write_mask != 1) { \
+         ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
+         return false; \
+      }
+   do { \
+      LValues &newDefs = convert(&insn->dest); \
+      operation preOp = preOperationNeeded(op); \
+      if (preOp != OP_NOP) { \
+         assert(info.num_inputs < 2); \
+         Instruction *i0 = mkOp(preOp, dType, newDefs[0]); \
+         Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]); \
+         if (info.num_inputs) { \
+            i0->setSrc(0, getSrc(&insn->src[0])); \
+            i1->setSrc(0, newDefs[0]); \
+         } \
+         i1->subOp = getSubOp(op); \
+      } else { \
+         Instruction *i = mkOp(getOperation(op), dType, newDefs[0]); \
+         for (auto s = 0u; s < info.num_inputs; ++s) { \
+            i->setSrc(s, getSrc(&insn->src[s])); \
+         } \
+         i->subOp = getSubOp(op); \
+      } \
+   } while (false)
+Converter::visit(nir_alu_instr *insn)
+   // some helper variables
+   const nir_op op = insn->op;
+   const nir_op_info &info = nir_op_infos[op];
+   DataType dType = getDType(insn);
+   const std::vector<DataType> sTypes = getSTypes(insn);
+   // save last instruction
+   Instruction *oldPos = this->bb->getExit();
+   switch (op) {
+   CASE_OP(abs):
+   CASE_OP(add):
+   CASE_OP(and):
+   case nir_op_fceil:
+   case nir_op_fcos:
+   case nir_op_fddx:
+   case nir_op_fddy:
+   CASE_OP3(div):
+   case nir_op_fexp2:
+   case nir_op_ffloor:
+   case nir_op_ffma:
+   case nir_op_flog2:
+   CASE_OP3(mod):
+   CASE_OP(mul):
+   CASE_OPIU(mul_high):
+   CASE_OP(neg):
+   CASE_OP(not):
+   CASE_OP(or):
+   case nir_op_frcp:
+   case nir_op_frsq:
+   case nir_op_fsat:
+   CASE_OPIU(shr):
+   case nir_op_fsin:
+   case nir_op_fsqrt:
+   CASE_OP(sub):
+   case nir_op_ftrunc:
+   case nir_op_ishl:
+   CASE_OP(xor): {
+      break;
+   }
+   CASE_OPIU(find_msb): {
+      dType = sTypes[0];
+      break;
+   }
+   CASE_OP3(max):
+   CASE_OP3(min): {
+      if (dType == TYPE_U64 || dType == TYPE_S64) {
+         operation op = getOperation(insn->op);
+         LValues &newDefs = convert(&insn->dest);
+         DataType sdType = typeOfSize(4, false, isSignedIntType(dType));
+         Value *flag = getSSA(1, FILE_FLAGS);
+         Value *split0[2];
+         Value *split1[2];
+         Value *merge[2];
+         merge[0] = getScratch();
+         merge[1] = getScratch();
+         mkSplit(split0, 4, getSrc(&insn->src[0]));
+         mkSplit(split1, 4, getSrc(&insn->src[1]));
+         Instruction *hi = mkOp2(op, sdType, merge[1], split0[1], split1[1]);
+         hi->subOp = NV50_IR_SUBOP_MINMAX_HIGH;
+         hi->setFlagsDef(1, flag);
+         Instruction *low = mkOp2(op, sdType, merge[0], split0[0], split1[0]);
+         low->subOp = NV50_IR_SUBOP_MINMAX_LOW;
+         low->setFlagsSrc(2, flag);
+         mkOp2(OP_MERGE, dType, newDefs[0], merge[0], merge[1]);
+     } else
+      break;
+   }
+   case nir_op_fround_even: {
+      LValues &newDefs = convert(&insn->dest);
+      mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
+      break;
+   }
+   // convert instructions
+   CASE_OP3(2f32):
+   CASE_OP3(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
+         i->rnd = ROUND_Z;
+      i->sType = sTypes[0];
+      break;
+   }
+   // compare instructions
+   CASE_OP(eq):
+   CASE_OP3(ge):
+   CASE_OP3(lt):
+   CASE_OP(ne): {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkCmp(getOperation(op),
+                             getCondCode(op),
+                             dType,
+                             newDefs[0],
+                             dType,
+                             getSrc(&insn->src[0]),
+                             getSrc(&insn->src[1]));
+      if (info.num_inputs == 3)
+         i->setSrc(2, getSrc(&insn->src[2]));
+      i->sType = sTypes[0];
+      break;
+   }
+   /* those are weird ALU ops and need special handling, because
+    *   1. they are always component-based
+    *   2. they basically just merge multiple values into one data type
+    */
+   CASE_OP(mov):
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4: {
+      LValues &newDefs = convert(&insn->dest);
+      for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
+         mkMov(newDefs[c], getSrc(&insn->src[c]), dType);
+      }
+      break;
+   }
+   // (un)pack
+   case nir_op_pack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
+      merge->setSrc(0, getSrc(&insn->src[0], 0));
+      merge->setSrc(1, getSrc(&insn->src[0], 1));
+      break;
+   }
+   case nir_op_unpack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
+      break;
+   }
+   // special instructions
+   CASE_OP(sign): {
+      DataType iType;
+      if (::isFloatType(dType))
+         iType = TYPE_F32;
+      else
+         iType = TYPE_S32;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *val0 = getScratch();
+      LValue *val1 = getScratch();
+      mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
+      mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
+      if (dType == TYPE_F64) {
+         mkOp2(OP_SUB, iType, val0, val0, val1);
+         mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
+      } else if (dType == TYPE_S64 || dType == TYPE_U64) {
+         mkOp2(OP_SUB, iType, val0, val1, val0);
+         mkOp2(OP_SHR, iType, val1, val0, loadImm(nullptr, 31));
+         mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
+      } else if (::isFloatType(dType))
+         mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
+      else
+         mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
+      break;
+   }
+   case nir_op_fcsel:
+   case nir_op_bcsel: {
+      LValues &newDefs = convert(&insn->dest);
+      if (typeSizeof(dType) > 4) {
+         Value *split0[2];
+         Value *split1[2];
+         Value *merge[2];
+         merge[0] = getScratch();
+         merge[1] = getScratch();
+         mkSplit(split0, 4, getSrc(&insn->src[1]));
+         mkSplit(split1, 4, getSrc(&insn->src[2]));
+         mkCmp(OP_SLCT, CC_NE, typeOfSize(4, ::isFloatType(dType)), merge[0], sTypes[0], split0[0], split1[0], getSrc(&insn->src[0]));
+         mkCmp(OP_SLCT, CC_NE, typeOfSize(4, ::isFloatType(dType)), merge[1], sTypes[0], split0[1], split1[1], getSrc(&insn->src[0]));
+         mkOp2(OP_MERGE, dType, newDefs[0], merge[0], merge[1]);
+      } else
+         mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
+      break;
+   }
+   CASE_OPIU(bfe):
+   CASE_OPIU(bitfield_extract): {
+      Value *tmp = getScratch();
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(getScratch(), 0x808), getSrc(&insn->src[1]));
+      mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
+      break;
+   }
+   case nir_op_bfm: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 0x808), getSrc(&insn->src[1]));
+      break;
+   }
+   case nir_op_bfi: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), getSrc(&insn->src[0]), getSrc(&insn->src[2]));
+      break;
+   }
+   case nir_op_bit_count: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bitfield_reverse: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      break;
+   }
+   case nir_op_find_lsb: {
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmp = getScratch();
+      mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
+      break;
+   }
+   // boolean conversions
+   case nir_op_b2f: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 1.0f));
+      break;
+   }
+   CASE_OP(2b): {
+      LValues &newDefs = convert(&insn->dest);
+      Value *src1;
+      if (typeSizeof(sTypes[0]) == 8) {
+         src1 = loadImm(getScratch(8), 0.0);
+      } else {
+         src1 = zero;
+      }
+      mkCmp(OP_SET, CC_NEU, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
+      break;
+   }
+   case nir_op_b2i: {
+      LValues &newDefs = convert(&insn->dest);
+      LValue *def;
+      if (typeSizeof(dType) == 8)
+         def = getScratch();
+      else
+         def = newDefs[0];
+      // bools are always 32bit values
+      mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(getScratch(), 1));
+      if (typeSizeof(dType) == 8)
+         mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(getScratch(), 0));
+      break;
+   }
+   case nir_op_i2i32:
+   case nir_op_u2u32: {
+      Value *src[2];
+      LValues &newDefs = convert(&insn->dest);
+      mkSplit(src, 4, getSrc(&insn->src[0]));
+      mkMov(newDefs[0], src[0]);
+      break;
+   }
+   case nir_op_i2i64: {
+      LValues &newDefs = convert(&insn->dest);
+      Value *dst0 = getSrc(&insn->src[0]);
+      Value *dst1 = getScratch();
+      mkOp2(OP_SHR, TYPE_S32, dst1, dst0, loadImm(NULL, 31));
+      mkOp2(OP_MERGE, TYPE_S64, newDefs[0], dst0, dst1);
+      break;
+   }
+   case nir_op_u2u64: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_MERGE, TYPE_U64, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 0));
+      break;
+   }
+   default:
+      ERROR("unknown nir_op %s\n", info.name);
+      return false;
+   }
+   if (!oldPos) {
+      oldPos = this->bb->getExit();
+      oldPos->precise = insn->exact;
+   }
+   while (oldPos->next) {
+      oldPos = oldPos->next;
+      oldPos->precise = insn->exact;
+   }
+   oldPos->saturate = insn->dest.saturate;
+   return true;
+#undef CASE_OP
+#undef CASE_OP3
+#undef CASE_OPIU

More information about the mesa-dev mailing list