[Mesa-dev] [PATCH 39/44] nir/lower_subgroups: Lower ballot intrinsics to the specified bit size
Jason Ekstrand
jason at jlekstrand.net
Tue Sep 5 15:13:31 UTC 2017
Ballot intrinsics return a bitfield with one bit per invocation in the
subgroup. In GLSL and some SPIR-V extensions, that bitfield is a uint64_t;
in SPV_KHR_shader_ballot, it is a uvec4. Also, some back-ends would rather
pass around 32-bit values because it's easier than messing with 64-bit all
the time. To solve this mess, we make nir_lower_subgroups take a new
parameter called ballot_bit_size. The pass lowers whichever ballot type it
gets from the source language (uint64_t or uvec4) to a scalar with the
specified number of bits. This replaces a chunk of the old ballot lowering
code in nir_opt_intrinsics.
---
src/compiler/nir/nir.h | 3 +-
src/compiler/nir/nir_lower_subgroups.c | 91 ++++++++++++++++++++++++++++++++--
src/compiler/nir/nir_opt_intrinsics.c | 18 -------
src/intel/compiler/brw_compiler.c | 1 -
src/intel/compiler/brw_nir.c | 1 +
5 files changed, 88 insertions(+), 26 deletions(-)
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 52aea05..f5b46c7 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -1840,8 +1840,6 @@ typedef struct nir_shader_compiler_options {
*/
bool use_interpolated_input_intrinsics;
- unsigned max_subgroup_size;
-
unsigned max_unroll_iterations;
} nir_shader_compiler_options;
@@ -2452,6 +2450,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader,
const struct gl_shader_program *shader_program);
typedef struct nir_lower_subgroups_options {
+ uint8_t ballot_bit_size;
bool lower_to_scalar:1;
bool lower_vote_trivial:1;
bool lower_subgroup_masks:1;
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index 02738c4..1cc6717 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -28,6 +28,42 @@
* \file nir_opt_intrinsics.c
*/
+/* Reshape a scalar 32- or 64-bit ballot value into the ballot type the
+ * intrinsic's destination expects: a 64-bit scalar (uint64_t, as used by
+ * GLSL) when num_components is 1, or a 32-bit uvec4 (as used by SPIR-V)
+ * when num_components is 4.
+ *
+ * Bits that the source value does not have are filled with zero.  As a
+ * consequence, any bitwise complement (inot) must be applied to the narrow
+ * value *before* calling this helper; complementing the widened result
+ * would also set the zero-padded bits/components.
+ *
+ * The zero immediate is emitted unconditionally, even on the path that
+ * returns value unchanged.
+ */
+static nir_ssa_def *
+uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
+                    unsigned num_components, unsigned bit_size)
+{
+   assert(value->num_components == 1);
+   assert(value->bit_size == 32 || value->bit_size == 64);
+
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+
+   if (num_components <= 1) {
+      /* GLSL uses a uint64_t for ballot values */
+      assert(num_components == 1);
+      assert(bit_size == 64);
+
+      if (value->bit_size == 32)
+         return nir_pack_64_2x32_split(b, value, zero);
+
+      assert(value->bit_size == 64);
+      return value;
+   }
+
+   /* SPIR-V uses a uvec4 for ballot values */
+   assert(num_components == 4);
+   assert(bit_size == 32);
+
+   if (value->bit_size == 32)
+      return nir_vec4(b, value, zero, zero, zero);
+
+   assert(value->bit_size == 64);
+   nir_ssa_def *low = nir_unpack_64_2x32_split_x(b, value);
+   nir_ssa_def *high = nir_unpack_64_2x32_split_y(b, value);
+   return nir_vec4(b, low, high, zero, zero);
+}
+
static nir_ssa_def *
lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
@@ -86,24 +122,69 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
if (!options->lower_subgroup_masks)
return NULL;
+ uint64_t mask;
+ switch (intrin->intrinsic) {
+ case nir_intrinsic_load_subgroup_eq_mask:
+ mask = 1ull;
+ break;
+ case nir_intrinsic_load_subgroup_ge_mask:
+ case nir_intrinsic_load_subgroup_lt_mask:
+ mask = ~0ull;
+ break;
+ case nir_intrinsic_load_subgroup_gt_mask:
+ case nir_intrinsic_load_subgroup_le_mask:
+ mask = ~1ull;
+ break;
+ default:
+ unreachable("you seriously can't tell this is unreachable?");
+ }
+
nir_ssa_def *count = nir_load_subgroup_invocation(b);
+ nir_ssa_def *shifted;
+ if (options->ballot_bit_size == 32) {
+ shifted = nir_ishl(b, nir_imm_int(b, mask), count);
+ } else {
+ assert(options->ballot_bit_size == 64);
+ shifted = nir_ishl(b, nir_imm_int64(b, mask), count);
+ }
+
+ nir_ssa_def *ballot =
+ uint_to_ballot_type(b, shifted,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size);
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_eq_mask:
- return nir_ishl(b, nir_imm_int64(b, 1ull), count);
case nir_intrinsic_load_subgroup_ge_mask:
- return nir_ishl(b, nir_imm_int64(b, ~0ull), count);
case nir_intrinsic_load_subgroup_gt_mask:
- return nir_ishl(b, nir_imm_int64(b, ~1ull), count);
+ return ballot;
case nir_intrinsic_load_subgroup_le_mask:
- return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~1ull), count));
case nir_intrinsic_load_subgroup_lt_mask:
- return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~0ull), count));
+ return nir_inot(b, ballot);
default:
unreachable("you seriously can't tell this is unreachable?");
}
break;
}
+
+ case nir_intrinsic_ballot: {
+ if (intrin->dest.ssa.num_components == 1 &&
+ intrin->dest.ssa.bit_size == options->ballot_bit_size)
+ return NULL;
+
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+ ballot->num_components = 1;
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+ 1, options->ballot_bit_size, NULL);
+ nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
+ nir_builder_instr_insert(b, &ballot->instr);
+
+ return uint_to_ballot_type(b, &ballot->dest.ssa,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size);
+ }
+
default:
break;
}
diff --git a/src/compiler/nir/nir_opt_intrinsics.c b/src/compiler/nir/nir_opt_intrinsics.c
index 98c8b1a..eb394af 100644
--- a/src/compiler/nir/nir_opt_intrinsics.c
+++ b/src/compiler/nir/nir_opt_intrinsics.c
@@ -54,24 +54,6 @@ opt_intrinsics_impl(nir_function_impl *impl)
if (nir_src_as_const_value(intrin->src[0]))
replacement = nir_imm_int(&b, NIR_TRUE);
break;
- case nir_intrinsic_ballot: {
- assert(b.shader->options->max_subgroup_size != 0);
- if (b.shader->options->max_subgroup_size > 32 ||
- intrin->dest.ssa.bit_size <= 32)
- continue;
-
- nir_intrinsic_instr *ballot =
- nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
- nir_ssa_dest_init(&ballot->instr, &ballot->dest, 1, 32, NULL);
- nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
-
- nir_builder_instr_insert(&b, &ballot->instr);
-
- replacement = nir_pack_64_2x32_split(&b,
- &ballot->dest.ssa,
- nir_imm_int(&b, 0));
- break;
- }
default:
break;
}
diff --git a/src/intel/compiler/brw_compiler.c b/src/intel/compiler/brw_compiler.c
index a6129e9..f31f29d 100644
--- a/src/intel/compiler/brw_compiler.c
+++ b/src/intel/compiler/brw_compiler.c
@@ -57,7 +57,6 @@ static const struct nir_shader_compiler_options scalar_nir_options = {
.lower_unpack_snorm_4x8 = true,
.lower_unpack_unorm_2x16 = true,
.lower_unpack_unorm_4x8 = true,
- .max_subgroup_size = 32,
.max_unroll_iterations = 32,
};
diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c
index 1091324..bc80df3 100644
--- a/src/intel/compiler/brw_nir.c
+++ b/src/intel/compiler/brw_nir.c
@@ -624,6 +624,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir)
OPT(nir_normalize_cubemap_coords);
const nir_lower_subgroups_options subgroups_options = {
+ .ballot_bit_size = 32,
.lower_to_scalar = true,
.lower_subgroup_masks = true,
.lower_vote_trivial = !is_scalar,
--
2.5.0.400.gff86faf
More information about the mesa-dev
mailing list