Mesa (main): zink: implement cs uniform inlining
GitLab Mirror
gitlab-mirror at kemper.freedesktop.org
Wed Nov 10 01:29:41 UTC 2021
Module: Mesa
Branch: main
Commit: a8d90c8ed55e77344bcf277934a5ff2fa52d3e15
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=a8d90c8ed55e77344bcf277934a5ff2fa52d3e15
Author: Mike Blumenkrantz <michael.blumenkrantz at gmail.com>
Date: Tue Nov 9 10:07:49 2021 -0500
zink: implement cs uniform inlining
this implements shader variants for compute
Reviewed-by: Dave Airlie <airlied at redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13727>
---
src/gallium/drivers/zink/zink_context.c | 7 +-
src/gallium/drivers/zink/zink_context.h | 1 -
src/gallium/drivers/zink/zink_draw.cpp | 6 ++
src/gallium/drivers/zink/zink_pipeline.h | 7 ++
src/gallium/drivers/zink/zink_program.c | 111 ++++++++++++++++++++++++++++---
src/gallium/drivers/zink/zink_program.h | 10 ++-
6 files changed, 127 insertions(+), 15 deletions(-)
diff --git a/src/gallium/drivers/zink/zink_context.c b/src/gallium/drivers/zink/zink_context.c
index 75b376369c9..defca8d80e0 100644
--- a/src/gallium/drivers/zink/zink_context.c
+++ b/src/gallium/drivers/zink/zink_context.c
@@ -1047,18 +1047,17 @@ zink_set_inlinable_constants(struct pipe_context *pctx,
struct zink_shader_key *key = NULL;
if (shader == PIPE_SHADER_COMPUTE) {
- inlinable_uniforms = ctx->compute_inlinable_uniforms;
+ key = &ctx->compute_pipeline_state.key;
} else {
key = &ctx->gfx_pipeline_state.shader_keys.key[shader];
- inlinable_uniforms = key->base.inlined_uniform_values;
}
+ inlinable_uniforms = key->base.inlined_uniform_values;
if (!(ctx->inlinable_uniforms_valid_mask & bit) ||
memcmp(inlinable_uniforms, values, num_values * 4)) {
memcpy(inlinable_uniforms, values, num_values * 4);
ctx->dirty_shader_stages |= bit;
ctx->inlinable_uniforms_valid_mask |= bit;
- if (key)
- key->inline_uniforms = true;
+ key->inline_uniforms = true;
}
}
diff --git a/src/gallium/drivers/zink/zink_context.h b/src/gallium/drivers/zink/zink_context.h
index 8301e90241f..65aaf02c00f 100644
--- a/src/gallium/drivers/zink/zink_context.h
+++ b/src/gallium/drivers/zink/zink_context.h
@@ -199,7 +199,6 @@ struct zink_context {
unsigned shader_has_inlinable_uniforms_mask;
unsigned inlinable_uniforms_valid_mask;
- uint32_t compute_inlinable_uniforms[MAX_INLINABLE_UNIFORMS];
struct pipe_constant_buffer ubos[PIPE_SHADER_TYPES][PIPE_MAX_CONSTANT_BUFFERS];
struct pipe_shader_buffer ssbos[PIPE_SHADER_TYPES][PIPE_MAX_SHADER_BUFFERS];
diff --git a/src/gallium/drivers/zink/zink_draw.cpp b/src/gallium/drivers/zink/zink_draw.cpp
index a80e37b67c0..1548f64229c 100644
--- a/src/gallium/drivers/zink/zink_draw.cpp
+++ b/src/gallium/drivers/zink/zink_draw.cpp
@@ -882,6 +882,12 @@ zink_launch_grid(struct pipe_context *pctx, const struct pipe_grid_info *info)
zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
}
+ if (ctx->dirty_shader_stages & BITFIELD_BIT(PIPE_SHADER_COMPUTE)) {
+ /* update inlinable constants */
+ zink_update_compute_program(ctx);
+ ctx->dirty_shader_stages &= ~BITFIELD_BIT(PIPE_SHADER_COMPUTE);
+ }
+
if (prev_pipeline != pipeline || BATCH_CHANGED)
VKCTX(CmdBindPipeline)(batch->state->cmdbuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
if (BATCH_CHANGED) {
diff --git a/src/gallium/drivers/zink/zink_pipeline.h b/src/gallium/drivers/zink/zink_pipeline.h
index 4acc6c44285..04c6a05d25e 100644
--- a/src/gallium/drivers/zink/zink_pipeline.h
+++ b/src/gallium/drivers/zink/zink_pipeline.h
@@ -92,10 +92,17 @@ struct zink_compute_pipeline_state {
/* Pre-hashed value for table lookup, invalid when zero.
* Members after this point are not included in pipeline state hash key */
uint32_t hash;
+ uint32_t final_hash;
bool dirty;
bool use_local_size;
uint32_t local_size[3];
+ uint32_t module_hash;
+ VkShaderModule module;
+ bool module_changed;
+
+ struct zink_shader_key key;
+
VkPipeline pipeline;
};
diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c
index fa5461dc990..92e7baca53d 100644
--- a/src/gallium/drivers/zink/zink_program.c
+++ b/src/gallium/drivers/zink/zink_program.c
@@ -155,7 +155,7 @@ destroy_shader_cache(struct zink_screen *screen, struct list_head *sc)
}
static void
-update_shader_modules(struct zink_context *ctx,
+update_gfx_shader_modules(struct zink_context *ctx,
struct zink_screen *screen,
struct zink_gfx_program *prog, uint32_t mask,
struct zink_gfx_pipeline_state *state)
@@ -245,7 +245,87 @@ equals_gfx_pipeline_state(const void *a, const void *b)
void
zink_update_gfx_program(struct zink_context *ctx, struct zink_gfx_program *prog)
{
- update_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state);
+ update_gfx_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state);
+}
+
+static bool
+uniforms_match(const struct zink_shader_module *zm, uint32_t *uniforms, unsigned num_uniforms)
+{
+ assert(zm->num_uniforms == num_uniforms);
+ return !memcmp(zm->key, uniforms, zm->num_uniforms * sizeof(uint32_t));
+}
+
+static uint32_t
+cs_module_hash(const struct zink_shader_module *zm)
+{
+ return _mesa_hash_data(zm->key, zm->num_uniforms * sizeof(uint32_t));
+}
+
+static void
+update_cs_shader_module(struct zink_context *ctx, struct zink_compute_program *comp)
+{
+ struct zink_shader *zs = comp->shader;
+ VkShaderModule mod;
+ struct zink_shader_module *zm = NULL;
+ unsigned base_size = 0;
+ struct zink_shader_key *key = &ctx->compute_pipeline_state.key;
+
+ if (ctx && zs->nir->info.num_inlinable_uniforms &&
+ ctx->inlinable_uniforms_valid_mask & BITFIELD64_BIT(PIPE_SHADER_COMPUTE)) {
+ if (comp->inlined_variant_count < ZINK_MAX_INLINED_VARIANTS)
+ base_size = zs->nir->info.num_inlinable_uniforms;
+ else
+ key->inline_uniforms = false;
+ }
+
+ if (base_size) {
+ struct zink_shader_module *iter, *next;
+ LIST_FOR_EACH_ENTRY_SAFE(iter, next, &comp->shader_cache, list) {
+ if (!uniforms_match(iter, key->base.inlined_uniform_values, base_size))
+ continue;
+ list_delinit(&iter->list);
+ zm = iter;
+ break;
+ }
+ } else {
+ zm = comp->module;
+ }
+
+ if (!zm) {
+ zm = malloc(sizeof(struct zink_shader_module) + base_size * sizeof(uint32_t));
+ if (!zm) {
+ return;
+ }
+ mod = zink_shader_compile(zink_screen(ctx->base.screen), zs, comp->shader->nir, key);
+ if (!mod) {
+ FREE(zm);
+ return;
+ }
+ zm->shader = mod;
+ list_inithead(&zm->list);
+ zm->num_uniforms = base_size;
+ zm->key_size = 0;
+ assert(base_size);
+ memcpy(zm->key, key->base.inlined_uniform_values, base_size * sizeof(uint32_t));
+ zm->hash = cs_module_hash(zm);
+ zm->default_variant = false;
+ comp->inlined_variant_count++;
+ }
+ if (zm->num_uniforms)
+ list_add(&zm->list, &comp->shader_cache);
+ if (comp->curr == zm)
+ return;
+ ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
+ comp->curr = zm;
+ ctx->compute_pipeline_state.module_hash = zm->hash;
+ ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
+ ctx->compute_pipeline_state.module_changed = true;
+}
+
+void
+zink_update_compute_program(struct zink_context *ctx)
+{
+ update_cs_shader_module(ctx, ctx->curr_compute);
}
VkPipelineLayout
@@ -418,7 +498,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink
static bool
equals_compute_pipeline_state(const void *a, const void *b)
{
- return memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) == 0;
+ const struct zink_compute_pipeline_state *sa = a;
+ const struct zink_compute_pipeline_state *sb = b;
+ return !memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) &&
+ sa->module == sb->module;
}
struct zink_compute_program *
@@ -432,12 +515,13 @@ zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader
pipe_reference_init(&comp->base.reference, 1);
comp->base.is_compute = true;
- comp->module = CALLOC_STRUCT(zink_shader_module);
+ comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module);
assert(comp->module);
comp->module->shader = zink_shader_compile(screen, shader, shader->nir, NULL);
assert(comp->module->shader);
+ list_inithead(&comp->shader_cache);
- comp->pipelines = _mesa_hash_table_create(NULL, hash_compute_pipeline_state,
+ comp->pipelines = _mesa_hash_table_create(NULL, NULL,
equals_compute_pipeline_state);
_mesa_set_add(shader->programs, comp);
@@ -736,13 +820,16 @@ zink_get_compute_pipeline(struct zink_screen *screen,
{
struct hash_entry *entry = NULL;
- if (!state->dirty)
+ if (!state->dirty && !state->module_changed)
return state->pipeline;
if (state->dirty) {
+ if (state->pipeline) //avoid on first hash
+ state->final_hash ^= state->hash;
state->hash = hash_compute_pipeline_state(state);
state->dirty = false;
+ state->final_hash ^= state->hash;
}
- entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->hash, state);
+ entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->final_hash, state);
if (!entry) {
util_queue_fence_wait(&comp->base.cache_fence);
@@ -758,7 +845,7 @@ zink_get_compute_pipeline(struct zink_screen *screen,
memcpy(&pc_entry->state, state, sizeof(*state));
pc_entry->pipeline = pipeline;
- entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->hash, pc_entry, pc_entry);
+ entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->final_hash, pc_entry, pc_entry);
assert(entry);
}
@@ -777,6 +864,11 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage,
ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage);
if (stage == PIPE_SHADER_COMPUTE) {
+ if (ctx->compute_stage) {
+ ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
+ ctx->compute_pipeline_state.module = VK_NULL_HANDLE;
+ ctx->compute_pipeline_state.module_hash = 0;
+ }
if (shader && shader != ctx->compute_stage) {
struct hash_entry *entry = _mesa_hash_table_search(&ctx->compute_program_cache, shader);
if (entry) {
@@ -789,6 +881,9 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage,
ctx->curr_compute = comp;
zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
}
+ ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash;
+ ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader;
+ ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
} else if (!shader)
ctx->curr_compute = NULL;
ctx->compute_stage = shader;
diff --git a/src/gallium/drivers/zink/zink_program.h b/src/gallium/drivers/zink/zink_program.h
index 111207f9541..78ac5f048ea 100644
--- a/src/gallium/drivers/zink/zink_program.h
+++ b/src/gallium/drivers/zink/zink_program.h
@@ -116,7 +116,12 @@ struct zink_gfx_program {
struct zink_compute_program {
struct zink_program base;
- struct zink_shader_module *module;
+ struct zink_shader_module *curr;
+
+ struct zink_shader_module *module; //base
+ struct list_head shader_cache; //inline uniforms
+ unsigned inlined_variant_count;
+
struct zink_shader *shader;
struct hash_table *pipelines;
};
@@ -272,7 +277,8 @@ zink_pipeline_layout_create(struct zink_screen *screen, struct zink_program *pg,
void
zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink_compute_program *comp, const uint block[3]);
-
+void
+zink_update_compute_program(struct zink_context *ctx);
VkPipeline
zink_get_compute_pipeline(struct zink_screen *screen,
struct zink_compute_program *comp,
More information about the mesa-commit mailing list