Mesa (main): ac/nir/ngg: Clean up mesh shader output LDS layout.
GitLab Mirror
gitlab-mirror at kemper.freedesktop.org
Wed Jun 8 09:12:36 UTC 2022
Module: Mesa
Branch: main
Commit: 304a0e948b49161abfd9badbf0783f099258d6f1
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=304a0e948b49161abfd9badbf0783f099258d6f1
Author: Timur Kristóf <timur.kristof at gmail.com>
Date: Fri May 20 15:27:03 2022 +0200
ac/nir/ngg: Clean up mesh shader output LDS layout.
Signed-off-by: Timur Kristóf <timur.kristof at gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset at gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16737>
---
src/amd/common/ac_nir_lower_ngg.c | 356 ++++++++++++++++++++++----------------
1 file changed, 210 insertions(+), 146 deletions(-)
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 8ab79e4580e..d2339967c6a 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -95,17 +95,48 @@ typedef struct
gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
} lower_ngg_gs_state;
+/* LDS layout of Mesh Shader workgroup info. */
+enum {
+ /* DW0: number of primitives */
+ lds_ms_num_prims = 0,
+ /* DW1: reserved for future use */
+ lds_ms_dw1_reserved = 4,
+ /* DW2: workgroup index within the current dispatch */
+ lds_ms_wg_index = 8,
+ /* DW3: number of API waves in flight */
+ lds_ms_num_api_waves = 12,
+};
+
+/* Potential location for Mesh Shader outputs. */
+typedef enum {
+ ms_out_mode_lds,
+} ms_out_mode;
+
+typedef struct
+{
+ uint64_t mask; /* Mask of output locations */
+ uint32_t addr; /* Base address */
+} ms_out_part;
+
typedef struct
{
+ /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
+ struct {
+ uint32_t workgroup_info_addr;
+ ms_out_part vtx_attr;
+ ms_out_part prm_attr;
+ uint32_t indices_addr;
+ uint32_t total_size;
+ } lds;
+} ms_out_mem_layout;
+
+typedef struct
+{
+ ms_out_mem_layout layout;
uint64_t per_vertex_outputs;
uint64_t per_primitive_outputs;
- unsigned num_per_vertex_outputs;
- unsigned num_per_primitive_outputs;
unsigned vertices_per_prim;
- unsigned vertex_attr_lds_addr;
- unsigned prim_attr_lds_addr;
- unsigned prim_vtx_indices_addr;
- unsigned numprims_lds_addr;
+
unsigned wave_size;
unsigned api_workgroup_size;
unsigned hw_workgroup_size;
@@ -1977,7 +2008,7 @@ ms_store_prim_indices(nir_builder *b,
if (!offset_src)
offset_src = nir_imm_int(b, 0);
- nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->prim_vtx_indices_addr + offset_const);
+ nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr + offset_const);
}
static nir_ssa_def *
@@ -1989,7 +2020,7 @@ ms_load_prim_indices(nir_builder *b,
if (!offset_src)
offset_src = nir_imm_int(b, 0);
- return nir_load_shared(b, 1, 8, offset_src, .base = s->prim_vtx_indices_addr + offset_const);
+ return nir_load_shared(b, 1, 8, offset_src, .base = s->layout.lds.indices_addr + offset_const);
}
static void
@@ -1998,7 +2029,7 @@ ms_store_num_prims(nir_builder *b,
lower_ngg_ms_state *s)
{
nir_ssa_def *addr = nir_imm_int(b, 0);
- nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->numprims_lds_addr);
+ nir_store_shared(b, nir_u2u32(b, store_val), addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
}
static nir_ssa_def *
@@ -2006,10 +2037,9 @@ ms_load_num_prims(nir_builder *b,
lower_ngg_ms_state *s)
{
nir_ssa_def *addr = nir_imm_int(b, 0);
- return nir_load_shared(b, 1, 32, addr, .base = s->numprims_lds_addr);
+ return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
}
-
static nir_ssa_def *
lower_ms_store_output(nir_builder *b,
nir_intrinsic_instr *intrin,
@@ -2135,25 +2165,9 @@ update_ms_output_info(nir_intrinsic_instr *intrin,
}
}
-static void
-ms_store_arrayed_output_intrin(nir_builder *b,
- nir_intrinsic_instr *intrin,
- unsigned num_arrayed_outputs,
- unsigned base_shared_addr)
+static nir_ssa_def *
+regroup_store_val(nir_builder *b, nir_ssa_def *store_val)
{
- nir_ssa_def *store_val = intrin->src[0].ssa;
- nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
- nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
-
- unsigned driver_location = nir_intrinsic_base(intrin);
- unsigned component_offset = nir_intrinsic_component(intrin);
- unsigned write_mask = nir_intrinsic_write_mask(intrin);
- unsigned bit_size = store_val->bit_size;
-
- nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_arrayed_outputs);
- nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
- nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
-
/* Vulkan spec 15.1.4-15.1.5:
*
* The shader interface consists of output slots with 4x 32-bit components.
@@ -2166,109 +2180,138 @@ ms_store_arrayed_output_intrin(nir_builder *b,
* 32-bit and then stored contiguously.
*/
- if (bit_size < 32) {
+ if (store_val->bit_size < 32) {
assert(store_val->num_components <= 4);
nir_ssa_def *comps[4] = {0};
for (unsigned c = 0; c < store_val->num_components; ++c)
comps[c] = nir_u2u32(b, nir_channel(b, store_val, c));
- store_val = nir_vec(b, comps, store_val->num_components);
- bit_size = 32;
+ return nir_vec(b, comps, store_val->num_components);
}
- unsigned const_off = base_shared_addr + component_offset * 4;
+ return store_val;
+}
+
+static nir_ssa_def *
+regroup_load_val(nir_builder *b, nir_ssa_def *load, unsigned dest_bit_size)
+{
+ if (dest_bit_size == load->bit_size)
+ return load;
+
+ /* Small bitsize components are not stored contiguously, take care of that here. */
+ unsigned num_components = load->num_components;
+ assert(num_components <= 4);
+ nir_ssa_def *components[4] = {0};
+ for (unsigned i = 0; i < num_components; ++i)
+ components[i] = nir_u2u(b, nir_channel(b, load, i), dest_bit_size);
+
+ return nir_vec(b, components, num_components);
+}
+
+static const ms_out_part *
+ms_get_out_layout_part(unsigned location,
+ shader_info *info,
+ ms_out_mode *out_mode,
+ lower_ngg_ms_state *s)
+{
+ uint64_t mask = BITFIELD64_BIT(location);
+
+ if (info->per_primitive_outputs & mask) {
+ if (mask & s->layout.lds.prm_attr.mask) {
+ *out_mode = ms_out_mode_lds;
+ return &s->layout.lds.prm_attr;
+ }
+ } else {
+ if (mask & s->layout.lds.vtx_attr.mask) {
+ *out_mode = ms_out_mode_lds;
+ return &s->layout.lds.vtx_attr;
+ }
+ }
+
+ unreachable("Couldn't figure out mesh shader output mode.");
+}
+
+static void
+ms_store_arrayed_output_intrin(nir_builder *b,
+ nir_intrinsic_instr *intrin,
+ lower_ngg_ms_state *s)
+{
+ ms_out_mode out_mode;
+ unsigned location = nir_intrinsic_io_semantics(intrin).location;
+ const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
+
+ unsigned driver_location = nir_intrinsic_base(intrin);
+ unsigned component_offset = nir_intrinsic_component(intrin);
+ unsigned write_mask = nir_intrinsic_write_mask(intrin);
+ unsigned num_outputs = util_bitcount64(out->mask);
+ unsigned const_off = out->addr + component_offset * 4;
+
+ nir_ssa_def *store_val = regroup_store_val(b, intrin->src[0].ssa);
+ nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
+ nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
+ nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
+ nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
+ nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
- nir_store_shared(b, store_val, addr, .base = const_off,
- .write_mask = write_mask, .align_mul = 16,
- .align_offset = const_off % 16);
+ if (out_mode == ms_out_mode_lds) {
+ nir_store_shared(b, store_val, addr, .base = const_off,
+ .write_mask = write_mask, .align_mul = 16,
+ .align_offset = const_off % 16);
+ } else {
+ unreachable("Invalid MS output mode for store");
+ }
}
static nir_ssa_def *
ms_load_arrayed_output(nir_builder *b,
nir_ssa_def *arr_index,
nir_ssa_def *base_offset,
+ unsigned location,
unsigned driver_location,
unsigned component_offset,
unsigned num_components,
unsigned load_bit_size,
- unsigned num_arrayed_outputs,
- unsigned base_shared_addr)
+ lower_ngg_ms_state *s)
{
+ ms_out_mode out_mode;
+ const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
+
unsigned component_addr_off = component_offset * 4;
+ unsigned num_outputs = util_bitcount64(out->mask);
+ unsigned const_off = out->addr + component_offset * 4;
- nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_arrayed_outputs);
+ nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
- return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
- .align_offset = component_addr_off % 16,
- .base = base_shared_addr + component_addr_off);
+ if (out_mode == ms_out_mode_lds) {
+ return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
+ .align_offset = component_addr_off % 16,
+ .base = const_off);
+ } else {
+ unreachable("Invalid MS output mode for load");
+ }
}
static nir_ssa_def *
ms_load_arrayed_output_intrin(nir_builder *b,
nir_intrinsic_instr *intrin,
- unsigned num_arrayed_outputs,
- unsigned base_shared_addr)
+ lower_ngg_ms_state *s)
{
nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
+ unsigned location = nir_intrinsic_io_semantics(intrin).location;
unsigned driver_location = nir_intrinsic_base(intrin);
unsigned component_offset = nir_intrinsic_component(intrin);
unsigned bit_size = intrin->dest.ssa.bit_size;
unsigned num_components = intrin->dest.ssa.num_components;
unsigned load_bit_size = MAX2(bit_size, 32);
- nir_ssa_def *loaded =
- ms_load_arrayed_output(b, arr_index, base_offset, driver_location, component_offset, num_components,
- load_bit_size, num_arrayed_outputs, base_shared_addr);
-
- if (bit_size == load_bit_size)
- return loaded;
-
- /* Small bitsize components are not stored contiguously, take care of that here. */
- assert(num_components <= 4);
- nir_ssa_def *components[4] = {0};
- for (unsigned i = 0; i < num_components; ++i)
- components[i] = nir_u2u(b, nir_channel(b, loaded, i), bit_size);
-
- return nir_vec(b, components, num_components);
-}
-
-static nir_ssa_def *
-lower_ms_store_per_vertex_output(nir_builder *b,
- nir_intrinsic_instr *intrin,
- lower_ngg_ms_state *s)
-{
- update_ms_output_info(intrin, s);
- ms_store_arrayed_output_intrin(b, intrin, s->num_per_vertex_outputs, s->vertex_attr_lds_addr);
- return NIR_LOWER_INSTR_PROGRESS_REPLACE;
-}
-
-static nir_ssa_def *
-lower_ms_load_per_vertex_output(nir_builder *b,
- nir_intrinsic_instr *intrin,
- lower_ngg_ms_state *s)
-{
- return ms_load_arrayed_output_intrin(b, intrin, s->num_per_vertex_outputs, s->vertex_attr_lds_addr);
-}
-
-static nir_ssa_def *
-lower_ms_store_per_primitive_output(nir_builder *b,
- nir_intrinsic_instr *intrin,
- lower_ngg_ms_state *s)
-{
- update_ms_output_info(intrin, s);
- ms_store_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
- return NIR_LOWER_INSTR_PROGRESS_REPLACE;
-}
+ nir_ssa_def *load =
+ ms_load_arrayed_output(b, arr_index, base_offset, location, driver_location,
+ component_offset, num_components, load_bit_size, s);
-static nir_ssa_def *
-lower_ms_load_per_primitive_output(nir_builder *b,
- nir_intrinsic_instr *intrin,
- lower_ngg_ms_state *s)
-{
- return ms_load_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
+ return regroup_load_val(b, load, bit_size);
}
static nir_ssa_def *
@@ -2309,24 +2352,26 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
- if (intrin->intrinsic == nir_intrinsic_store_output)
+ switch (intrin->intrinsic) {
+ case nir_intrinsic_store_output:
return lower_ms_store_output(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_load_output)
+ case nir_intrinsic_load_output:
return lower_ms_load_output(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_store_per_vertex_output)
- return lower_ms_store_per_vertex_output(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_load_per_vertex_output)
- return lower_ms_load_per_vertex_output(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_store_per_primitive_output)
- return lower_ms_store_per_primitive_output(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_load_per_primitive_output)
- return lower_ms_load_per_primitive_output(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_load_workgroup_id)
+ case nir_intrinsic_store_per_vertex_output:
+ case nir_intrinsic_store_per_primitive_output:
+ update_ms_output_info(intrin, s);
+ ms_store_arrayed_output_intrin(b, intrin, s);
+ return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+ case nir_intrinsic_load_per_vertex_output:
+ case nir_intrinsic_load_per_primitive_output:
+ return ms_load_arrayed_output_intrin(b, intrin, s);
+ case nir_intrinsic_load_workgroup_id:
return lower_ms_load_workgroup_id(b, intrin, s);
- else if (intrin->intrinsic == nir_intrinsic_scoped_barrier)
+ case nir_intrinsic_scoped_barrier:
return update_ms_scoped_barrier(b, intrin, s);
- else
+ default:
unreachable("Not a lowerable mesh shader intrinsic.");
+ }
}
static bool
@@ -2356,14 +2401,12 @@ lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
static void
ms_emit_arrayed_outputs(nir_builder *b,
nir_ssa_def *invocation_index,
- uint64_t arrayed_outputs_mask,
- unsigned num_arrayed_outputs,
- unsigned lds_base_addr,
+ uint64_t mask,
lower_ngg_ms_state *s)
{
nir_ssa_def *zero = nir_imm_int(b, 0);
- u_foreach_bit64(slot, arrayed_outputs_mask) {
+ u_foreach_bit64(slot, mask) {
/* Should not occour here, handled separately. */
assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
@@ -2376,8 +2419,8 @@ ms_emit_arrayed_outputs(nir_builder *b,
u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
nir_ssa_def *load =
- ms_load_arrayed_output(b, invocation_index, zero, driver_location, start_comp,
- num_components, 32, num_arrayed_outputs, lds_base_addr);
+ ms_load_arrayed_output(b, invocation_index, zero, slot, driver_location, start_comp,
+ num_components, 32, s);
nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .component = start_comp,
.io_semantics = io_sem);
@@ -2416,7 +2459,7 @@ emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
return;
}
- unsigned workgroup_index_lds_addr = s->numprims_lds_addr + 8;
+ unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
nir_ssa_def *zero = nir_imm_int(b, 0);
nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
@@ -2524,8 +2567,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
{
/* All per-vertex attributes. */
- ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs,
- s->num_per_vertex_outputs, s->vertex_attr_lds_addr, s);
+ ms_emit_arrayed_outputs(b, invocation_index, s->per_vertex_outputs, s);
nir_export_vertex_amd(b);
}
nir_pop_if(b, if_has_output_vertex);
@@ -2535,8 +2577,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
{
/* Generic per-primitive attributes. */
- ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs,
- s->num_per_primitive_outputs, s->prim_attr_lds_addr, s);
+ ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s);
/* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
if (s->insert_layer_output) {
@@ -2549,7 +2590,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
/* Primitive connectivity data: describes which vertices the primitive uses. */
nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
- nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->prim_vtx_indices_addr);
+ nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
nir_ssa_def *indices[3];
nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
@@ -2591,7 +2632,7 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
- unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12;
+ unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
/* Scan the shader for workgroup barriers. */
@@ -2707,6 +2748,50 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
}
}
+static void
+ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
+ unsigned max_vertices,
+ unsigned max_primitives)
+{
+ uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
+ uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
+ l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
+ l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
+}
+
+static ms_out_mem_layout
+ms_calculate_output_layout(unsigned api_shared_size,
+ uint64_t per_vertex_output_mask,
+ uint64_t per_primitive_output_mask,
+ unsigned max_vertices,
+ unsigned max_primitives,
+ unsigned vertices_per_prim)
+{
+ uint64_t lds_per_vertex_output_mask = per_vertex_output_mask;
+ uint64_t lds_per_primitive_output_mask = per_primitive_output_mask;
+
+ /* Shared memory used by the API shader. */
+ ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
+
+ /* Workgroup information, see the lds_ms_* enum for the layout. */
+ l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
+ l.lds.total_size = l.lds.workgroup_info_addr + 16;
+
+ /* Per-vertex and per-primitive output attributes. */
+ l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
+ l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
+ l.lds.prm_attr.mask = lds_per_primitive_output_mask;
+ ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
+
+ /* Indices: flat array of 8-bit vertex indices for each primitive. */
+ l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
+ l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
+
+ /* NGG is only allowed to address up to 32K of LDS. */
+ assert(l.lds.total_size <= 32 * 1024);
+ return l;
+}
+
void
ac_nir_lower_ngg_ms(nir_shader *shader,
unsigned wave_size,
@@ -2722,30 +2807,14 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
uint64_t per_primitive_outputs =
shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
- unsigned num_per_vertex_outputs = util_bitcount64(per_vertex_outputs);
- unsigned num_per_primitive_outputs = util_bitcount64(per_primitive_outputs);
unsigned max_vertices = shader->info.mesh.max_vertices_out;
unsigned max_primitives = shader->info.mesh.max_primitives_out;
- /* LDS area for total number of output primitives and other info.
- * DW0: number of primitives
- * DW1: reserved for later use
- * DW2: workgroup index within the current dispatch
- * DW3: number of API workgroups in flight
- */
- unsigned numprims_lds_addr = ALIGN(shader->info.shared_size, 16);
- unsigned numprims_lds_size = 16;
- /* LDS area for vertex attributes */
- unsigned vertex_attr_lds_addr = ALIGN(numprims_lds_addr + numprims_lds_size, 16);
- unsigned vertex_attr_lds_size = max_vertices * num_per_vertex_outputs * 16;
- /* LDS area for primitive attributes */
- unsigned prim_attr_lds_addr = ALIGN(vertex_attr_lds_addr + vertex_attr_lds_size, 16);
- unsigned prim_attr_lds_size = max_primitives * num_per_primitive_outputs * 16;
- /* LDS area for the vertex indices (stored as a flat array) */
- unsigned prim_vtx_indices_addr = ALIGN(prim_attr_lds_addr + prim_attr_lds_size, 16);
- unsigned prim_vtx_indices_size = max_primitives * vertices_per_prim;
-
- shader->info.shared_size = prim_vtx_indices_addr + prim_vtx_indices_size;
+ ms_out_mem_layout layout =
+ ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
+ max_vertices, max_primitives, vertices_per_prim);
+
+ shader->info.shared_size = layout.lds.total_size;
/* The workgroup size that is specified by the API shader may be different
* from the size of the workgroup that actually runs on the HW, due to the
@@ -2762,16 +2831,11 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
lower_ngg_ms_state state = {
+ .layout = layout,
.wave_size = wave_size,
.per_vertex_outputs = per_vertex_outputs,
.per_primitive_outputs = per_primitive_outputs,
- .num_per_vertex_outputs = num_per_vertex_outputs,
- .num_per_primitive_outputs = num_per_primitive_outputs,
.vertices_per_prim = vertices_per_prim,
- .vertex_attr_lds_addr = vertex_attr_lds_addr,
- .prim_attr_lds_addr = prim_attr_lds_addr,
- .prim_vtx_indices_addr = prim_vtx_indices_addr,
- .numprims_lds_addr = numprims_lds_addr,
.api_workgroup_size = api_workgroup_size,
.hw_workgroup_size = hw_workgroup_size,
.insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
More information about the mesa-commit
mailing list