Mesa (main): d3d12: Implement num workgroups as a state var
GitLab Mirror
gitlab-mirror at kemper.freedesktop.org
Tue Jan 11 01:48:38 UTC 2022
Module: Mesa
Branch: main
Commit: 9cc6b17c8ad508bcc404b1f75285f7a9d5f7f792
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=9cc6b17c8ad508bcc404b1f75285f7a9d5f7f792
Author: Jesse Natalie <jenatali at microsoft.com>
Date: Fri Dec 31 14:50:07 2021 -0800
d3d12: Implement num workgroups as a state var
Reviewed-by: Sil Vilerino <sivileri at microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14367>
---
src/gallium/drivers/d3d12/d3d12_compiler.cpp | 2 ++
src/gallium/drivers/d3d12/d3d12_compiler.h | 7 +++++-
src/gallium/drivers/d3d12/d3d12_draw.cpp | 35 +++++++++++++++++++++++++-
src/gallium/drivers/d3d12/d3d12_nir_passes.c | 37 ++++++++++++++++++++++++++++
src/gallium/drivers/d3d12/d3d12_nir_passes.h | 3 +++
5 files changed, 82 insertions(+), 2 deletions(-)
diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
index c3dea8eb4f3..c154c5e8bf8 100644
--- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
@@ -1183,6 +1183,8 @@ d3d12_create_compute_shader(struct d3d12_context *ctx,
nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
+ NIR_PASS_V(nir, d3d12_lower_compute_state_vars);
+
return d3d12_create_shader_impl(ctx, sel, nir, nullptr, nullptr);
}
diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.h b/src/gallium/drivers/d3d12/d3d12_compiler.h
index 587bd9a039a..703b058bb48 100644
--- a/src/gallium/drivers/d3d12/d3d12_compiler.h
+++ b/src/gallium/drivers/d3d12/d3d12_compiler.h
@@ -45,7 +45,12 @@ enum d3d12_state_var {
D3D12_STATE_VAR_PT_SPRITE,
D3D12_STATE_VAR_FIRST_VERTEX,
D3D12_STATE_VAR_DEPTH_TRANSFORM,
- D3D12_MAX_STATE_VARS
+ D3D12_MAX_GRAPHICS_STATE_VARS,
+
+ D3D12_STATE_VAR_NUM_WORKGROUPS = 0,
+ D3D12_MAX_COMPUTE_STATE_VARS,
+
+ D3D12_MAX_STATE_VARS = MAX2(D3D12_MAX_GRAPHICS_STATE_VARS, D3D12_MAX_COMPUTE_STATE_VARS)
};
#define D3D12_MAX_POINT_SIZE 255.0f
diff --git a/src/gallium/drivers/d3d12/d3d12_draw.cpp b/src/gallium/drivers/d3d12/d3d12_draw.cpp
index dd7bbbb6d16..78b3be97880 100644
--- a/src/gallium/drivers/d3d12/d3d12_draw.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_draw.cpp
@@ -378,6 +378,32 @@ fill_graphics_state_vars(struct d3d12_context *ctx,
return size;
}
+/* Fill the root-constant payload for a compute shader's state vars.
+ *
+ * Walks shader->state_vars in order and writes one 4-dword slot per entry
+ * into `values`. `values` must hold at least num_state_vars * 4 dwords.
+ * Returns the number of dwords written, i.e. the count to pass to
+ * SetComputeRoot32BitConstants.
+ *
+ * `ctx` is currently unused; it keeps the signature parallel with
+ * fill_graphics_state_vars.
+ */
+static unsigned
+fill_compute_state_vars(struct d3d12_context *ctx,
+                        const struct pipe_grid_info *info,
+                        struct d3d12_shader *shader,
+                        uint32_t *values)
+{
+   unsigned size = 0;
+
+   for (unsigned j = 0; j < shader->num_state_vars; ++j) {
+      uint32_t *ptr = values + size;
+
+      switch (shader->state_vars[j].var) {
+      case D3D12_STATE_VAR_NUM_WORKGROUPS:
+         /* NOTE(review): info->grid only describes direct dispatches; when
+          * info->indirect is set the real group counts live in the indirect
+          * buffer, so these values would be stale — confirm how indirect
+          * launches are handled. */
+         ptr[0] = info->grid[0];
+         ptr[1] = info->grid[1];
+         ptr[2] = info->grid[2];
+         /* The shader only consumes xyz, but the whole 4-dword slot is
+          * uploaded; zero the pad so no uninitialized stack data is read
+          * and sent to the GPU. */
+         ptr[3] = 0;
+         size += 4;
+         break;
+      default:
+         unreachable("unknown state variable");
+      }
+   }
+
+   return size;
+}
+
static bool
check_descriptors_left(struct d3d12_context *ctx, bool compute)
{
@@ -489,7 +515,7 @@ update_graphics_root_parameters(struct d3d12_context *ctx,
update_shader_stage_root_parameters(ctx, shader_sel, num_params, num_root_descriptors, root_desc_tables, root_desc_indices);
/* TODO Don't always update state vars */
if (shader_sel->current->num_state_vars > 0) {
- uint32_t constants[D3D12_MAX_STATE_VARS * 4];
+ uint32_t constants[D3D12_MAX_GRAPHICS_STATE_VARS * 4];
unsigned size = fill_graphics_state_vars(ctx, dinfo, draw, shader_sel->current, constants);
ctx->cmdlist->SetGraphicsRoot32BitConstants(num_params, size, constants, 0);
num_params++;
@@ -510,6 +536,13 @@ update_compute_root_parameters(struct d3d12_context *ctx,
struct d3d12_shader_selector *shader_sel = ctx->compute_state;
if (shader_sel) {
update_shader_stage_root_parameters(ctx, shader_sel, num_params, num_root_descriptors, root_desc_tables, root_desc_indices);
+ /* TODO Don't always update state vars */
+ if (shader_sel->current->num_state_vars > 0) {
+ uint32_t constants[D3D12_MAX_COMPUTE_STATE_VARS * 4];
+ unsigned size = fill_compute_state_vars(ctx, info, shader_sel->current, constants);
+ ctx->cmdlist->SetComputeRoot32BitConstants(num_params, size, constants, 0);
+ num_params++;
+ }
}
return num_root_descriptors;
}
diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.c b/src/gallium/drivers/d3d12/d3d12_nir_passes.c
index 7ed43cbc7b8..c2cc4d70473 100644
--- a/src/gallium/drivers/d3d12/d3d12_nir_passes.c
+++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.c
@@ -220,6 +220,43 @@ d3d12_lower_depth_range(nir_shader *nir)
}
}
+/* Per-shader state threaded through lower_compute_state_vars() via the
+ * nir_shader_instructions_pass() callback data pointer.
+ *
+ * num_workgroups: passed by address to get_state_var() for
+ *   D3D12_STATE_VAR_NUM_WORKGROUPS. NOTE(review): presumably get_state_var()
+ *   creates the variable on first use and reuses it on later calls, which is
+ *   why the caller zero-initializes this struct — confirm against
+ *   get_state_var().
+ */
+struct compute_state_vars {
+ nir_variable *num_workgroups;
+};
+
+/* Instruction callback for d3d12_lower_compute_state_vars().
+ *
+ * Rewrites every nir_intrinsic_load_num_workgroups into the value produced
+ * by get_state_var() for D3D12_STATE_VAR_NUM_WORKGROUPS (named
+ * "d3d12_NumWorkgroups", a 3-component vector), then removes the original
+ * intrinsic. Any other instruction is left untouched.
+ *
+ * Returns true only when the instruction was replaced.
+ */
+static bool
+lower_compute_state_vars(nir_builder *b, nir_instr *instr, void *_state)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   /* Any replacement code is emitted right after the instruction it replaces. */
+   b->cursor = nir_after_instr(instr);
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   if (intrin->intrinsic != nir_intrinsic_load_num_workgroups)
+      return false;
+
+   struct compute_state_vars *state = _state;
+   nir_ssa_def *replacement =
+      get_state_var(b, D3D12_STATE_VAR_NUM_WORKGROUPS, "d3d12_NumWorkgroups",
+                    glsl_vec_type(3), &state->num_workgroups);
+
+   nir_ssa_def_rewrite_uses(&intrin->dest.ssa, replacement);
+   nir_instr_remove(instr);
+   return true;
+}
+
+/* Lower compute-only system values that D3D12 does not supply natively
+ * (currently just the workgroup count) into d3d12 state vars, which the
+ * driver fills as root constants at dispatch time.
+ *
+ * Must only be called on compute shaders. Returns true if any instruction
+ * was rewritten.
+ */
+bool
+d3d12_lower_compute_state_vars(nir_shader *nir)
+{
+   assert(nir->info.stage == MESA_SHADER_COMPUTE);
+
+   /* The pass only adds/removes instructions inside existing blocks. */
+   const nir_metadata preserved = nir_metadata_block_index | nir_metadata_dominance;
+   struct compute_state_vars state = { .num_workgroups = NULL };
+
+   bool progress = nir_shader_instructions_pass(nir, lower_compute_state_vars,
+                                                preserved, &state);
+   return progress;
+}
+
static bool
is_color_output(nir_variable *var)
{
diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.h b/src/gallium/drivers/d3d12/d3d12_nir_passes.h
index 54a9d2452fb..03f7454572f 100644
--- a/src/gallium/drivers/d3d12/d3d12_nir_passes.h
+++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.h
@@ -55,6 +55,9 @@ d3d12_lower_depth_range(nir_shader *nir);
bool
d3d12_lower_load_first_vertex(nir_shader *nir);
+/* Compute shaders only: replaces nir_intrinsic_load_num_workgroups with a
+ * read of the D3D12_STATE_VAR_NUM_WORKGROUPS state var. Returns true if the
+ * shader was modified. */
+bool
+d3d12_lower_compute_state_vars(nir_shader *nir);
+
void
d3d12_lower_uint_cast(nir_shader *nir, bool is_signed);
More information about the mesa-commit
mailing list