[Mesa-dev] [PATCH 2/2] nir: Add support for 8 and 16-bit types

Jason Ekstrand jason at jlekstrand.net
Thu Mar 9 18:23:57 UTC 2017


---
 src/compiler/nir/nir.h                       |  4 ++++
 src/compiler/nir/nir_constant_expressions.py | 16 +++++++++++++++-
 src/compiler/nir/nir_opcodes.py              |  6 +++++-
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 57b8be3..eaa103d 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -105,6 +105,10 @@ typedef enum {
 typedef union {
    float f32[4];
    double f64[4];
+   int8_t i8[4];
+   uint8_t u8[4];
+   int16_t i16[4];
+   uint16_t u16[4];
    int32_t i32[4];
    uint32_t u32[4];
    int64_t i64[4];
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index aecca8b..cbda4b1 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -14,8 +14,10 @@ def type_size(type_):
 def type_sizes(type_):
     if type_has_size(type_):
         return [type_size(type_)]
+    elif type_ == 'float':
+        return [16, 32, 64]
     else:
-        return [32, 64]
+        return [8, 16, 32, 64]
 
 def type_add_size(type_, size):
     if type_has_size(type_):
@@ -34,6 +36,8 @@ def op_bit_sizes(op):
 def get_const_field(type_):
     if type_ == "bool32":
         return "u32"
+    elif type_ == "float16":
+        return "u16"
     else:
         m = type_split_re.match(type_)
         if not m:
@@ -246,6 +250,7 @@ unpack_half_1x16(uint16_t u)
 }
 
 /* Some typed vector structures to make things like src0.y work */
+typedef float float16_t;
 typedef float float32_t;
 typedef double float64_t;
 typedef bool bool32_t;
@@ -297,6 +302,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
          % for k in range(op.input_sizes[j]):
             % if input_types[j] == "bool32":
                _src[${j}].u32[${k}] != 0,
+            % elif input_types[j] == "float16":
+               _mesa_half_to_float(_src[${j}].u16[${k}]),
             % else:
                _src[${j}].${get_const_field(input_types[j])}[${k}],
             % endif
@@ -322,6 +329,9 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
                   <% continue %>
                % elif input_types[j] == "bool32":
                   const bool src${j} = _src[${j}].u32[_i] != 0;
+               % elif input_types[j] == "float16":
+                  const float src${j} =
+                     _mesa_half_to_float(_src[${j}].u16[_i]);
                % else:
                   const ${input_types[j]}_t src${j} =
                      _src[${j}].${get_const_field(input_types[j])}[_i];
@@ -344,6 +354,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
             % if output_type == "bool32":
                ## Sanitize the C value to a proper NIR bool
                _dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
+            % elif output_type == "float16":
+               _dst_val.u16[_i] = _mesa_float_to_half(dst);
             % else:
                _dst_val.${get_const_field(output_type)}[_i] = dst;
             % endif
@@ -371,6 +383,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
             % if output_type == "bool32":
                ## Sanitize the C value to a proper NIR bool
                _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
+            % elif output_type == "float16":
+               _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
             % else:
                _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
             % endif
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index 53e9aff..37c655b 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -175,7 +175,11 @@ for src_t in [tint, tuint, tfloat]:
       dst_types = [tint, tuint, tfloat]
 
    for dst_t in dst_types:
-      for bit_size in [32, 64]:
+      if dst_t == tfloat:
+         bit_sizes = [16, 32, 64]
+      else:
+         bit_sizes = [8, 16, 32, 64]
+      for bit_size in bit_sizes:
          unop_convert("{}2{}{}".format(src_t[0], dst_t[0], bit_size),
                       dst_t + str(bit_size), src_t, "src0")
 
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list