[Mesa-dev] [PATCH 08/18] nir/spirv: Simplify matrix loads/stores

Jason Ekstrand jason at jlekstrand.net
Thu Jun 29 17:33:30 UTC 2017


Instead of handling all of the complexity at the end, we choose to
decorate types a bit more cleverly.  When we have a row-major matrix
type, we give it the stride of a single vector and give its array
element type (which represents a column) the actual matrix stride.

Previously, we were using stop_at_matrix and handling everything from
matrix on down as special cases but now we walk the access chain all the
way to the end and then load.  Even though this looks like it may lead
to a significant functional change, it doesn't.  The reason why we
needed to do stop_at_matrix before was to handle row-major properly
since the offsets and strides would be all out of order.  Now that
row-major matrix types have the small stride on the matrix and the large
stride on the vector, offsetting to a single column of a row-major
matrix works fine.  The load/store code simply picks up on the fact that
the stride isn't the component size and does one load or store per
component.  The generated
code from these methods should be the same.
---
 src/compiler/spirv/spirv_to_nir.c  |  29 ++++++-
 src/compiler/spirv/vtn_variables.c | 161 ++++++++++++++-----------------------
 2 files changed, 90 insertions(+), 100 deletions(-)

diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 7a98843..72a8904 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -520,7 +520,7 @@ struct_member_decoration_cb(struct vtn_builder *b,
       ctx->type->offsets[member] = dec->literals[0];
       break;
    case SpvDecorationMatrixStride:
-      mutable_matrix_member(b, ctx->type, member)->stride = dec->literals[0];
+      /* Handled as a second pass */
       break;
    case SpvDecorationColMajor:
       break; /* Nothing to do here.  Column-major is the default. */
@@ -571,6 +571,32 @@ struct_member_decoration_cb(struct vtn_builder *b,
    }
 }
 
+/* Second pass over member decorations: MatrixStride handling depends on
+ * row_major.  Row-major: the matrix takes array_element's stride and a
+ * copy of array_element takes MatrixStride; else the matrix takes it. */
+static void
+struct_member_matrix_stride_cb(struct vtn_builder *b,
+                               struct vtn_value *val, int member,
+                               const struct vtn_decoration *dec,
+                               void *void_ctx)
+{
+   if (dec->decoration != SpvDecorationMatrixStride)
+      return;
+   assert(member >= 0);
+   /* SPIR-V only allows MatrixStride on struct members (OpMemberDecorate). */
+   struct member_decoration_ctx *ctx = void_ctx;
+   /* Row-major: copy array_element first, since it may be shared (confirm). */
+   struct vtn_type *mat_type = mutable_matrix_member(b, ctx->type, member);
+   if (mat_type->row_major) {
+      mat_type->array_element = vtn_type_copy(b, mat_type->array_element);
+      mat_type->stride = mat_type->array_element->stride;
+      mat_type->array_element->stride = dec->literals[0];
+   } else {
+      assert(mat_type->array_element->stride > 0);
+      mat_type->stride = dec->literals[0];
+   }
+}
+
 static void
 type_decoration_cb(struct vtn_builder *b,
                    struct vtn_value *val, int member,
@@ -807,6 +833,7 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       };
 
       vtn_foreach_decoration(b, val, struct_member_decoration_cb, &ctx);
+      vtn_foreach_decoration(b, val, struct_member_matrix_stride_cb, &ctx);
 
       const char *name = val->name ? val->name : "struct";
 
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 1854707..480bc36 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -530,114 +530,77 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
        * a vector, a scalar, or a matrix.
        */
       if (glsl_type_is_matrix(type->type)) {
-         if (chain == NULL) {
-            /* Loading the whole matrix */
-            struct vtn_ssa_value *transpose;
-            unsigned num_ops, vec_width;
-            if (type->row_major) {
-               num_ops = glsl_get_vector_elements(type->type);
-               vec_width = glsl_get_matrix_columns(type->type);
-               if (load) {
-                  const struct glsl_type *transpose_type =
-                     glsl_matrix_type(base_type, vec_width, num_ops);
-                  *inout = vtn_create_ssa_value(b, transpose_type);
-               } else {
-                  transpose = vtn_ssa_transpose(b, *inout);
-                  inout = &transpose;
-               }
+         /* Loading the whole matrix */
+         struct vtn_ssa_value *transpose;
+         unsigned num_ops, vec_width, col_stride;
+         if (type->row_major) {
+            num_ops = glsl_get_vector_elements(type->type);
+            vec_width = glsl_get_matrix_columns(type->type);
+            col_stride = type->array_element->stride;
+            if (load) {
+               const struct glsl_type *transpose_type =
+                  glsl_matrix_type(base_type, vec_width, num_ops);
+               *inout = vtn_create_ssa_value(b, transpose_type);
             } else {
-               num_ops = glsl_get_matrix_columns(type->type);
-               vec_width = glsl_get_vector_elements(type->type);
+               transpose = vtn_ssa_transpose(b, *inout);
+               inout = &transpose;
             }
+         } else {
+            num_ops = glsl_get_matrix_columns(type->type);
+            vec_width = glsl_get_vector_elements(type->type);
+            col_stride = type->stride;
+         }
 
-            for (unsigned i = 0; i < num_ops; i++) {
+         for (unsigned i = 0; i < num_ops; i++) {
+            nir_ssa_def *elem_offset =
+               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, i * col_stride));
+            _vtn_load_store_tail(b, op, load, index, elem_offset,
+                                 access_offset, access_size,
+                                 &(*inout)->elems[i],
+                                 glsl_vector_type(base_type, vec_width));
+         }
+
+         if (load && type->row_major)
+            *inout = vtn_ssa_transpose(b, *inout);
+      } else {
+         unsigned elems = glsl_get_vector_elements(type->type);
+         unsigned type_size = glsl_get_bit_size(type->type) / 8;
+         if (elems == 1 || type->stride == type_size) {
+            /* This is a tightly-packed normal scalar or vector load */
+            assert(glsl_type_is_vector_or_scalar(type->type));
+            _vtn_load_store_tail(b, op, load, index, offset,
+                                 access_offset, access_size,
+                                 inout, type->type);
+         } else {
+            /* This is a strided load.  We have to load N things separately.
+             * This is the single column of a row-major matrix case.
+             */
+            assert(type->stride > type_size);
+            assert(type->stride % type_size == 0);
+
+            nir_ssa_def *per_comp[4];
+            for (unsigned i = 0; i < elems; i++) {
                nir_ssa_def *elem_offset =
                   nir_iadd(&b->nb, offset,
-                           nir_imm_int(&b->nb, i * type->stride));
+                                   nir_imm_int(&b->nb, i * type->stride));
+               struct vtn_ssa_value *comp, temp_val;
+               if (!load) {
+                  temp_val.def = nir_channel(&b->nb, (*inout)->def, i);
+                  temp_val.type = glsl_scalar_type(base_type);
+               }
+               comp = &temp_val;
                _vtn_load_store_tail(b, op, load, index, elem_offset,
                                     access_offset, access_size,
-                                    &(*inout)->elems[i],
-                                    glsl_vector_type(base_type, vec_width));
+                                    &comp, glsl_scalar_type(base_type));
+               per_comp[i] = comp->def;
             }
 
-            if (load && type->row_major)
-               *inout = vtn_ssa_transpose(b, *inout);
-         } else if (type->row_major) {
-            /* Row-major but with an access chiain. */
-            nir_ssa_def *col_offset =
-               vtn_access_link_as_ssa(b, chain->link[chain_idx],
-                                      type->array_element->stride);
-            offset = nir_iadd(&b->nb, offset, col_offset);
-
-            if (chain_idx + 1 < chain->length) {
-               /* Picking off a single element */
-               nir_ssa_def *row_offset =
-                  vtn_access_link_as_ssa(b, chain->link[chain_idx + 1],
-                                         type->stride);
-               offset = nir_iadd(&b->nb, offset, row_offset);
-               if (load)
-                  *inout = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
-               _vtn_load_store_tail(b, op, load, index, offset,
-                                    access_offset, access_size,
-                                    inout, glsl_scalar_type(base_type));
-            } else {
-               /* Grabbing a column; picking one element off each row */
-               unsigned num_comps = glsl_get_vector_elements(type->type);
-               const struct glsl_type *column_type =
-                  glsl_get_column_type(type->type);
-
-               nir_ssa_def *comps[4];
-               for (unsigned i = 0; i < num_comps; i++) {
-                  nir_ssa_def *elem_offset =
-                     nir_iadd(&b->nb, offset,
-                              nir_imm_int(&b->nb, i * type->stride));
-
-                  struct vtn_ssa_value *comp, temp_val;
-                  if (!load) {
-                     temp_val.def = nir_channel(&b->nb, (*inout)->def, i);
-                     temp_val.type = glsl_scalar_type(base_type);
-                  }
-                  comp = &temp_val;
-                  _vtn_load_store_tail(b, op, load, index, elem_offset,
-                                       access_offset, access_size,
-                                       &comp, glsl_scalar_type(base_type));
-                  comps[i] = comp->def;
-               }
-
-               if (load) {
-                  if (*inout == NULL)
-                     *inout = vtn_create_ssa_value(b, column_type);
-
-                  (*inout)->def = nir_vec(&b->nb, comps, num_comps);
-               }
+            if (load) {
+               if (*inout == NULL)
+                  *inout = vtn_create_ssa_value(b, type->type);
+               (*inout)->def = nir_vec(&b->nb, per_comp, elems);
             }
-         } else {
-            /* Column-major with a deref. Fall through to array case. */
-            nir_ssa_def *col_offset =
-               vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
-            offset = nir_iadd(&b->nb, offset, col_offset);
-
-            _vtn_block_load_store(b, op, load, index, offset,
-                                  access_offset, access_size,
-                                  chain, chain_idx + 1,
-                                  type->array_element, inout);
          }
-      } else if (chain == NULL) {
-         /* Single whole vector */
-         assert(glsl_type_is_vector_or_scalar(type->type));
-         _vtn_load_store_tail(b, op, load, index, offset,
-                              access_offset, access_size,
-                              inout, type->type);
-      } else {
-         /* Single component of a vector. Fall through to array case. */
-         nir_ssa_def *elem_offset =
-            vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
-         offset = nir_iadd(&b->nb, offset, elem_offset);
-
-         _vtn_block_load_store(b, op, load, index, offset,
-                               access_offset, access_size,
-                               NULL, 0,
-                               type->array_element, inout);
       }
       return;
 
@@ -696,7 +659,7 @@ vtn_block_load(struct vtn_builder *b, struct vtn_pointer *src)
    nir_ssa_def *offset, *index = NULL;
    unsigned chain_idx;
    struct vtn_type *type;
-   offset = vtn_pointer_to_offset(b, src, &index, &type, &chain_idx, true);
+   offset = vtn_pointer_to_offset(b, src, &index, &type, &chain_idx, false);
 
    struct vtn_ssa_value *value = NULL;
    _vtn_block_load_store(b, op, true, index, offset,
@@ -712,7 +675,7 @@ vtn_block_store(struct vtn_builder *b, struct vtn_ssa_value *src,
    nir_ssa_def *offset, *index = NULL;
    unsigned chain_idx;
    struct vtn_type *type;
-   offset = vtn_pointer_to_offset(b, dst, &index, &type, &chain_idx, true);
+   offset = vtn_pointer_to_offset(b, dst, &index, &type, &chain_idx, false);
 
    _vtn_block_load_store(b, nir_intrinsic_store_ssbo, false, index, offset,
                          0, 0, dst->chain, chain_idx, type, &src);
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list