[Mesa-dev] [PATCH 3/4] nir: add all combinations of conversions with rounding and saturation
Jason Ekstrand
jason at jlekstrand.net
Sat Apr 28 15:29:52 UTC 2018
On Sat, Apr 28, 2018 at 4:14 AM, Karol Herbst <kherbst at redhat.com> wrote:
> OpenCL has explicit casts where one can specify the rounding mode and put a
> sat modifier:
>
> https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/man/xhtml/convert_T.html
>
> _sat is valid for all conversions to an integer type and rounding modes are
> valid for all conversions involving floats.
>
> Although the FPRoundingMode modifier is allowed without any restrictions in
> capabilities, it can only be used together with fp16 in GLSL. Additionally it
> can be used for conversions to/from floating points in OpenCL.
>
> The SaturatedConversion modifier, OpSatConvertUToS and OpSatConvertSToU are
> only supported for Kernels, so current drivers are safe.
>
> Signed-off-by: Karol Herbst <kherbst at redhat.com>
> ---
> src/compiler/glsl/glsl_to_nir.cpp | 2 +-
> src/compiler/nir/nir.h | 2 +-
> src/compiler/nir/nir_opcodes.py | 28 +++++-----
> src/compiler/nir/nir_opcodes_c.py | 26 +++++----
> src/compiler/spirv/spirv_to_nir.c | 4 +-
> src/compiler/spirv/vtn_alu.c | 108 ++++++++++++++++++++++++------
> --------
> src/compiler/spirv/vtn_glsl450.c | 2 +-
> src/compiler/spirv/vtn_private.h | 2 +-
> 8 files changed, 107 insertions(+), 67 deletions(-)
>
> diff --git a/src/compiler/glsl/glsl_to_nir.cpp
> b/src/compiler/glsl/glsl_to_nir.cpp
> index 8e5e9c34912..fcb6ef27e47 100644
> --- a/src/compiler/glsl/glsl_to_nir.cpp
> +++ b/src/compiler/glsl/glsl_to_nir.cpp
> @@ -1589,7 +1589,7 @@ nir_visitor::visit(ir_expression *ir)
> nir_alu_type src_type = nir_get_nir_type_for_glsl_
> base_type(types[0]);
> nir_alu_type dst_type = nir_get_nir_type_for_glsl_
> base_type(out_type);
> result = nir_build_alu(&b, nir_type_conversion_op(src_type,
> dst_type,
> - nir_rounding_mode_undef),
> + nir_rounding_mode_undef, false),
> srcs[0], NULL, NULL, NULL);
> /* b2i and b2f don't have fixed bit-size versions so the builder
> will
> * just assume 32 and we have to fix it up here.
> diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
> index f3326e6df94..f32e5bd8bb2 100644
> --- a/src/compiler/nir/nir.h
> +++ b/src/compiler/nir/nir.h
> @@ -784,7 +784,7 @@ nir_get_nir_type_for_glsl_type(const struct glsl_type
> *type)
> }
>
> nir_op nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
> - nir_rounding_mode rnd);
> + nir_rounding_mode rnd, bool saturation);
>
> typedef enum {
> NIR_OP_IS_COMMUTATIVE = (1 << 0),
> diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_
> opcodes.py
> index f4cd175bc6a..9c51f77bf1b 100644
> --- a/src/compiler/nir/nir_opcodes.py
> +++ b/src/compiler/nir/nir_opcodes.py
> @@ -168,26 +168,28 @@ unop("flog2", tfloat, "log2f(src0)")
>
> # Generate all of the numeric conversion opcodes
> for src_t in [tint, tuint, tfloat]:
> - if src_t in (tint, tuint):
> - dst_types = [tfloat, src_t]
> - elif src_t == tfloat:
> - dst_types = [tint, tuint, tfloat]
> -
> - for dst_t in dst_types:
> + for dst_t in [tint, tuint, tfloat]:
> if dst_t == tfloat:
> bit_sizes = [16, 32, 64]
> + sat_modes = ['']
> else:
> bit_sizes = [8, 16, 32, 64]
> + if src_t != tfloat and dst_t != src_t:
> + sat_modes = ['_sat']
> + else:
> + sat_modes = ['_sat', '']
> for bit_size in bit_sizes:
> - if dst_t == tfloat and src_t == tfloat:
> - rnd_modes = ['_rtne', '_rtz', '']
> - for rnd_mode in rnd_modes:
> + for sat_mode in sat_modes:
> + if src_t == tfloat or dst_t == tfloat:
> + for rnd_mode in ['_rtne', '_rtz', '_ru', '_rd', '']:
> + unop_convert("{0}2{1}{2}{3}{4}".format(src_t[0],
> dst_t[0],
> + bit_size,
> rnd_mode,
> + sat_mode),
> + dst_t + str(bit_size), src_t, "src0")
> + else:
> unop_convert("{0}2{1}{2}{3}".format(src_t[0], dst_t[0],
> - bit_size,
> rnd_mode),
> + bit_size, sat_mode),
> dst_t + str(bit_size), src_t, "src0")
> - else:
> - unop_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
> bit_size),
> - dst_t + str(bit_size), src_t, "src0")
>
As I mentioned on IRC, we need proper constant folding. Getting rounding
modes on f32->f16 wrong isn't good and I probably shouldn't have let it
through. Let's not make the problem worse. Not correctly handling _sat is
especially bad.
>
> # We'll hand-code the to/from bool conversion opcodes. Because bool
> doesn't
> # have multiple bit-sizes, we can always infer the size from the other
> type.
> diff --git a/src/compiler/nir/nir_opcodes_c.py b/src/compiler/nir/nir_
> opcodes_c.py
> index 19079f86e7b..9b8642f0cc1 100644
> --- a/src/compiler/nir/nir_opcodes_c.py
> +++ b/src/compiler/nir/nir_opcodes_c.py
> @@ -30,7 +30,8 @@ template = Template("""
> #include "nir.h"
>
> nir_op
> -nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
> nir_rounding_mode rnd)
> +nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
> nir_rounding_mode rnd,
> + bool saturate)
> {
> nir_alu_type src_base = (nir_alu_type) nir_alu_type_get_base_type(
> src);
> nir_alu_type dst_base = (nir_alu_type) nir_alu_type_get_base_type(
> dst);
> @@ -41,7 +42,8 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type
> dst, nir_rounding_mode rnd
> return nir_op_fmov;
> } else if ((src_base == nir_type_int || src_base == nir_type_uint) &&
> (dst_base == nir_type_int || dst_base == nir_type_uint) &&
> - src_bit_size == dst_bit_size) {
> + src_bit_size == dst_bit_size &&
> + (src_base == dst_base || !saturate)) {
> /* Integer <-> integer conversions with the same bit-size on both
> * ends are just no-op moves.
> */
> @@ -54,12 +56,9 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type
> dst, nir_rounding_mode rnd
> switch (dst_base) {
> % for dst_t in ['int', 'uint', 'float']:
> case nir_type_${dst_t}:
> +<% orig_dst_t = dst_t %>
> % if src_t in ['int', 'uint'] and dst_t in ['int', 'uint']:
> -% if dst_t == 'int':
> -<% continue %>
> -% else:
> -<% dst_t = src_t %>
> -% endif
> +<% dst_t = src_t %>
> % endif
> switch (dst_bit_size) {
> % if dst_t == 'float':
> @@ -69,18 +68,25 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type
> dst, nir_rounding_mode rnd
> % endif
> % for dst_bits in bit_sizes:
> case ${dst_bits}:
> -% if src_t == 'float' and dst_t == 'float':
> +% if src_t == 'float' or dst_t == 'float':
> switch(rnd) {
> -% for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'),
> ('undef', '')]:
> +% for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'),
> ('ru', '_ru'), ('rd', '_rd'), ('undef', '')]:
> case nir_rounding_mode_${rnd_t[0]}:
> +% if dst_t != 'float':
> + if (saturate)
> + return ${'nir_op_{0}2{1}{2}{3}_sat'.format(src_t[0],
> dst_t[0],
> +
> dst_bits, rnd_t[1])};
> +% endif
> return ${'nir_op_{0}2{1}{2}{3}'.format(src_t[0],
> dst_t[0],
>
> dst_bits, rnd_t[1])};
> % endfor
> default:
> - unreachable("Invalid 16-bit nir rounding
> mode");
> + unreachable("Invalid float nir rounding mode");
> }
> % else:
> assert(rnd == nir_rounding_mode_undef);
> + if (saturate)
> + return ${'nir_op_{0}2{1}{2}_sat'.format(src_t[0],
> orig_dst_t[0], dst_bits)};
> return ${'nir_op_{0}2{1}{2}'.format(src_t[0],
> dst_t[0], dst_bits)};
> % endif
> % endfor
> diff --git a/src/compiler/spirv/spirv_to_nir.c
> b/src/compiler/spirv/spirv_to_nir.c
> index 2a835f047e4..6f1a1871b38 100644
> --- a/src/compiler/spirv/spirv_to_nir.c
> +++ b/src/compiler/spirv/spirv_to_nir.c
> @@ -1726,7 +1726,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp
> opcode,
> bit_size = glsl_get_bit_size(val->type->type);
> };
>
> - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode,
> &swap,
>
> nir_alu_type_get_type_size(src_alu_type),
>
> nir_alu_type_get_type_size(dst_alu_type));
> nir_const_value src[4];
> @@ -3839,6 +3839,8 @@ vtn_handle_body_instruction(struct vtn_builder *b,
> SpvOp opcode,
> case SpvOpUConvert:
> case SpvOpSConvert:
> case SpvOpFConvert:
> + case SpvOpSatConvertUToS:
> + case SpvOpSatConvertSToU:
> case SpvOpQuantizeToF16:
> case SpvOpConvertPtrToU:
> case SpvOpConvertUToPtr:
> diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
> index 3134849ba90..b96f7d688fb 100644
> --- a/src/compiler/spirv/vtn_alu.c
> +++ b/src/compiler/spirv/vtn_alu.c
> @@ -273,8 +273,46 @@ vtn_handle_bitcast(struct vtn_builder *b, struct
> vtn_ssa_value *dest,
> dest->def = nir_vec(&b->nb, dest_chan, dest_components);
> }
>
> +static void
> +handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int
> member,
> + const struct vtn_decoration *dec, void
> *_out_rounding_mode)
> +{
> + nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
> + assert(dec->scope == VTN_DEC_DECORATION);
> + if (dec->decoration != SpvDecorationFPRoundingMode)
> + return;
> + switch (dec->literals[0]) {
> + case SpvFPRoundingModeRTE:
> + *out_rounding_mode = nir_rounding_mode_rtne;
> + break;
> + case SpvFPRoundingModeRTZ:
> + *out_rounding_mode = nir_rounding_mode_rtz;
> + break;
> + case SpvFPRoundingModeRTP:
> + *out_rounding_mode = nir_rounding_mode_ru;
> + break;
> + case SpvFPRoundingModeRTN:
> + *out_rounding_mode = nir_rounding_mode_rd;
> + break;
> + default:
> + unreachable("Not supported rounding mode");
> + break;
> + }
> +}
> +
> +static void
> +handle_saturation(struct vtn_builder *b, struct vtn_value *val, int
> member,
> + const struct vtn_decoration *dec, void *_out_saturation)
> +{
> + bool *out_saturation = _out_saturation;
> + assert(dec->scope == VTN_DEC_DECORATION);
> + if (dec->decoration != SpvDecorationSaturatedConversion)
> + return;
> + *out_saturation = true;
> +}
> +
> nir_op
> -vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
> +vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct vtn_value
> *val,
> SpvOp opcode, bool *swap,
> unsigned src_bit_size, unsigned
> dst_bit_size)
> {
> @@ -356,42 +394,67 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder
> *b,
> case SpvOpConvertSToF:
> case SpvOpConvertUToF:
> case SpvOpSConvert:
> - case SpvOpFConvert: {
> + case SpvOpFConvert:
> + case SpvOpSatConvertUToS:
> + case SpvOpSatConvertSToU: {
> nir_alu_type src_type;
> nir_alu_type dst_type;
>
> + nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
> + bool saturation = false;
> +
> switch (opcode) {
> case SpvOpConvertFToS:
> src_type = nir_type_float;
> dst_type = nir_type_int;
> + vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> + vtn_foreach_decoration(b, val, handle_saturation, &saturation);
> break;
> case SpvOpConvertFToU:
> src_type = nir_type_float;
> dst_type = nir_type_uint;
> + vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> + vtn_foreach_decoration(b, val, handle_saturation, &saturation);
> break;
> case SpvOpFConvert:
> src_type = dst_type = nir_type_float;
> + vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> break;
> case SpvOpConvertSToF:
> src_type = nir_type_int;
> dst_type = nir_type_float;
> + vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> break;
> case SpvOpSConvert:
> src_type = dst_type = nir_type_int;
> + vtn_foreach_decoration(b, val, handle_saturation, &saturation);
> break;
> case SpvOpConvertUToF:
> src_type = nir_type_uint;
> dst_type = nir_type_float;
> + vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> break;
> case SpvOpUConvert:
> src_type = dst_type = nir_type_uint;
> + vtn_foreach_decoration(b, val, handle_saturation, &saturation);
> + break;
> + case SpvOpSatConvertUToS:
> + src_type = nir_type_uint;
> + dst_type = nir_type_int;
> + saturation = true;
> + break;
> + case SpvOpSatConvertSToU:
> + src_type = nir_type_int;
> + dst_type = nir_type_uint;
> + saturation = true;
> break;
> default:
> unreachable("Invalid opcode");
> }
> src_type |= src_bit_size;
> dst_type |= dst_bit_size;
> - return nir_type_conversion_op(src_type, dst_type,
> nir_rounding_mode_undef);
> +
> + return nir_type_conversion_op(src_type, dst_type, rounding_mode,
> saturation);
> }
> /* Derivatives: */
> case SpvOpDPdx: return nir_op_fddx;
> @@ -417,27 +480,6 @@ handle_no_contraction(struct vtn_builder *b, struct
> vtn_value *val, int member,
> b->nb.exact = true;
> }
>
> -static void
> -handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int
> member,
> - const struct vtn_decoration *dec, void
> *_out_rounding_mode)
> -{
> - nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
> - assert(dec->scope == VTN_DEC_DECORATION);
> - if (dec->decoration != SpvDecorationFPRoundingMode)
> - return;
> - switch (dec->literals[0]) {
> - case SpvFPRoundingModeRTE:
> - *out_rounding_mode = nir_rounding_mode_rtne;
> - break;
> - case SpvFPRoundingModeRTZ:
> - *out_rounding_mode = nir_rounding_mode_rtz;
> - break;
> - default:
> - unreachable("Not supported rounding mode");
> - break;
> - }
> -}
> -
> void
> vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
> const uint32_t *w, unsigned count)
> @@ -579,7 +621,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
> bool swap;
> unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
> unsigned dst_bit_size = glsl_get_bit_size(type);
> - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
> src_bit_size,
> dst_bit_size);
>
> if (swap) {
> @@ -605,7 +647,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
> bool swap;
> unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
> unsigned dst_bit_size = glsl_get_bit_size(type);
> - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
> src_bit_size,
> dst_bit_size);
>
> assert(!swap);
> @@ -623,23 +665,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
> vtn_handle_bitcast(b, val->ssa, src[0]);
> break;
>
> - case SpvOpFConvert: {
> - nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_
> type(vtn_src[0]->type);
> - nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
> - nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
> -
> - vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> - nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type,
> rounding_mode);
> -
> - val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL,
> NULL);
> - break;
> - }
> -
> default: {
> bool swap;
> unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
> unsigned dst_bit_size = glsl_get_bit_size(type);
> - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
> src_bit_size,
> dst_bit_size);
>
> if (swap) {
> diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_
> glsl450.c
> index 6fa759b1bba..284371446b5 100644
> --- a/src/compiler/spirv/vtn_glsl450.c
> +++ b/src/compiler/spirv/vtn_glsl450.c
> @@ -659,7 +659,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum
> GLSLstd450 entrypoint,
> nir_op conversion_op =
> nir_type_conversion_op(nir_type_float | eta->bit_size,
> nir_type_float | I->bit_size,
> - nir_rounding_mode_undef);
> + nir_rounding_mode_undef, false);
> eta = nir_build_alu(nb, conversion_op, eta, NULL, NULL, NULL);
> }
> /* k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I)) */
> diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_
> private.h
> index b501bbf9b4a..0895c865fbb 100644
> --- a/src/compiler/spirv/vtn_private.h
> +++ b/src/compiler/spirv/vtn_private.h
> @@ -708,7 +708,7 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct
> vtn_builder *,
> void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value
> *value,
> vtn_execution_mode_foreach_cb cb, void
> *data);
>
> -nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
> +nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct
> vtn_value *val,
> SpvOp opcode, bool *swap,
> unsigned src_bit_size, unsigned
> dst_bit_size);
>
> --
> 2.14.3
>
> _______________________________________________
> mesa-dev mailing list
> mesa-dev at lists.freedesktop.org
> https://lists.freedesktop.org/mailman/listinfo/mesa-dev
>
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://lists.freedesktop.org/archives/mesa-dev/attachments/20180428/e74fc0df/attachment-0001.html>
More information about the mesa-dev
mailing list