Mesa (main): zink: implement indirect buffer indexing

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue May 10 06:06:43 UTC 2022


Module: Mesa
Branch: main
Commit: a7327c7cac9b49ac1c7679f91c4727da0d60f501
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=a7327c7cac9b49ac1c7679f91c4727da0d60f501

Author: Mike Blumenkrantz <michael.blumenkrantz at gmail.com>
Date:   Tue Apr 12 17:26:53 2022 -0400

zink: implement indirect buffer indexing

this compacts all buffers in the shader into an array that can be
used in a single descriptor, thus handling the case of indirect indexing
while also turning constant indexing into indirect indexing (with const
offsets), since there's no sane way to distinguish the two

a "proper" implementation of this would be to skip gl_nir_lower_buffers
and nir_lower_explicit_io altogether and retain the derefs, but that would
require a ton of legwork of other nir passes which only operate on the
explicit io intrinsics

Reviewed-by: Dave Airlie <airlied at redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15906>
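
The key consequence of the compaction, visible in the remove_bo_access_instr()
changes below, is that the (constant or indirect) buffer index carried by a
load/store/atomic intrinsic gets remapped onto the compacted array before it is
used as a deref index. The following is a minimal standalone C sketch of that
slot-to-index remapping; the struct and helper names (bo_remap, remap_ubo_slot,
remap_ssbo_slot) and the example slot masks are purely illustrative and not
part of the patch, which emits the equivalent arithmetic as nir_iadd_imm
instructions instead:

/* Hypothetical, standalone sketch of the slot -> compacted-array-index
 * remapping performed by remove_bo_access_instr() in the diff below.
 * Names and example slot masks are illustrative only, not Mesa API. */
#include <assert.h>
#include <strings.h> /* for ffs() */

struct bo_remap {
   unsigned first_ubo;  /* lowest non-default UBO slot used, minus 1 */
   unsigned first_ssbo; /* lowest SSBO slot used */
};

/* UBO slot 0 (the default uniform block) keeps its own single-element
 * array; every other UBO lands in a second, compacted array. */
static unsigned
remap_ubo_slot(const struct bo_remap *r, unsigned slot)
{
   if (slot == 0)
      return 0;
   return slot - 1 - r->first_ubo;
}

/* All SSBOs share one compacted array starting at slot first_ssbo. */
static unsigned
remap_ssbo_slot(const struct bo_remap *r, unsigned slot)
{
   return slot - r->first_ssbo;
}

int
main(void)
{
   /* example shader: UBO slots {0, 2, 3} used -> mask 0xd,
    *                 SSBO slots {1, 2} used   -> mask 0x6 */
   struct bo_remap r = {
      .first_ubo  = ffs(0xd & ~1u) - 2, /* ffs(0xc) - 2 == 1 */
      .first_ssbo = ffs(0x6) - 1,       /* == 1 */
   };
   assert(remap_ubo_slot(&r, 0) == 0);  /* default uniform block */
   assert(remap_ubo_slot(&r, 2) == 0);  /* first element of "ubos" array */
   assert(remap_ubo_slot(&r, 3) == 1);
   assert(remap_ssbo_slot(&r, 1) == 0); /* first element of "ssbos" array */
   assert(remap_ssbo_slot(&r, 2) == 1);
   return 0;
}

In the patch itself, first_ubo and first_ssbo are derived from the
ubos_used/ssbos_used masks gathered by analyze_io().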

---

 .../drivers/zink/nir_to_spirv/nir_to_spirv.c       | 123 ++++++----
 src/gallium/drivers/zink/zink_compiler.c           | 271 +++++++++++++--------
 src/gallium/drivers/zink/zink_descriptors_lazy.c   |  37 +--
 3 files changed, 252 insertions(+), 179 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index d975eff4857..3dd102ffbf4 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -46,18 +46,19 @@ struct ntv_context {
    nir_shader *nir;
 
    struct hash_table *glsl_types;
-   struct hash_table *bo_types;
+   struct hash_table *bo_struct_types;
+   struct hash_table *bo_array_types;
 
    SpvId GLSL_std_450;
 
    gl_shader_stage stage;
    const struct zink_shader_info *sinfo;
 
-   SpvId ubos[PIPE_MAX_CONSTANT_BUFFERS][5]; //8, 16, 32, unused, 64
-   nir_variable *ubo_vars[PIPE_MAX_CONSTANT_BUFFERS];
+   SpvId ubos[2][5]; //8, 16, 32, unused, 64
+   nir_variable *ubo_vars[2];
 
-   SpvId ssbos[PIPE_MAX_SHADER_BUFFERS][5]; //8, 16, 32, unused, 64
-   nir_variable *ssbo_vars[PIPE_MAX_SHADER_BUFFERS];
+   SpvId ssbos[5]; //8, 16, 32, unused, 64
+   nir_variable *ssbo_vars;
    SpvId image_types[PIPE_MAX_SAMPLERS];
    SpvId images[PIPE_MAX_SAMPLERS];
    SpvId sampler_types[PIPE_MAX_SAMPLERS];
@@ -961,23 +962,22 @@ get_sized_uint_array_type(struct ntv_context *ctx, unsigned array_size, unsigned
    return array_type;
 }
 
+/* get array<struct(array_type <--this one)> */
 static SpvId
 get_bo_array_type(struct ntv_context *ctx, struct nir_variable *var)
 {
-   struct hash_entry *he = _mesa_hash_table_search(ctx->bo_types, var);
+   struct hash_entry *he = _mesa_hash_table_search(ctx->bo_array_types, var);
    if (he)
       return (SpvId)(uintptr_t)he->data;
-   unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(var->type, 0)));
+   unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(glsl_without_array(var->type), 0)));
    assert(bitsize);
    SpvId array_type;
-   const struct glsl_type *type = var->type;
-   if (!glsl_type_is_unsized_array(type)) {
-      type = glsl_get_struct_field(var->interface_type, 0);
-      if (!glsl_type_is_unsized_array(type)) {
-         uint32_t array_size = glsl_get_length(type) * (bitsize / 4);
-         assert(array_size);
-         return get_sized_uint_array_type(ctx, array_size, bitsize);
-      }
+   const struct glsl_type *type = glsl_without_array(var->type);
+   const struct glsl_type *first_type = glsl_get_struct_field(type, 0);
+   if (!glsl_type_is_unsized_array(first_type)) {
+      uint32_t array_size = glsl_get_length(first_type);
+      assert(array_size);
+      return get_sized_uint_array_type(ctx, array_size, bitsize);
    }
    SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bitsize);
    array_type = spirv_builder_type_runtime_array(&ctx->builder, uint_type);
@@ -985,18 +985,23 @@ get_bo_array_type(struct ntv_context *ctx, struct nir_variable *var)
    return array_type;
 }
 
+/* get array<struct(array_type) <--this one> */
 static SpvId
 get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
 {
-   unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(var->type, 0)));
+   struct hash_entry *he = _mesa_hash_table_search(ctx->bo_struct_types, var);
+   if (he)
+      return (SpvId)(uintptr_t)he->data;
+   const struct glsl_type *bare_type = glsl_without_array(var->type);
+   unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(bare_type, 0)));
    SpvId array_type = get_bo_array_type(ctx, var);
-   _mesa_hash_table_insert(ctx->bo_types, var, (void *)(uintptr_t)array_type);
+   _mesa_hash_table_insert(ctx->bo_array_types, var, (void *)(uintptr_t)array_type);
    bool ssbo = var->data.mode == nir_var_mem_ssbo;
 
    // wrap UBO-array in a struct
    SpvId runtime_array = 0;
-   if (ssbo && glsl_get_length(var->interface_type) > 1) {
-       const struct glsl_type *last_member = glsl_get_struct_field(var->interface_type, glsl_get_length(var->interface_type) - 1);
+   if (ssbo && glsl_get_length(bare_type) > 1) {
+       const struct glsl_type *last_member = glsl_get_struct_field(bare_type, glsl_get_length(bare_type) - 1);
        if (glsl_type_is_unsized_array(last_member)) {
           bool is_64bit = glsl_type_is_64bit(glsl_without_array(last_member));
           runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, is_64bit ? 64 : bitsize, 1));
@@ -1014,35 +1019,35 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
    spirv_builder_emit_decoration(&ctx->builder, struct_type,
                                  SpvDecorationBlock);
    spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
-   if (runtime_array) {
-      spirv_builder_emit_member_offset(&ctx->builder, struct_type, 1,
-                                      glsl_get_struct_field_offset(var->interface_type,
-                                                                   glsl_get_length(var->interface_type) - 1));
-   }
+   if (runtime_array)
+      spirv_builder_emit_member_offset(&ctx->builder, struct_type, 1, 0);
 
-   return spirv_builder_type_pointer(&ctx->builder,
-                                                   ssbo ? SpvStorageClassStorageBuffer : SpvStorageClassUniform,
-                                                   struct_type);
+   return struct_type;
 }
 
 static void
 emit_bo(struct ntv_context *ctx, struct nir_variable *var)
 {
-   unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(var->type, 0)));
+   unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(glsl_without_array(var->type), 0)));
    bool ssbo = var->data.mode == nir_var_mem_ssbo;
-   SpvId pointer_type = get_bo_struct_type(ctx, var);
-
+   SpvId struct_type = get_bo_struct_type(ctx, var);
+   _mesa_hash_table_insert(ctx->bo_struct_types, var, (void *)(uintptr_t)struct_type);
+   SpvId array_length = emit_uint_const(ctx, 32, glsl_get_length(var->type));
+   SpvId array_type = spirv_builder_type_array(&ctx->builder, struct_type, array_length);
+   SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
+                                                   ssbo ? SpvStorageClassStorageBuffer : SpvStorageClassUniform,
+                                                   array_type);
    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
                                          ssbo ? SpvStorageClassStorageBuffer : SpvStorageClassUniform);
    if (var->name)
       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
 
    unsigned idx = bitsize >> 4;
-   assert(idx < ARRAY_SIZE(ctx->ssbos[0]));
+   assert(idx < ARRAY_SIZE(ctx->ssbos));
    if (ssbo) {
-      assert(!ctx->ssbos[var->data.driver_location][idx]);
-      ctx->ssbos[var->data.driver_location][idx] = var_id;
-      ctx->ssbo_vars[var->data.driver_location] = var;
+      assert(!ctx->ssbos[idx]);
+      ctx->ssbos[idx] = var_id;
+      ctx->ssbo_vars = var;
    } else {
       assert(!ctx->ubos[var->data.driver_location][idx]);
       ctx->ubos[var->data.driver_location][idx] = var_id;
@@ -2494,13 +2499,10 @@ emit_ssbo_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    SpvId param;
    SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint32);
 
-   nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
-   assert(const_block_index); // no dynamic indexing for now
    unsigned bit_size = MIN2(nir_src_bit_size(intr->src[0]), 32);
    unsigned idx = bit_size >> 4;
-   assert(idx < ARRAY_SIZE(ctx->ssbos[0]));
-   assert(ctx->ssbos[const_block_index->u32][idx]);
-   ssbo = ctx->ssbos[const_block_index->u32][idx];
+   assert(idx < ARRAY_SIZE(ctx->ssbos));
+   ssbo = ctx->ssbos[idx];
    param = get_src(ctx, &intr->src[2]);
 
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
@@ -2509,10 +2511,11 @@ emit_ssbo_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    SpvId uint_type = get_uvec_type(ctx, 32, 1);
    /* an id of the array stride in bytes */
    SpvId uint_size = emit_uint_const(ctx, 32, bit_size / 8);
+   SpvId bo = get_src(ctx, &intr->src[0]);
    SpvId member = emit_uint_const(ctx, 32, 0);
    SpvId offset = get_src(ctx, &intr->src[1]);
    SpvId vec_offset = emit_binop(ctx, SpvOpUDiv, uint_type, offset, uint_size);
-   SpvId indices[] = { member, vec_offset };
+   SpvId indices[] = { bo, member, vec_offset };
    SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
                                                ssbo, indices,
                                                ARRAY_SIZE(indices));
@@ -2550,25 +2553,32 @@ static void
 emit_get_ssbo_size(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
    SpvId uint_type = get_uvec_type(ctx, 32, 1);
-   nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
-   assert(const_block_index); // no dynamic indexing for now
-   nir_variable *var = ctx->ssbo_vars[const_block_index->u32];
-   unsigned last_member_idx = glsl_get_length(var->interface_type) - 1;
+   nir_variable *var = ctx->ssbo_vars;
+   const struct glsl_type *bare_type = glsl_without_array(var->type);
+   unsigned last_member_idx = glsl_get_length(bare_type) - 1;
+   SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
+                                                   SpvStorageClassStorageBuffer,
+                                                   get_bo_struct_type(ctx, var));
+   SpvId bo = get_src(ctx, &intr->src[0]);
+   SpvId indices[] = { bo };
+   SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
+                                               ctx->ssbos[2], indices,
+                                               ARRAY_SIZE(indices));
    SpvId result = spirv_builder_emit_binop(&ctx->builder, SpvOpArrayLength, uint_type,
-                                             ctx->ssbos[const_block_index->u32][2], last_member_idx);
+                                           ptr, last_member_idx);
    /* this is going to be converted by nir to:
 
       length = (buffer_size - offset) / stride
 
       * so we need to un-convert it to avoid having the calculation performed twice
       */
-   const struct glsl_type *last_member = glsl_get_struct_field(var->interface_type, last_member_idx);
+   const struct glsl_type *last_member = glsl_get_struct_field(bare_type, last_member_idx);
    /* multiply by stride */
    result = emit_binop(ctx, SpvOpIMul, uint_type, result, emit_uint_const(ctx, 32, glsl_get_explicit_stride(last_member)));
    /* get total ssbo size by adding offset */
    result = emit_binop(ctx, SpvOpIAdd, uint_type, result,
                         emit_uint_const(ctx, 32,
-                                       glsl_get_struct_field_offset(var->interface_type, last_member_idx)));
+                                       glsl_get_struct_field_offset(bare_type, last_member_idx)));
    store_dest(ctx, &intr->dest, result, nir_type_uint);
 }
 
@@ -3535,11 +3545,21 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
    SpvStorageClass storage_class = get_storage_class(var);
    SpvId base, type;
    switch (var->data.mode) {
+
+   case nir_var_mem_ubo:
+   case nir_var_mem_ssbo:
+      base = get_src(ctx, &deref->parent);
+      /* this is either the array<buffers> deref or the array<uint> deref */
+      if (glsl_type_is_struct_or_ifc(deref->type)) {
+         /* array<buffers> */
+         type = get_bo_struct_type(ctx, var);
+         break;
+      }
+      /* array<uint> */
+      FALLTHROUGH;
    case nir_var_function_temp:
    case nir_var_shader_in:
    case nir_var_shader_out:
-   case nir_var_mem_ubo:
-   case nir_var_mem_ssbo:
       base = get_src(ctx, &deref->parent);
       type = get_glsl_type(ctx, deref->type);
       break;
@@ -3898,8 +3918,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
    ctx.spirv_1_4_interfaces = spirv_version >= SPIRV_VERSION(1, 4);
 
    ctx.glsl_types = _mesa_pointer_hash_table_create(ctx.mem_ctx);
-   ctx.bo_types = _mesa_pointer_hash_table_create(ctx.mem_ctx);
-   if (!ctx.glsl_types || !ctx.bo_types)
+   ctx.bo_array_types = _mesa_pointer_hash_table_create(ctx.mem_ctx);
+   ctx.bo_struct_types = _mesa_pointer_hash_table_create(ctx.mem_ctx);
+   if (!ctx.glsl_types || !ctx.bo_array_types || !ctx.bo_struct_types)
       goto fail;
 
    spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index 56b0fc56141..cd850fc6010 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -878,6 +878,7 @@ rewrite_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
    case nir_intrinsic_load_ubo: {
       /* ubo0 can have unaligned 64bit loads, particularly for bindless texture ids */
       bool force_2x32 = intr->intrinsic == nir_intrinsic_load_ubo &&
+                        nir_src_is_const(intr->src[0]) &&
                         nir_src_as_uint(intr->src[0]) == 0 &&
                         nir_dest_bit_size(intr->dest) == 64 &&
                         nir_intrinsic_align_offset(intr) % 8 != 0;
@@ -958,36 +959,52 @@ rewrite_bo_access(nir_shader *shader, struct zink_screen *screen)
 }
 
 struct bo_vars {
-   nir_variable *ubo[PIPE_MAX_CONSTANT_BUFFERS][5];
-   nir_variable *ssbo[PIPE_MAX_CONSTANT_BUFFERS][5];
+   nir_variable *ubo[2][5];
+   nir_variable *ssbo[5];
+   uint32_t first_ubo;
+   uint32_t first_ssbo;
 };
 
 static nir_variable *
-get_bo_var(nir_shader *shader, struct bo_vars *bo, bool ssbo, unsigned idx, unsigned bit_size)
+get_bo_var(nir_shader *shader, struct bo_vars *bo, bool ssbo, nir_src *src, unsigned bit_size)
 {
-   nir_variable *var;
-   nir_variable **arr = (nir_variable**)(ssbo ? bo->ssbo : bo->ubo);
+   nir_variable *var, **ptr;
+   nir_variable **arr = (nir_variable**)bo->ubo;
+   unsigned idx = ssbo || (nir_src_is_const(*src) && !nir_src_as_uint(*src)) ? 0 : 1;
 
-   var = arr[idx * 5 + (bit_size >> 4)];
+   if (ssbo)
+      ptr = &bo->ssbo[bit_size >> 4];
+   else
+      ptr = &arr[idx * 5 + (bit_size >> 4)];
+   var = *ptr;
    if (!var) {
-      arr[idx * 5 + (bit_size >> 4)] = var = nir_variable_clone(arr[idx * 5 + (32 >> 4)], shader);
+      if (ssbo)
+         var = bo->ssbo[32 >> 4];
+      else
+         var = arr[idx * 5 + (32 >> 4)];
+      var = nir_variable_clone(var, shader);
+      *ptr = var;
       nir_shader_add_variable(shader, var);
 
       struct glsl_struct_field *fields = rzalloc_array(shader, struct glsl_struct_field, 2);
       fields[0].name = ralloc_strdup(shader, "base");
       fields[1].name = ralloc_strdup(shader, "unsized");
-      const struct glsl_type *array_type = glsl_get_struct_field(var->type, 0);
+      unsigned array_size = glsl_get_length(var->type);
+      const struct glsl_type *bare_type = glsl_without_array(var->type);
+      const struct glsl_type *array_type = glsl_get_struct_field(bare_type, 0);
+      unsigned length = glsl_get_length(array_type);
       const struct glsl_type *type;
       const struct glsl_type *unsized = glsl_array_type(glsl_uintN_t_type(bit_size), 0, bit_size / 8);
       if (bit_size > 32) {
          assert(bit_size == 64);
-         type = glsl_array_type(glsl_uintN_t_type(bit_size), glsl_get_length(array_type) / 2, bit_size / 8);
+         type = glsl_array_type(glsl_uintN_t_type(bit_size), length / 2, bit_size / 8);
       } else {
-         type = glsl_array_type(glsl_uintN_t_type(bit_size), glsl_get_length(array_type) * (32 / bit_size), bit_size / 8);
+         type = glsl_array_type(glsl_uintN_t_type(bit_size), length * (32 / bit_size), bit_size / 8);
       }
       fields[0].type = type;
       fields[1].type = unsized;
-      var->type = glsl_struct_type(fields, glsl_get_length(var->type), "struct", false);
+      var->type = glsl_array_type(glsl_struct_type(fields, glsl_get_length(bare_type), "struct", false), array_size, 0);
+      var->data.driver_location = idx;
    }
    return var;
 }
@@ -1003,31 +1020,57 @@ remove_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
    nir_ssa_def *offset = NULL;
    bool is_load = true;
    b->cursor = nir_before_instr(instr);
+   nir_src *src;
+   bool ssbo = true;
    switch (intr->intrinsic) {
+   /* TODO: these should all be rewritten to use deref intrinsics */
+   case nir_intrinsic_ssbo_atomic_add:
+   case nir_intrinsic_ssbo_atomic_umin:
+   case nir_intrinsic_ssbo_atomic_imin:
+   case nir_intrinsic_ssbo_atomic_umax:
+   case nir_intrinsic_ssbo_atomic_imax:
+   case nir_intrinsic_ssbo_atomic_and:
+   case nir_intrinsic_ssbo_atomic_or:
+   case nir_intrinsic_ssbo_atomic_xor:
+   case nir_intrinsic_ssbo_atomic_exchange:
+   case nir_intrinsic_ssbo_atomic_comp_swap:
+      nir_instr_rewrite_src_ssa(instr, &intr->src[0], nir_iadd_imm(b, intr->src[0].ssa, -bo->first_ssbo));
+      return true;
    case nir_intrinsic_store_ssbo:
-      var = get_bo_var(b->shader, bo, true, nir_src_as_uint(intr->src[1]), nir_src_bit_size(intr->src[0]));
+      src = &intr->src[1];
+      var = get_bo_var(b->shader, bo, true, src, nir_src_bit_size(intr->src[0]));
       offset = intr->src[2].ssa;
       is_load = false;
       break;
    case nir_intrinsic_load_ssbo:
-      var = get_bo_var(b->shader, bo, true, nir_src_as_uint(intr->src[0]), nir_dest_bit_size(intr->dest));
+      src = &intr->src[0];
+      var = get_bo_var(b->shader, bo, true, src, nir_dest_bit_size(intr->dest));
       offset = intr->src[1].ssa;
       break;
    case nir_intrinsic_load_ubo:
-      var = get_bo_var(b->shader, bo, false, nir_src_as_uint(intr->src[0]), nir_dest_bit_size(intr->dest));
+      src = &intr->src[0];
+      var = get_bo_var(b->shader, bo, false, src, nir_dest_bit_size(intr->dest));
       offset = intr->src[1].ssa;
+      ssbo = false;
       break;
    default:
       return false;
    }
    assert(var);
    assert(offset);
-   nir_deref_instr *deref_var = nir_build_deref_struct(b, nir_build_deref_var(b, var), 0);
+   nir_deref_instr *deref_var = nir_build_deref_var(b, var);
+   nir_ssa_def *idx = !ssbo && var->data.driver_location ? nir_iadd_imm(b, src->ssa, -1) : src->ssa;
+   if (!ssbo && bo->first_ubo && var->data.driver_location)
+      idx = nir_iadd_imm(b, idx, -bo->first_ubo);
+   else if (ssbo && bo->first_ssbo)
+      idx = nir_iadd_imm(b, idx, -bo->first_ssbo);
+   nir_deref_instr *deref_array = nir_build_deref_array(b, deref_var, idx);
+   nir_deref_instr *deref_struct = nir_build_deref_struct(b, deref_array, 0);
    assert(intr->num_components <= 2);
    if (is_load) {
       nir_ssa_def *result[2];
       for (unsigned i = 0; i < intr->num_components; i++) {
-         nir_deref_instr *deref_arr = nir_build_deref_array(b, deref_var, offset);
+         nir_deref_instr *deref_arr = nir_build_deref_array(b, deref_struct, offset);
          result[i] = nir_load_deref(b, deref_arr);
          if (intr->intrinsic == nir_intrinsic_load_ssbo)
             nir_intrinsic_set_access(nir_instr_as_intrinsic(result[i]->parent_instr), nir_intrinsic_access(intr));
@@ -1036,7 +1079,7 @@ remove_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
       nir_ssa_def *load = nir_vec(b, result, intr->num_components);
       nir_ssa_def_rewrite_uses(&intr->dest.ssa, load);
    } else {
-      nir_deref_instr *deref_arr = nir_build_deref_array(b, deref_var, offset);
+      nir_deref_instr *deref_arr = nir_build_deref_array(b, deref_struct, offset);
       nir_build_store_deref(b, &deref_arr->dest.ssa, intr->src[0].ssa, BITFIELD_MASK(intr->num_components), nir_intrinsic_access(intr));
    }
    nir_instr_remove(instr);
@@ -1044,17 +1087,23 @@ remove_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
 }
 
 static bool
-remove_bo_access(nir_shader *shader)
+remove_bo_access(nir_shader *shader, struct zink_shader *zs)
 {
    struct bo_vars bo;
    memset(&bo, 0, sizeof(bo));
+   if (zs->ubos_used)
+      bo.first_ubo = ffs(zs->ubos_used & ~BITFIELD_BIT(0)) - 2;
+   assert(bo.first_ssbo < PIPE_MAX_CONSTANT_BUFFERS);
+   if (zs->ssbos_used)
+      bo.first_ssbo = ffs(zs->ssbos_used) - 1;
+   assert(bo.first_ssbo < PIPE_MAX_SHADER_BUFFERS);
    nir_foreach_variable_with_modes(var, shader, nir_var_mem_ssbo | nir_var_mem_ubo) {
       if (var->data.mode == nir_var_mem_ssbo) {
-         assert(!bo.ssbo[var->data.driver_location][32 >> 4]);
-         bo.ssbo[var->data.driver_location][32 >> 4] = var;
+         assert(!bo.ssbo[32 >> 4]);
+         bo.ssbo[32 >> 4] = var;
       } else {
-         assert(!bo.ubo[var->data.driver_location][32 >> 4]);
-         bo.ubo[var->data.driver_location][32 >> 4] = var;
+         assert(!bo.ubo[!!var->data.driver_location][32 >> 4]);
+         bo.ubo[!!var->data.driver_location][32 >> 4] = var;
       }
    }
    return nir_shader_instructions_pass(shader, remove_bo_access_instr, nir_metadata_dominance, &bo);
@@ -1385,7 +1434,7 @@ zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs, nir_shad
    if (screen->driconf.inline_uniforms) {
       NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
-      NIR_PASS_V(nir, remove_bo_access);
+      NIR_PASS_V(nir, remove_bo_access, zs);
    }
    if (inlined_uniforms) {
       optimize_nir(nir);
@@ -1441,32 +1490,108 @@ bool nir_lower_dynamic_bo_access(nir_shader *shader);
  * so instead we delete all those broken variables and just make new ones
  */
 static bool
-unbreak_bos(nir_shader *shader)
+unbreak_bos(nir_shader *shader, struct zink_shader *zs, bool needs_size)
 {
-   uint32_t ssbo_used = 0;
-   uint32_t ubo_used = 0;
    uint64_t max_ssbo_size = 0;
    uint64_t max_ubo_size = 0;
-   bool ssbo_sizes[PIPE_MAX_SHADER_BUFFERS] = {false};
 
-   if (!shader->info.num_ssbos && !shader->info.num_ubos && !shader->num_uniforms)
+   if (!shader->info.num_ssbos && !shader->info.num_ubos)
       return false;
+
+   nir_foreach_variable_with_modes(var, shader, nir_var_mem_ssbo | nir_var_mem_ubo) {
+      const struct glsl_type *type = glsl_without_array(var->type);
+      if (type_is_counter(type))
+         continue;
+      unsigned size = glsl_count_attribute_slots(glsl_type_is_array(var->type) ? var->type : type, false);
+      if (var->data.mode == nir_var_mem_ubo)
+         max_ubo_size = MAX2(max_ubo_size, size);
+      else
+         max_ssbo_size = MAX2(max_ssbo_size, size);
+      var->data.mode = nir_var_shader_temp;
+   }
+   nir_fixup_deref_modes(shader);
+   NIR_PASS_V(shader, nir_remove_dead_variables, nir_var_shader_temp, NULL);
+   optimize_nir(shader);
+
+   struct glsl_struct_field *fields = rzalloc_array(shader, struct glsl_struct_field, 2);
+   fields[0].name = ralloc_strdup(shader, "base");
+   fields[1].name = ralloc_strdup(shader, "unsized");
+   if (shader->info.num_ubos) {
+      const struct glsl_type *ubo_type = glsl_array_type(glsl_uint_type(), max_ubo_size * 4, 4);
+      fields[0].type = ubo_type;
+      if (shader->num_uniforms && zs->ubos_used & BITFIELD_BIT(0)) {
+         nir_variable *var = nir_variable_create(shader, nir_var_mem_ubo,
+                                                 glsl_array_type(glsl_struct_type(fields, 1, "struct", false), 1, 0),
+                                                 "uniform_0");
+         var->interface_type = var->type;
+         var->data.mode = nir_var_mem_ubo;
+         var->data.driver_location = 0;
+      }
+
+      unsigned num_ubos = shader->info.num_ubos - !!shader->info.first_ubo_is_default_ubo;
+      uint32_t ubos_used = zs->ubos_used & ~BITFIELD_BIT(0);
+      if (num_ubos && ubos_used) {
+         /* shrink array as much as possible */
+         unsigned first_ubo = ffs(ubos_used) - 2;
+         assert(first_ubo < PIPE_MAX_CONSTANT_BUFFERS);
+         num_ubos -= first_ubo;
+         assert(num_ubos);
+         nir_variable *var = nir_variable_create(shader, nir_var_mem_ubo,
+                                   glsl_array_type(glsl_struct_type(fields, 1, "struct", false), num_ubos, 0),
+                                   "ubos");
+         var->interface_type = var->type;
+         var->data.mode = nir_var_mem_ubo;
+         var->data.driver_location = first_ubo + !!shader->info.first_ubo_is_default_ubo;
+      }
+   }
+   if (shader->info.num_ssbos && zs->ssbos_used) {
+      /* shrink array as much as possible */
+      unsigned first_ssbo = ffs(zs->ssbos_used) - 1;
+      assert(first_ssbo < PIPE_MAX_SHADER_BUFFERS);
+      unsigned num_ssbos = shader->info.num_ssbos - first_ssbo;
+      assert(num_ssbos);
+      const struct glsl_type *ssbo_type = glsl_array_type(glsl_uint_type(), max_ssbo_size * 4, 4);
+      const struct glsl_type *unsized = glsl_array_type(glsl_uint_type(), 0, 4);
+      fields[0].type = ssbo_type;
+      fields[1].type = max_ssbo_size ? unsized : NULL;
+      unsigned field_count = max_ssbo_size && needs_size ? 2 : 1;
+      nir_variable *var = nir_variable_create(shader, nir_var_mem_ssbo,
+                                              glsl_array_type(glsl_struct_type(fields, field_count, "struct", false), num_ssbos, 0),
+                                              "ssbos");
+      var->interface_type = var->type;
+      var->data.mode = nir_var_mem_ssbo;
+      var->data.driver_location = first_ssbo;
+   }
+   return true;
+}
+
+static uint32_t
+get_src_mask(unsigned total, nir_src src)
+{
+   if (nir_src_is_const(src))
+      return BITFIELD_BIT(nir_src_as_uint(src));
+   return BITFIELD_MASK(total);
+}
+
+static bool
+analyze_io(struct zink_shader *zs, nir_shader *shader)
+{
+   bool ret = false;
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    nir_foreach_block(block, impl) {
       nir_foreach_instr(instr, block) {
          if (instr->type != nir_instr_type_intrinsic)
             continue;
-
+ 
          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
          switch (intrin->intrinsic) {
          case nir_intrinsic_store_ssbo:
-            ssbo_used |= BITFIELD_BIT(nir_src_as_uint(intrin->src[1]));
+            zs->ssbos_used |= get_src_mask(shader->info.num_ssbos, intrin->src[1]);
             break;
-
+ 
          case nir_intrinsic_get_ssbo_size: {
-            uint32_t slot = nir_src_as_uint(intrin->src[0]);
-            ssbo_used |= BITFIELD_BIT(slot);
-            ssbo_sizes[slot] = true;
+            zs->ssbos_used |= get_src_mask(shader->info.num_ssbos, intrin->src[0]);
+            ret = true;
             break;
          }
          case nir_intrinsic_ssbo_atomic_add:
@@ -1483,69 +1608,18 @@ unbreak_bos(nir_shader *shader)
          case nir_intrinsic_ssbo_atomic_fmax:
          case nir_intrinsic_ssbo_atomic_fcomp_swap:
          case nir_intrinsic_load_ssbo:
-            ssbo_used |= BITFIELD_BIT(nir_src_as_uint(intrin->src[0]));
+            zs->ssbos_used |= get_src_mask(shader->info.num_ssbos, intrin->src[0]);
             break;
          case nir_intrinsic_load_ubo:
          case nir_intrinsic_load_ubo_vec4:
-            ubo_used |= BITFIELD_BIT(nir_src_as_uint(intrin->src[0]));
+            zs->ubos_used |= get_src_mask(shader->info.num_ubos, intrin->src[0]);
             break;
          default:
             break;
          }
       }
    }
-
-   nir_foreach_variable_with_modes(var, shader, nir_var_mem_ssbo | nir_var_mem_ubo) {
-      const struct glsl_type *type = glsl_without_array(var->type);
-      if (type_is_counter(type))
-         continue;
-      unsigned size = glsl_count_attribute_slots(glsl_type_is_array(var->type) ? var->type : type, false);
-      if (var->data.mode == nir_var_mem_ubo)
-         max_ubo_size = MAX2(max_ubo_size, size);
-      else
-         max_ssbo_size = MAX2(max_ssbo_size, size);
-      var->data.mode = nir_var_shader_temp;
-   }
-   nir_fixup_deref_modes(shader);
-   NIR_PASS_V(shader, nir_remove_dead_variables, nir_var_shader_temp, NULL);
-   optimize_nir(shader);
-
-   if (!ssbo_used && !ubo_used)
-      return false;
-
-   struct glsl_struct_field *fields = rzalloc_array(shader, struct glsl_struct_field, 2);
-   fields[0].name = ralloc_strdup(shader, "base");
-   fields[1].name = ralloc_strdup(shader, "unsized");
-   if (ubo_used) {
-      const struct glsl_type *ubo_type = glsl_array_type(glsl_uint_type(), max_ubo_size * 4, 4);
-      fields[0].type = ubo_type;
-      u_foreach_bit(slot, ubo_used) {
-         char buf[64];
-         snprintf(buf, sizeof(buf), "ubo_slot_%u", slot);
-         nir_variable *var = nir_variable_create(shader, nir_var_mem_ubo, glsl_struct_type(fields, 1, "struct", false), buf);
-         var->interface_type = var->type;
-         var->data.driver_location = slot;
-      }
-   }
-   if (ssbo_used) {
-      const struct glsl_type *ssbo_type = glsl_array_type(glsl_uint_type(), max_ssbo_size * 4, 4);
-      const struct glsl_type *unsized = glsl_array_type(glsl_uint_type(), 0, 4);
-      fields[0].type = ssbo_type;
-      u_foreach_bit(slot, ssbo_used) {
-         char buf[64];
-         snprintf(buf, sizeof(buf), "ssbo_slot_%u", slot);
-         bool use_runtime = ssbo_sizes[slot] && max_ssbo_size;
-         if (use_runtime)
-            fields[1].type = unsized;
-         else
-            fields[1].type = NULL;
-         nir_variable *var = nir_variable_create(shader, nir_var_mem_ssbo,
-                                                 glsl_struct_type(fields, 1 + use_runtime, "struct", false), buf);
-         var->interface_type = var->type;
-         var->data.driver_location = slot;
-      }
-   }
-   return true;
+   return ret;
 }
 
 /* this is a "default" bindless texture used if the shader has no texture variables */
@@ -1729,8 +1803,7 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
       switch (type) {
       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
-         assert(index < PIPE_MAX_CONSTANT_BUFFERS);
-         return (stage * PIPE_MAX_CONSTANT_BUFFERS) + index;
+         return stage * 2 + !!index;
 
       case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
       case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
@@ -1738,8 +1811,7 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
          return (stage * PIPE_MAX_SAMPLERS) + index;
 
       case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
-         assert(index < PIPE_MAX_SHADER_BUFFERS);
-         return (stage * PIPE_MAX_SHADER_BUFFERS) + index;
+         return stage;
 
       case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
       case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
@@ -2063,12 +2135,13 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    NIR_PASS_V(nir, nir_lower_fragcolor,
          nir->info.fs.color_is_dual_source ? 1 : 8);
    NIR_PASS_V(nir, lower_64bit_vertex_attribs);
-   NIR_PASS_V(nir, unbreak_bos);
+   bool needs_size = analyze_io(ret, nir);
+   NIR_PASS_V(nir, unbreak_bos, ret, needs_size);
    /* run in compile if there could be inlined uniforms */
    if (!screen->driconf.inline_uniforms) {
       NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
-      NIR_PASS_V(nir, remove_bo_access);
+      NIR_PASS_V(nir, remove_bo_access, ret);
    }
 
    if (zink_debug & ZINK_DEBUG_NIR) {
@@ -2115,8 +2188,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
             ret->bindings[ztype][ret->num_bindings[ztype]].index = var->data.driver_location;
             ret->bindings[ztype][ret->num_bindings[ztype]].binding = binding;
             ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype;
-            ret->bindings[ztype][ret->num_bindings[ztype]].size = 1;
-            ret->ubos_used |= (1 << ret->bindings[ztype][ret->num_bindings[ztype]].index);
+            ret->bindings[ztype][ret->num_bindings[ztype]].size = glsl_get_length(var->type);
+            assert(ret->bindings[ztype][ret->num_bindings[ztype]].size);
             ret->num_bindings[ztype]++;
          } else if (var->data.mode == nir_var_mem_ssbo) {
             ztype = ZINK_DESCRIPTOR_TYPE_SSBO;
@@ -2125,10 +2198,10 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
                                              VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
                                              var->data.driver_location);
             ret->bindings[ztype][ret->num_bindings[ztype]].index = var->data.driver_location;
-            ret->ssbos_used |= (1 << ret->bindings[ztype][ret->num_bindings[ztype]].index);
             ret->bindings[ztype][ret->num_bindings[ztype]].binding = var->data.binding;
             ret->bindings[ztype][ret->num_bindings[ztype]].type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-            ret->bindings[ztype][ret->num_bindings[ztype]].size = 1;
+            ret->bindings[ztype][ret->num_bindings[ztype]].size = glsl_get_length(var->type);
+            assert(ret->bindings[ztype][ret->num_bindings[ztype]].size);
             ret->num_bindings[ztype]++;
          } else {
             assert(var->data.mode == nir_var_uniform ||
@@ -2209,8 +2282,6 @@ zink_shader_finalize(struct pipe_screen *pscreen, void *nirptr)
    if (nir->info.stage == MESA_SHADER_GEOMETRY)
       NIR_PASS_V(nir, nir_lower_gs_intrinsics, nir_lower_gs_intrinsics_per_stream);
    optimize_nir(nir);
-   if (nir->info.num_ubos || nir->info.num_ssbos)
-      NIR_PASS_V(nir, nir_lower_dynamic_bo_access);
    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
    if (screen->driconf.inline_uniforms)
       nir_find_inlinable_uniforms(nir);
diff --git a/src/gallium/drivers/zink/zink_descriptors_lazy.c b/src/gallium/drivers/zink/zink_descriptors_lazy.c
index be3f3dfef9f..cca84b95803 100644
--- a/src/gallium/drivers/zink/zink_descriptors_lazy.c
+++ b/src/gallium/drivers/zink/zink_descriptors_lazy.c
@@ -78,12 +78,13 @@ bdd_lazy(struct zink_batch_state *bs)
 
 static void
 init_template_entry(struct zink_shader *shader, enum zink_descriptor_type type,
-                    unsigned idx, unsigned offset, VkDescriptorUpdateTemplateEntry *entry, unsigned *entry_idx, bool flatten_dynamic)
+                    unsigned idx, VkDescriptorUpdateTemplateEntry *entry, unsigned *entry_idx, bool flatten_dynamic)
 {
     int index = shader->bindings[type][idx].index;
     enum pipe_shader_type stage = pipe_shader_type_from_mesa(shader->nir->info.stage);
     entry->dstArrayElement = 0;
     entry->dstBinding = shader->bindings[type][idx].binding;
+    entry->descriptorCount = shader->bindings[type][idx].size;
     if (shader->bindings[type][idx].type == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC && flatten_dynamic)
        /* filter out DYNAMIC type here */
        entry->descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
@@ -92,33 +93,27 @@ init_template_entry(struct zink_shader *shader, enum zink_descriptor_type type,
     switch (shader->bindings[type][idx].type) {
     case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
     case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
-       entry->descriptorCount = 1;
-       entry->offset = offsetof(struct zink_context, di.ubos[stage][index + offset]);
+       entry->offset = offsetof(struct zink_context, di.ubos[stage][index]);
        entry->stride = sizeof(VkDescriptorBufferInfo);
        break;
     case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
-       entry->descriptorCount = shader->bindings[type][idx].size;
-       entry->offset = offsetof(struct zink_context, di.textures[stage][index + offset]);
+       entry->offset = offsetof(struct zink_context, di.textures[stage][index]);
        entry->stride = sizeof(VkDescriptorImageInfo);
        break;
     case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
-       entry->descriptorCount = shader->bindings[type][idx].size;
-       entry->offset = offsetof(struct zink_context, di.tbos[stage][index + offset]);
+       entry->offset = offsetof(struct zink_context, di.tbos[stage][index]);
        entry->stride = sizeof(VkBufferView);
        break;
     case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
-       entry->descriptorCount = 1;
-       entry->offset = offsetof(struct zink_context, di.ssbos[stage][index + offset]);
+       entry->offset = offsetof(struct zink_context, di.ssbos[stage][index]);
        entry->stride = sizeof(VkDescriptorBufferInfo);
        break;
     case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
-       entry->descriptorCount = shader->bindings[type][idx].size;
-       entry->offset = offsetof(struct zink_context, di.images[stage][index + offset]);
+       entry->offset = offsetof(struct zink_context, di.images[stage][index]);
        entry->stride = sizeof(VkDescriptorImageInfo);
        break;
     case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
-       entry->descriptorCount = shader->bindings[type][idx].size;
-       entry->offset = offsetof(struct zink_context, di.texel_images[stage][index + offset]);
+       entry->offset = offsetof(struct zink_context, di.texel_images[stage][index]);
        entry->stride = sizeof(VkBufferView);
        break;
     default:
@@ -207,21 +202,7 @@ zink_descriptor_program_init_lazy(struct zink_context *ctx, struct zink_program
             enum zink_descriptor_size_index idx = zink_vktype_to_size_idx(shader->bindings[j][k].type);
             sizes[idx].descriptorCount += shader->bindings[j][k].size;
             sizes[idx].type = shader->bindings[j][k].type;
-            switch (shader->bindings[j][k].type) {
-            case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
-            case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
-            case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
-            case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
-               init_template_entry(shader, j, k, 0, &entries[j][entry_idx[j]], &entry_idx[j], screen->descriptor_mode == ZINK_DESCRIPTOR_MODE_LAZY);
-               break;
-            case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
-            case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
-               for (unsigned l = 0; l < shader->bindings[j][k].size; l++)
-                  init_template_entry(shader, j, k, l, &entries[j][entry_idx[j]], &entry_idx[j], screen->descriptor_mode == ZINK_DESCRIPTOR_MODE_LAZY);
-               break;
-            default:
-               break;
-            }
+            init_template_entry(shader, j, k, &entries[j][entry_idx[j]], &entry_idx[j], screen->descriptor_mode == ZINK_DESCRIPTOR_MODE_LAZY);
             num_bindings[j]++;
             has_bindings |= BITFIELD_BIT(j);
          }


