[Mesa-dev] [PATCH v5 17/34] nvir/nir: implement nir_alu_instr handling
Karol Herbst
kherbst at redhat.com
Tue Feb 20 21:02:35 UTC 2018
Signed-off-by: Karol Herbst <kherbst at redhat.com>
v2: use bitfield_insert instead of bfi
rework switch helper macros
remove some lowering code (LoweringHelper is now used for this)
v3: add pack_half_2x16_split
add unpack_half_2x16_split_x/y
v5: replace first argument with nullptr in loadImm calls
prefer getSSA over getScratch
Signed-off-by: Karol Herbst <kherbst at redhat.com>
---
.../drivers/nouveau/codegen/nv50_ir_from_nir.cpp | 487 ++++++++++++++++++++-
1 file changed, 486 insertions(+), 1 deletion(-)
diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
index 544d2cc778..7425cf8874 100644
--- a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
+++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
@@ -32,6 +32,31 @@
#include <unordered_map>
#include <vector>
+/* Helpers expanding to a run of `case` labels for the float (f), signed
+ * integer (i) and unsigned integer (u) variants of one NIR opcode, e.g.
+ * CASE_OPFI(add) -> case nir_op_fadd : case nir_op_iadd
+ * The *_RET flavours additionally return `val` for all of those labels.
+ */
+#define CASE_OPFI(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni
+#define CASE_OPIU(ni) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+#define CASE_OPFIU(ni) \
+   CASE_OPFI(ni) : \
+   case nir_op_u ## ni
+
+#define CASE_OPFI_RET(ni, val) CASE_OPFI(ni) : return val
+#define CASE_OPIU_RET(ni, val) CASE_OPIU(ni) : return val
+#define CASE_OPFIU_RET(ni, val) CASE_OPFIU(ni) : return val
+
static int
type_size(const struct glsl_type *type)
{
@@ -78,6 +103,7 @@ public:
Instruction *loadFrom(DataFile, uint8_t, DataType, Value *def, uint32_t base, uint8_t c, Value *indirect0 = nullptr, Value *indirect1 = nullptr, bool patch = false);
void storeTo(nir_intrinsic_instr *, DataFile, operation, DataType, Value *src, uint8_t idx, uint8_t c, Value *indirect0 = nullptr, Value *indirect1 = nullptr);
+ bool visit(nir_alu_instr *);
bool visit(nir_block *);
bool visit(nir_cf_node *);
bool visit(nir_function *);
@@ -100,6 +126,10 @@ public:
std::vector<DataType> getSTypes(nir_alu_instr*);
DataType getSType(nir_src&, bool isFloat, bool isSigned);
+ operation getOperation(nir_op);
+ operation preOperationNeeded(nir_op);
+ int getSubOp(nir_op);
+ CondCode getCondCode(nir_op);
private:
nir_shader *nir;
@@ -109,6 +139,7 @@ private:
unsigned int curLoopDepth;
BasicBlock *exit;
+ Value *zero;
union {
struct {
@@ -120,7 +151,10 @@ private:
Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
: ConverterCommon(prog, info),
nir(nir),
- curLoopDepth(0) {}
+     curLoopDepth(0),
+     // shared 32-bit immediate 0, reused as a comparison operand when lowering
+     zero(mkImm((uint32_t)0))
+{
+}
BasicBlock *
Converter::convert(nir_block *block)
@@ -239,6 +273,136 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
return typeOfSize(bitSize / 8, isFloat, isSigned);
}
+/* Map a NIR ALU opcode onto the nv50 IR operation implementing it.
+ * Opcodes without a mapping are reported, trip an assertion and yield
+ * OP_NOP.
+ */
+operation
+Converter::getOperation(nir_op op)
+{
+   switch (op) {
+   // arithmetic
+   CASE_OPFI(abs):
+      return OP_ABS;
+   CASE_OPFI(add):
+      return OP_ADD;
+   CASE_OPFI(sub):
+      return OP_SUB;
+   CASE_OPFI(mul):
+   CASE_OPIU(mul_high):
+      return OP_MUL;
+   CASE_OPFIU(div):
+      return OP_DIV;
+   CASE_OPFIU(mod):
+      return OP_MOD;
+   CASE_OPFI(neg):
+      return OP_NEG;
+   CASE_OPFIU(max):
+      return OP_MAX;
+   CASE_OPFIU(min):
+      return OP_MIN;
+   case nir_op_ffma:
+      return OP_FMA;
+   case nir_op_fsat:
+      return OP_SAT;
+   // bitwise / logic
+   CASE_OPFI(and):
+      return OP_AND;
+   CASE_OPFI(or):
+      return OP_OR;
+   CASE_OPFI(xor):
+      return OP_XOR;
+   CASE_OPFI(not):
+      return OP_NOT;
+   case nir_op_ishl:
+      return OP_SHL;
+   CASE_OPIU(shr):
+      return OP_SHR;
+   CASE_OPIU(find_msb):
+      return OP_BFIND;
+   // comparisons; the condition itself comes from getCondCode()
+   CASE_OPFI(eq):
+   CASE_OPFIU(ge):
+   CASE_OPFIU(lt):
+   CASE_OPFI(ne):
+      return OP_SET;
+   // rounding
+   case nir_op_fceil:
+      return OP_CEIL;
+   case nir_op_ffloor:
+      return OP_FLOOR;
+   case nir_op_ftrunc:
+      return OP_TRUNC;
+   // transcendental and friends
+   case nir_op_fsin:
+      return OP_SIN;
+   case nir_op_fcos:
+      return OP_COS;
+   case nir_op_fexp2:
+      return OP_EX2;
+   case nir_op_flog2:
+      return OP_LG2;
+   case nir_op_frcp:
+      return OP_RCP;
+   case nir_op_frsq:
+      return OP_RSQ;
+   case nir_op_fsqrt:
+      return OP_SQRT;
+   // derivatives
+   case nir_op_fddx:
+   case nir_op_fddx_coarse:
+   case nir_op_fddx_fine:
+      return OP_DFDX;
+   case nir_op_fddy:
+   case nir_op_fddy_coarse:
+   case nir_op_fddy_fine:
+      return OP_DFDY;
+   // type conversions
+   CASE_OPFIU(2f32):
+   CASE_OPFIU(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2i32:
+   case nir_op_i2i64:
+   case nir_op_u2u32:
+   case nir_op_u2u64:
+      return OP_CVT;
+   case nir_op_pack_64_2x32_split:
+      return OP_MERGE;
+   default:
+      ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return OP_NOP;
+   }
+}
+
+/* Operation that must be applied to the source before `op` itself can run,
+ * or OP_NOP when `op` consumes its source directly. Currently only sin/cos
+ * need a pre-op (OP_PRESIN).
+ */
+operation
+Converter::preOperationNeeded(nir_op op)
+{
+   if (op == nir_op_fsin || op == nir_op_fcos)
+      return OP_PRESIN;
+   return OP_NOP;
+}
+
+/* Sub-operation to attach to the instruction emitted for `op`; 0 for none.
+ * Only the high-half integer multiplies need one so far.
+ */
+int
+Converter::getSubOp(nir_op op)
+{
+   const bool isMulHigh = op == nir_op_imul_high || op == nir_op_umul_high;
+   return isMulHigh ? NV50_IR_SUBOP_MUL_HIGH : 0;
+}
+
+/* Condition code for the OP_SET emitted for a NIR comparison opcode.
+ * Non-comparison opcodes are reported, trip an assertion and yield CC_FL.
+ */
+CondCode
+Converter::getCondCode(nir_op op)
+{
+   switch (op) {
+   case nir_op_feq:
+   case nir_op_ieq:
+      return CC_EQ;
+   case nir_op_fge:
+   case nir_op_ige:
+   case nir_op_uge:
+      return CC_GE;
+   case nir_op_flt:
+   case nir_op_ilt:
+   case nir_op_ult:
+      return CC_LT;
+   case nir_op_fne:
+   case nir_op_ine:
+      return CC_NEU;
+   default:
+      ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return CC_FL;
+   }
+}
+
+// An ALU destination is converted through the nir_dest it wraps.
+Converter::LValues&
+Converter::convert(nir_alu_dest *dest)
+{
+   nir_dest *wrapped = &dest->dest;
+   return convert(wrapped);
+}
+
Converter::LValues&
Converter::convert(nir_dest *dest)
{
@@ -1213,6 +1377,10 @@ bool
Converter::visit(nir_instr *insn)
{
switch (insn->type) {
+ case nir_instr_type_alu:
+ if (!visit(nir_instr_as_alu(insn)))
+ return false;
+ break;
case nir_instr_type_intrinsic:
if (!visit(nir_instr_as_intrinsic(insn)))
return false;
@@ -1288,6 +1456,323 @@ Converter::visit(nir_intrinsic_instr *insn)
return true;
}
+#define DEFAULT_CHECKS \
+ if (insn->dest.dest.ssa.num_components > 1) { \
+ ERROR("nir_alu_instr only supported with 1 component!\n"); \
+ return false; \
+ } \
+ if (insn->dest.write_mask != 1) { \
+ ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
+ return false; \
+ }
+/* Translate one NIR ALU instruction into nv50 IR.
+ *
+ * Most opcodes go through DEFAULT_CHECKS and therefore only accept a
+ * single-component destination with a write mask of 1; unsupported opcodes
+ * or shapes are reported and make this return false. Afterwards every
+ * instruction emitted for the ALU op gets insn->exact as its `precise` flag,
+ * and the last one also gets the destination's saturate flag.
+ */
+bool
+Converter::visit(nir_alu_instr *insn)
+{
+   // some helper variables
+   const nir_op op = insn->op;
+   const nir_op_info &info = nir_op_infos[op];
+   DataType dType = getDType(insn);
+   const std::vector<DataType> sTypes = getSTypes(insn);
+   // Remember the current tail of the block (NULL while the block is still
+   // empty) so the instructions emitted below can be found afterwards.
+   Instruction *oldPos = this->bb->getExit();
+
+   switch (op) {
+   CASE_OPFI(abs):
+   CASE_OPFI(add):
+   CASE_OPFI(and):
+   case nir_op_fceil:
+   case nir_op_fcos:
+   case nir_op_fddx:
+   case nir_op_fddx_coarse:
+   case nir_op_fddx_fine:
+   case nir_op_fddy:
+   case nir_op_fddy_coarse:
+   case nir_op_fddy_fine:
+   CASE_OPFIU(div):
+   case nir_op_fexp2:
+   case nir_op_ffloor:
+   case nir_op_ffma:
+   case nir_op_flog2:
+   CASE_OPFIU(max):
+   CASE_OPFIU(min):
+   CASE_OPFIU(mod):
+   CASE_OPFI(mul):
+   CASE_OPIU(mul_high):
+   CASE_OPFI(neg):
+   CASE_OPFI(not):
+   CASE_OPFI(or):
+   case nir_op_pack_64_2x32_split:
+   case nir_op_frcp:
+   case nir_op_frsq:
+   case nir_op_fsat:
+   CASE_OPIU(shr):
+   case nir_op_fsin:
+   case nir_op_fsqrt:
+   CASE_OPFI(sub):
+   case nir_op_ftrunc:
+   case nir_op_ishl:
+   CASE_OPFI(xor): {
+      // direct mapping onto one nv50 IR operation (see getOperation)
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      operation preOp = preOperationNeeded(op);
+      if (preOp != OP_NOP) {
+         // route the (single) source through the pre-op into a temporary
+         assert(info.num_inputs < 2);
+         Value *tmp = getSSA(typeSizeof(dType));
+         Instruction *i0 = mkOp(preOp, dType, tmp);
+         Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]);
+         if (info.num_inputs) {
+            i0->setSrc(0, getSrc(&insn->src[0]));
+            i1->setSrc(0, tmp);
+         }
+         i1->subOp = getSubOp(op);
+      } else {
+         Instruction *i = mkOp(getOperation(op), dType, newDefs[0]);
+         for (auto s = 0u; s < info.num_inputs; ++s) {
+            i->setSrc(s, getSrc(&insn->src[s]));
+         }
+         i->subOp = getSubOp(op);
+      }
+      break;
+   }
+   CASE_OPIU(find_msb): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      // the operation type follows the source type
+      dType = sTypes[0];
+      mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_fround_even: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
+      break;
+   }
+   // convert instructions
+   CASE_OPFIU(2f32):
+   CASE_OPFIU(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2i32:
+   case nir_op_i2i64:
+   case nir_op_u2u32:
+   case nir_op_u2u64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      // float -> integer conversions use round-towards-zero
+      if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
+         i->rnd = ROUND_Z;
+      i->sType = sTypes[0];
+      break;
+   }
+   // compare instructions
+   CASE_OPFI(eq):
+   CASE_OPFIU(ge):
+   CASE_OPFIU(lt):
+   CASE_OPFI(ne): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkCmp(getOperation(op),
+                             getCondCode(op),
+                             dType,
+                             newDefs[0],
+                             dType,
+                             getSrc(&insn->src[0]),
+                             getSrc(&insn->src[1]));
+      if (info.num_inputs == 3)
+         i->setSrc(2, getSrc(&insn->src[2]));
+      i->sType = sTypes[0];
+      break;
+   }
+   /* These are unusual ALU ops and need special handling, because
+    * 1. they always work per component
+    * 2. they basically just merge multiple values into one data type
+    */
+   CASE_OPFI(mov):
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4: {
+      LValues &newDefs = convert(&insn->dest);
+      // NOTE(review): for fmov/imov this reads src[c] per destination
+      // component, which is only valid for scalar movs - confirm ALU ops are
+      // scalarized before this pass.
+      for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
+         mkMov(newDefs[c], getSrc(&insn->src[c]), dType);
+      }
+      break;
+   }
+   // (un)pack
+   case nir_op_pack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
+      merge->setSrc(0, getSrc(&insn->src[0], 0));
+      merge->setSrc(1, getSrc(&insn->src[0], 1));
+      break;
+   }
+   case nir_op_pack_half_2x16_split: {
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmpH = getSSA();
+      Value *tmpL = getSSA();
+
+      mkCvt(OP_CVT, TYPE_F16, tmpL, TYPE_F32, getSrc(&insn->src[0]));
+      mkCvt(OP_CVT, TYPE_F16, tmpH, TYPE_F32, getSrc(&insn->src[1]));
+      mkOp3(OP_INSBF, TYPE_U32, newDefs[0], tmpH, mkImm(0x1010), tmpL);
+      break;
+   }
+   case nir_op_unpack_half_2x16_split_x:
+   case nir_op_unpack_half_2x16_split_y: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *cvt = mkCvt(OP_CVT, TYPE_F32, newDefs[0], TYPE_F16, getSrc(&insn->src[0]));
+      if (op == nir_op_unpack_half_2x16_split_y)
+         cvt->subOp = 1;
+      break;
+   }
+   case nir_op_unpack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
+      break;
+   }
+   case nir_op_unpack_64_2x32_split_x: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, getSSA());
+      break;
+   }
+   case nir_op_unpack_64_2x32_split_y: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, getSSA(), getSrc(&insn->src[0]))->setDef(1, newDefs[0]);
+      break;
+   }
+   // special instructions
+   CASE_OPFI(sign): {
+      DEFAULT_CHECKS;
+      DataType iType;
+      if (::isFloatType(dType))
+         iType = TYPE_F32;
+      else
+         iType = TYPE_S32;
+
+      LValues &newDefs = convert(&insn->dest);
+      LValue *val0 = getScratch();
+      LValue *val1 = getScratch();
+      // val0 = (x > 0), val1 = (x < 0); the sign is derived from their difference
+      mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
+      mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
+
+      if (dType == TYPE_F64) {
+         mkOp2(OP_SUB, iType, val0, val0, val1);
+         mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
+      } else if (dType == TYPE_S64 || dType == TYPE_U64) {
+         mkOp2(OP_SUB, iType, val0, val1, val0);
+         // the high word is the sign extension of the low word
+         mkOp2(OP_SHR, iType, val1, val0, loadImm(nullptr, 31));
+         mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
+      } else if (::isFloatType(dType))
+         mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
+      else
+         mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
+      break;
+   }
+   case nir_op_fcsel:
+   case nir_op_bcsel: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
+      break;
+   }
+   CASE_OPIU(bfe):
+   CASE_OPIU(bitfield_extract): {
+      DEFAULT_CHECKS;
+      Value *tmp = getSSA();
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(nullptr, 0x808), getSrc(&insn->src[1]));
+      mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
+      break;
+   }
+   case nir_op_bfm: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[0]), loadImm(nullptr, 0x808), getSrc(&insn->src[1]));
+      break;
+   }
+   case nir_op_bitfield_insert: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *temp = getSSA();
+      mkOp3(OP_INSBF, TYPE_U32, temp, getSrc(&insn->src[3]), mkImm(0x808), getSrc(&insn->src[2]));
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), temp, getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bit_count: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bitfield_reverse: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      break;
+   }
+   case nir_op_find_lsb: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmp = getSSA();
+      mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
+      break;
+   }
+   // boolean conversions
+   case nir_op_b2f: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(nullptr, 1.0f));
+      break;
+   }
+   CASE_OPFI(2b): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *src1;
+      if (typeSizeof(sTypes[0]) == 8) {
+         src1 = loadImm(getSSA(8), 0.0);
+      } else {
+         src1 = zero;
+      }
+      mkCmp(OP_SET, CC_NEU, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
+      break;
+   }
+   case nir_op_b2i: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *def;
+      if (typeSizeof(dType) == 8)
+         def = getScratch();
+      else
+         def = newDefs[0];
+
+      // bools are always 32bit values
+      mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(nullptr, 1));
+      if (typeSizeof(dType) == 8)
+         mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(nullptr, 0));
+
+      break;
+   }
+   default:
+      ERROR("unknown nir_op %s\n", info.name);
+      return false;
+   }
+
+   // First instruction emitted for this ALU op: the one after the saved tail,
+   // or - if the block was empty on entry - the block's entry. Starting from
+   // getExit() there would skip every instruction but the last.
+   Instruction *cur = oldPos ? oldPos->next : this->bb->getEntry();
+   if (!cur)
+      return true; // nothing emitted; never touch a previous op's instruction
+   for (;; cur = cur->next) {
+      cur->precise = insn->exact;
+      if (!cur->next)
+         break;
+   }
+   // saturation applies to the instruction producing the final result
+   cur->saturate = insn->dest.saturate;
+
+   return true;
+}
+#undef DEFAULT_CHECKS
+
bool
Converter::run()
{
--
2.14.3
More information about the mesa-dev
mailing list