Mesa (main): d3d12: Implement num workgroups as a state var

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Jan 11 01:48:38 UTC 2022


Module: Mesa
Branch: main
Commit: 9cc6b17c8ad508bcc404b1f75285f7a9d5f7f792
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=9cc6b17c8ad508bcc404b1f75285f7a9d5f7f792

Author: Jesse Natalie <jenatali at microsoft.com>
Date:   Fri Dec 31 14:50:07 2021 -0800

d3d12: Implement num workgroups as a state var

Reviewed-by: Sil Vilerino <sivileri at microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14367>

---

 src/gallium/drivers/d3d12/d3d12_compiler.cpp |  2 ++
 src/gallium/drivers/d3d12/d3d12_compiler.h   |  7 +++++-
 src/gallium/drivers/d3d12/d3d12_draw.cpp     | 35 +++++++++++++++++++++++++-
 src/gallium/drivers/d3d12/d3d12_nir_passes.c | 37 ++++++++++++++++++++++++++++
 src/gallium/drivers/d3d12/d3d12_nir_passes.h |  3 +++
 5 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
index c3dea8eb4f3..c154c5e8bf8 100644
--- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
@@ -1183,6 +1183,8 @@ d3d12_create_compute_shader(struct d3d12_context *ctx,
 
    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
 
+   NIR_PASS_V(nir, d3d12_lower_compute_state_vars);
+
    return d3d12_create_shader_impl(ctx, sel, nir, nullptr, nullptr);
 }
 
diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.h b/src/gallium/drivers/d3d12/d3d12_compiler.h
index 587bd9a039a..703b058bb48 100644
--- a/src/gallium/drivers/d3d12/d3d12_compiler.h
+++ b/src/gallium/drivers/d3d12/d3d12_compiler.h
@@ -45,7 +45,12 @@ enum d3d12_state_var {
    D3D12_STATE_VAR_PT_SPRITE,
    D3D12_STATE_VAR_FIRST_VERTEX,
    D3D12_STATE_VAR_DEPTH_TRANSFORM,
-   D3D12_MAX_STATE_VARS
+   D3D12_MAX_GRAPHICS_STATE_VARS,
+
+   D3D12_STATE_VAR_NUM_WORKGROUPS = 0,
+   D3D12_MAX_COMPUTE_STATE_VARS,
+
+   D3D12_MAX_STATE_VARS = MAX2(D3D12_MAX_GRAPHICS_STATE_VARS, D3D12_MAX_COMPUTE_STATE_VARS)
 };
 
 #define D3D12_MAX_POINT_SIZE 255.0f
diff --git a/src/gallium/drivers/d3d12/d3d12_draw.cpp b/src/gallium/drivers/d3d12/d3d12_draw.cpp
index dd7bbbb6d16..78b3be97880 100644
--- a/src/gallium/drivers/d3d12/d3d12_draw.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_draw.cpp
@@ -378,6 +378,32 @@ fill_graphics_state_vars(struct d3d12_context *ctx,
    return size;
 }
 
+static unsigned
+fill_compute_state_vars(struct d3d12_context *ctx,
+                        const struct pipe_grid_info *info,
+                        struct d3d12_shader *shader,
+                        uint32_t *values)
+{
+   /* Writes one vec4 (4 x uint32) per state var; returns words written. */
+   unsigned size = 0;
+   for (unsigned j = 0; j < shader->num_state_vars; ++j) {
+      uint32_t *ptr = values + size;
+      switch (shader->state_vars[j].var) {
+      case D3D12_STATE_VAR_NUM_WORKGROUPS:
+         /* NOTE(review): confirm grid[] is valid for indirect dispatches. */
+         ptr[0] = info->grid[0];
+         ptr[1] = info->grid[1];
+         ptr[2] = info->grid[2];
+         ptr[3] = 0; /* pad the vec4 slot; don't upload stack garbage */
+         size += 4;
+         break;
+      default:
+         unreachable("unknown state variable");
+      }
+   }
+   return size;
+}
+
 static bool
 check_descriptors_left(struct d3d12_context *ctx, bool compute)
 {
@@ -489,7 +515,7 @@ update_graphics_root_parameters(struct d3d12_context *ctx,
       update_shader_stage_root_parameters(ctx, shader_sel, num_params, num_root_descriptors, root_desc_tables, root_desc_indices);
       /* TODO Don't always update state vars */
       if (shader_sel->current->num_state_vars > 0) {
-         uint32_t constants[D3D12_MAX_STATE_VARS * 4];
+         uint32_t constants[D3D12_MAX_GRAPHICS_STATE_VARS * 4];
          unsigned size = fill_graphics_state_vars(ctx, dinfo, draw, shader_sel->current, constants);
          ctx->cmdlist->SetGraphicsRoot32BitConstants(num_params, size, constants, 0);
          num_params++;
@@ -510,6 +536,13 @@ update_compute_root_parameters(struct d3d12_context *ctx,
    struct d3d12_shader_selector *shader_sel = ctx->compute_state;
    if (shader_sel) {
       update_shader_stage_root_parameters(ctx, shader_sel, num_params, num_root_descriptors, root_desc_tables, root_desc_indices);
+      /* TODO Don't always update state vars */
+      if (shader_sel->current->num_state_vars > 0) {
+         uint32_t constants[D3D12_MAX_COMPUTE_STATE_VARS * 4];
+         unsigned size = fill_compute_state_vars(ctx, info, shader_sel->current, constants);
+         ctx->cmdlist->SetComputeRoot32BitConstants(num_params, size, constants, 0);
+         num_params++;
+      }
    }
    return num_root_descriptors;
 }
diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.c b/src/gallium/drivers/d3d12/d3d12_nir_passes.c
index 7ed43cbc7b8..c2cc4d70473 100644
--- a/src/gallium/drivers/d3d12/d3d12_nir_passes.c
+++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.c
@@ -220,6 +220,43 @@ d3d12_lower_depth_range(nir_shader *nir)
    }
 }
 
+/* State-var handles shared by one d3d12_lower_compute_state_vars() run. */
+struct compute_state_vars {
+   nir_variable *num_workgroups;
+};
+
+static bool
+lower_compute_state_vars(nir_builder *b, nir_instr *instr, void *_state)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+   b->cursor = nir_after_instr(instr);
+   nir_intrinsic_instr *load = nir_instr_as_intrinsic(instr);
+   if (load->intrinsic != nir_intrinsic_load_num_workgroups)
+      return false;
+
+   /* Replace the system-value load with the NUM_WORKGROUPS state var, which
+    * fill_compute_state_vars() in d3d12_draw.cpp writes from info->grid. */
+   struct compute_state_vars *state = _state;
+   nir_ssa_def *replacement =
+      get_state_var(b, D3D12_STATE_VAR_NUM_WORKGROUPS, "d3d12_NumWorkgroups",
+                    glsl_vec_type(3), &state->num_workgroups);
+   nir_ssa_def_rewrite_uses(&load->dest.ssa, replacement);
+   nir_instr_remove(instr);
+   return true;
+}
+
+/* Lowers load_num_workgroups in a compute shader to a d3d12 state-var load.
+ * Returns true if any instruction was rewritten. */
+bool
+d3d12_lower_compute_state_vars(nir_shader *nir)
+{
+   assert(nir->info.stage == MESA_SHADER_COMPUTE);
+   struct compute_state_vars state = { .num_workgroups = NULL };
+   return nir_shader_instructions_pass(nir, lower_compute_state_vars,
+      nir_metadata_block_index | nir_metadata_dominance, &state);
+}
+
 static bool
 is_color_output(nir_variable *var)
 {
diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.h b/src/gallium/drivers/d3d12/d3d12_nir_passes.h
index 54a9d2452fb..03f7454572f 100644
--- a/src/gallium/drivers/d3d12/d3d12_nir_passes.h
+++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.h
@@ -55,6 +55,9 @@ d3d12_lower_depth_range(nir_shader *nir);
 bool
 d3d12_lower_load_first_vertex(nir_shader *nir);
 
+bool
+d3d12_lower_compute_state_vars(nir_shader *nir); /* compute shaders only; returns progress */
+
 void
 d3d12_lower_uint_cast(nir_shader *nir, bool is_signed);
 



More information about the mesa-commit mailing list