[Mesa-dev] [PATCH 2/2] nir: Add support for 8 and 16-bit types
Iago Toral
itoral at igalia.com
Fri Mar 10 09:56:08 UTC 2017
On Thu, 2017-03-09 at 14:05 -0800, Jason Ekstrand wrote:
> ---
> src/compiler/nir/nir.h | 4 ++++
> src/compiler/nir/nir_constant_expressions.py | 16 +++++++++++++++-
> src/compiler/nir/nir_opcodes.py | 6 +++++-
> 3 files changed, 24 insertions(+), 2 deletions(-)
>
> diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
> index 57b8be3..eaa103d 100644
> --- a/src/compiler/nir/nir.h
> +++ b/src/compiler/nir/nir.h
> @@ -105,6 +105,10 @@ typedef enum {
> typedef union {
> float f32[4];
> double f64[4];
> + int8_t i8[4];
> + uint8_t u8[4];
> + int16_t i16[4];
> + uint16_t u16[4];
> int32_t i32[4];
> uint32_t u32[4];
> int64_t i64[4];
>
> diff --git a/src/compiler/nir/nir_constant_expressions.py
> b/src/compiler/nir/nir_constant_expressions.py
> index aecca8b..cbda4b1 100644
> --- a/src/compiler/nir/nir_constant_expressions.py
> +++ b/src/compiler/nir/nir_constant_expressions.py
> @@ -14,8 +14,10 @@ def type_size(type_):
> def type_sizes(type_):
> if type_has_size(type_):
> return [type_size(type_)]
> + elif type_ == 'float':
> + return [16, 32, 64]
> else:
> - return [32, 64]
> + return [8, 16, 32, 64]
>
> def type_add_size(type_, size):
> if type_has_size(type_):
> @@ -34,6 +36,8 @@ def op_bit_sizes(op):
> def get_const_field(type_):
> if type_ == "bool32":
> return "u32"
> + elif type_ == "float16":
> + return "u16"
I was wondering why we don't have an f16 field of type float16_t
in nir_const_value instead, but reading the rest of the patch it seems
we really want to work with 32-bit floats and then encode the results back to 16-bit values. Is that because there are some operations that we can't implement directly with half-floats?
Iago
> else:
> m = type_split_re.match(type_)
> if not m:
> @@ -246,6 +250,7 @@ unpack_half_1x16(uint16_t u)
> }
>
> /* Some typed vector structures to make things like src0.y work */
> +typedef float float16_t;
> typedef float float32_t;
> typedef double float64_t;
> typedef bool bool32_t;
> @@ -297,6 +302,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> num_components, unsigned bit_size,
> % for k in range(op.input_sizes[j]):
> % if input_types[j] == "bool32":
> _src[${j}].u32[${k}] != 0,
> + % elif input_types[j] == "float16":
> + _mesa_half_to_float(_src[${j}].u16[${k}]),
> % else:
> _src[${j}].${get_const_field(input_types[j])}[${k}],
> % endif
> @@ -322,6 +329,9 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> num_components, unsigned bit_size,
> <% continue %>
> % elif input_types[j] == "bool32":
> const bool src${j} = _src[${j}].u32[_i] != 0;
> + % elif input_types[j] == "float16":
> + const float src${j} =
> + _mesa_half_to_float(_src[${j}].u16[_i]);
> % else:
> const ${input_types[j]}_t src${j} =
> _src[${j}].${get_const_field(input_types[j])}[_
> i];
> @@ -344,6 +354,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> num_components, unsigned bit_size,
> % if output_type == "bool32":
> ## Sanitize the C value to a proper NIR bool
> _dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
> + % elif output_type == "float16":
> + _dst_val.u16[_i] = _mesa_float_to_half(dst);
> % else:
> _dst_val.${get_const_field(output_type)}[_i] = dst;
> % endif
> @@ -371,6 +383,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> num_components, unsigned bit_size,
> % if output_type == "bool32":
> ## Sanitize the C value to a proper NIR bool
> _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE :
> NIR_FALSE;
> + % elif output_type == "float16":
> + _dst_val.u16[${k}] =
> _mesa_float_to_half(dst.${"xyzw"[k]});
> % else:
> _dst_val.${get_const_field(output_type)}[${k}] =
> dst.${"xyzw"[k]};
> % endif
> diff --git a/src/compiler/nir/nir_opcodes.py
> b/src/compiler/nir/nir_opcodes.py
> index 53e9aff..37c655b 100644
> --- a/src/compiler/nir/nir_opcodes.py
> +++ b/src/compiler/nir/nir_opcodes.py
> @@ -175,7 +175,11 @@ for src_t in [tint, tuint, tfloat]:
> dst_types = [tint, tuint, tfloat]
>
> for dst_t in dst_types:
> - for bit_size in [32, 64]:
> + if dst_t == tfloat:
> + bit_sizes = [16, 32, 64]
> + else:
> + bit_sizes = [8, 16, 32, 64]
> + for bit_size in bit_sizes:
> unop_convert("{}2{}{}".format(src_t[0], dst_t[0],
> bit_size),
> dst_t + str(bit_size), src_t, "src0")
>
More information about the mesa-dev
mailing list