Mesa (main): zink: implement cs uniform inlining

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Nov 10 01:29:41 UTC 2021


Module: Mesa
Branch: main
Commit: a8d90c8ed55e77344bcf277934a5ff2fa52d3e15
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=a8d90c8ed55e77344bcf277934a5ff2fa52d3e15

Author: Mike Blumenkrantz <michael.blumenkrantz at gmail.com>
Date:   Tue Nov  9 10:07:49 2021 -0500

zink: implement cs uniform inlining

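this implements shader variants for compute: the inlinable uniform values
now live in the compute pipeline state's shader key, and each compute
program caches one shader module per set of inlined values (up to
ZINK_MAX_INLINED_VARIANTS, after which the base module is used)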

Reviewed-by: Dave Airlie <airlied at redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13727>

---

 src/gallium/drivers/zink/zink_context.c  |   7 +-
 src/gallium/drivers/zink/zink_context.h  |   1 -
 src/gallium/drivers/zink/zink_draw.cpp   |   6 ++
 src/gallium/drivers/zink/zink_pipeline.h |   7 ++
 src/gallium/drivers/zink/zink_program.c  | 111 ++++++++++++++++++++++++++++---
 src/gallium/drivers/zink/zink_program.h  |  10 ++-
 6 files changed, 127 insertions(+), 15 deletions(-)

diff --git a/src/gallium/drivers/zink/zink_context.c b/src/gallium/drivers/zink/zink_context.c
index 75b376369c9..defca8d80e0 100644
--- a/src/gallium/drivers/zink/zink_context.c
+++ b/src/gallium/drivers/zink/zink_context.c
@@ -1047,18 +1047,17 @@ zink_set_inlinable_constants(struct pipe_context *pctx,
    struct zink_shader_key *key = NULL;
 
    if (shader == PIPE_SHADER_COMPUTE) {
-      inlinable_uniforms = ctx->compute_inlinable_uniforms;
+      key = &ctx->compute_pipeline_state.key;
    } else {
       key = &ctx->gfx_pipeline_state.shader_keys.key[shader];
-      inlinable_uniforms = key->base.inlined_uniform_values;
    }
+   inlinable_uniforms = key->base.inlined_uniform_values;
    if (!(ctx->inlinable_uniforms_valid_mask & bit) ||
        memcmp(inlinable_uniforms, values, num_values * 4)) {
       memcpy(inlinable_uniforms, values, num_values * 4);
       ctx->dirty_shader_stages |= bit;
       ctx->inlinable_uniforms_valid_mask |= bit;
-      if (key)
-         key->inline_uniforms = true;
+      key->inline_uniforms = true;
    }
 }
 
diff --git a/src/gallium/drivers/zink/zink_context.h b/src/gallium/drivers/zink/zink_context.h
index 8301e90241f..65aaf02c00f 100644
--- a/src/gallium/drivers/zink/zink_context.h
+++ b/src/gallium/drivers/zink/zink_context.h
@@ -199,7 +199,6 @@ struct zink_context {
 
    unsigned shader_has_inlinable_uniforms_mask;
    unsigned inlinable_uniforms_valid_mask;
-   uint32_t compute_inlinable_uniforms[MAX_INLINABLE_UNIFORMS];
 
    struct pipe_constant_buffer ubos[PIPE_SHADER_TYPES][PIPE_MAX_CONSTANT_BUFFERS];
    struct pipe_shader_buffer ssbos[PIPE_SHADER_TYPES][PIPE_MAX_SHADER_BUFFERS];
diff --git a/src/gallium/drivers/zink/zink_draw.cpp b/src/gallium/drivers/zink/zink_draw.cpp
index a80e37b67c0..1548f64229c 100644
--- a/src/gallium/drivers/zink/zink_draw.cpp
+++ b/src/gallium/drivers/zink/zink_draw.cpp
@@ -882,6 +882,12 @@ zink_launch_grid(struct pipe_context *pctx, const struct pipe_grid_info *info)
       zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
    }
 
+   if (ctx->dirty_shader_stages & BITFIELD_BIT(PIPE_SHADER_COMPUTE)) {
+      /* update inlinable constants */
+      zink_update_compute_program(ctx);
+      ctx->dirty_shader_stages &= ~BITFIELD_BIT(PIPE_SHADER_COMPUTE);
+   }
+
    if (prev_pipeline != pipeline || BATCH_CHANGED)
       VKCTX(CmdBindPipeline)(batch->state->cmdbuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    if (BATCH_CHANGED) {
diff --git a/src/gallium/drivers/zink/zink_pipeline.h b/src/gallium/drivers/zink/zink_pipeline.h
index 4acc6c44285..04c6a05d25e 100644
--- a/src/gallium/drivers/zink/zink_pipeline.h
+++ b/src/gallium/drivers/zink/zink_pipeline.h
@@ -92,10 +92,17 @@ struct zink_compute_pipeline_state {
    /* Pre-hashed value for table lookup, invalid when zero.
     * Members after this point are not included in pipeline state hash key */
    uint32_t hash;
+   uint32_t final_hash;
    bool dirty;
    bool use_local_size;
    uint32_t local_size[3];
 
+   uint32_t module_hash;
+   VkShaderModule module;
+   bool module_changed;
+
+   struct zink_shader_key key;
+
    VkPipeline pipeline;
 };
 
diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c
index fa5461dc990..92e7baca53d 100644
--- a/src/gallium/drivers/zink/zink_program.c
+++ b/src/gallium/drivers/zink/zink_program.c
@@ -155,7 +155,7 @@ destroy_shader_cache(struct zink_screen *screen, struct list_head *sc)
 }
 
 static void
-update_shader_modules(struct zink_context *ctx,
+update_gfx_shader_modules(struct zink_context *ctx,
                       struct zink_screen *screen,
                       struct zink_gfx_program *prog, uint32_t mask,
                       struct zink_gfx_pipeline_state *state)
@@ -245,7 +245,87 @@ equals_gfx_pipeline_state(const void *a, const void *b)
 void
 zink_update_gfx_program(struct zink_context *ctx, struct zink_gfx_program *prog)
 {
-   update_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state);
+   update_gfx_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state);
+}
+
+static bool
+uniforms_match(const struct zink_shader_module *zm, uint32_t *uniforms, unsigned num_uniforms)
+{
+   assert(zm->num_uniforms == num_uniforms);
+   return !memcmp(zm->key, uniforms, zm->num_uniforms * sizeof(uint32_t));
+}
+
+static uint32_t
+cs_module_hash(const struct zink_shader_module *zm)
+{
+   return _mesa_hash_data(zm->key, zm->num_uniforms * sizeof(uint32_t));
+}
+
+static void
+update_cs_shader_module(struct zink_context *ctx, struct zink_compute_program *comp)
+{
+   struct zink_shader *zs = comp->shader;
+   VkShaderModule mod;
+   struct zink_shader_module *zm = NULL;
+   unsigned base_size = 0;
+   struct zink_shader_key *key = &ctx->compute_pipeline_state.key;
+
+   if (ctx && zs->nir->info.num_inlinable_uniforms &&
+       ctx->inlinable_uniforms_valid_mask & BITFIELD64_BIT(PIPE_SHADER_COMPUTE)) {
+      if (comp->inlined_variant_count < ZINK_MAX_INLINED_VARIANTS)
+         base_size = zs->nir->info.num_inlinable_uniforms;
+      else
+         key->inline_uniforms = false;
+   }
+
+   if (base_size) {
+      struct zink_shader_module *iter, *next;
+      LIST_FOR_EACH_ENTRY_SAFE(iter, next, &comp->shader_cache, list) {
+         if (!uniforms_match(iter, key->base.inlined_uniform_values, base_size))
+            continue;
+         list_delinit(&iter->list);
+         zm = iter;
+         break;
+      }
+   } else {
+      zm = comp->module;
+   }
+
+   if (!zm) {
+      zm = malloc(sizeof(struct zink_shader_module) + base_size * sizeof(uint32_t));
+      if (!zm) {
+         return;
+      }
+      mod = zink_shader_compile(zink_screen(ctx->base.screen), zs, comp->shader->nir, key);
+      if (!mod) {
+         FREE(zm);
+         return;
+      }
+      zm->shader = mod;
+      list_inithead(&zm->list);
+      zm->num_uniforms = base_size;
+      zm->key_size = 0;
+      assert(base_size);
+      memcpy(zm->key, key->base.inlined_uniform_values, base_size * sizeof(uint32_t));
+      zm->hash = cs_module_hash(zm);
+      zm->default_variant = false;
+      comp->inlined_variant_count++;
+   }
+   if (zm->num_uniforms)
+      list_add(&zm->list, &comp->shader_cache);
+   if (comp->curr == zm)
+      return;
+   ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
+   comp->curr = zm;
+   ctx->compute_pipeline_state.module_hash = zm->hash;
+   ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
+   ctx->compute_pipeline_state.module_changed = true;
+}
+
+void
+zink_update_compute_program(struct zink_context *ctx)
+{
+   update_cs_shader_module(ctx, ctx->curr_compute);
 }
 
 VkPipelineLayout
@@ -418,7 +498,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink
 static bool
 equals_compute_pipeline_state(const void *a, const void *b)
 {
-   return memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) == 0;
+   const struct zink_compute_pipeline_state *sa = a;
+   const struct zink_compute_pipeline_state *sb = b;
+   return !memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) &&
+          sa->module == sb->module;
 }
 
 struct zink_compute_program *
@@ -432,12 +515,13 @@ zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader
    pipe_reference_init(&comp->base.reference, 1);
    comp->base.is_compute = true;
 
-   comp->module = CALLOC_STRUCT(zink_shader_module);
+   comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module);
    assert(comp->module);
    comp->module->shader = zink_shader_compile(screen, shader, shader->nir, NULL);
    assert(comp->module->shader);
+   list_inithead(&comp->shader_cache);
 
-   comp->pipelines = _mesa_hash_table_create(NULL, hash_compute_pipeline_state,
+   comp->pipelines = _mesa_hash_table_create(NULL, NULL,
                                              equals_compute_pipeline_state);
 
    _mesa_set_add(shader->programs, comp);
@@ -736,13 +820,16 @@ zink_get_compute_pipeline(struct zink_screen *screen,
 {
    struct hash_entry *entry = NULL;
 
-   if (!state->dirty)
+   if (!state->dirty && !state->module_changed)
       return state->pipeline;
    if (state->dirty) {
+      if (state->pipeline) //avoid on first hash
+         state->final_hash ^= state->hash;
       state->hash = hash_compute_pipeline_state(state);
       state->dirty = false;
+      state->final_hash ^= state->hash;
    }
-   entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->hash, state);
+   entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->final_hash, state);
 
    if (!entry) {
       util_queue_fence_wait(&comp->base.cache_fence);
@@ -758,7 +845,7 @@ zink_get_compute_pipeline(struct zink_screen *screen,
       memcpy(&pc_entry->state, state, sizeof(*state));
       pc_entry->pipeline = pipeline;
 
-      entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->hash, pc_entry, pc_entry);
+      entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->final_hash, pc_entry, pc_entry);
       assert(entry);
    }
 
@@ -777,6 +864,11 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage,
       ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage);
 
    if (stage == PIPE_SHADER_COMPUTE) {
+      if (ctx->compute_stage) {
+         ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
+         ctx->compute_pipeline_state.module = VK_NULL_HANDLE;
+         ctx->compute_pipeline_state.module_hash = 0;
+      }
       if (shader && shader != ctx->compute_stage) {
          struct hash_entry *entry = _mesa_hash_table_search(&ctx->compute_program_cache, shader);
          if (entry) {
@@ -789,6 +881,9 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage,
             ctx->curr_compute = comp;
             zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
          }
+         ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash;
+         ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader;
+         ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
       } else if (!shader)
          ctx->curr_compute = NULL;
       ctx->compute_stage = shader;
diff --git a/src/gallium/drivers/zink/zink_program.h b/src/gallium/drivers/zink/zink_program.h
index 111207f9541..78ac5f048ea 100644
--- a/src/gallium/drivers/zink/zink_program.h
+++ b/src/gallium/drivers/zink/zink_program.h
@@ -116,7 +116,12 @@ struct zink_gfx_program {
 struct zink_compute_program {
    struct zink_program base;
 
-   struct zink_shader_module *module;
+   struct zink_shader_module *curr;
+
+   struct zink_shader_module *module; //base
+   struct list_head shader_cache; //inline uniforms
+   unsigned inlined_variant_count;
+
    struct zink_shader *shader;
    struct hash_table *pipelines;
 };
@@ -272,7 +277,8 @@ zink_pipeline_layout_create(struct zink_screen *screen, struct zink_program *pg,
 
 void
 zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink_compute_program *comp, const uint block[3]);
-
+void
+zink_update_compute_program(struct zink_context *ctx);
 VkPipeline
 zink_get_compute_pipeline(struct zink_screen *screen,
                       struct zink_compute_program *comp,



More information about the mesa-commit mailing list