[Mesa-dev] [PATCH 16/22] nir+vtn: vec8+vec16 support

Karol Herbst kherbst at redhat.com
Tue Nov 13 15:48:20 UTC 2018


This introduces new vec8 and vec16 instructions (the only instructions
taking more than 4 sources) in order to construct 8- and 16-component
vectors.

To avoid fixing up the non-autogenerated nir_build_alu() call sites and
making them pass 16 src args for the benefit of the two instructions that
take more than 4 srcs (i.e. vec8 and vec16), the common tail of
nir_build_alu() is split out into nir_build_alu_tail(), which is re-used
by the new nir_build_alu2(). nir_build_alu2() takes its sources as an
array, so it handles the > 4 src args case; the generated nir_<op>()
helpers now call it for every opcode.

Signed-off-by: Rob Clark <robdclark at gmail.com>
Signed-off-by: Karol Herbst <kherbst at redhat.com>
---
 src/compiler/nir/nir.h                       |  4 +-
 src/compiler/nir/nir_builder.h               | 58 +++++++++++++++-----
 src/compiler/nir/nir_builder_opcodes_h.py    |  5 +-
 src/compiler/nir/nir_constant_expressions.py | 33 +++++++++--
 src/compiler/nir/nir_lower_alu_to_scalar.c   |  2 +
 src/compiler/nir/nir_opcodes.py              | 39 ++++++++++++-
 src/compiler/nir/nir_print.c                 | 17 ++++--
 src/compiler/nir/nir_search.c                |  8 ++-
 src/compiler/spirv/spirv_to_nir.c            |  4 +-
 9 files changed, 140 insertions(+), 30 deletions(-)

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 3855eb0b582..89c28e36618 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -57,8 +57,8 @@ extern "C" {
 
 #define NIR_FALSE 0u
 #define NIR_TRUE (~0u)
-#define NIR_MAX_VEC_COMPONENTS 4
-typedef uint8_t nir_component_mask_t;
+#define NIR_MAX_VEC_COMPONENTS 16
+typedef uint16_t nir_component_mask_t;
 
 /** Defines a cast function
  *
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 3271a480520..57f0a188c46 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -352,24 +352,12 @@ nir_imm_ivec4(nir_builder *build, int x, int y, int z, int w)
 }
 
 static inline nir_ssa_def *
-nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
-              nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
+nir_build_alu_tail(nir_builder *build, nir_alu_instr *instr)
 {
-   const nir_op_info *op_info = &nir_op_infos[op];
-   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
-   if (!instr)
-      return NULL;
+   const nir_op_info *op_info = &nir_op_infos[instr->op];
 
    instr->exact = build->exact;
 
-   instr->src[0].src = nir_src_for_ssa(src0);
-   if (src1)
-      instr->src[1].src = nir_src_for_ssa(src1);
-   if (src2)
-      instr->src[2].src = nir_src_for_ssa(src2);
-   if (src3)
-      instr->src[3].src = nir_src_for_ssa(src3);
-
    /* Guess the number of components the destination temporary should have
     * based on our input sizes, if it's not fixed for the op.
     */
@@ -425,12 +413,54 @@ nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
    return &instr->dest.dest.ssa;
 }
 
+static inline nir_ssa_def *
+nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
+              nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
+{
+   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
+   if (!instr)
+      return NULL;
+
+   instr->src[0].src = nir_src_for_ssa(src0);
+   if (src1)
+      instr->src[1].src = nir_src_for_ssa(src1);
+   if (src2)
+      instr->src[2].src = nir_src_for_ssa(src2);
+   if (src3)
+      instr->src[3].src = nir_src_for_ssa(src3);
+
+   return nir_build_alu_tail(build, instr);
+}
+
+/* Like nir_build_alu(), but reads num_inputs srcs from srcs[] (allows >4): */
+static inline nir_ssa_def *
+nir_build_alu2(nir_builder *build, nir_op op, nir_ssa_def **srcs)
+{
+   const nir_op_info *op_info = &nir_op_infos[op];
+   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
+   if (!instr)
+      return NULL;
+
+   for (unsigned i = 0; i < op_info->num_inputs; i++)
+      instr->src[i].src = nir_src_for_ssa(srcs[i]);
+
+   return nir_build_alu_tail(build, instr);
+}
+
 #include "nir_builder_opcodes.h"
 
 static inline nir_ssa_def *
 nir_vec(nir_builder *build, nir_ssa_def **comp, unsigned num_components)
 {
    switch (num_components) {
+   case 16:
+      return nir_vec16(build, comp[0], comp[1], comp[2], comp[3],
+                       comp[4], comp[5], comp[6], comp[7],
+                       comp[8], comp[9], comp[10], comp[11],
+                       comp[12], comp[13], comp[14], comp[15]);
+   case 8:
+      return nir_vec8(build, comp[0], comp[1], comp[2], comp[3],
+                      comp[4], comp[5], comp[6], comp[7]);
    case 4:
       return nir_vec4(build, comp[0], comp[1], comp[2], comp[3]);
    case 3:
diff --git a/src/compiler/nir/nir_builder_opcodes_h.py b/src/compiler/nir/nir_builder_opcodes_h.py
index 84e5400958e..47edc02896c 100644
--- a/src/compiler/nir/nir_builder_opcodes_h.py
+++ b/src/compiler/nir/nir_builder_opcodes_h.py
@@ -31,14 +31,15 @@ def src_decl_list(num_srcs):
    return ', '.join('nir_ssa_def *src' + str(i) for i in range(num_srcs))
 
 def src_list(num_srcs):
-   return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
+   return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(16))
 %>
 
 % for name, opcode in sorted(opcodes.items()):
 static inline nir_ssa_def *
 nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)})
 {
-   return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)});
+   nir_ssa_def *srcs[] = {${src_list(opcode.num_inputs)}};
+   return nir_build_alu2(build, nir_op_${name}, srcs);
 }
 % endfor
 
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index 118af9f7818..fe54a8f710d 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -258,6 +258,7 @@ typedef float float16_t;
 typedef float float32_t;
 typedef double float64_t;
 typedef bool bool32_t;
+
 % for type in ["float", "int", "uint"]:
 % for width in type_sizes(type):
 struct ${type}${width}_vec {
@@ -265,6 +266,18 @@ struct ${type}${width}_vec {
    ${type}${width}_t y;
    ${type}${width}_t z;
    ${type}${width}_t w;
+   ${type}${width}_t e;
+   ${type}${width}_t f;
+   ${type}${width}_t g;
+   ${type}${width}_t h;
+   ${type}${width}_t i;
+   ${type}${width}_t j;
+   ${type}${width}_t k;
+   ${type}${width}_t l;
+   ${type}${width}_t m;
+   ${type}${width}_t n;
+   ${type}${width}_t o;
+   ${type}${width}_t p;
 };
 % endfor
 % endfor
@@ -274,6 +287,18 @@ struct bool32_vec {
     bool y;
     bool z;
     bool w;
+    bool e;
+    bool f;
+    bool g;
+    bool h;
+    bool i;
+    bool j;
+    bool k;
+    bool l;
+    bool m;
+    bool n;
+    bool o;
+    bool p;
 };
 
 <%def name="evaluate_op(op, bit_size)">
@@ -303,7 +328,7 @@ struct bool32_vec {
             _src[${j}].${get_const_field(input_types[j])}[${k}],
          % endif
       % endfor
-      % for k in range(op.input_sizes[j], 4):
+      % for k in range(op.input_sizes[j], 16):
          0,
       % endfor
       };
@@ -377,11 +402,11 @@ struct bool32_vec {
       % for k in range(op.output_size):
          % if output_type == "bool32":
             ## Sanitize the C value to a proper NIR bool
-            _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
+            _dst_val.u32[${k}] = dst.${"xyzwefghijklmnop"[k]} ? NIR_TRUE : NIR_FALSE;
          % elif output_type == "float16":
-            _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
+            _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzwefghijklmnop"[k]});
          % else:
-            _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
+            _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzwefghijklmnop"[k]};
          % endif
       % endfor
    % endif
diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c
index 0be3aba9456..5e8b76426fb 100644
--- a/src/compiler/nir/nir_lower_alu_to_scalar.c
+++ b/src/compiler/nir/nir_lower_alu_to_scalar.c
@@ -93,6 +93,8 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
       return true;
 
    switch (instr->op) {
+   case nir_op_vec16:
+   case nir_op_vec8:
    case nir_op_vec4:
    case nir_op_vec3:
    case nir_op_vec2:
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index 4ef4ecc6f22..bd212b7c2fb 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -72,7 +72,7 @@ class Opcode(object):
       assert isinstance(algebraic_properties, str)
       assert isinstance(const_expr, str)
       assert len(input_sizes) == len(input_types)
-      assert 0 <= output_size <= 4
+      assert (0 <= output_size <= 4) or (output_size == 8) or (output_size == 16)
       for size in input_sizes:
          assert 0 <= size <= 4
          if output_size != 0:
@@ -804,4 +804,41 @@ dst.z = src2.x;
 dst.w = src3.x;
 """)
 
+opcode("vec8", 8, tuint,
+        [1, 1, 1, 1, 1, 1, 1, 1],
+        [tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint],
+        "", """
+dst.x = src0.x;
+dst.y = src1.x;
+dst.z = src2.x;
+dst.w = src3.x;
+dst.e = src4.x;
+dst.f = src5.x;
+dst.g = src6.x;
+dst.h = src7.x;
+""")
+
+opcode("vec16", 16, tuint,
+        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+        [tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint,
+         tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint],
+        "", """
+dst.x = src0.x;
+dst.y = src1.x;
+dst.z = src2.x;
+dst.w = src3.x;
+dst.e = src4.x;
+dst.f = src5.x;
+dst.g = src6.x;
+dst.h = src7.x;
+dst.i = src8.x;
+dst.j = src9.x;
+dst.k = src10.x;
+dst.l = src11.x;
+dst.m = src12.x;
+dst.n = src13.x;
+dst.o = src14.x;
+dst.p = src15.x;
+""")
+
 
diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c
index ab3d5115688..0de82800f8c 100644
--- a/src/compiler/nir/nir_print.c
+++ b/src/compiler/nir/nir_print.c
@@ -173,6 +173,12 @@ print_dest(nir_dest *dest, print_state *state)
       print_reg_dest(&dest->reg, state);
 }
 
+static const char *
+wrmask_string(unsigned num_components)
+{
+   return (num_components > 4) ? "abcdefghijklmnop" : "xyzw";
+}
+
 static void
 print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
 {
@@ -208,7 +214,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
          if (!nir_alu_instr_channel_used(instr, src, i))
             continue;
 
-         fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]);
+         fprintf(fp, "%c", wrmask_string(live_channels)[instr->src[src].swizzle[i]]);
       }
    }
 
@@ -226,10 +232,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state)
 
    if (!dest->dest.is_ssa &&
        dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) {
+      unsigned live_channels = dest->dest.reg.reg->num_components;
       fprintf(fp, ".");
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
          if ((dest->write_mask >> i) & 1)
-            fprintf(fp, "%c", "xyzw"[i]);
+            fprintf(fp, "%c", wrmask_string(live_channels)[i]);
    }
 }
 
@@ -493,7 +500,7 @@ print_var_decl(nir_variable *var, print_state *state)
       case nir_var_shader_in:
       case nir_var_shader_out:
          if (num_components < 4 && num_components != 0) {
-            const char *xyzw = "xyzw";
+            const char *xyzw = wrmask_string(num_components);
             for (int i = 0; i < num_components; i++)
                components_local[i + 1] = xyzw[i + var->data.location_frac];
 
@@ -700,9 +707,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
          /* special case wrmask to show it as a writemask.. */
          unsigned wrmask = nir_intrinsic_write_mask(instr);
          fprintf(fp, " wrmask=");
-         for (unsigned i = 0; i < 4; i++)
+         for (unsigned i = 0; i < instr->num_components; i++)
             if ((wrmask >> i) & 1)
-               fprintf(fp, "%c", "xyzw"[i]);
+               fprintf(fp, "%c", wrmask_string(instr->num_components)[i]);
       } else if (idx == NIR_INTRINSIC_REDUCTION_OP) {
          nir_op reduction_op = nir_intrinsic_reduction_op(instr);
          fprintf(fp, " reduction_op=%s", nir_op_infos[reduction_op].name);
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index 0270302fd3d..642755f2a6a 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -42,7 +42,13 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state);
 
-static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
+static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
+{
+   0,  1,  2,  3,
+   4,  5,  6,  7,
+   8,  9, 10, 11,
+  12, 13, 14, 15
+};
 
 /**
  * Check if a source produces a value of the given type.
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index e597b2462cb..a350a95e27e 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -2838,6 +2838,8 @@ create_vec(struct vtn_builder *b, unsigned num_components, unsigned bit_size)
    case 2: op = nir_op_vec2; break;
    case 3: op = nir_op_vec3; break;
    case 4: op = nir_op_vec4; break;
+   case 8: op = nir_op_vec8; break;
+   case 16: op = nir_op_vec16; break;
    default: vtn_fail("bad vector size");
    }
 
@@ -3422,10 +3424,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityInputAttachment:
       case SpvCapabilityImageGatherExtended:
       case SpvCapabilityStorageImageExtendedFormats:
+      case SpvCapabilityVector16:
          break;
 
       case SpvCapabilityLinkage:
-      case SpvCapabilityVector16:
       case SpvCapabilityFloat16Buffer:
       case SpvCapabilityFloat16:
       case SpvCapabilityInt64Atomics:
-- 
2.19.1



More information about the mesa-dev mailing list