[Mesa-dev] [PATCH 43/48] nir/lower_subgroups: Lower ballot intrinsics to the specified bit size
Jason Ekstrand
jason at jlekstrand.net
Tue Oct 31 23:54:32 UTC 2017
Ballot intrinsics return a bitfield with one bit per invocation in the
subgroup. In GLSL and some SPIR-V extensions, they return a uint64_t. In
SPV_KHR_shader_ballot, they return a uvec4. Also, some back-ends would
rather pass around 32-bit values because that is easier than dealing with
64-bit values all the time.
To solve this, we make nir_lower_subgroups take a new parameter called
ballot_bit_size. The pass lowers whichever ballot type the source language
uses (uint64_t or uvec4) to a scalar with the specified number of bits.
This replaces a chunk of the old ballot lowering code in
nir_opt_intrinsics, along with the max_subgroup_size compiler option.
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin at intel.com>
Reviewed-by: Iago Toral Quiroga <itoral at igalia.com>
---
src/compiler/nir/nir.h | 3 +-
src/compiler/nir/nir_lower_subgroups.c | 92 ++++++++++++++++++++++++++++++----
src/compiler/nir/nir_opt_intrinsics.c | 18 -------
src/intel/compiler/brw_compiler.c | 1 -
src/intel/compiler/brw_nir.c | 1 +
5 files changed, 84 insertions(+), 31 deletions(-)
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 1a25d7b..563b57f 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -1854,8 +1854,6 @@ typedef struct nir_shader_compiler_options {
*/
bool use_interpolated_input_intrinsics;
- unsigned max_subgroup_size;
-
unsigned max_unroll_iterations;
} nir_shader_compiler_options;
@@ -2469,6 +2467,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader,
const struct gl_shader_program *shader_program);
typedef struct nir_lower_subgroups_options {
+ uint8_t ballot_bit_size;
bool lower_to_scalar:1;
bool lower_vote_trivial:1;
bool lower_subgroup_masks:1;
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index 0d11dc9..76e8316 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -28,6 +28,42 @@
* \file nir_opt_intrinsics.c
*/
+/* Converts a scalar 32- or 64-bit ballot value to the ballot destination type: uvec4 (num_components 4, bit_size 32) or uint64_t (num_components 1, bit_size 64), zero-filling unused bits. */
+static nir_ssa_def *
+uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
+ unsigned num_components, unsigned bit_size) /* shape of the destination being replaced */
+{
+ assert(value->num_components == 1);
+ assert(value->bit_size == 32 || value->bit_size == 64);
+
+ nir_ssa_def *zero = nir_imm_int(b, 0); /* filler for unused components / high half */
+ if (num_components > 1) {
+ /* SPIR-V uses a uvec4 for ballot values */
+ assert(num_components == 4);
+ assert(bit_size == 32);
+
+ if (value->bit_size == 32) {
+ return nir_vec4(b, value, zero, zero, zero); /* only .x carries ballot bits */
+ } else {
+ assert(value->bit_size == 64);
+ return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value), /* .x = low 32 bits */
+ nir_unpack_64_2x32_split_y(b, value), /* .y = high 32 bits */
+ zero, zero);
+ }
+ } else {
+ /* GLSL uses a uint64_t for ballot values */
+ assert(num_components == 1);
+ assert(bit_size == 64);
+
+ if (value->bit_size == 32) {
+ return nir_pack_64_2x32_split(b, value, zero); /* value is the low half, high half is zero */
+ } else {
+ assert(value->bit_size == 64);
+ return value; /* already a scalar uint64_t */
+ }
+ }
+}
+
static nir_ssa_def *
lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
@@ -62,7 +98,8 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
static nir_ssa_def *
high_subgroup_mask(nir_builder *b,
nir_ssa_def *count,
- uint64_t base_mask)
+ uint64_t base_mask,
+ unsigned bit_size)
{
/* group_mask could probably be calculated more efficiently but we want to
* be sure not to shift by 64 if the subgroup size is 64 because the GLSL
@@ -71,10 +108,11 @@ high_subgroup_mask(nir_builder *b,
* subgroup size is likely to be known at compile time.
*/
nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
- nir_ssa_def *all_bits = nir_imm_int64(b, ~0ull);
+ nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size);
nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size);
nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift);
- nir_ssa_def *higher_bits = nir_ishl(b, nir_imm_int64(b, base_mask), count);
+ nir_ssa_def *higher_bits =
+ nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count);
return nir_iand(b, higher_bits, group_mask);
}
@@ -109,24 +147,58 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
if (!options->lower_subgroup_masks)
return NULL;
- nir_ssa_def *count = nir_load_subgroup_invocation(b);
+ /* If either the result or the requested bit size is 64-bits then we
+ * know that we have 64-bit types and using them will probably be more
+ * efficient than messing around with 32-bit shifts and packing.
+ */
+ const unsigned bit_size = MAX2(options->ballot_bit_size,
+ intrin->dest.ssa.bit_size);
+ nir_ssa_def *count = nir_load_subgroup_invocation(b);
+ nir_ssa_def *val;
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_eq_mask:
- return nir_ishl(b, nir_imm_int64(b, 1ull), count);
+ val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
+ break;
case nir_intrinsic_load_subgroup_ge_mask:
- return high_subgroup_mask(b, count, ~0ull);
+ val = high_subgroup_mask(b, count, ~0ull, bit_size);
+ break;
case nir_intrinsic_load_subgroup_gt_mask:
- return high_subgroup_mask(b, count, ~1ull);
+ val = high_subgroup_mask(b, count, ~1ull, bit_size);
+ break;
case nir_intrinsic_load_subgroup_le_mask:
- return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~1ull), count));
+ val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
+ break;
case nir_intrinsic_load_subgroup_lt_mask:
- return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~0ull), count));
+ val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
+ break;
default:
unreachable("you seriously can't tell this is unreachable?");
}
- break;
+
+ return uint_to_ballot_type(b, val,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size);
+ }
+
+ case nir_intrinsic_ballot: {
+ if (intrin->dest.ssa.num_components == 1 &&
+ intrin->dest.ssa.bit_size == options->ballot_bit_size)
+ return NULL;
+
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+ ballot->num_components = 1;
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+ 1, options->ballot_bit_size, NULL);
+ nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
+ nir_builder_instr_insert(b, &ballot->instr);
+
+ return uint_to_ballot_type(b, &ballot->dest.ssa,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size);
}
+
default:
break;
}
diff --git a/src/compiler/nir/nir_opt_intrinsics.c b/src/compiler/nir/nir_opt_intrinsics.c
index 98c8b1a..eb394af 100644
--- a/src/compiler/nir/nir_opt_intrinsics.c
+++ b/src/compiler/nir/nir_opt_intrinsics.c
@@ -54,24 +54,6 @@ opt_intrinsics_impl(nir_function_impl *impl)
if (nir_src_as_const_value(intrin->src[0]))
replacement = nir_imm_int(&b, NIR_TRUE);
break;
- case nir_intrinsic_ballot: {
- assert(b.shader->options->max_subgroup_size != 0);
- if (b.shader->options->max_subgroup_size > 32 ||
- intrin->dest.ssa.bit_size <= 32)
- continue;
-
- nir_intrinsic_instr *ballot =
- nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
- nir_ssa_dest_init(&ballot->instr, &ballot->dest, 1, 32, NULL);
- nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
-
- nir_builder_instr_insert(&b, &ballot->instr);
-
- replacement = nir_pack_64_2x32_split(&b,
- &ballot->dest.ssa,
- nir_imm_int(&b, 0));
- break;
- }
default:
break;
}
diff --git a/src/intel/compiler/brw_compiler.c b/src/intel/compiler/brw_compiler.c
index a6129e9..f31f29d 100644
--- a/src/intel/compiler/brw_compiler.c
+++ b/src/intel/compiler/brw_compiler.c
@@ -57,7 +57,6 @@ static const struct nir_shader_compiler_options scalar_nir_options = {
.lower_unpack_snorm_4x8 = true,
.lower_unpack_unorm_2x16 = true,
.lower_unpack_unorm_4x8 = true,
- .max_subgroup_size = 32,
.max_unroll_iterations = 32,
};
diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c
index f599f74..0d59d36 100644
--- a/src/intel/compiler/brw_nir.c
+++ b/src/intel/compiler/brw_nir.c
@@ -637,6 +637,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir)
OPT(nir_lower_system_values);
const nir_lower_subgroups_options subgroups_options = {
+ .ballot_bit_size = 32,
.lower_to_scalar = true,
.lower_subgroup_masks = true,
.lower_vote_trivial = !is_scalar,
--
2.5.0.400.gff86faf
More information about the mesa-dev
mailing list