Mesa (master): spirv,nir: Move the SPIR-V vector insert code to NIR

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Fri Apr 17 19:38:47 UTC 2020


Module: Mesa
Branch: master
Commit: f5deed138a0b4765438135367248f1d8f0649975
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=f5deed138a0b4765438135367248f1d8f0649975

Author: Jason Ekstrand <jason at jlekstrand.net>
Date:   Thu Apr  9 17:09:10 2020 -0500

spirv,nir: Move the SPIR-V vector insert code to NIR

This also makes spirv_to_nir a bit simpler because the new
nir_vector_insert helper, like nir_vector_extract, automatically handles
a constant component selector.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira at intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4495>

---

 src/compiler/nir/nir_builder.h     | 55 ++++++++++++++++++++++++++++++++++++++
 src/compiler/spirv/spirv_to_nir.c  | 46 +++----------------------------
 src/compiler/spirv/vtn_private.h   |  5 ----
 src/compiler/spirv/vtn_variables.c |  8 ++----
 4 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 481ea6382bf..52fcf9e2250 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -601,6 +601,61 @@ nir_vector_extract(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *c)
    }
 }
 
+/** Builds a copy of `vec` with immediate component `c` replaced by `scalar` */
+static inline nir_ssa_def *
+nir_vector_insert_imm(nir_builder *b, nir_ssa_def *vec,
+                      nir_ssa_def *scalar, unsigned c)
+{
+   assert(scalar->num_components == 1);
+   assert(c < vec->num_components);
+
+   nir_op vec_op = nir_op_vec(vec->num_components);
+   nir_alu_instr *vec_instr = nir_alu_instr_create(b->shader, vec_op);
+
+   for (unsigned i = 0; i < vec->num_components; i++) {
+      if (i == c) {
+         vec_instr->src[i].src = nir_src_for_ssa(scalar);
+         vec_instr->src[i].swizzle[0] = 0; /* scalar's only channel */
+      } else {
+         vec_instr->src[i].src = nir_src_for_ssa(vec);
+         vec_instr->src[i].swizzle[0] = i; /* keep vec's own channel i */
+      }
+   }
+
+   return nir_builder_alu_instr_finish_and_insert(b, vec_instr);
+}
+
+/** As nir_vector_insert_imm, but `c` is an SSA value and may be dynamic */
+static inline nir_ssa_def *
+nir_vector_insert(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *scalar,
+                  nir_ssa_def *c)
+{
+   assert(scalar->num_components == 1);
+   assert(c->num_components == 1);
+
+   nir_src c_src = nir_src_for_ssa(c);
+   if (nir_src_is_const(c_src)) {
+      uint64_t c_const = nir_src_as_uint(c_src);
+      if (c_const < vec->num_components)
+         return nir_vector_insert_imm(b, vec, scalar, c_const);
+      else
+         return vec; /* out of range: same result as the bcsel path */
+   } else {
+      nir_const_value per_comp_idx_const[NIR_MAX_VEC_COMPONENTS];
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         per_comp_idx_const[i] = nir_const_value_for_int(i, c->bit_size);
+      nir_ssa_def *per_comp_idx =
+         nir_build_imm(b, vec->num_components,
+                       c->bit_size, per_comp_idx_const);
+
+      /* nir_builder splats scalars out to vectors, so an insert is just
+       * "if I'm the channel, replace me with the scalar."  If `c` matches
+       * no channel, every lane selects `vec` and the result is unchanged.
+       */
+      return nir_bcsel(b, nir_ieq(b, c, per_comp_idx), scalar, vec);
+   }
+}
+
 static inline nir_ssa_def *
 nir_i2i(nir_builder *build, nir_ssa_def *x, unsigned dest_bit_size)
 {
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 2cc8f2570c7..3cac23433f2 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -3311,44 +3311,6 @@ vtn_ssa_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
    return dest;
 }
 
-nir_ssa_def *
-vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
-                  unsigned index)
-{
-   nir_alu_instr *vec = create_vec(b, src->num_components,
-                                   src->bit_size);
-
-   for (unsigned i = 0; i < src->num_components; i++) {
-      if (i == index) {
-         vec->src[i].src = nir_src_for_ssa(insert);
-      } else {
-         vec->src[i].src = nir_src_for_ssa(src);
-         vec->src[i].swizzle[0] = i;
-      }
-   }
-
-   nir_builder_instr_insert(&b->nb, &vec->instr);
-
-   return &vec->dest.dest.ssa;
-}
-
-nir_ssa_def *
-vtn_vector_insert_dynamic(struct vtn_builder *b, nir_ssa_def *src,
-                          nir_ssa_def *insert, nir_ssa_def *index)
-{
-   nir_const_value per_comp_idx_const[NIR_MAX_VEC_COMPONENTS];
-   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-      per_comp_idx_const[i] = nir_const_value_for_int(i, index->bit_size);
-   nir_ssa_def *per_comp_idx =
-      nir_build_imm(&b->nb, src->num_components,
-                    index->bit_size, per_comp_idx_const);
-
-   /* nir_builder will automatically splat out scalars to vectors so an insert
-    * is as simple as "if I'm the channel, replace me with the scalar."
-    */
-   return nir_bcsel(&b->nb, nir_ieq(&b->nb, index, per_comp_idx), insert, src);
-}
-
 static nir_ssa_def *
 vtn_vector_shuffle(struct vtn_builder *b, unsigned num_components,
                    nir_ssa_def *src0, nir_ssa_def *src1,
@@ -3462,7 +3424,7 @@ vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
        * the index to insert the scalar into the vector.
        */
 
-      cur->def = vtn_vector_insert(b, cur->def, insert->def, indices[i]);
+      cur->def = nir_vector_insert_imm(&b->nb, cur->def, insert->def, indices[i]);
    } else {
       vtn_fail_if(indices[i] >= glsl_get_length(cur->type),
                   "All indices in an OpCompositeInsert must be in-bounds");
@@ -3516,9 +3478,9 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpVectorInsertDynamic:
-      ssa->def = vtn_vector_insert_dynamic(b, vtn_ssa_value(b, w[3])->def,
-                                           vtn_ssa_value(b, w[4])->def,
-                                           vtn_ssa_value(b, w[5])->def);
+      ssa->def = nir_vector_insert(&b->nb, vtn_ssa_value(b, w[3])->def,
+                                   vtn_ssa_value(b, w[4])->def,
+                                   vtn_ssa_value(b, w[5])->def);
       break;
 
    case SpvOpVectorShuffle:
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index 3624ac3fa7e..f4e6201febe 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -797,11 +797,6 @@ struct vtn_ssa_value *vtn_create_ssa_value(struct vtn_builder *b,
 struct vtn_ssa_value *vtn_ssa_transpose(struct vtn_builder *b,
                                         struct vtn_ssa_value *src);
 
-nir_ssa_def *vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src,
-                               nir_ssa_def *insert, unsigned index);
-nir_ssa_def *vtn_vector_insert_dynamic(struct vtn_builder *b, nir_ssa_def *src,
-                                       nir_ssa_def *insert, nir_ssa_def *index);
-
 nir_deref_instr *vtn_nir_deref(struct vtn_builder *b, uint32_t id);
 
 struct vtn_pointer *vtn_pointer_for_variable(struct vtn_builder *b,
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 8bb00a5dd40..9dc2c755ca5 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -751,12 +751,8 @@ vtn_local_store(struct vtn_builder *b, struct vtn_ssa_value *src,
       struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
       _vtn_local_load_store(b, true, dest_tail, val, access);
 
-      if (nir_src_is_const(dest->arr.index))
-         val->def = vtn_vector_insert(b, val->def, src->def,
-                                      nir_src_as_uint(dest->arr.index));
-      else
-         val->def = vtn_vector_insert_dynamic(b, val->def, src->def,
-                                              dest->arr.index.ssa);
+      val->def = nir_vector_insert(&b->nb, val->def, src->def,
+                                   dest->arr.index.ssa);
       _vtn_local_load_store(b, false, dest_tail, val, access);
    } else {
       _vtn_local_load_store(b, false, dest_tail, src, access);



More information about the mesa-commit mailing list