[Mesa-dev] [PATCH 36/56] spirv: Add subgroup ballot support
Jason Ekstrand
jason at jlekstrand.net
Wed Mar 7 14:35:24 UTC 2018
Reviewed-by: Iago Toral Quiroga <itoral at igalia.com>
---
src/compiler/shader_info.h | 1 +
src/compiler/spirv/spirv_to_nir.c | 5 ++
src/compiler/spirv/vtn_subgroup.c | 143 ++++++++++++++++++++++++++++++++++---
src/compiler/spirv/vtn_variables.c | 20 ++++++
4 files changed, 161 insertions(+), 8 deletions(-)
diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h
index 6183432..3a6e545 100644
--- a/src/compiler/shader_info.h
+++ b/src/compiler/shader_info.h
@@ -44,6 +44,7 @@ struct spirv_supported_capabilities {
bool multiview;
bool variable_pointers;
bool storage_16bit;
+ bool subgroup_ballot;
bool subgroup_basic;
};
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 0a06c39..451d44f 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -3296,6 +3296,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
spv_check_supported(subgroup_basic, cap);
break;
+ case SpvCapabilitySubgroupBallotKHR:
+ case SpvCapabilityGroupNonUniformBallot:
+ spv_check_supported(subgroup_ballot, cap);
+ break;
+
case SpvCapabilityVariablePointersStorageBuffer:
case SpvCapabilityVariablePointers:
spv_check_supported(variable_pointers, cap);
diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c
index 033c43e..a86f0cb 100644
--- a/src/compiler/spirv/vtn_subgroup.c
+++ b/src/compiler/spirv/vtn_subgroup.c
@@ -23,6 +23,44 @@
#include "vtn_private.h"
+/* Emit the subgroup intrinsic nir_op reading src0 and store the result in dst.
+ *
+ * Vectors and scalars get a single intrinsic; any other type is split into
+ * its elements and each one is handled recursively.  When index is non-NULL
+ * it becomes src[1].  SPIR-V lets that index be any integer width, but to
+ * keep things simple for drivers only 32-bit indices are ever emitted.
+ */
+static void
+vtn_build_subgroup_instr(struct vtn_builder *b,
+                         nir_intrinsic_op nir_op,
+                         struct vtn_ssa_value *dst,
+                         struct vtn_ssa_value *src0,
+                         nir_ssa_def *index)
+{
+   if (index != NULL && index->bit_size != 32)
+      index = nir_u2u32(&b->nb, index);
+
+   vtn_assert(dst->type == src0->type);
+
+   if (glsl_type_is_vector_or_scalar(dst->type)) {
+      nir_intrinsic_instr *instr =
+         nir_intrinsic_instr_create(b->nb.shader, nir_op);
+      nir_ssa_dest_init_for_type(&instr->instr, &instr->dest, dst->type, NULL);
+      instr->num_components = instr->dest.ssa.num_components;
+      instr->src[0] = nir_src_for_ssa(src0->def);
+      if (index != NULL)
+         instr->src[1] = nir_src_for_ssa(index);
+      nir_builder_instr_insert(&b->nb, &instr->instr);
+      dst->def = &instr->dest.ssa;
+      return;
+   }
+
+   /* Composite: recurse per element; index is already 32-bit by now. */
+   const unsigned elem_count = glsl_get_length(dst->type);
+   for (unsigned e = 0; e < elem_count; e++)
+      vtn_build_subgroup_instr(b, nir_op, dst->elems[e], src0->elems[e], index);
+}
+
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
break;
}
- case SpvOpGroupNonUniformAll:
- case SpvOpGroupNonUniformAny:
- case SpvOpGroupNonUniformAllEqual:
- case SpvOpGroupNonUniformBroadcast:
- case SpvOpGroupNonUniformBroadcastFirst:
- case SpvOpGroupNonUniformBallot:
- case SpvOpGroupNonUniformInverseBallot:
+ case SpvOpGroupNonUniformBallot: {
+ vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+ "OpGroupNonUniformBallot must return a uvec4");
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
+ ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
+ ballot->num_components = 4;
+ nir_builder_instr_insert(&b->nb, &ballot->instr);
+ val->ssa->def = &ballot->dest.ssa;
+ break;
+ }
+
+ case SpvOpGroupNonUniformInverseBallot: {
+ /* This one is just a BallotBitfieldExtract with subgroup invocation.
+ * We could add a NIR intrinsic but it's easier to just lower it on the
+ * spot.
+ */
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader,
+ nir_intrinsic_ballot_bitfield_extract);
+
+ intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+ intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
+
+ nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ val->ssa->def = &intrin->dest.ssa;
+ break;
+ }
+
case SpvOpGroupNonUniformBallotBitExtract:
case SpvOpGroupNonUniformBallotBitCount:
case SpvOpGroupNonUniformBallotFindLSB:
- case SpvOpGroupNonUniformBallotFindMSB:
+   case SpvOpGroupNonUniformBallotFindMSB: {
+      /* Each of these ballot queries becomes one NIR intrinsic whose src[0]
+       * is the ballot value and whose result is a single 32-bit component.
+       * Only BitExtract takes a second source (the invocation index).
+       */
+      nir_ssa_def *src0, *src1 = NULL;
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformBallotBitExtract:
+         op = nir_intrinsic_ballot_bitfield_extract;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         src1 = vtn_ssa_value(b, w[5])->def;
+         /* SPIR-V allows the index to be any integer width.  As in
+          * vtn_build_subgroup_instr, only 32-bit indices reach drivers.
+          */
+         if (src1->bit_size != 32)
+            src1 = nir_u2u32(&b->nb, src1);
+         break;
+      case SpvOpGroupNonUniformBallotBitCount:
+         /* w[4] holds the group operation; the ballot value is w[5]. */
+         switch ((SpvGroupOperation)w[4]) {
+         case SpvGroupOperationReduce:
+            op = nir_intrinsic_ballot_bit_count_reduce;
+            break;
+         case SpvGroupOperationInclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_inclusive;
+            break;
+         case SpvGroupOperationExclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_exclusive;
+            break;
+         default:
+            /* The operation is read from the SPIR-V module, so any other
+             * value is malformed input rather than an internal invariant:
+             * report it with vtn_fail() like other invalid-module errors.
+             */
+            vtn_fail("Invalid group operation");
+         }
+         src0 = vtn_ssa_value(b, w[5])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindLSB:
+         op = nir_intrinsic_ballot_find_lsb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindMSB:
+         op = nir_intrinsic_ballot_find_msb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      default:
+         /* The enclosing case labels restrict opcode to the four above. */
+         unreachable("Unhandled opcode");
+      }
+
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader, op);
+
+      intrin->src[0] = nir_src_for_ssa(src0);
+      if (src1)
+         intrin->src[1] = nir_src_for_ssa(src1);
+
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
+ case SpvOpGroupNonUniformBroadcastFirst:
+ vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+ val->ssa, vtn_ssa_value(b, w[4]), NULL);
+ break;
+
+ case SpvOpGroupNonUniformBroadcast:
+ vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+ val->ssa, vtn_ssa_value(b, w[4]),
+ vtn_ssa_value(b, w[5])->def);
+ break;
+
+ case SpvOpGroupNonUniformAll:
+ case SpvOpGroupNonUniformAny:
+ case SpvOpGroupNonUniformAllEqual:
case SpvOpGroupNonUniformShuffle:
case SpvOpGroupNonUniformShuffleXor:
case SpvOpGroupNonUniformShuffleUp:
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 2b8c5f6..53952e8 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -1309,6 +1309,26 @@ vtn_get_builtin_location(struct vtn_builder *b,
*location = SYSTEM_VALUE_VIEW_INDEX;
set_mode_system_value(b, mode);
break;
+   /* Subgroup invocation masks are exposed to NIR as system values. */
+   case SpvBuiltInSubgroupEqMask:
+      *location = SYSTEM_VALUE_SUBGROUP_EQ_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupGeMask:
+      *location = SYSTEM_VALUE_SUBGROUP_GE_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupGtMask:
+      *location = SYSTEM_VALUE_SUBGROUP_GT_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupLeMask:
+      *location = SYSTEM_VALUE_SUBGROUP_LE_MASK;
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupLtMask:
+      *location = SYSTEM_VALUE_SUBGROUP_LT_MASK;
+      set_mode_system_value(b, mode);
+      break;
default:
vtn_fail("unsupported builtin");
}
--
2.5.0.400.gff86faf
More information about the mesa-dev mailing list