[Mesa-dev] [PATCH 11/22] nvir/nir: implement nir_alu_instr handling
Karol Herbst
kherbst at redhat.com
Thu Dec 21 15:51:29 UTC 2017
TODO: move the lowering code somewhere else. We do the same thing as from_tgsi
for a few ops, and we could move that lowering further down so the input IR
doesn't have to deal with things like slct and min/max with 64-bit destination
types.
TODO: move DEFAULT_HANDLER into its own function.
TODO: check whether some code duplication can be eliminated through templates.
Signed-off-by: Karol Herbst <kherbst at redhat.com>
---
.../drivers/nouveau/codegen/nv50_ir_from_nir.cpp | 524 ++++++++++++++++++++-
1 file changed, 523 insertions(+), 1 deletion(-)
diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
index fe11280537..d2b2236c17 100644
--- a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
+++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
@@ -64,6 +64,7 @@ public:
Value* getSrc(nir_src *, uint8_t);
Value* getSrc(nir_ssa_def *, uint8_t);
+ bool visit(nir_alu_instr *);
bool visit(nir_block *);
bool visit(nir_cf_node *);
bool visit(nir_function *);
@@ -86,6 +87,10 @@ public:
std::vector<DataType> getSTypes(nir_alu_instr*);
DataType getSType(nir_src&, bool isFloat, bool isSigned);
+ operation getOperation(nir_op);
+ operation preOperationNeeded(nir_op);
+ int getSubOp(nir_op);
+ CondCode getCondCode(nir_op);
private:
nir_shader *nir;
@@ -95,6 +100,7 @@ private:
unsigned int curLoopDepth;
BasicBlock *exit;
+ Value *zero;
union {
struct {
@@ -106,7 +112,10 @@ private:
Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
: ConverterCommon(prog, info),
nir(nir),
- curLoopDepth(0) {}
+   curLoopDepth(0),
+   // Shared 32-bit immediate 0; several ALU lowerings use it as a compare operand.
+   zero(mkImm(static_cast<uint32_t>(0)))
+{
+}
BasicBlock *
Converter::convert(nir_block *block)
@@ -224,6 +233,157 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
return typeOfSize(bitSize / 8, isFloat, isSigned);
}
+// Map a NIR ALU opcode to the nv50 IR operation used to implement it.
+// The f/i/u flavours of an opcode share one operation; the data type the
+// instruction is emitted with tells them apart. Every comparison maps to
+// OP_SET; its condition code is obtained separately via getCondCode().
+// Unhandled opcodes are reported, assert in debug builds and yield OP_NOP.
+operation
+Converter::getOperation(nir_op op)
+{
+   switch (op) {
+   // arithmetic
+   case nir_op_fabs:
+   case nir_op_iabs:
+      return OP_ABS;
+   case nir_op_fadd:
+   case nir_op_iadd:
+      return OP_ADD;
+   case nir_op_fsub:
+   case nir_op_isub:
+      return OP_SUB;
+   case nir_op_fmul:
+   case nir_op_imul:
+   case nir_op_imul_high:
+   case nir_op_umul_high:
+      return OP_MUL;
+   case nir_op_fdiv:
+   case nir_op_idiv:
+   case nir_op_udiv:
+      return OP_DIV;
+   case nir_op_fmod:
+   case nir_op_imod:
+   case nir_op_umod:
+      return OP_MOD;
+   case nir_op_fneg:
+   case nir_op_ineg:
+      return OP_NEG;
+   case nir_op_fmax:
+   case nir_op_imax:
+   case nir_op_umax:
+      return OP_MAX;
+   case nir_op_fmin:
+   case nir_op_imin:
+   case nir_op_umin:
+      return OP_MIN;
+   case nir_op_ffma:
+      return OP_FMA;
+   case nir_op_fsat:
+      return OP_SAT;
+   // bit operations
+   case nir_op_fand:
+   case nir_op_iand:
+      return OP_AND;
+   case nir_op_for:
+   case nir_op_ior:
+      return OP_OR;
+   case nir_op_fxor:
+   case nir_op_ixor:
+      return OP_XOR;
+   case nir_op_fnot:
+   case nir_op_inot:
+      return OP_NOT;
+   case nir_op_ishl:
+      return OP_SHL;
+   case nir_op_ishr:
+   case nir_op_ushr:
+      return OP_SHR;
+   case nir_op_ifind_msb:
+   case nir_op_ufind_msb:
+      return OP_BFIND;
+   // comparisons
+   case nir_op_feq:
+   case nir_op_ieq:
+   case nir_op_fne:
+   case nir_op_ine:
+   case nir_op_fge:
+   case nir_op_ige:
+   case nir_op_uge:
+   case nir_op_flt:
+   case nir_op_ilt:
+   case nir_op_ult:
+      return OP_SET;
+   // conversions
+   case nir_op_f2f32:
+   case nir_op_f2f64:
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2f32:
+   case nir_op_i2f64:
+   case nir_op_u2f32:
+   case nir_op_u2f64:
+      return OP_CVT;
+   // rounding
+   case nir_op_fceil:
+      return OP_CEIL;
+   case nir_op_ffloor:
+      return OP_FLOOR;
+   case nir_op_ftrunc:
+      return OP_TRUNC;
+   // transcendental / reciprocal
+   case nir_op_fcos:
+      return OP_COS;
+   case nir_op_fsin:
+      return OP_SIN;
+   case nir_op_fexp2:
+      return OP_EX2;
+   case nir_op_flog2:
+      return OP_LG2;
+   case nir_op_frcp:
+      return OP_RCP;
+   case nir_op_frsq:
+      return OP_RSQ;
+   case nir_op_fsqrt:
+      return OP_SQRT;
+   // derivatives
+   case nir_op_fddx:
+      return OP_DFDX;
+   case nir_op_fddy:
+      return OP_DFDY;
+   default:
+      ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return OP_NOP;
+   }
+}
+
+// Return the operation that has to be emitted in front of `op`, or OP_NOP
+// when none is required. Only fsin/fcos need one: OP_PRESIN, whose result
+// then feeds the actual SIN/COS.
+operation
+Converter::preOperationNeeded(nir_op op)
+{
+   const bool isTrig = (op == nir_op_fsin) || (op == nir_op_fcos);
+   return isTrig ? OP_PRESIN : OP_NOP;
+}
+
+// Return the nv50 IR sub-operation to attach to `op`, or 0 for none.
+// Only the high-half multiplies (imul_high/umul_high) need one.
+int
+Converter::getSubOp(nir_op op)
+{
+   if (op == nir_op_imul_high || op == nir_op_umul_high)
+      return NV50_IR_SUBOP_MUL_HIGH;
+   return 0;
+}
+
+// Map a NIR comparison opcode to the nv50 IR condition code used with
+// OP_SET. Float/int/uint flavours share a code; both ne variants map to
+// CC_NEU. Unhandled opcodes are reported, assert in debug builds and
+// yield CC_FL.
+CondCode
+Converter::getCondCode(nir_op op)
+{
+   switch (op) {
+   case nir_op_feq:
+   case nir_op_ieq:
+      return CC_EQ;
+   case nir_op_fge:
+   case nir_op_ige:
+   case nir_op_uge:
+      return CC_GE;
+   case nir_op_flt:
+   case nir_op_ilt:
+   case nir_op_ult:
+      return CC_LT;
+   case nir_op_fne:
+   case nir_op_ine:
+      return CC_NEU;
+   default:
+      ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return CC_FL;
+   }
+}
+
+// Resolve the nv50 IR values for an ALU destination by forwarding its
+// embedded nir_dest to convert(nir_dest *).
+Converter::LValues&
+Converter::convert(nir_alu_dest *dest)
+{
+   nir_dest *const inner = &dest->dest;
+   return convert(inner);
+}
+
Converter::LValues&
Converter::convert(nir_dest *dest)
{
@@ -486,6 +646,10 @@ bool
Converter::visit(nir_instr *insn)
{
switch (insn->type) {
+ case nir_instr_type_alu:
+ if (!visit(nir_instr_as_alu(insn)))
+ return false;
+ break;
case nir_instr_type_intrinsic:
if (!visit(nir_instr_as_intrinsic(insn)))
return false;
@@ -559,6 +723,364 @@ Converter::visit(nir_intrinsic_instr *insn)
return true;
}
+// Case-label helpers: expand to the f/i (CASE_OP), f/i/u (CASE_OP3) or
+// i/u (CASE_OPIU) flavours of a NIR opcode.
+#define CASE_OP(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni
+#define CASE_OP3(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+#define CASE_OPIU(ni) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+// Bail out (return false) unless the ALU instruction writes exactly one
+// component; the single-value handlers only produce newDefs[0].
+// Expects `insn` in scope.
+#define DEFAULT_CHECKS \
+   if (insn->dest.dest.ssa.num_components > 1) { \
+      ERROR("nir_alu_instr only supported with 1 component!\n"); \
+      return false; \
+   } \
+   if (insn->dest.write_mask != 1) { \
+      ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
+      return false; \
+   }
+// Emit `op` as one nv50 IR instruction whose sources are the NIR sources,
+// optionally preceded by the pre-operation from preOperationNeeded()
+// (e.g. OP_PRESIN for fsin/fcos). The pre-operation writes a fresh scratch
+// value that feeds the main operation, so newDefs[0] is defined exactly
+// once; it used to be written by both instructions and read by the second.
+// Expects `insn`, `op`, `info` and `dType` in scope.
+#define DEFAULT_HANDLER \
+   do { \
+      LValues &newDefs = convert(&insn->dest); \
+      operation preOp = preOperationNeeded(op); \
+      if (preOp != OP_NOP) { \
+         assert(info.num_inputs < 2); \
+         Value *preDef = getScratch(typeSizeof(dType)); \
+         Instruction *i0 = mkOp(preOp, dType, preDef); \
+         Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]); \
+         if (info.num_inputs) { \
+            i0->setSrc(0, getSrc(&insn->src[0])); \
+            i1->setSrc(0, preDef); \
+         } \
+         i1->subOp = getSubOp(op); \
+      } else { \
+         Instruction *i = mkOp(getOperation(op), dType, newDefs[0]); \
+         for (auto s = 0u; s < info.num_inputs; ++s) { \
+            i->setSrc(s, getSrc(&insn->src[s])); \
+         } \
+         i->subOp = getSubOp(op); \
+      } \
+   } while (false)
+
+// Translate one NIR ALU instruction into nv50 IR, appending the generated
+// instructions to the current basic block. Most opcodes are only handled
+// in scalar form (see DEFAULT_CHECKS). Returns false for unsupported
+// opcodes or destination shapes. Every emitted instruction inherits
+// insn->exact as `precise`; the last one inherits the saturate flag.
+bool
+Converter::visit(nir_alu_instr *insn)
+{
+   // some helper variables
+   const nir_op op = insn->op;
+   const nir_op_info &info = nir_op_infos[op];
+   DataType dType = getDType(insn);
+   const std::vector<DataType> sTypes = getSTypes(insn);
+   // Last instruction before emission (null if the block is still empty),
+   // used afterwards to find the newly emitted instructions.
+   Instruction *oldPos = this->bb->getExit();
+
+   switch (op) {
+   CASE_OP(abs):
+   CASE_OP(add):
+   CASE_OP(and):
+   case nir_op_fceil:
+   case nir_op_fcos:
+   case nir_op_fddx:
+   case nir_op_fddy:
+   CASE_OP3(div):
+   case nir_op_fexp2:
+   case nir_op_ffloor:
+   case nir_op_ffma:
+   case nir_op_flog2:
+   CASE_OP3(mod):
+   CASE_OP(mul):
+   CASE_OPIU(mul_high):
+   CASE_OP(neg):
+   CASE_OP(not):
+   CASE_OP(or):
+   case nir_op_frcp:
+   case nir_op_frsq:
+   case nir_op_fsat:
+   CASE_OPIU(shr):
+   case nir_op_fsin:
+   case nir_op_fsqrt:
+   CASE_OP(sub):
+   case nir_op_ftrunc:
+   case nir_op_ishl:
+   CASE_OP(xor): {
+      DEFAULT_CHECKS;
+      DEFAULT_HANDLER;
+      break;
+   }
+   CASE_OPIU(find_msb): {
+      DEFAULT_CHECKS;
+      // emit BFIND with the source's type
+      dType = sTypes[0];
+      DEFAULT_HANDLER;
+      break;
+   }
+   CASE_OP3(max):
+   CASE_OP3(min): {
+      DEFAULT_CHECKS;
+      if (dType == TYPE_U64 || dType == TYPE_S64) {
+         // 64-bit integer min/max is lowered to 32-bit halves: the high
+         // halves produce a flag that the low-half operation consumes,
+         // then both results are merged.
+         const operation minMaxOp = getOperation(insn->op);
+         LValues &newDefs = convert(&insn->dest);
+         DataType sdType = typeOfSize(4, false, isSignedIntType(dType));
+         Value *flag = getSSA(1, FILE_FLAGS);
+
+         Value *split0[2];
+         Value *split1[2];
+         Value *merge[2];
+
+         merge[0] = getScratch();
+         merge[1] = getScratch();
+
+         mkSplit(split0, 4, getSrc(&insn->src[0]));
+         mkSplit(split1, 4, getSrc(&insn->src[1]));
+
+         Instruction *hi = mkOp2(minMaxOp, sdType, merge[1], split0[1], split1[1]);
+         hi->subOp = NV50_IR_SUBOP_MINMAX_HIGH;
+         hi->setFlagsDef(1, flag);
+
+         Instruction *low = mkOp2(minMaxOp, sdType, merge[0], split0[0], split1[0]);
+         low->subOp = NV50_IR_SUBOP_MINMAX_LOW;
+         low->setFlagsSrc(2, flag);
+
+         mkOp2(OP_MERGE, dType, newDefs[0], merge[0], merge[1]);
+      } else {
+         DEFAULT_HANDLER;
+      }
+      break;
+   }
+   case nir_op_fround_even: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
+      break;
+   }
+   // convert instructions
+   CASE_OP3(2f32):
+   CASE_OP3(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      // float -> integer conversions round toward zero
+      if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
+         i->rnd = ROUND_Z;
+      i->sType = sTypes[0];
+      break;
+   }
+   // compare instructions
+   CASE_OP(eq):
+   CASE_OP3(ge):
+   CASE_OP3(lt):
+   CASE_OP(ne): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkCmp(getOperation(op),
+                             getCondCode(op),
+                             dType,
+                             newDefs[0],
+                             dType,
+                             getSrc(&insn->src[0]),
+                             getSrc(&insn->src[1]));
+      if (info.num_inputs == 3)
+         i->setSrc(2, getSrc(&insn->src[2]));
+      i->sType = sTypes[0];
+      break;
+   }
+   /* those are weird ALU ops and need special handling, because
+    * 1. they are always component based
+    * 2. they basically just merge multiple values into one data type
+    */
+   CASE_OP(mov):
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4: {
+      LValues &newDefs = convert(&insn->dest);
+      for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
+         mkMov(newDefs[c], getSrc(&insn->src[c]), dType);
+      }
+      break;
+   }
+   // (un)pack
+   case nir_op_pack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
+      merge->setSrc(0, getSrc(&insn->src[0], 0));
+      merge->setSrc(1, getSrc(&insn->src[0], 1));
+      break;
+   }
+   case nir_op_unpack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
+      break;
+   }
+   // special instructions
+   CASE_OP(sign): {
+      DEFAULT_CHECKS;
+      DataType iType;
+      if (::isFloatType(dType))
+         iType = TYPE_F32;
+      else
+         iType = TYPE_S32;
+
+      LValues &newDefs = convert(&insn->dest);
+      LValue *val0 = getScratch();
+      LValue *val1 = getScratch();
+      mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
+      mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
+
+      if (dType == TYPE_F64) {
+         mkOp2(OP_SUB, iType, val0, val0, val1);
+         mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
+      } else if (dType == TYPE_S64 || dType == TYPE_U64) {
+         mkOp2(OP_SUB, iType, val0, val1, val0);
+         mkOp2(OP_SHR, iType, val1, val0, loadImm(nullptr, 31));
+         mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
+      } else if (::isFloatType(dType)) {
+         mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
+      } else {
+         mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
+      }
+      break;
+   }
+   case nir_op_fcsel:
+   case nir_op_bcsel: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      if (typeSizeof(dType) > 4) {
+         // select each 32-bit half separately, then merge
+         const DataType halfType = typeOfSize(4, ::isFloatType(dType));
+         Value *split0[2];
+         Value *split1[2];
+         Value *merge[2];
+         merge[0] = getScratch();
+         merge[1] = getScratch();
+         mkSplit(split0, 4, getSrc(&insn->src[1]));
+         mkSplit(split1, 4, getSrc(&insn->src[2]));
+         mkCmp(OP_SLCT, CC_NE, halfType, merge[0], sTypes[0], split0[0], split1[0], getSrc(&insn->src[0]));
+         mkCmp(OP_SLCT, CC_NE, halfType, merge[1], sTypes[0], split0[1], split1[1], getSrc(&insn->src[0]));
+         mkOp2(OP_MERGE, dType, newDefs[0], merge[0], merge[1]);
+      } else {
+         mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
+      }
+      break;
+   }
+   CASE_OPIU(bfe):
+   CASE_OPIU(bitfield_extract): {
+      DEFAULT_CHECKS;
+      Value *tmp = getScratch();
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(getScratch(), 0x808), getSrc(&insn->src[1]));
+      mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
+      break;
+   }
+   case nir_op_bfm: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 0x808), getSrc(&insn->src[1]));
+      break;
+   }
+   case nir_op_bfi: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), getSrc(&insn->src[0]), getSrc(&insn->src[2]));
+      break;
+   }
+   case nir_op_bit_count: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bitfield_reverse: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      break;
+   }
+   case nir_op_find_lsb: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      // reverse the bits, then find the shift amount of the top set bit
+      Value *tmp = getScratch();
+      mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
+      break;
+   }
+   // boolean conversions
+   case nir_op_b2f: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 1.0f));
+      break;
+   }
+   CASE_OP(2b): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *src1;
+      if (typeSizeof(sTypes[0]) == 8) {
+         src1 = loadImm(getScratch(8), 0.0);
+      } else {
+         src1 = zero;
+      }
+      mkCmp(OP_SET, CC_NEU, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
+      break;
+   }
+   case nir_op_b2i: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *def;
+      if (typeSizeof(dType) == 8)
+         def = getScratch();
+      else
+         def = newDefs[0];
+
+      // bools are always 32bit values
+      mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(getScratch(), 1));
+      if (typeSizeof(dType) == 8)
+         mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(getScratch(), 0));
+
+      break;
+   }
+   case nir_op_i2i32:
+   case nir_op_u2u32: {
+      DEFAULT_CHECKS;
+      Value *src[2];
+      LValues &newDefs = convert(&insn->dest);
+      mkSplit(src, 4, getSrc(&insn->src[0]));
+      mkMov(newDefs[0], src[0]);
+      break;
+   }
+   case nir_op_i2i64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      // sign-extend: high word = low word arithmetically shifted right by 31
+      Value *lo = getSrc(&insn->src[0]);
+      Value *hi = getScratch();
+      mkOp2(OP_SHR, TYPE_S32, hi, lo, loadImm(nullptr, 31));
+      mkOp2(OP_MERGE, TYPE_S64, newDefs[0], lo, hi);
+      break;
+   }
+   case nir_op_u2u64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_MERGE, TYPE_U64, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 0));
+      break;
+   }
+   default:
+      ERROR("unknown nir_op %s\n", info.name);
+      return false;
+   }
+
+   // Walk the newly emitted instructions. If the block was empty before,
+   // start from its first instruction (getEntry), not from its last one,
+   // otherwise only the final instruction would receive `precise`.
+   if (!oldPos) {
+      oldPos = this->bb->getEntry();
+      if (!oldPos)
+         return true; // nothing was emitted
+      oldPos->precise = insn->exact;
+   }
+
+   while (oldPos->next) {
+      oldPos = oldPos->next;
+      oldPos->precise = insn->exact;
+   }
+   oldPos->saturate = insn->dest.saturate;
+
+   return true;
+}
+// Drop the helper macros used by visit(nir_alu_instr *) so they cannot leak
+// into (or clash with) the rest of the file.
+#undef CASE_OP
+#undef CASE_OP3
+#undef CASE_OPIU
+#undef DEFAULT_CHECKS
+#undef DEFAULT_HANDLER
+
bool
Converter::run()
{
--
2.14.3
More information about the mesa-dev
mailing list