[Mesa-dev] [PATCH] nir+vtn: vec8+vec16 support
Rob Clark
robdclark at gmail.com
Sat Apr 7 18:16:03 UTC 2018
This introduces new vec8 and vec16 instructions (which are the only
instructions taking more than 4 sources), in order to construct 8 and 16
component vectors.
In order to avoid fixing up the non-autogenerated nir_build_alu() call sites
and making them pass 16 src args for the benefit of the two instructions
that take more than 4 srcs (i.e. vec8 and vec16), the common tail of
nir_build_alu() is split out into nir_build_alu_tail(), which is re-used by
nir_build_alu2() (the variant used for the > 4 src args case).
Signed-off-by: Rob Clark <robdclark at gmail.com>
---
src/compiler/nir/nir.h | 30 +++++-----
src/compiler/nir/nir_builder.h | 69 ++++++++++++++++-------
src/compiler/nir/nir_builder_opcodes_h.py | 5 ++
src/compiler/nir/nir_constant_expressions.py | 33 +++++++++--
src/compiler/nir/nir_lower_alu_to_scalar.c | 8 ++-
src/compiler/nir/nir_lower_io_to_scalar.c | 4 +-
src/compiler/nir/nir_lower_load_const_to_scalar.c | 2 +-
src/compiler/nir/nir_opcodes.py | 39 ++++++++++++-
src/compiler/nir/nir_print.c | 19 +++++--
src/compiler/nir/nir_validate.c | 4 +-
src/compiler/spirv/spirv_to_nir.c | 6 +-
src/compiler/spirv/vtn_alu.c | 2 +-
12 files changed, 165 insertions(+), 56 deletions(-)
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 25595d1f0bf..a82c9658580 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -118,16 +118,16 @@ typedef enum {
} nir_rounding_mode;
typedef union {
- float f32[4];
- double f64[4];
- int8_t i8[4];
- uint8_t u8[4];
- int16_t i16[4];
- uint16_t u16[4];
- int32_t i32[4];
- uint32_t u32[4];
- int64_t i64[4];
- uint64_t u64[4];
+ float f32[16];
+ double f64[16];
+ int8_t i8[16];
+ uint8_t u8[16];
+ int16_t i16[16];
+ uint16_t u16[16];
+ int32_t i32[16];
+ uint32_t u32[16];
+ int64_t i64[16];
+ uint64_t u64[16];
} nir_const_value;
typedef struct nir_constant {
@@ -138,7 +138,7 @@ typedef struct nir_constant {
* by the type associated with the \c nir_variable. Constants may be
* scalars, vectors, or matrices.
*/
- nir_const_value values[4];
+ nir_const_value values[16];
/* we could get this from the var->type but makes clone *much* easier to
* not have to care about the type.
@@ -693,7 +693,7 @@ typedef struct {
* a statement like "foo.xzw = bar.zyx" would have a writemask of 1101b and
* a swizzle of {2, x, 1, 0} where x means "don't care."
*/
- uint8_t swizzle[4];
+ uint8_t swizzle[16];
} nir_alu_src;
typedef struct {
@@ -708,7 +708,7 @@ typedef struct {
bool saturate;
- unsigned write_mask : 4; /* ignored if dest.is_ssa is true */
+ unsigned write_mask : 16; /* ignored if dest.is_ssa is true */
} nir_alu_dest;
typedef enum {
@@ -837,14 +837,14 @@ typedef struct {
/**
* The number of components in each input
*/
- unsigned input_sizes[4];
+ unsigned input_sizes[16];
/**
* The type of vector that each input takes. Note that negate and
* absolute value are only allowed on inputs with int or float type and
* behave differently on the two.
*/
- nir_alu_type input_types[4];
+ nir_alu_type input_types[16];
nir_op_algebraic_property algebraic_properties;
} nir_op_info;
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 8eddd8c4c01..51cb4227643 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -290,24 +290,12 @@ nir_imm_ivec4(nir_builder *build, int x, int y, int z, int w)
}
static inline nir_ssa_def *
-nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
- nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
+nir_build_alu_tail(nir_builder *build, nir_alu_instr *instr)
{
- const nir_op_info *op_info = &nir_op_infos[op];
- nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
- if (!instr)
- return NULL;
+ const nir_op_info *op_info = &nir_op_infos[instr->op];
instr->exact = build->exact;
- instr->src[0].src = nir_src_for_ssa(src0);
- if (src1)
- instr->src[1].src = nir_src_for_ssa(src1);
- if (src2)
- instr->src[2].src = nir_src_for_ssa(src2);
- if (src3)
- instr->src[3].src = nir_src_for_ssa(src3);
-
/* Guess the number of components the destination temporary should have
* based on our input sizes, if it's not fixed for the op.
*/
@@ -362,12 +350,54 @@ nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
return &instr->dest.dest.ssa;
}
+static inline nir_ssa_def *
+nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
+ nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
+{
+ nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
+ if (!instr)
+ return NULL;
+
+ instr->src[0].src = nir_src_for_ssa(src0);
+ if (src1)
+ instr->src[1].src = nir_src_for_ssa(src1);
+ if (src2)
+ instr->src[2].src = nir_src_for_ssa(src2);
+ if (src3)
+ instr->src[3].src = nir_src_for_ssa(src3);
+
+ return nir_build_alu_tail(build, instr);
+}
+
+/* for the couple special cases with more than 4 src args: */
+static inline nir_ssa_def *
+nir_build_alu2(nir_builder *build, nir_op op, nir_ssa_def **srcs)
+{
+ const nir_op_info *op_info = &nir_op_infos[op];
+ nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
+ if (!instr)
+ return NULL;
+
+ for (unsigned i = 0; i < op_info->num_inputs; i++)
+ instr->src[i].src = nir_src_for_ssa(srcs[i]);
+
+ return nir_build_alu_tail(build, instr);
+}
+
#include "nir_builder_opcodes.h"
static inline nir_ssa_def *
nir_vec(nir_builder *build, nir_ssa_def **comp, unsigned num_components)
{
switch (num_components) {
+ case 16:
+ return nir_vec16(build, comp[0], comp[1], comp[2], comp[3],
+ comp[4], comp[5], comp[6], comp[7],
+ comp[8], comp[9], comp[10], comp[11],
+ comp[12], comp[13], comp[14], comp[15]);
+ case 8:
+ return nir_vec8(build, comp[0], comp[1], comp[2], comp[3],
+ comp[4], comp[5], comp[6], comp[7]);
case 4:
return nir_vec4(build, comp[0], comp[1], comp[2], comp[3]);
case 3:
@@ -417,7 +447,7 @@ nir_imov_alu(nir_builder *build, nir_alu_src src, unsigned num_components)
* Construct an fmov or imov that reswizzles the source's components.
*/
static inline nir_ssa_def *
-nir_swizzle(nir_builder *build, nir_ssa_def *src, const unsigned swiz[4],
+nir_swizzle(nir_builder *build, nir_ssa_def *src, const unsigned swiz[16],
unsigned num_components, bool use_fmov)
{
nir_alu_src alu_src = { NIR_SRC_INIT };
@@ -468,16 +498,16 @@ nir_bany(nir_builder *b, nir_ssa_def *src)
static inline nir_ssa_def *
nir_channel(nir_builder *b, nir_ssa_def *def, unsigned c)
{
- unsigned swizzle[4] = {c, c, c, c};
+ unsigned swizzle[16] = { c };
return nir_swizzle(b, def, swizzle, 1, false);
}
static inline nir_ssa_def *
nir_channels(nir_builder *b, nir_ssa_def *def, unsigned mask)
{
- unsigned num_channels = 0, swizzle[4] = { 0, 0, 0, 0 };
+ unsigned num_channels = 0, swizzle[16] = {0};
- for (unsigned i = 0; i < 4; i++) {
+ for (unsigned i = 0; i < 16; i++) {
if ((mask & (1 << i)) == 0)
continue;
swizzle[num_channels++] = i;
@@ -513,7 +543,8 @@ nir_ssa_for_src(nir_builder *build, nir_src src, int num_components)
static inline nir_ssa_def *
nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn)
{
- static uint8_t trivial_swizzle[4] = { 0, 1, 2, 3 };
+ static uint8_t trivial_swizzle[16] = { 0, 1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15 };
nir_alu_src *src = &instr->src[srcn];
unsigned num_components = nir_ssa_alu_instr_src_components(instr, srcn);
diff --git a/src/compiler/nir/nir_builder_opcodes_h.py b/src/compiler/nir/nir_builder_opcodes_h.py
index 4a41e6079ed..c9ac032d390 100644
--- a/src/compiler/nir/nir_builder_opcodes_h.py
+++ b/src/compiler/nir/nir_builder_opcodes_h.py
@@ -37,7 +37,12 @@ def src_list(num_srcs):
static inline nir_ssa_def *
nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)})
{
+% if opcode.num_inputs > 4:
+ nir_ssa_def *srcs[] = {${src_list(opcode.num_inputs)}};
+ return nir_build_alu2(build, nir_op_${name}, srcs);
+% else:
return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)});
+% endif
}
% endfor
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index ee92be51dbe..03e6b9e8e5b 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -258,6 +258,7 @@ typedef float float16_t;
typedef float float32_t;
typedef double float64_t;
typedef bool bool32_t;
+
% for type in ["float", "int", "uint"]:
% for width in type_sizes(type):
struct ${type}${width}_vec {
@@ -265,6 +266,18 @@ struct ${type}${width}_vec {
${type}${width}_t y;
${type}${width}_t z;
${type}${width}_t w;
+ ${type}${width}_t e;
+ ${type}${width}_t f;
+ ${type}${width}_t g;
+ ${type}${width}_t h;
+ ${type}${width}_t i;
+ ${type}${width}_t j;
+ ${type}${width}_t k;
+ ${type}${width}_t l;
+ ${type}${width}_t m;
+ ${type}${width}_t n;
+ ${type}${width}_t o;
+ ${type}${width}_t p;
};
% endfor
% endfor
@@ -274,6 +287,18 @@ struct bool32_vec {
bool y;
bool z;
bool w;
+ bool e;
+ bool f;
+ bool g;
+ bool h;
+ bool i;
+ bool j;
+ bool k;
+ bool l;
+ bool m;
+ bool n;
+ bool o;
+ bool p;
};
<%def name="evaluate_op(op, bit_size)">
@@ -303,7 +328,7 @@ struct bool32_vec {
_src[${j}].${get_const_field(input_types[j])}[${k}],
% endif
% endfor
- % for k in range(op.input_sizes[j], 4):
+ % for k in range(op.input_sizes[j], 16):
0,
% endfor
};
@@ -377,11 +402,11 @@ struct bool32_vec {
% for k in range(op.output_size):
% if output_type == "bool32":
## Sanitize the C value to a proper NIR bool
- _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
+ _dst_val.u32[${k}] = dst.${"xyzwefghijklmnop"[k]} ? NIR_TRUE : NIR_FALSE;
% elif output_type == "float16":
- _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
+ _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzwefghijklmnop"[k]});
% else:
- _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
+ _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzwefghijklmnop"[k]};
% endif
% endfor
% endif
diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c
index a0377dcb0be..99b4e0cdd06 100644
--- a/src/compiler/nir/nir_lower_alu_to_scalar.c
+++ b/src/compiler/nir/nir_lower_alu_to_scalar.c
@@ -93,6 +93,8 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
return true;
switch (instr->op) {
+ case nir_op_vec16:
+ case nir_op_vec8:
case nir_op_vec4:
case nir_op_vec3:
case nir_op_vec2:
@@ -209,9 +211,9 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
return false;
unsigned num_components = instr->dest.dest.ssa.num_components;
- nir_ssa_def *comps[] = { NULL, NULL, NULL, NULL };
+ nir_ssa_def *comps[16] = {NULL};
- for (chan = 0; chan < 4; chan++) {
+ for (chan = 0; chan < 16; chan++) {
if (!(instr->dest.write_mask & (1 << chan)))
continue;
@@ -225,7 +227,7 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
0 : chan);
nir_alu_src_copy(&lower->src[i], &instr->src[i], lower);
- for (int j = 0; j < 4; j++)
+ for (int j = 0; j < 16; j++)
lower->src[i].swizzle[j] = instr->src[i].swizzle[src_chan];
}
diff --git a/src/compiler/nir/nir_lower_io_to_scalar.c b/src/compiler/nir/nir_lower_io_to_scalar.c
index 179eb42a4d0..f8eb8f15145 100644
--- a/src/compiler/nir/nir_lower_io_to_scalar.c
+++ b/src/compiler/nir/nir_lower_io_to_scalar.c
@@ -38,7 +38,7 @@ lower_load_input_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
assert(intr->dest.is_ssa);
- nir_ssa_def *loads[4];
+ nir_ssa_def *loads[16];
for (unsigned i = 0; i < intr->num_components; i++) {
nir_intrinsic_instr *chan_intr =
@@ -177,7 +177,7 @@ lower_load_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
assert(intr->dest.is_ssa);
- nir_ssa_def *loads[4];
+ nir_ssa_def *loads[16];
nir_variable **chan_vars;
if (var->data.mode == nir_var_shader_in) {
diff --git a/src/compiler/nir/nir_lower_load_const_to_scalar.c b/src/compiler/nir/nir_lower_load_const_to_scalar.c
index 39447d42c23..8ec05006f41 100644
--- a/src/compiler/nir/nir_lower_load_const_to_scalar.c
+++ b/src/compiler/nir/nir_lower_load_const_to_scalar.c
@@ -46,7 +46,7 @@ lower_load_const_instr_scalar(nir_load_const_instr *lower)
b.cursor = nir_before_instr(&lower->instr);
/* Emit the individual loads. */
- nir_ssa_def *loads[4];
+ nir_ssa_def *loads[16];
for (unsigned i = 0; i < lower->def.num_components; i++) {
nir_load_const_instr *load_comp =
nir_load_const_instr_create(b.shader, 1, lower->def.bit_size);
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index a762fdd2201..847374bb341 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -72,7 +72,7 @@ class Opcode(object):
assert isinstance(algebraic_properties, str)
assert isinstance(const_expr, str)
assert len(input_sizes) == len(input_types)
- assert 0 <= output_size <= 4
+ assert (0 <= output_size <= 4) or (output_size == 8) or (output_size == 16)
for size in input_sizes:
assert 0 <= size <= 4
if output_size != 0:
@@ -785,4 +785,41 @@ dst.z = src2.x;
dst.w = src3.x;
""")
+opcode("vec8", 8, tuint,
+ [1, 1, 1, 1, 1, 1, 1, 1],
+ [tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint],
+ "", """
+dst.x = src0.x;
+dst.y = src1.x;
+dst.z = src2.x;
+dst.w = src3.x;
+dst.e = src4.x;
+dst.f = src5.x;
+dst.g = src6.x;
+dst.h = src7.x;
+""")
+
+opcode("vec16", 16, tuint,
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint,
+ tuint, tuint, tuint, tuint, tuint, tuint, tuint, tuint],
+ "", """
+dst.x = src0.x;
+dst.y = src1.x;
+dst.z = src2.x;
+dst.w = src3.x;
+dst.e = src4.x;
+dst.f = src5.x;
+dst.g = src6.x;
+dst.h = src7.x;
+dst.i = src8.x;
+dst.j = src9.x;
+dst.k = src10.x;
+dst.l = src11.x;
+dst.m = src12.x;
+dst.n = src13.x;
+dst.o = src14.x;
+dst.p = src15.x;
+""")
+
diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c
index 5cec6a49963..094d76ace76 100644
--- a/src/compiler/nir/nir_print.c
+++ b/src/compiler/nir/nir_print.c
@@ -171,6 +171,12 @@ print_dest(nir_dest *dest, print_state *state)
print_reg_dest(&dest->reg, state);
}
+static const char *
+wrmask_string(unsigned num_components)
+{
+ return (num_components > 4) ? "abcdefghijklmnop" : "xyzw";
+}
+
static void
print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
{
@@ -208,7 +214,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
if (!nir_alu_instr_channel_used(instr, src, i))
continue;
- fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]);
+ fprintf(fp, "%c", wrmask_string(live_channels)[instr->src[src].swizzle[i]]);
}
}
@@ -226,10 +232,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state)
if (!dest->dest.is_ssa &&
dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) {
+ unsigned live_channels = dest->dest.reg.reg->num_components;
fprintf(fp, ".");
- for (unsigned i = 0; i < 4; i++)
+ for (unsigned i = 0; i < live_channels; i++)
if ((dest->write_mask >> i) & 1)
- fprintf(fp, "%c", "xyzw"[i]);
+ fprintf(fp, "%c", wrmask_string(live_channels)[i]);
}
}
@@ -459,7 +466,7 @@ print_var_decl(nir_variable *var, print_state *state)
case nir_var_shader_in:
case nir_var_shader_out:
if (num_components < 4 && num_components != 0) {
- const char *xyzw = "xyzw";
+ const char *xyzw = wrmask_string(num_components);
for (int i = 0; i < num_components; i++)
components_local[i + 1] = xyzw[i + var->data.location_frac];
@@ -595,9 +602,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
/* special case wrmask to show it as a writemask.. */
unsigned wrmask = nir_intrinsic_write_mask(instr);
fprintf(fp, " wrmask=");
- for (unsigned i = 0; i < 4; i++)
+ for (unsigned i = 0; i < instr->num_components; i++)
if ((wrmask >> i) & 1)
- fprintf(fp, "%c", "xyzw"[i]);
+ fprintf(fp, "%c", wrmask_string(instr->num_components)[i]);
} else if (idx == NIR_INTRINSIC_REDUCTION_OP) {
nir_op reduction_op = nir_intrinsic_reduction_op(instr);
fprintf(fp, " reduction_op=%s", nir_op_infos[reduction_op].name);
diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c
index da5d8fb0f74..7ca7f04f558 100644
--- a/src/compiler/nir/nir_validate.c
+++ b/src/compiler/nir/nir_validate.c
@@ -237,8 +237,8 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
else
num_components = src->src.reg.reg->num_components;
}
- for (unsigned i = 0; i < 4; i++) {
- validate_assert(state, src->swizzle[i] < 4);
+ for (unsigned i = 0; i < num_components; i++) {
+ validate_assert(state, src->swizzle[i] < num_components);
if (nir_alu_instr_channel_used(instr, index, i))
validate_assert(state, src->swizzle[i] < num_components);
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 819098446b1..390b3584c4f 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -2712,6 +2712,8 @@ create_vec(struct vtn_builder *b, unsigned num_components, unsigned bit_size)
case 2: op = nir_op_vec2; break;
case 3: op = nir_op_vec3; break;
case 4: op = nir_op_vec4; break;
+ case 8: op = nir_op_vec8; break;
+ case 16: op = nir_op_vec16; break;
default: vtn_fail("bad vector size");
}
@@ -2756,7 +2758,7 @@ vtn_ssa_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
nir_ssa_def *
vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
{
- unsigned swiz[4] = { index };
+ unsigned swiz[16] = { index };
return nir_swizzle(&b->nb, src, swiz, 1, true);
}
@@ -2973,7 +2975,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
unsigned elems = count - 3;
assume(elems >= 1);
if (glsl_type_is_vector_or_scalar(type)) {
- nir_ssa_def *srcs[4];
+ nir_ssa_def *srcs[16];
for (unsigned i = 0; i < elems; i++)
srcs[i] = vtn_ssa_value(b, w[3 + i])->def;
val->ssa->def =
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index fc378495b81..268e1be3c14 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -246,7 +246,7 @@ vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
unsigned dest_components = glsl_get_vector_elements(dest->type);
vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
- nir_ssa_def *dest_chan[4];
+ nir_ssa_def *dest_chan[16];
if (src_bit_size > dest_bit_size) {
vtn_assert(src_bit_size % dest_bit_size == 0);
unsigned divisor = src_bit_size / dest_bit_size;
--
2.14.3
More information about the mesa-dev
mailing list