[Mesa-dev] [PATCH 13/28] nir: take into account rounding modes in conversions
Samuel Iglesias Gonsálvez
siglesias at igalia.com
Wed Dec 5 15:55:28 UTC 2018
Signed-off-by: Samuel Iglesias Gonsálvez <siglesias at igalia.com>
---
src/compiler/nir/nir.h | 15 +++++++
src/compiler/nir/nir_constant_expressions.py | 46 +++++++++++++++++---
src/compiler/spirv/vtn_alu.c | 16 ++++++-
3 files changed, 71 insertions(+), 6 deletions(-)
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 65a1f60c3c6..f22ac13b2ac 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -866,6 +866,21 @@ nir_get_nir_type_for_glsl_type(const struct glsl_type *type)
nir_op nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
nir_rounding_mode rnd);
+static inline nir_rounding_mode
+nir_get_rounding_mode_from_float_controls(unsigned execution_mode,
+                                          nir_alu_type type)
+{
+   /* Only float types get a mode from the flags; RTZ is tested before RTE. */
+   if (nir_alu_type_get_base_type(type) == nir_type_float) {
+      if (execution_mode & SHADER_ROUNDING_MODE_RTZ)
+         return nir_rounding_mode_rtz;
+      if (execution_mode & SHADER_ROUNDING_MODE_RTE)
+         return nir_rounding_mode_rtne;
+   }
+
+   return nir_rounding_mode_undef;
+}
+
typedef enum {
NIR_OP_IS_COMMUTATIVE = (1 << 0),
NIR_OP_IS_ASSOCIATIVE = (1 << 1),
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index 118af9f7818..a9af1bd233d 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -79,6 +79,7 @@ template = """\
#include <math.h>
#include "util/rounding.h" /* for _mesa_roundeven */
#include "util/half_float.h"
+#include "util/double.h"
#include "nir_constant_expressions.h"
/**
@@ -300,7 +301,15 @@ struct bool32_vec {
% elif input_types[j] == "float16":
_mesa_half_to_float(_src[${j}].u16[${k}]),
% else:
- _src[${j}].${get_const_field(input_types[j])}[${k}],
+ % if ("rtne" in op.name) and ("float" in input_types[j]) and ("int" in output_type):
+ % if "float32" in input_types[j]:
+ _mesa_roundevenf(_src[${j}].${get_const_field(input_types[j])}[${k}]),
+ % else:
+ _mesa_roundeven(_src[${j}].${get_const_field(input_types[j])}[${k}]),
+ % endif
+ % else:
+ _src[${j}].${get_const_field(input_types[j])}[${k}],
+ % endif
% endif
% endfor
% for k in range(op.input_sizes[j], 4):
@@ -328,8 +337,27 @@ struct bool32_vec {
const float src${j} =
_mesa_half_to_float(_src[${j}].u16[_i]);
% else:
- const ${input_types[j]}_t src${j} =
- _src[${j}].${get_const_field(input_types[j])}[_i];
+ % if ("rtne" in op.name) and ("float" in input_types[j]) and ("int" in output_type):
+ % if "float32" in input_types[j]:
+ const ${input_types[j]}_t src${j} =
+ _mesa_roundevenf(_src[${j}].${get_const_field(input_types[j])}[_i]);
+ % else:
+ const ${input_types[j]}_t src${j} =
+ _mesa_roundeven(_src[${j}].${get_const_field(input_types[j])}[_i]);
+
+ % endif
+ % elif ("float64" in input_types[j]) and ("float32" in output_type):
+ % if ("rtz" in op.name):
+ const ${input_types[j]}_t src${j} =
+ _mesa_double_to_float_rtz(_src[${j}].${get_const_field(input_types[j])}[_i]);
+ % else:
+ const ${input_types[j]}_t src${j} =
+ _mesa_double_to_float_rtne(_src[${j}].${get_const_field(input_types[j])}[_i]);
+ % endif
+ % else:
+ const ${input_types[j]}_t src${j} =
+ _src[${j}].${get_const_field(input_types[j])}[_i];
+ % endif
% endif
% endfor
@@ -350,7 +378,11 @@ struct bool32_vec {
## Sanitize the C value to a proper NIR bool
_dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
% elif output_type == "float16":
- _dst_val.u16[_i] = _mesa_float_to_half(dst);
+ % if "rtz" in op.name:
+ _dst_val.u16[_i] = _mesa_float_to_float16_rtz(dst);
+ % else:
+ _dst_val.u16[_i] = _mesa_float_to_float16_rtne(dst);
+ % endif
% else:
_dst_val.${get_const_field(output_type)}[_i] = dst;
% endif
@@ -379,7 +411,11 @@ struct bool32_vec {
## Sanitize the C value to a proper NIR bool
_dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
% elif output_type == "float16":
- _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
+ % if "rtz" in op.name:
+ _dst_val.u16[${k}] = _mesa_float_to_float16_rtz(dst.${"xyzw"[k]});
+ % else:
+ _dst_val.u16[${k}] = _mesa_float_to_float16_rtne(dst.${"xyzw"[k]});
+ % endif
% else:
_dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
% endif
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 629b57560ca..fba2b031c09 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -329,7 +329,12 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
}
src_type |= src_bit_size;
dst_type |= dst_bit_size;
- return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
+      /* Float-controls rounding applies to float->float conversions only. */
+      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
+      if (nir_alu_type_get_base_type(src_type) == nir_type_float)
+         rounding_mode = nir_get_rounding_mode_from_float_controls(
+            b->shader->info.shader_float_controls_execution_mode, dst_type);
+      return nir_type_conversion_op(src_type, dst_type, rounding_mode);
}
/* Derivatives: */
case SpvOpDPdx: return nir_op_fddx;
@@ -590,12 +595,21 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
glsl_get_bit_size(type));
break;
+ case SpvOpConvertFToS:
+ case SpvOpConvertFToU:
case SpvOpFConvert: {
nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
+ unsigned float_controls = b->shader->info.shader_float_controls_execution_mode;
vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+
+      if (rounding_mode == nir_rounding_mode_undef && float_controls) {
+         /* Key on the result type: f2i keeps SPIR-V's round-toward-zero. */
+         rounding_mode = nir_get_rounding_mode_from_float_controls(
+            float_controls, dst_alu_type);
+      }
nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
--
2.19.1
More information about the mesa-dev
mailing list