[Mesa-dev] [PATCH A 12/15] FIXUP: Fix NIR producers and consumers to use unsized conversions

Jason Ekstrand jason at jlekstrand.net
Fri Nov 9 03:45:13 UTC 2018


---
 src/amd/common/ac_nir_to_llvm.c               |  32 ++--
 src/gallium/auxiliary/nir/tgsi_to_nir.c       |   8 +-
 .../drivers/freedreno/ir3/ir3_compiler_nir.c  | 148 ++++++++----------
 src/gallium/drivers/vc4/vc4_program.c         |   8 +-
 src/intel/compiler/brw_fs_nir.cpp             |  78 +++------
 src/intel/compiler/brw_vec4_nir.cpp           |  39 ++---
 6 files changed, 124 insertions(+), 189 deletions(-)

diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c
index c950b81dca2..07c999edf5c 100644
--- a/src/amd/common/ac_nir_to_llvm.c
+++ b/src/amd/common/ac_nir_to_llvm.c
@@ -857,58 +857,44 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
 			src[i] = ac_to_integer(&ctx->ac, src[i]);
 		result = ac_build_gather_values(&ctx->ac, src, num_components);
 		break;
-	case nir_op_f2i16:
-	case nir_op_f2i32:
-	case nir_op_f2i64:
+	case nir_op_f2i: /* result width is taken from def_type */
 		src[0] = ac_to_float(&ctx->ac, src[0]);
 		result = LLVMBuildFPToSI(ctx->ac.builder, src[0], def_type, "");
 		break;
-	case nir_op_f2u16:
-	case nir_op_f2u32:
-	case nir_op_f2u64:
+	case nir_op_f2u: /* result width is taken from def_type */
 		src[0] = ac_to_float(&ctx->ac, src[0]);
 		result = LLVMBuildFPToUI(ctx->ac.builder, src[0], def_type, "");
 		break;
-	case nir_op_i2f16:
-	case nir_op_i2f32:
-	case nir_op_i2f64:
+	case nir_op_i2f: /* result is the float type with def_type's width */
 		src[0] = ac_to_integer(&ctx->ac, src[0]);
 		result = LLVMBuildSIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
 		break;
-	case nir_op_u2f16:
-	case nir_op_u2f32:
-	case nir_op_u2f64:
+	case nir_op_u2f: /* result is the float type with def_type's width */
 		src[0] = ac_to_integer(&ctx->ac, src[0]);
 		result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
 		break;
-	case nir_op_f2f16_rtz:
+	case nir_op_f2f_rtz: /* NOTE(review): always uses cvt_pkrtz_f16, i.e. assumes a 16-bit dest -- confirm */
 		src[0] = ac_to_float(&ctx->ac, src[0]);
 		LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
 		result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
 		result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
 		break;
-	case nir_op_f2f16_rtne:
-	case nir_op_f2f16:
-	case nir_op_f2f32:
-	case nir_op_f2f64:
+	case nir_op_f2f_rtne: /* NOTE(review): no explicit rounding mode; shares the plain f2f path */
+	case nir_op_f2f: /* FPExt when widening to def_type, FPTrunc otherwise */
 		src[0] = ac_to_float(&ctx->ac, src[0]);
 		if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
 			result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
 		else
 			result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
 		break;
-	case nir_op_u2u16:
-	case nir_op_u2u32:
-	case nir_op_u2u64:
+	case nir_op_u2u: /* ZExt when widening to def_type, Trunc otherwise */
 		src[0] = ac_to_integer(&ctx->ac, src[0]);
 		if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
 			result = LLVMBuildZExt(ctx->ac.builder, src[0], def_type, "");
 		else
 			result = LLVMBuildTrunc(ctx->ac.builder, src[0], def_type, "");
 		break;
-	case nir_op_i2i16:
-	case nir_op_i2i32:
-	case nir_op_i2i64:
+	case nir_op_i2i: /* SExt when widening to def_type */
 		src[0] = ac_to_integer(&ctx->ac, src[0]);
 		if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
 			result = LLVMBuildSExt(ctx->ac.builder, src[0], def_type, "");
diff --git a/src/gallium/auxiliary/nir/tgsi_to_nir.c b/src/gallium/auxiliary/nir/tgsi_to_nir.c
index 0ad274b535a..508d7a96130 100644
--- a/src/gallium/auxiliary/nir/tgsi_to_nir.c
+++ b/src/gallium/auxiliary/nir/tgsi_to_nir.c
@@ -1449,7 +1449,7 @@ static const nir_op op_trans[TGSI_OPCODE_LAST] = {
    [TGSI_OPCODE_DDY_FINE] = nir_op_fddy_fine,
 
    [TGSI_OPCODE_CEIL] = nir_op_fceil,
-   [TGSI_OPCODE_I2F] = nir_op_i2f32,
+   [TGSI_OPCODE_I2F] = nir_op_i2f, /* NOTE(review): unsized op; confirm a 32-bit dest is set where the instr is built */
    [TGSI_OPCODE_NOT] = nir_op_inot,
    [TGSI_OPCODE_TRUNC] = nir_op_ftrunc,
    [TGSI_OPCODE_SHL] = nir_op_ishl,
@@ -1480,7 +1480,7 @@ static const nir_op op_trans[TGSI_OPCODE_LAST] = {
 
    [TGSI_OPCODE_END] = 0,
 
-   [TGSI_OPCODE_F2I] = nir_op_f2i32,
+   [TGSI_OPCODE_F2I] = nir_op_f2i, /* NOTE(review): unsized op; confirm a 32-bit dest */
    [TGSI_OPCODE_IDIV] = nir_op_idiv,
    [TGSI_OPCODE_IMAX] = nir_op_imax,
    [TGSI_OPCODE_IMIN] = nir_op_imin,
@@ -1488,8 +1488,8 @@ static const nir_op op_trans[TGSI_OPCODE_LAST] = {
    [TGSI_OPCODE_ISGE] = nir_op_ige,
    [TGSI_OPCODE_ISHR] = nir_op_ishr,
    [TGSI_OPCODE_ISLT] = nir_op_ilt,
-   [TGSI_OPCODE_F2U] = nir_op_f2u32,
-   [TGSI_OPCODE_U2F] = nir_op_u2f32,
+   [TGSI_OPCODE_F2U] = nir_op_f2u, /* NOTE(review): unsized op; confirm a 32-bit dest */
+   [TGSI_OPCODE_U2F] = nir_op_u2f, /* NOTE(review): unsized op; confirm a 32-bit dest */
    [TGSI_OPCODE_UADD] = nir_op_iadd,
    [TGSI_OPCODE_UDIV] = nir_op_udiv,
    [TGSI_OPCODE_UMAD] = 0,
diff --git a/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c b/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c
index 0c7a722aa0c..e42a3f52a8b 100644
--- a/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c
+++ b/src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c
@@ -881,21 +881,16 @@ ir3_n2b(struct ir3_block *block, struct ir3_instruction *instr)
 
 static struct ir3_instruction *
 create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
-		unsigned src_bitsize, nir_op op)
+		unsigned src_bitsize, unsigned dst_bitsize, nir_op op)
 {
 	type_t src_type, dst_type;
 
 	switch (op) {
-	case nir_op_f2f32:
-	case nir_op_f2f16_rtne:
-	case nir_op_f2f16_rtz:
-	case nir_op_f2f16:
-	case nir_op_f2i32:
-	case nir_op_f2i16:
-	case nir_op_f2i8:
-	case nir_op_f2u32:
-	case nir_op_f2u16:
-	case nir_op_f2u8:
+	case nir_op_f2f:
+	case nir_op_f2f_rtne:
+	case nir_op_f2f_rtz:
+	case nir_op_f2i:
+	case nir_op_f2u:
 		switch (src_bitsize) {
 		case 32:
 			src_type = TYPE_F32;
@@ -908,11 +903,8 @@ create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
 		}
 		break;
 
-	case nir_op_i2f32:
-	case nir_op_i2f16:
-	case nir_op_i2i32:
-	case nir_op_i2i16:
-	case nir_op_i2i8:
+	case nir_op_i2f:
+	case nir_op_i2i:
 		switch (src_bitsize) {
 		case 32:
 			src_type = TYPE_S32;
@@ -928,11 +920,8 @@ create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
 		}
 		break;
 
-	case nir_op_u2f32:
-	case nir_op_u2f16:
-	case nir_op_u2u32:
-	case nir_op_u2u16:
-	case nir_op_u2u8:
+	case nir_op_u2f:
+	case nir_op_u2u:
 		switch (src_bitsize) {
 		case 32:
 			src_type = TYPE_U32;
@@ -953,49 +942,57 @@ create_cov(struct ir3_context *ctx, struct ir3_instruction *src,
 	}
 
+	/* Pick dst_type from the op family and dst_bitsize (not src_bitsize). */
 	switch (op) {
-	case nir_op_f2f32:
-	case nir_op_i2f32:
-	case nir_op_u2f32:
-		dst_type = TYPE_F32;
-		break;
-
-	case nir_op_f2f16_rtne:
-	case nir_op_f2f16_rtz:
-	case nir_op_f2f16:
+	case nir_op_f2f:
+	case nir_op_i2f:
+	case nir_op_u2f:
+	case nir_op_f2f_rtne:
+	case nir_op_f2f_rtz:
 		/* TODO how to handle rounding mode? */
-	case nir_op_i2f16:
-	case nir_op_u2f16:
-		dst_type = TYPE_F16;
-		break;
-
-	case nir_op_f2i32:
-	case nir_op_i2i32:
-		dst_type = TYPE_S32;
-		break;
-
-	case nir_op_f2i16:
-	case nir_op_i2i16:
-		dst_type = TYPE_S16;
-		break;
-
-	case nir_op_f2i8:
-	case nir_op_i2i8:
-		dst_type = TYPE_S8;
-		break;
-
-	case nir_op_f2u32:
-	case nir_op_u2u32:
-		dst_type = TYPE_U32;
+		switch (dst_bitsize) {
+		case 32:
+			dst_type = TYPE_F32;
+			break;
+		case 16:
+			dst_type = TYPE_F16;
+			break;
+		default:
+			compile_error(ctx, "invalid dst bit size: %u", dst_bitsize);
+		}
 		break;
 
-	case nir_op_f2u16:
-	case nir_op_u2u16:
-		dst_type = TYPE_U16;
+	case nir_op_f2i:
+	case nir_op_i2i:
+		switch (dst_bitsize) {
+		case 32:
+			dst_type = TYPE_S32;
+			break;
+		case 16:
+			dst_type = TYPE_S16;
+			break;
+		case 8:
+			dst_type = TYPE_S8;
+			break;
+		default:
+			compile_error(ctx, "invalid dst bit size: %u", dst_bitsize);
+		}
 		break;
 
-	case nir_op_f2u8:
-	case nir_op_u2u8:
-		dst_type = TYPE_U8;
+	case nir_op_f2u:
+	case nir_op_u2u:
+		switch (dst_bitsize) {
+		case 32:
+			dst_type = TYPE_U32;
+			break;
+		case 16:
+			dst_type = TYPE_U16;
+			break;
+		case 8:
+			dst_type = TYPE_U8;
+			break;
+		default:
+			compile_error(ctx, "invalid dst bit size: %u", dst_bitsize);
+		}
 		break;
 
 	default:
@@ -1012,7 +1009,7 @@ emit_alu(struct ir3_context *ctx, nir_alu_instr *alu)
 	struct ir3_instruction **dst, *src[info->num_inputs];
 	unsigned bs[info->num_inputs];     /* bit size */
 	struct ir3_block *b = ctx->block;
-	unsigned dst_sz, wrmask;
+	unsigned dst_sz, dst_bs, wrmask;
 
 	if (alu->dest.dest.is_ssa) {
 		dst_sz = alu->dest.dest.ssa.num_components;
@@ -1081,29 +1078,19 @@ emit_alu(struct ir3_context *ctx, nir_alu_instr *alu)
 
 		compile_assert(ctx, src[i]);
 	}
+	dst_bs = nir_dest_bit_size(alu->dest.dest);
 
 	switch (alu->op) {
-	case nir_op_f2f32:
-	case nir_op_f2f16_rtne:
-	case nir_op_f2f16_rtz:
-	case nir_op_f2f16:
-	case nir_op_f2i32:
-	case nir_op_f2i16:
-	case nir_op_f2i8:
-	case nir_op_f2u32:
-	case nir_op_f2u16:
-	case nir_op_f2u8:
-	case nir_op_i2f32:
-	case nir_op_i2f16:
-	case nir_op_i2i32:
-	case nir_op_i2i16:
-	case nir_op_i2i8:
-	case nir_op_u2f32:
-	case nir_op_u2f16:
-	case nir_op_u2u32:
-	case nir_op_u2u16:
-	case nir_op_u2u8:
-		dst[0] = create_cov(ctx, src[0], bs[0], alu->op);
+	case nir_op_f2f:
+	case nir_op_f2f_rtne:
+	case nir_op_f2f_rtz:
+	case nir_op_f2i:
+	case nir_op_f2u:
+	case nir_op_i2f:
+	case nir_op_i2i:
+	case nir_op_u2f:
+	case nir_op_u2u:
+		dst[0] = create_cov(ctx, src[0], bs[0], dst_bs, alu->op);
 		break;
 	case nir_op_f2b:
 		dst[0] = ir3_CMPS_F(b, src[0], 0, create_immed(b, fui(0.0)), 0);
diff --git a/src/gallium/drivers/vc4/vc4_program.c b/src/gallium/drivers/vc4/vc4_program.c
index bc9bd76ae95..615124e3562 100644
--- a/src/gallium/drivers/vc4/vc4_program.c
+++ b/src/gallium/drivers/vc4/vc4_program.c
@@ -1200,12 +1200,12 @@ ntq_emit_alu(struct vc4_compile *c, nir_alu_instr *instr)
                 result = qir_FMAX(c, src[0], src[1]);
                 break;
 
-        case nir_op_f2i32:
-        case nir_op_f2u32:
+        case nir_op_f2i: /* NOTE(review): FTOI emitted for any size; assumes only 32-bit reaches here */
+        case nir_op_f2u:
                 result = qir_FTOI(c, src[0]);
                 break;
-        case nir_op_i2f32:
-        case nir_op_u2f32:
+        case nir_op_i2f: /* NOTE(review): ITOF emitted for any size; assumes only 32-bit reaches here */
+        case nir_op_u2f:
                 result = qir_ITOF(c, src[0]);
                 break;
         case nir_op_b2f:
diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp
index 2b36171136e..649b4dc545c 100644
--- a/src/intel/compiler/brw_fs_nir.cpp
+++ b/src/intel/compiler/brw_fs_nir.cpp
@@ -654,9 +654,9 @@ emit_find_msb_using_lzd(const fs_builder &bld,
 static brw_rnd_mode
 brw_rnd_mode_from_nir_op (const nir_op op) {
    switch (op) {
-   case nir_op_f2f16_rtz:
+   case nir_op_f2f_rtz:
       return BRW_RND_MODE_RTZ;
-   case nir_op_f2f16_rtne:
+   case nir_op_f2f_rtne:
       return BRW_RND_MODE_RTNE;
    default:
       unreachable("Operation doesn't support rounding mode");
@@ -758,47 +758,43 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
    }
 
    switch (instr->op) {
-   case nir_op_i2f32:
-   case nir_op_u2f32:
-      if (optimize_extract_to_float(instr, result))
+   case nir_op_i2f:
+   case nir_op_i2i:
+   case nir_op_u2f:
+   case nir_op_u2u:
+   case nir_op_f2i:
+   case nir_op_f2u:
+   case nir_op_f2f:
+   case nir_op_f2f_rtne:
+   case nir_op_f2f_rtz:
+   case nir_op_b2f:
+   case nir_op_b2i:
+      if ((instr->op == nir_op_i2f || instr->op == nir_op_u2f) &&
+          nir_dest_bit_size(instr->dest.dest) == 32 &&
+          optimize_extract_to_float(instr, result))
          return;
-      inst = bld.MOV(result, op[0]);
-      inst->saturate = instr->dest.saturate;
-      break;
-
-   case nir_op_f2f16_rtne:
-   case nir_op_f2f16_rtz:
-      bld.emit(SHADER_OPCODE_RND_MODE, bld.null_reg_ud(),
-               brw_imm_d(brw_rnd_mode_from_nir_op(instr->op)));
-      /* fallthrough */
 
-      /* In theory, it would be better to use BRW_OPCODE_F32TO16. Depending
-       * on the HW gen, it is a special hw opcode or just a MOV, and
-       * brw_F32TO16 (at brw_eu_emit) would do the work to chose.
-       *
-       * But if we want to use that opcode, we need to provide support on
-       * different optimizations and lowerings. As right now HF support is
-       * only for gen8+, it will be better to use directly the MOV, and use
-       * BRW_OPCODE_F32TO16 when/if we work for HF support on gen7.
-       */
+      if (instr->op == nir_op_f2f_rtne || instr->op == nir_op_f2f_rtz) {
+         assert(nir_dest_bit_size(instr->dest.dest) == 16);
+         bld.emit(SHADER_OPCODE_RND_MODE, bld.null_reg_ud(),
+                  brw_imm_d(brw_rnd_mode_from_nir_op(instr->op)));
+      }
+
+      /* In theory, it would be better to use BRW_OPCODE_F32TO16. Depending
+       * on the HW gen, it is a special hw opcode or just a MOV, and
+       * brw_F32TO16 (at brw_eu_emit) would do the work to choose.
+       *
+       * But if we want to use that opcode, we need to provide support on
+       * different optimizations and lowerings. As right now HF support is
+       * only for gen8+, it will be better to use directly the MOV, and use
+       * BRW_OPCODE_F32TO16 when/if we work for HF support on gen7.
+       */
 
-   case nir_op_f2f16:
-      inst = bld.MOV(result, op[0]);
-      inst->saturate = instr->dest.saturate;
-      break;
+      if (nir_op_infos[instr->op].input_types[0] == nir_type_bool32) {
+         op[0].type = BRW_REGISTER_TYPE_D;
+         op[0].negate = !op[0].negate;
+      }
 
-   case nir_op_b2i:
-   case nir_op_b2f:
-      op[0].type = BRW_REGISTER_TYPE_D;
-      op[0].negate = !op[0].negate;
-      /* fallthrough */
-   case nir_op_f2f64:
-   case nir_op_f2i64:
-   case nir_op_f2u64:
-   case nir_op_i2f64:
-   case nir_op_i2i64:
-   case nir_op_u2f64:
-   case nir_op_u2u64:
       /* CHV PRM, vol07, 3D Media GPGPU Engine, Register Region Restrictions:
        *
        *    "When source or destination is 64b (...), regioning in Align1
@@ -822,20 +818,6 @@ fs_visitor::nir_emit_alu(const fs_builder &bld, nir_alu_instr *instr)
          inst->saturate = instr->dest.saturate;
          break;
       }
-      /* fallthrough */
-   case nir_op_f2f32:
-   case nir_op_f2i32:
-   case nir_op_f2u32:
-   case nir_op_f2i16:
-   case nir_op_f2u16:
-   case nir_op_i2i32:
-   case nir_op_u2u32:
-   case nir_op_i2i16:
-   case nir_op_u2u16:
-   case nir_op_i2f16:
-   case nir_op_u2f16:
-   case nir_op_i2i8:
-   case nir_op_u2u8:
       inst = bld.MOV(result, op[0]);
       inst->saturate = instr->dest.saturate;
       break;
diff --git a/src/intel/compiler/brw_vec4_nir.cpp b/src/intel/compiler/brw_vec4_nir.cpp
index 564be7e5eee..6799dff03bc 100644
--- a/src/intel/compiler/brw_vec4_nir.cpp
+++ b/src/intel/compiler/brw_vec4_nir.cpp
@@ -1154,27 +1154,31 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
    case nir_op_vec4:
       unreachable("not reached: should be handled by lower_vec_to_movs()");
 
-   case nir_op_i2f32:
-   case nir_op_u2f32:
-      inst = emit(MOV(dst, op[0]));
-      inst->saturate = instr->dest.saturate;
-      break;
+   case nir_op_i2f:
+   case nir_op_i2i:
+   case nir_op_u2f:
+   case nir_op_u2u:
+   case nir_op_f2f:
+   case nir_op_f2i:
+   case nir_op_f2u:
+   case nir_op_b2i:
+   case nir_op_b2f:
+      /* bool32 true is ~0 (-1) in a D register; negating it yields 1. */
+      if (nir_op_infos[instr->op].input_types[0] == nir_type_bool32) {
+         assert(op[0].type == BRW_REGISTER_TYPE_D);
+         op[0].negate = !op[0].negate;
+      }
 
-   case nir_op_f2f32:
-   case nir_op_f2i32:
-   case nir_op_f2u32:
       if (nir_src_bit_size(instr->src[0].src) == 64)
          emit_conversion_from_double(dst, op[0], instr->dest.saturate);
+      else if (nir_dest_bit_size(instr->dest.dest) == 64)
+         emit_conversion_to_double(dst, op[0], instr->dest.saturate);
-      else
-         inst = emit(MOV(dst, op[0]));
+      else {
+         inst = emit(MOV(dst, op[0]));
+         inst->saturate = instr->dest.saturate;
+      }
       break;
 
-   case nir_op_f2f64:
-   case nir_op_i2f64:
-   case nir_op_u2f64:
-      emit_conversion_to_double(dst, op[0], instr->dest.saturate);
-      break;
-
    case nir_op_iadd:
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
       /* fall through */
@@ -1538,16 +1542,6 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       emit(AND(dst, op[0], op[1]));
       break;
 
-   case nir_op_b2i:
-   case nir_op_b2f:
-      if (nir_dest_bit_size(instr->dest.dest) > 32) {
-         assert(dst.type == BRW_REGISTER_TYPE_DF);
-         emit_conversion_to_double(dst, negate(op[0]), false);
-      } else {
-         emit(MOV(dst, negate(op[0])));
-      }
-      break;
-
    case nir_op_f2b:
       if (nir_src_bit_size(instr->src[0].src) == 64) {
          /* We use a MOV with conditional_mod to check if the provided value is
-- 
2.19.1



More information about the mesa-dev mailing list