[Mesa-dev] [PATCH A 12/15] FIXUP: Fix NIR producers and consumers to use unsized conversions
Jason Ekstrand
jason at jlekstrand.net
Fri Nov 9 03:45:13 UTC 2018
---
src/amd/common/ac_nir_to_llvm.c | 32 ++--
src/gallium/auxiliary/nir/tgsi_to_nir.c | 8 +-
.../drivers/freedreno/ir3/ir3_compiler_nir.c | 148 ++++++++----------
src/gallium/drivers/vc4/vc4_program.c | 8 +-
src/intel/compiler/brw_fs_nir.cpp | 78 +++------
src/intel/compiler/brw_vec4_nir.cpp | 39 ++---
6 files changed, 124 insertions(+), 189 deletions(-)
diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c
index c950b81dca2..07c999edf5c 100644
--- a/src/amd/common/ac_nir_to_llvm.c
+++ b/src/amd/common/ac_nir_to_llvm.c
@@ -857,58 +857,44 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
src[i] = ac_to_integer(&ctx->ac, src[i]);
result = ac_build_gather_values(&ctx->ac, src, num_components);
break;
- case nir_op_f2i16:
- case nir_op_f2i32:
- case nir_op_f2i64:
+ case nir_op_f2i:
src[0] = ac_to_float(&ctx->ac, src[0]);
result = LLVMBuildFPToSI(ctx->ac.builder, src[0], def_type, "");
break;
- case nir_op_f2u16:
- case nir_op_f2u32:
- case nir_op_f2u64:
+ case nir_op_f2u:
src[0] = ac_to_float(&ctx->ac, src[0]);
result = LLVMBuildFPToUI(ctx->ac.builder, src[0], def_type, "");
break;
- case nir_op_i2f16:
- case nir_op_i2f32:
- case nir_op_i2f64:
+ case nir_op_i2f:
src[0] = ac_to_integer(&ctx->ac, src[0]);
result = LLVMBuildSIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
break;
- case nir_op_u2f16:
- case nir_op_u2f32:
- case nir_op_u2f64:
+ case nir_op_u2f:
src[0] = ac_to_integer(&ctx->ac, src[0]);
result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
break;
- case nir_op_f2f16_rtz:
+ case nir_op_f2f_rtz:
src[0] = ac_to_float(&ctx->ac, src[0]);
LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
break;
- case nir_op_f2f16_rtne:
- case nir_op_f2f16:
- case nir_op_f2f32:
- case nir_op_f2f64:
+ case nir_op_f2f_rtne:
+ case nir_op_f2f:
src[0] = ac_to_float(&ctx->ac, src[0]);
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
else
result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
break;
- case nir_op_u2u16:
- case nir_op_u2u32:
- case nir_op_u2u64:
+ case nir_op_u2u:
src[0] = ac_to_integer(&ctx->ac, src[0]);
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
result = LLVMBuildZExt(ctx->ac.builder, src[0], def_type, "");
else
result = LLVMBuildTrunc(ctx->ac.builder, src[0], def_type, "");
break;
- case nir_op_i2i16:
- case nir_op_i2i32:
- case nir_op_i2i64:
+ case nir_op_i2i:
src[0] = ac_to_integer(&ctx->ac, src[0]);
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
result = LLVMBuildSExt(ctx->ac.builder, src[0], def_type, "");
diff --git a/src/gallium/auxiliary/nir/tgsi_to_nir.c b/src/gallium/auxiliary/nir/tgsi_to_nir.c
index 0ad274b535a..508d7a96130 100644
--- a/src/gallium/auxiliary/nir/tgsi_to_nir.c
+++ b/src/gallium/auxiliary/nir/tgsi_to_nir.c
@@ -1449,7 +1449,7 @@ static const nir_op op_trans[TGSI_OPCODE_LAST] = {
[TGSI_OPCODE_DDY_FINE] = nir_op_fddy_fine,
[TGSI_OPCODE_CEIL] = nir_op_fceil,
- [TGSI_OPCODE_I2F] = nir_op_i2f32,
+ [TGSI_OPCODE_I2F] = nir_op_i2f,
[TGSI_OPCODE_NOT] = nir_op_inot,
[TGSI_OPCODE_TRUNC] = nir_op_ftrunc,
[TGSI_OPCODE_SHL] = nir_op_ishl,
@@ -1480,7 +1480,7 @@ static const nir_op op_trans[TGSI_OPCODE_LAST] = {
[TGSI_OPCODE_END] = 0,
- [TGSI_OPCODE_F2I] = nir_op_f2i32,
+ [TGSI_OPCODE_F2I] = nir_op_f2i,
[TGSI_OPCODE_IDIV] = nir_op_idiv,
[TGSI_OPCODE_IMAX] = nir_op_imax,
[TGSI_OPCODE_IMIN] = nir_op_imin,
@@ -1488,8 +1488,8 @@ static const nir_op op_trans[TGSI_OPCODE_LAST] = {
[TGSI_OPCODE_ISGE] = nir_op_ige,
[TGSI_OPCODE_ISHR] = nir_op_ishr,
[TGSI_OPCODE_ISLT] = nir_op_ilt,
- [TGSI_OPCODE_F2U] = nir_op_f2u32,
- [TGSI_OPCODE_U2F] = nir_op_u2f32,
+ [TGSI_OPCODE_F2U] = nir_op_f2u,
+ [TGSI_OPCODE_U2F] = nir_op_u2f,
[TGSI_OPCODE_UADD] = nir_op_iadd,
[TGSI_OPCODE_UDIV] = nir_op_udiv,
[TGSI_OPCODE_UMAD] = 0,
diff --git a/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c b/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c
index 0c7a722aa0c..e42a3f52a8b 100644
--- a/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c
+++ b/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c
@@ -881,21 +881,16 @@ ir3_n2b(struct ir3_block *block, struct ir3_instruction *instr)
static struct ir3_instruction *
create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
- unsigned src_bitsize, nir_op op)
+ unsigned src_bitsize, unsigned dst_bitsize, nir_op op)
{
type_t src_type, dst_type;
switch (op) {
- case nir_op_f2f32:
- case nir_op_f2f16_rtne:
- case nir_op_f2f16_rtz:
- case nir_op_f2f16:
- case nir_op_f2i32:
- case nir_op_f2i16:
- case nir_op_f2i8:
- case nir_op_f2u32:
- case nir_op_f2u16:
- case nir_op_f2u8:
+ case nir_op_f2f:
+ case nir_op_f2f_rtne:
+ case nir_op_f2f_rtz:
+ case nir_op_f2i:
+ case nir_op_f2u:
switch (src_bitsize) {
case 32:
src_type = TYPE_F32;
@@ -908,11 +903,8 @@ create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
}
break;
- case nir_op_i2f32:
- case nir_op_i2f16:
- case nir_op_i2i32:
- case nir_op_i2i16:
- case nir_op_i2i8:
+ case nir_op_i2f:
+ case nir_op_i2i:
switch (src_bitsize) {
case 32:
src_type = TYPE_S32;
@@ -928,11 +920,8 @@ create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
}
break;
- case nir_op_u2f32:
- case nir_op_u2f16:
- case nir_op_u2u32:
- case nir_op_u2u16:
- case nir_op_u2u8:
+ case nir_op_u2f:
+ case nir_op_u2u:
switch (src_bitsize) {
case 32:
src_type = TYPE_U32;
@@ -953,49 +942,56 @@ create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
}
switch (op) {
- case nir_op_f2f32:
- case nir_op_i2f32:
- case nir_op_u2f32:
- dst_type = TYPE_F32;
- break;
-
- case nir_op_f2f16_rtne:
- case nir_op_f2f16_rtz:
- case nir_op_f2f16:
+ case nir_op_f2f:
+ case nir_op_i2f:
+ case nir_op_u2f:
+ case nir_op_f2f_rtne:
+ case nir_op_f2f_rtz:
/* TODO how to handle rounding mode? */
- case nir_op_i2f16:
- case nir_op_u2f16:
- dst_type = TYPE_F16;
- break;
-
- case nir_op_f2i32:
- case nir_op_i2i32:
- dst_type = TYPE_S32;
- break;
-
- case nir_op_f2i16:
- case nir_op_i2i16:
- dst_type = TYPE_S16;
- break;
-
- case nir_op_f2i8:
- case nir_op_i2i8:
- dst_type = TYPE_S8;
- break;
-
- case nir_op_f2u32:
- case nir_op_u2u32:
- dst_type = TYPE_U32;
+ switch (dst_bitsize) {
+ case 32:
+ dst_type = TYPE_F32;
+ break;
+ case 16:
+ dst_type = TYPE_F16;
+ break;
+ default:
+ compile_error(ctx, "invalid dst bit size: %u", dst_bitsize);
+ }
break;
- case nir_op_f2u16:
- case nir_op_u2u16:
- dst_type = TYPE_U16;
+ case nir_op_f2i:
+ case nir_op_i2i:
+ switch (dst_bitsize) {
+ case 32:
+ dst_type = TYPE_S32;
+ break;
+ case 16:
+ dst_type = TYPE_S16;
+ break;
+ case 8:
+ dst_type = TYPE_S8;
+ break;
+ default:
+ compile_error(ctx, "invalid dst bit size: %u", dst_bitsize);
+ }
break;
- case nir_op_f2u8:
- case nir_op_u2u8:
- dst_type = TYPE_U8;
+ case nir_op_f2u:
+ case nir_op_u2u:
+ switch (dst_bitsize) {
+ case 32:
+ dst_type = TYPE_U32;
+ break;
+ case 16:
+ dst_type = TYPE_U16;
+ break;
+ case 8:
+ dst_type = TYPE_U8;
+ break;
+ default:
+ compile_error(ctx, "invalid dst bit size: %u", dst_bitsize);
+ }
break;
default:
@@ -1012,7 +1008,7 @@ emit_alu(struct ir3_context *ctx, nir_alu_instr *alu)
struct ir3_instruction **dst, *src[info->num_inputs];
unsigned bs[info->num_inputs]; /* bit size */
struct ir3_block *b = ctx->block;
- unsigned dst_sz, wrmask;
+ unsigned dst_sz, dst_bs, wrmask;
if (alu->dest.dest.is_ssa) {
dst_sz = alu->dest.dest.ssa.num_components;
@@ -1081,29 +1077,19 @@ emit_alu(struct ir3_context *ctx, nir_alu_instr *alu)
compile_assert(ctx, src[i]);
}
+ dst_bs = nir_dest_bit_size(alu->dest.dest);
switch (alu->op) {
- case nir_op_f2f32:
- case nir_op_f2f16_rtne:
- case nir_op_f2f16_rtz:
- case nir_op_f2f16:
- case nir_op_f2i32:
- case nir_op_f2i16:
- case nir_op_f2i8:
- case nir_op_f2u32:
- case nir_op_f2u16:
- case nir_op_f2u8:
- case nir_op_i2f32:
- case nir_op_i2f16:
- case nir_op_i2i32:
- case nir_op_i2i16:
- case nir_op_i2i8:
- case nir_op_u2f32:
- case nir_op_u2f16:
- case nir_op_u2u32:
- case nir_op_u2u16:
- case nir_op_u2u8:
- dst[0] = create_cov(ctx, src[0], bs[0], alu->op);
+ case nir_op_f2f:
+ case nir_op_f2f_rtne:
+ case nir_op_f2f_rtz:
+ case nir_op_f2i:
+ case nir_op_f2u:
+ case nir_op_i2f:
+ case nir_op_i2i:
+ case nir_op_u2f:
+ case nir_op_u2u:
+ dst[0] = create_cov(ctx, src[0], bs[0], dst_bs, alu->op);
break;
case nir_op_f2b:
dst[0] = ir3_CMPS_F(b, src[0], 0, create_immed(b, fui(0.0)), 0);
diff --git a/src/gallium/drivers/vc4/vc4_program.c b/src/gallium/drivers/vc4/vc4_program.c
index bc9bd76ae95..615124e3562 100644
--- a/src/gallium/drivers/vc4/vc4_program.c
+++ b/src/gallium/drivers/vc4/vc4_program.c
@@ -1200,12 +1200,12 @@ ntq_emit_alu(struct vc4_compile *c, nir_alu_instr *instr)
result = qir_FMAX(c, src[0], src[1]);
break;
- case nir_op_f2i32:
- case nir_op_f2u32:
+ case nir_op_f2i:
+ case nir_op_f2u:
result = qir_FTOI(c, src[0]);
break;
- case nir_op_i2f32:
- case nir_op_u2f32:
+ case nir_op_i2f:
+ case nir_op_u2f:
result = qir_ITOF(c, src[0]);
break;
case nir_op_b2f:
diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp
index 2b36171136e..649b4dc545c 100644
--- a/src/intel/compiler/brw_fs_nir.cpp
+++ b/src/intel/compiler/brw_fs_nir.cpp
@@ -654,9 +654,9 @@ emit_find_msb_using_lzd(const fs_builder &bld,
static brw_rnd_mode
brw_rnd_mode_from_nir_op (const nir_op op) {
switch (op) {
- case nir_op_f2f16_rtz:
+ case nir_op_f2f_rtz:
return BRW_RND_MODE_RTZ;
- case nir_op_f2f16_rtne:
+ case nir_op_f2f_rtne:
return BRW_RND_MODE_RTNE;
default:
unreachable("Operation doesn't support rounding mode");
@@ -758,47 +758,33 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
}
switch (instr->op) {
- case nir_op_i2f32:
- case nir_op_u2f32:
- if (optimize_extract_to_float(instr, result))
+ case nir_op_i2f:
+ case nir_op_i2i:
+ case nir_op_u2f:
+ case nir_op_u2u:
+ case nir_op_f2i:
+ case nir_op_f2u:
+ case nir_op_f2f:
+ case nir_op_f2f_rtne:
+ case nir_op_f2f_rtz:
+ case nir_op_b2f:
+ case nir_op_b2i:
+ if ((instr->op == nir_op_i2f || instr->op == nir_op_u2f) &&
+ nir_dest_bit_size(instr->dest.dest) == 32 &&
+ optimize_extract_to_float(instr, result))
return;
- inst = bld.MOV(result, op[0]);
- inst->saturate = instr->dest.saturate;
- break;
-
- case nir_op_f2f16_rtne:
- case nir_op_f2f16_rtz:
- bld.emit(SHADER_OPCODE_RND_MODE, bld.null_reg_ud(),
- brw_imm_d(brw_rnd_mode_from_nir_op(instr->op)));
- /* fallthrough */
- /* In theory, it would be better to use BRW_OPCODE_F32TO16. Depending
- * on the HW gen, it is a special hw opcode or just a MOV, and
- * brw_F32TO16 (at brw_eu_emit) would do the work to chose.
- *
- * But if we want to use that opcode, we need to provide support on
- * different optimizations and lowerings. As right now HF support is
- * only for gen8+, it will be better to use directly the MOV, and use
- * BRW_OPCODE_F32TO16 when/if we work for HF support on gen7.
- */
+ if (instr->op == nir_op_f2f_rtne || instr->op == nir_op_f2f_rtz) {
+ assert(nir_dest_bit_size(instr->dest.dest) == 16);
+ bld.emit(SHADER_OPCODE_RND_MODE, bld.null_reg_ud(),
+ brw_imm_d(brw_rnd_mode_from_nir_op(instr->op)));
+ }
- case nir_op_f2f16:
- inst = bld.MOV(result, op[0]);
- inst->saturate = instr->dest.saturate;
- break;
+ if (nir_op_infos[instr->op].input_types[0] == nir_type_bool32) {
+ op[0].type = BRW_REGISTER_TYPE_D;
+ op[0].negate = !op[0].negate;
+ }
- case nir_op_b2i:
- case nir_op_b2f:
- op[0].type = BRW_REGISTER_TYPE_D;
- op[0].negate = !op[0].negate;
- /* fallthrough */
- case nir_op_f2f64:
- case nir_op_f2i64:
- case nir_op_f2u64:
- case nir_op_i2f64:
- case nir_op_i2i64:
- case nir_op_u2f64:
- case nir_op_u2u64:
/* CHV PRM, vol07, 3D Media GPGPU Engine, Register Region Restrictions:
*
* "When source or destination is 64b (...), regioning in Align1
@@ -822,20 +808,6 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
inst->saturate = instr->dest.saturate;
break;
}
- /* fallthrough */
- case nir_op_f2f32:
- case nir_op_f2i32:
- case nir_op_f2u32:
- case nir_op_f2i16:
- case nir_op_f2u16:
- case nir_op_i2i32:
- case nir_op_u2u32:
- case nir_op_i2i16:
- case nir_op_u2u16:
- case nir_op_i2f16:
- case nir_op_u2f16:
- case nir_op_i2i8:
- case nir_op_u2u8:
inst = bld.MOV(result, op[0]);
inst->saturate = instr->dest.saturate;
break;
diff --git a/src/intel/compiler/brw_vec4_nir.cpp b/src/intel/compiler/brw_vec4_nir.cpp
index 564be7e5eee..6799dff03bc 100644
--- a/src/intel/compiler/brw_vec4_nir.cpp
+++ b/src/intel/compiler/brw_vec4_nir.cpp
@@ -1154,27 +1154,28 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
case nir_op_vec4:
unreachable("not reached: should be handled by lower_vec_to_movs()");
- case nir_op_i2f32:
- case nir_op_u2f32:
- inst = emit(MOV(dst, op[0]));
- inst->saturate = instr->dest.saturate;
- break;
+ case nir_op_i2f:
+ case nir_op_i2i:
+ case nir_op_u2f:
+ case nir_op_u2u:
+ case nir_op_f2f:
+ case nir_op_f2i:
+ case nir_op_f2u:
+ case nir_op_b2i:
+ case nir_op_b2f:
+ if (nir_op_infos[instr->op].input_types[0] == nir_type_bool32) {
+ assert(op[0].type == BRW_REGISTER_TYPE_D);
+ op[0].negate = true;
+ }
- case nir_op_f2f32:
- case nir_op_f2i32:
- case nir_op_f2u32:
if (nir_src_bit_size(instr->src[0].src) == 64)
emit_conversion_from_double(dst, op[0], instr->dest.saturate);
+ else if (nir_dest_bit_size(instr->dest.dest) == 64)
+ emit_conversion_to_double(dst, op[0], instr->dest.saturate);
else
inst = emit(MOV(dst, op[0]));
break;
- case nir_op_f2f64:
- case nir_op_i2f64:
- case nir_op_u2f64:
- emit_conversion_to_double(dst, op[0], instr->dest.saturate);
- break;
-
case nir_op_iadd:
assert(nir_dest_bit_size(instr->dest.dest) < 64);
/* fall through */
@@ -1538,16 +1539,6 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
emit(AND(dst, op[0], op[1]));
break;
- case nir_op_b2i:
- case nir_op_b2f:
- if (nir_dest_bit_size(instr->dest.dest) > 32) {
- assert(dst.type == BRW_REGISTER_TYPE_DF);
- emit_conversion_to_double(dst, negate(op[0]), false);
- } else {
- emit(MOV(dst, negate(op[0])));
- }
- break;
-
case nir_op_f2b:
if (nir_src_bit_size(instr->src[0].src) == 64) {
/* We use a MOV with conditional_mod to check if the provided value is
--
2.19.1
More information about the mesa-dev
mailing list