[Mesa-dev] [PATCH v3 14/44] nir: add new fdot* opcodes taking into account rounding mode
Samuel Iglesias Gonsálvez
siglesias at igalia.com
Wed Feb 6 10:44:43 UTC 2019
According to the Vulkan spec, the new execution modes affect only
correctly rounded SPIR-V instructions, which include FaceForward.
FaceForward is lowered into fdot* instructions, so fdot* needs
variants that honor the RTNE and RTZ rounding modes.
Signed-off-by: Samuel Iglesias Gonsálvez <siglesias at igalia.com>
---
src/compiler/nir/nir_builder.h | 32 +++++++++++++++
src/compiler/nir/nir_lower_alu_to_scalar.c | 9 +++++
src/compiler/nir/nir_opcodes.py | 45 +++++++++++++---------
3 files changed, 67 insertions(+), 19 deletions(-)
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 2a36eb3c91b..4c5f044aaf2 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -521,6 +521,38 @@ nir_fdot(nir_builder *build, nir_ssa_def *src0, nir_ssa_def *src1)
return NULL;
}
+static inline nir_ssa_def *
+nir_fdot_rtne(nir_builder *build, nir_ssa_def *src0, nir_ssa_def *src1)
+{
+   /* Dot product via the _rtne opcodes; 1..4 components, fmul_rtne for scalars. */
+   const unsigned n = src0->num_components;
+   assert(n == src1->num_components);
+   switch (n) {
+   case 1: return nir_fmul_rtne(build, src0, src1);
+   case 2: return nir_fdot2_rtne(build, src0, src1);
+   case 3: return nir_fdot3_rtne(build, src0, src1);
+   case 4: return nir_fdot4_rtne(build, src0, src1);
+   default: unreachable("bad component size");
+   }
+   return NULL;
+}
+
+static inline nir_ssa_def *
+nir_fdot_rtz(nir_builder *build, nir_ssa_def *src0, nir_ssa_def *src1)
+{
+   /* Dot product via the _rtz opcodes; 1..4 components, fmul_rtz for scalars. */
+   const unsigned n = src0->num_components;
+   assert(n == src1->num_components);
+   switch (n) {
+   case 1: return nir_fmul_rtz(build, src0, src1);
+   case 2: return nir_fdot2_rtz(build, src0, src1);
+   case 3: return nir_fdot3_rtz(build, src0, src1);
+   case 4: return nir_fdot4_rtz(build, src0, src1);
+   default: unreachable("bad component size");
+   }
+   return NULL;
+}
+
static inline nir_ssa_def *
nir_bany_inequal(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
{
diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c
index 9b175878c15..2ed4098d59b 100644
--- a/src/compiler/nir/nir_lower_alu_to_scalar.c
+++ b/src/compiler/nir/nir_lower_alu_to_scalar.c
@@ -91,6 +91,12 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
case name##4: \
lower_reduction(instr, chan, merge, b); \
return true;
+/* Rounding-mode twin of LOWER_REDUCTION: emits case labels for the */
+/* name##{2,3,4}_##rounding opcodes (e.g. nir_op_fdot2_rtne). */
+#define LOWER_REDUCTION_ROUNDING(name, chan, merge, rounding) \
+   case name##2_##rounding: case name##3_##rounding: case name##4_##rounding: \
+      lower_reduction(instr, chan, merge, b); \
+      return true;
switch (instr->op) {
case nir_op_vec4:
@@ -198,6 +204,9 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
return false;
LOWER_REDUCTION(nir_op_fdot, nir_op_fmul, nir_op_fadd);
+ LOWER_REDUCTION_ROUNDING(nir_op_fdot, nir_op_fmul_rtne, nir_op_fadd_rtne, rtne);
+ LOWER_REDUCTION_ROUNDING(nir_op_fdot, nir_op_fmul_rtz, nir_op_fadd_rtz, rtz);
+
LOWER_REDUCTION(nir_op_ball_fequal, nir_op_feq, nir_op_iand);
LOWER_REDUCTION(nir_op_ball_iequal, nir_op_ieq, nir_op_iand);
LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fne, nir_op_ior);
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index d0087d350a8..21f6569c888 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -474,7 +474,7 @@ def binop_horiz(name, out_size, out_type, src1_size, src1_type, src2_size,
False, "", const_expr, "")
def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
- reduce_expr, final_expr):
+ reduce_expr, final_expr, rounding_mode):
def final(src):
return final_expr.format(src= "(" + src + ")")
def reduce_(src0, src1):
@@ -485,15 +485,15 @@ def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
src1 = prereduce("src0.y", "src1.y")
src2 = prereduce("src0.z", "src1.z")
src3 = prereduce("src0.w", "src1.w")
- opcode(name + "2", output_size, output_type,
+ opcode(name + "2" + rounding_mode, output_size, output_type,
[2, 2], [src_type, src_type], False, commutative,
- final(reduce_(src0, src1)), "")
- opcode(name + "3", output_size, output_type,
+ final(reduce_(src0, src1)), rounding_mode)
+ opcode(name + "3"+ rounding_mode, output_size, output_type,
[3, 3], [src_type, src_type], False, commutative,
- final(reduce_(reduce_(src0, src1), src2)), "")
- opcode(name + "4", output_size, output_type,
+ final(reduce_(reduce_(src0, src1), src2)), rounding_mode)
+ opcode(name + "4" + rounding_mode, output_size, output_type,
[4, 4], [src_type, src_type], False, commutative,
- final(reduce_(reduce_(src0, src1), reduce_(src2, src3))), "")
+ final(reduce_(reduce_(src0, src1), reduce_(src2, src3))), rounding_mode)
binop("fadd", tfloat, commutative + associative, "src0 + src1")
binop_rounding_mode("fadd_rtne", tfloat, commutative + associative,
@@ -620,29 +620,29 @@ binop_compare32("uge32", tuint, "", "src0 >= src1")
# integer-aware GLSL-style comparisons that compare floats and ints
binop_reduce("ball_fequal", 1, tbool1, tfloat, "{src0} == {src1}",
- "{src0} && {src1}", "{src}")
+             "{src0} && {src1}", "{src}", rounding_mode="")
binop_reduce("bany_fnequal", 1, tbool1, tfloat, "{src0} != {src1}",
- "{src0} || {src1}", "{src}")
+             "{src0} || {src1}", "{src}", rounding_mode="")
binop_reduce("ball_iequal", 1, tbool1, tint, "{src0} == {src1}",
- "{src0} && {src1}", "{src}")
+             "{src0} && {src1}", "{src}", rounding_mode="")
binop_reduce("bany_inequal", 1, tbool1, tint, "{src0} != {src1}",
- "{src0} || {src1}", "{src}")
+             "{src0} || {src1}", "{src}", rounding_mode="")
binop_reduce("b32all_fequal", 1, tbool32, tfloat, "{src0} == {src1}",
- "{src0} && {src1}", "{src}")
+             "{src0} && {src1}", "{src}", rounding_mode="")
binop_reduce("b32any_fnequal", 1, tbool32, tfloat, "{src0} != {src1}",
- "{src0} || {src1}", "{src}")
+             "{src0} || {src1}", "{src}", rounding_mode="")
binop_reduce("b32all_iequal", 1, tbool32, tint, "{src0} == {src1}",
- "{src0} && {src1}", "{src}")
+             "{src0} && {src1}", "{src}", rounding_mode="")
binop_reduce("b32any_inequal", 1, tbool32, tint, "{src0} != {src1}",
- "{src0} || {src1}", "{src}")
+             "{src0} || {src1}", "{src}", rounding_mode="")
# non-integer-aware GLSL-style comparisons that return 0.0 or 1.0
binop_reduce("fall_equal", 1, tfloat32, tfloat32, "{src0} == {src1}",
- "{src0} && {src1}", "{src} ? 1.0f : 0.0f")
+             "{src0} && {src1}", "{src} ? 1.0f : 0.0f", rounding_mode="")
binop_reduce("fany_nequal", 1, tfloat32, tfloat32, "{src0} != {src1}",
- "{src0} || {src1}", "{src} ? 1.0f : 0.0f")
+             "{src0} || {src1}", "{src} ? 1.0f : 0.0f", rounding_mode="")
# These comparisons for integer-less hardware return 1.0 and 0.0 for true
# and false respectively
@@ -681,10 +681,17 @@ binop("fxor", tfloat32, commutative,
"(src0 != 0.0f && src1 == 0.0f) || (src0 == 0.0f && src1 != 0.0f) ? 1.0f : 0.0f")
binop_reduce("fdot", 1, tfloat, tfloat, "{src0} * {src1}", "{src0} + {src1}",
- "{src}")
+             "{src}", rounding_mode="")
+
+# fdot{2,3,4}_rtne / _rtz. Host float math (default FE_TONEAREST) already rounds to
+# nearest even, so _rtne folds exactly like fdot. TODO: _rtz still folds with RTNE.
+binop_reduce("fdot", 1, tfloat, tfloat, "{src0} * {src1}", "{src0} + {src1}",
+             "{src}", rounding_mode="_rtne")
+binop_reduce("fdot", 1, tfloat, tfloat, "{src0} * {src1}", "{src0} + {src1}",
+             "{src}", rounding_mode="_rtz")
binop_reduce("fdot_replicated", 4, tfloat, tfloat,
- "{src0} * {src1}", "{src0} + {src1}", "{src}")
+             "{src0} * {src1}", "{src0} + {src1}", "{src}", rounding_mode="")
opcode("fdph", 1, tfloat, [3, 4], [tfloat, tfloat], False, "",
"src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w", "")
--
2.19.1
More information about the mesa-dev
mailing list