[Mesa-dev] [PATCH 06/14] nir: handle different bit sizes when constant folding
Iago Toral
itoral at igalia.com
Wed Mar 16 08:02:51 UTC 2016
On Tue, 2016-03-15 at 08:02 -0700, Jason Ekstrand wrote:
>
> On Mar 15, 2016 7:48 AM, "Connor Abbott" <cwabbott0 at gmail.com> wrote:
> >
> > On Tue, Mar 15, 2016 at 10:43 AM, Connor Abbott
> <cwabbott0 at gmail.com> wrote:
> > > On Tue, Mar 15, 2016 at 5:53 AM, Iago Toral <itoral at igalia.com>
> wrote:
> > >> On Mon, 2016-03-14 at 16:48 -0700, Jason Ekstrand wrote:
> > >>>
> > >>>
> > >>> On Mon, Mar 7, 2016 at 12:46 AM, Samuel Iglesias Gonsálvez
> > >>> <siglesias at igalia.com> wrote:
> > >>> From: Connor Abbott <connor.w.abbott at intel.com>
> > >>>
> > >>> v2: Use the bit-size information from the opcode
> information
> > >>> if defined (Iago)
> > >>>
> > >>> Signed-off-by: Iago Toral Quiroga <itoral at igalia.com>
> > >>>
> > >>> FIXME: This should be squashed into the previous commit
> so we
> > >>> don't break
> > >>> the build. The break happens because the python script
> that
> > >>> generates the
> > >>> constant folding pass does not know how to handle the
> sized
> > >>> types introduced
> > >>> by the previous commit until this patch, so it ends up
> > >>> generating code with
> > >>> invalid types. Keep it separated for review purposes.
> > >>> ---
> > >>> src/compiler/nir/nir_constant_expressions.h | 2 +-
> > >>> src/compiler/nir/nir_constant_expressions.py | 246
> > >>> +++++++++++++++++----------
> > >>> src/compiler/nir/nir_opt_constant_folding.c | 24 ++-
> > >>> 3 files changed, 182 insertions(+), 90 deletions(-)
> > >>>
> > >>> diff --git a/src/compiler/nir/nir_constant_expressions.h
> > >>> b/src/compiler/nir/nir_constant_expressions.h
> > >>> index 97997f2..201f278 100644
> > >>> --- a/src/compiler/nir/nir_constant_expressions.h
> > >>> +++ b/src/compiler/nir/nir_constant_expressions.h
> > >>> @@ -28,4 +28,4 @@
> > >>> #include "nir.h"
> > >>>
> > >>> nir_const_value nir_eval_const_opcode(nir_op op,
> unsigned
> > >>> num_components,
> > >>> - nir_const_value
> *src);
> > >>> + unsigned
> bit_size,
> > >>> nir_const_value *src);
> > >>> diff --git
> a/src/compiler/nir/nir_constant_expressions.py
> > >>> b/src/compiler/nir/nir_constant_expressions.py
> > >>> index 32784f6..972d281 100644
> > >>> --- a/src/compiler/nir/nir_constant_expressions.py
> > >>> +++ b/src/compiler/nir/nir_constant_expressions.py
> > >>> @@ -1,4 +1,43 @@
> > >>> #! /usr/bin/python2
> > >>> +
> > >>> +def type_has_size(type_):
> > >>> + return type_[-1:].isdigit()
> > >>> +
> > >>> +def type_sizes(type_):
> > >>> + if type_.endswith("8"):
> > >>> + return [8]
> > >>> + elif type_.endswith("16"):
> > >>> + return [16]
> > >>> + elif type_.endswith("32"):
> > >>> + return [32]
> > >>> + elif type_.endswith("64"):
> > >>> + return [64]
> > >>> + else:
> > >>> + return [32, 64]
> > >>> +
> > >>> +def type_add_size(type_, size):
> > >>> + if type_has_size(type_):
> > >>> + return type_
> > >>> + return type_ + str(size)
> > >>> +
> > >>> +def get_const_field(type_):
> > >>> + if type_ == "int32":
> > >>> + return "i"
> > >>> + if type_ == "uint32":
> > >>> + return "u"
> > >>> + if type_ == "int64":
> > >>> + return "l"
> > >>> + if type_ == "uint64":
> > >>> + return "ul"
> > >>> + if type_ == "bool32":
> > >>> + return "b"
> > >>> + if type_ == "float32":
> > >>> + return "f"
> > >>> + if type_ == "float64":
> > >>> + return "d"
> > >>> + raise Exception(str(type_))
> > >>> + assert(0)
> > >>> +
> > >>> template = """\
> > >>> /*
> > >>> * Copyright (C) 2014 Intel Corporation
> > >>> @@ -205,110 +244,140 @@ unpack_half_1x16(uint16_t u)
> > >>> }
> > >>>
> > >>> /* Some typed vector structures to make things like
> src0.y
> > >>> work */
> > >>> -% for type in ["float", "int", "uint", "bool"]:
> > >>> -struct ${type}_vec {
> > >>> - ${type} x;
> > >>> - ${type} y;
> > >>> - ${type} z;
> > >>> - ${type} w;
> > >>> +typedef float float32_t;
> > >>> +typedef double float64_t;
> > >>> +typedef bool bool32_t;
> > >>> +% for type in ["float", "int", "uint"]:
> > >>> +% for width in [32, 64]:
> > >>> +struct ${type}${width}_vec {
> > >>> + ${type}${width}_t x;
> > >>> + ${type}${width}_t y;
> > >>> + ${type}${width}_t z;
> > >>> + ${type}${width}_t w;
> > >>> };
> > >>> % endfor
> > >>> +% endfor
> > >>> +
> > >>> +struct bool32_vec {
> > >>> + bool x;
> > >>> + bool y;
> > >>> + bool z;
> > >>> + bool w;
> > >>> +};
> > >>>
> > >>> % for name, op in sorted(opcodes.iteritems()):
> > >>> static nir_const_value
> > >>> -evaluate_${name}(unsigned num_components,
> nir_const_value
> > >>> *_src)
> > >>> +evaluate_${name}(unsigned num_components, unsigned
> bit_size,
> > >>> + nir_const_value *_src)
> > >>> {
> > >>> nir_const_value _dst_val = { { {0, 0, 0, 0} } };
> > >>>
> > >>> - ## For each non-per-component input, create a
> variable
> > >>> srcN that
> > >>> - ## contains x, y, z, and w elements which are filled
> in
> > >>> with the
> > >>> - ## appropriately-typed values.
> > >>> - % for j in range(op.num_inputs):
> > >>> - % if op.input_sizes[j] == 0:
> > >>> - <% continue %>
> > >>> - % elif "src" + str(j) not in op.const_expr:
> > >>> - ## Avoid unused variable warnings
> > >>> - <% continue %>
> > >>> - %endif
> > >>> -
> > >>> - struct ${op.input_types[j]}_vec src${j} = {
> > >>> - % for k in range(op.input_sizes[j]):
> > >>> - % if op.input_types[j] == "bool":
> > >>> - _src[${j}].u[${k}] != 0,
> > >>> - % else:
> > >>> - _src[${j}].${op.input_types[j][:1]}[${k}],
> > >>> - % endif
> > >>> - % endfor
> > >>> - };
> > >>> - % endfor
> > >>> + switch (bit_size) {
> > >>> + % for bit_size in [32, 64]:
> > >>> + case ${bit_size}: {
> > >>> + <%
> > >>> + output_type = type_add_size(op.output_type,
> bit_size)
> > >>> + input_types = [type_add_size(type_, bit_size) for
> type_
> > >>> in op.input_types]
> > >>> + %>
> > >>> +
> > >>> + ## For each non-per-component input, create a
> variable
> > >>> srcN that
> > >>> + ## contains x, y, z, and w elements which are
> filled in
> > >>> with the
> > >>> + ## appropriately-typed values.
> > >>> + % for j in range(op.num_inputs):
> > >>> + % if op.input_sizes[j] == 0:
> > >>> + <% continue %>
> > >>> + % elif "src" + str(j) not in op.const_expr:
> > >>> + ## Avoid unused variable warnings
> > >>> + <% continue %>
> > >>> + %endif
> > >>>
> > >>> - % if op.output_size == 0:
> > >>> - ## For per-component instructions, we need to
> iterate
> > >>> over the
> > >>> - ## components and apply the constant expression
> one
> > >>> component
> > >>> - ## at a time.
> > >>> - for (unsigned _i = 0; _i < num_components; _i++)
> {
> > >>> - ## For each per-component input, create a
> variable
> > >>> srcN that
> > >>> - ## contains the value of the current (_i'th)
> > >>> component.
> > >>> - % for j in range(op.num_inputs):
> > >>> - % if op.input_sizes[j] != 0:
> > >>> - <% continue %>
> > >>> - % elif "src" + str(j) not in op.const_expr:
> > >>> - ## Avoid unused variable warnings
> > >>> - <% continue %>
> > >>> - % elif op.input_types[j] == "bool":
> > >>> - bool src${j} = _src[${j}].u[_i] != 0;
> > >>> + struct ${input_types[j]}_vec src${j} = {
> > >>> + % for k in range(op.input_sizes[j]):
> > >>> + % if input_types[j] == "bool32":
> > >>> + _src[${j}].u[${k}] != 0,
> > >>> % else:
> > >>> - ${op.input_types[j]} src${j} =
> > >>> _src[${j}].${op.input_types[j][:1]}[_i];
> > >>> +
> > >>> _src[${j}].${get_const_field(input_types[j])}[${k}],
> > >>> % endif
> > >>> % endfor
> > >>> + };
> > >>> + % endfor
> > >>> +
> > >>> + % if op.output_size == 0:
> > >>> + ## For per-component instructions, we need to
> > >>> iterate over the
> > >>> + ## components and apply the constant
> expression one
> > >>> component
> > >>> + ## at a time.
> > >>> + for (unsigned _i = 0; _i < num_components; _i
> ++) {
> > >>> + ## For each per-component input, create a
> > >>> variable srcN that
> > >>> + ## contains the value of the current
> (_i'th)
> > >>> component.
> > >>> + % for j in range(op.num_inputs):
> > >>> + % if op.input_sizes[j] != 0:
> > >>> + <% continue %>
> > >>> + % elif "src" + str(j) not in
> op.const_expr:
> > >>> + ## Avoid unused variable warnings
> > >>> + <% continue %>
> > >>> + % elif input_types[j] == "bool32":
> > >>> + bool src${j} = _src[${j}].u[_i] != 0;
> > >>> + % else:
> > >>> + ${input_types[j]}_t src${j} =
> > >>> +
> > >>> _src[${j}].${get_const_field(input_types[j])}[_i];
> > >>> + % endif
> > >>> + % endfor
> > >>> +
> > >>> + ## Create an appropriately-typed variable
> dst and
> > >>> assign the
> > >>> + ## result of the const_expr to it. If
> const_expr
> > >>> already contains
> > >>> + ## writes to dst, just include const_expr
> > >>> directly.
> > >>> + % if "dst" in op.const_expr:
> > >>> + ${output_type}_t dst;
> > >>> + ${op.const_expr}
> > >>> + % else:
> > >>> + ${output_type}_t dst = ${op.const_expr};
> > >>> + % endif
> > >>> +
> > >>> + ## Store the current component of the
> actual
> > >>> destination to the
> > >>> + ## value of dst.
> > >>> + % if output_type == "bool32":
> > >>> + ## Sanitize the C value to a proper NIR
> bool
> > >>> + _dst_val.u[_i] = dst ? NIR_TRUE :
> NIR_FALSE;
> > >>> + % else:
> > >>> +
> _dst_val.${get_const_field(output_type)}[_i] =
> > >>> dst;
> > >>> + % endif
> > >>> + }
> > >>> + % else:
> > >>> + ## In the non-per-component case, create a
> struct
> > >>> dst with
> > >>> + ## appropriately-typed elements x, y, z, and w
> and
> > >>> assign the result
> > >>> + ## of the const_expr to all components of dst,
> or
> > >>> include the
> > >>> + ## const_expr directly if it writes to dst
> already.
> > >>> + struct ${output_type}_vec dst;
> > >>>
> > >>> - ## Create an appropriately-typed variable dst
> and
> > >>> assign the
> > >>> - ## result of the const_expr to it. If
> const_expr
> > >>> already contains
> > >>> - ## writes to dst, just include const_expr
> directly.
> > >>> % if "dst" in op.const_expr:
> > >>> - ${op.output_type} dst;
> > >>> ${op.const_expr}
> > >>> % else:
> > >>> - ${op.output_type} dst = ${op.const_expr};
> > >>> + ## Splat the value to all components. This
> way
> > >>> expressions which
> > >>> + ## write the same value to all components
> don't
> > >>> need to explicitly
> > >>> + ## write to dest. One such example is
> fnoise
> > >>> which has a
> > >>> + ## const_expr of 0.0f.
> > >>> + dst.x = dst.y = dst.z = dst.w =
> ${op.const_expr};
> > >>> % endif
> > >>>
> > >>> - ## Store the current component of the actual
> > >>> destination to the
> > >>> - ## value of dst.
> > >>> - % if op.output_type == "bool":
> > >>> - ## Sanitize the C value to a proper NIR
> bool
> > >>> - _dst_val.u[_i] = dst ? NIR_TRUE :
> NIR_FALSE;
> > >>> - % else:
> > >>> - _dst_val.${op.output_type[:1]}[_i] = dst;
> > >>> - % endif
> > >>> - }
> > >>> - % else:
> > >>> - ## In the non-per-component case, create a struct
> dst
> > >>> with
> > >>> - ## appropriately-typed elements x, y, z, and w
> and
> > >>> assign the result
> > >>> - ## of the const_expr to all components of dst, or
> > >>> include the
> > >>> - ## const_expr directly if it writes to dst
> already.
> > >>> - struct ${op.output_type}_vec dst;
> > >>> -
> > >>> - % if "dst" in op.const_expr:
> > >>> - ${op.const_expr}
> > >>> - % else:
> > >>> - ## Splat the value to all components. This
> way
> > >>> expressions which
> > >>> - ## write the same value to all components
> don't need
> > >>> to explicitly
> > >>> - ## write to dest. One such example is fnoise
> which
> > >>> has a
> > >>> - ## const_expr of 0.0f.
> > >>> - dst.x = dst.y = dst.z = dst.w =
> ${op.const_expr};
> > >>> + ## For each component in the destination, copy
> the
> > >>> value of dst to
> > >>> + ## the actual destination.
> > >>> + % for k in range(op.output_size):
> > >>> + % if output_type == "bool32":
> > >>> + ## Sanitize the C value to a proper NIR
> bool
> > >>> + _dst_val.u[${k}] = dst.${"xyzw"[k]} ?
> > >>> NIR_TRUE : NIR_FALSE;
> > >>> + % else:
> > >>> +
> _dst_val.${get_const_field(output_type)}[${k}]
> > >>> = dst.${"xyzw"[k]};
> > >>> + % endif
> > >>> + % endfor
> > >>> % endif
> > >>>
> > >>> - ## For each component in the destination, copy
> the
> > >>> value of dst to
> > >>> - ## the actual destination.
> > >>> - % for k in range(op.output_size):
> > >>> - % if op.output_type == "bool":
> > >>> - ## Sanitize the C value to a proper NIR
> bool
> > >>> - _dst_val.u[${k}] = dst.${"xyzw"[k]} ?
> NIR_TRUE :
> > >>> NIR_FALSE;
> > >>> - % else:
> > >>> - _dst_val.${op.output_type[:1]}[${k}] =
> > >>> dst.${"xyzw"[k]};
> > >>> - % endif
> > >>> - % endfor
> > >>> - % endif
> > >>> + break;
> > >>> + }
> > >>> + % endfor
> > >>> +
> > >>> + default:
> > >>> + unreachable("unknown bit width");
> > >>> + }
> > >>>
> > >>> return _dst_val;
> > >>> }
> > >>> @@ -316,12 +385,12 @@ evaluate_${name}(unsigned
> > >>> num_components, nir_const_value *_src)
> > >>>
> > >>> nir_const_value
> > >>> nir_eval_const_opcode(nir_op op, unsigned
> num_components,
> > >>> - nir_const_value *src)
> > >>> + unsigned bit_width,
> nir_const_value
> > >>> *src)
> > >>> {
> > >>> switch (op) {
> > >>> % for name in sorted(opcodes.iterkeys()):
> > >>> case nir_op_${name}: {
> > >>> - return evaluate_${name}(num_components, src);
> > >>> + return evaluate_${name}(num_components,
> bit_width,
> > >>> src);
> > >>> break;
> > >>> }
> > >>> % endfor
> > >>> @@ -333,4 +402,7 @@ nir_eval_const_opcode(nir_op op,
> unsigned
> > >>> num_components,
> > >>> from nir_opcodes import opcodes
> > >>> from mako.template import Template
> > >>>
> > >>> -print Template(template).render(opcodes=opcodes)
> > >>> +print Template(template).render(opcodes=opcodes,
> > >>> type_sizes=type_sizes,
> > >>> +
> type_has_size=type_has_size,
> > >>> +
> type_add_size=type_add_size,
> > >>> +
> > >>> get_const_field=get_const_field)
> > >>> diff --git a/src/compiler/nir/nir_opt_constant_folding.c
> > >>> b/src/compiler/nir/nir_opt_constant_folding.c
> > >>> index 04876a4..29905a0 100644
> > >>> --- a/src/compiler/nir/nir_opt_constant_folding.c
> > >>> +++ b/src/compiler/nir/nir_opt_constant_folding.c
> > >>> @@ -46,10 +46,23 @@
> constant_fold_alu_instr(nir_alu_instr
> > >>> *instr, void *mem_ctx)
> > >>> if (!instr->dest.dest.is_ssa)
> > >>> return false;
> > >>>
> > >>> + unsigned bit_size = 0;
> > >>> + if (!(nir_op_infos[instr->op].output_type &
> > >>> NIR_ALU_TYPE_SIZE_MASK))
> > >>> + bit_size = instr->dest.dest.ssa.bit_size;
> > >>> + else
> > >>> + bit_size = nir_op_infos[instr->op].output_type &
> > >>> NIR_ALU_TYPE_SIZE_MASK;
> > >>>
> > >>>
> > >>> This isn't right. We need to look at all the unsized types and
> try to
> > >>> pull it from one of those. We shouldn't fall back to grabbing
> from
> > >>> the sized type.
> > >>
> > >> Ok, so you don't like that in the case that the alu operation has
> a
> > >> sized destination we grab the bit-size from the opcode
> definition? I am
> > >> not sure I see the problem with that... isn't the opcode
> mandating a
> > >> specific bit-size in that case? How can the bit-size we want be
> > >> different from that?
> > >>
> > >>>
> > >>> +
> > >>> for (unsigned i = 0; i <
> > >>> nir_op_infos[instr->op].num_inputs; i++) {
> > >>> if (!instr->src[i].src.is_ssa)
> > >>> return false;
> > >>>
> > >>> + if (bit_size == 0) {
> > >>> + if (!(nir_op_infos[instr->op].input_sizes[i] &
> > >>> NIR_ALU_TYPE_SIZE_MASK))
> > >>> + bit_size = instr->src[i].src.ssa->bit_size;
> > >>> + else
> > >>> + bit_size =
> nir_op_infos[instr->op].input_sizes[i]
> > >>> & NIR_ALU_TYPE_SIZE_MASK;
> > >>>
> > >>>
> > >>> Same here. If they don't have any unsized sources or
> destinations to
> > >>> grab from, we should let bit_size be zero.
> > >>
> > >> But if we have an opcode with all sized 64-bit types then...
> > >>
> > >>> Also, if we have multiple sources with the same unsized type, we
> > >>> should assert that the sizes match.
> > >>>
> > >>>
> > >>> + }
> > >>> +
> > >>> nir_instr *src_instr =
> > >>> instr->src[i].src.ssa->parent_instr;
> > >>>
> > >>> if (src_instr->type != nir_instr_type_load_const)
> > >>> @@ -58,24 +71,31 @@
> constant_fold_alu_instr(nir_alu_instr
> > >>> *instr, void *mem_ctx)
> > >>>
> > >>> for (unsigned j = 0; j <
> > >>> nir_ssa_alu_instr_src_components(instr, i);
> > >>> j++) {
> > >>> - src[i].u[j] =
> > >>> load_const->value.u[instr->src[i].swizzle[j]];
> > >>> + if (load_const->def.bit_size == 64)
> > >>> + src[i].ul[j] =
> > >>> load_const->value.ul[instr->src[i].swizzle[j]];
> > >>> + else
> > >>> + src[i].u[j] =
> > >>> load_const->value.u[instr->src[i].swizzle[j]];
> > >>> }
> > >>>
> > >>> /* We shouldn't have any source modifiers in the
> > >>> optimization loop. */
> > >>> assert(!instr->src[i].abs && !
> instr->src[i].negate);
> > >>> }
> > >>>
> > >>> + if (bit_size == 0)
> > >>> + bit_size = 32;
> > >>
> > >> ... this default to 32 here would not be correct any more. If at
> this
> > >> point the bit-size is 0 (meaning that all inputs and output are
> sized)
> > >> then we should take the bit-size from the opcode's output type,
> which is
> > >> known to be sized, right?
> > >
> > > Not quite. xd
> >
> > Err, whoops...
> >
> > Not quite. If any of the operands or the destination is already
> sized
> > by the opcode, then the constant propagation code we generated has
> all
> > of the information it needs. The only case where we need to figure
> > anything out is if there are unsized types, in which case we know
> the
> > bitsizes match -- we just need to find the first unsized type and
> pass
> > the bitsize of that. If all the operands and the destination are
> > sized, then the bit_size will be ignored and it's fine if it's 0.
Oh right, the auto-generated code for opcodes where all the operands
are sized is the same for both the 32-bit and 64-bit paths, so as long
as the bit-size is one of these two we are fine.
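
To make that concrete, here is a rough sketch of what the generated evaluator could look like for a fully sized opcode. The opcode and the name sketch_eval_sized are hypothetical, not taken from nir_opcodes.py; the assumption is only that type_add_size() leaves already-sized types untouched, so both switch cases expand to the same body.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical opcode: float32 input, uint32 output, both sized by the
 * opcode definition. This is an illustration, not real generated code. */
static uint32_t
sketch_eval_sized(unsigned bit_size, float src0)
{
   uint32_t dst = 0;

   switch (bit_size) {
   case 32: {
      /* Sized types are left as-is, so nothing here depends on bit_size. */
      float s = src0;
      dst = (uint32_t)s;
      break;
   }
   case 64: {
      /* Same body as the 32-bit case. */
      float s = src0;
      dst = (uint32_t)s;
      break;
   }
   default:
      /* A bit_size of 0 would land here, hence the fallback to 32. */
      assert(!"unknown bit width");
   }

   return dst;
}
```

Either of the two valid sizes gives the same result, which is why the caller only has to avoid passing 0.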
> More to the point, you could have an instruction with, say, two
> unsized inputs and a sized destination (comparison operations are an
> example of this). In this case, the two sources have to match but they
> don't have to match the destination; the destination size is fixed.
> Does that make sense?
Yes, absolutely.
Thank you both for the clarifications. I think this deserves a small
comment in the code, so I'll add that too.
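
For reference, a minimal sketch of the selection rule as I understand it from this thread. It uses a simplified stand-in struct (sketch_alu and its fields are hypothetical, not the real nir_alu_instr or nir_op_infos layout), and it assumes that all unsized operands of an instruction share one bit size, as discussed above.

```c
#include <assert.h>

/* Hypothetical, simplified model of an ALU instruction. A value of 0 in
 * the *_op_size fields means the opcode leaves that operand unsized. */
struct sketch_alu {
   unsigned num_inputs;
   unsigned output_op_size;    /* size fixed by the opcode, or 0 */
   unsigned input_op_size[4];  /* size fixed by the opcode, or 0 */
   unsigned dest_bit_size;     /* actual size of the SSA destination */
   unsigned src_bit_size[4];   /* actual size of each SSA source */
};

static unsigned
sketch_pick_bit_size(const struct sketch_alu *alu)
{
   unsigned bit_size = 0;

   /* Only unsized operands tell us anything; sized ones are ignored. */
   if (alu->output_op_size == 0)
      bit_size = alu->dest_bit_size;

   for (unsigned i = 0; i < alu->num_inputs; i++) {
      if (alu->input_op_size[i] != 0)
         continue;

      if (bit_size == 0)
         bit_size = alu->src_bit_size[i];
      else
         assert(bit_size == alu->src_bit_size[i]);
   }

   /* Everything is sized by the opcode: the value is not used for typing,
    * but it still has to select a valid case in the generated switch. */
   if (bit_size == 0)
      bit_size = 32;

   return bit_size;
}
```

With this rule, a comparison with two unsized 64-bit sources and a sized 32-bit destination yields 64, and an opcode with only sized operands yields the default of 32.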
> > >
> > >>
> > >>>
> > >>> Then it'll get set here.
> > >>>
> > >>>
> > >>> +
> > >>> /* We shouldn't have any saturate modifiers in the
> > >>> optimization loop. */
> > >>> assert(!instr->dest.saturate);
> > >>>
> > >>> nir_const_value dest =
> > >>> nir_eval_const_opcode(instr->op,
> > >>> instr->dest.dest.ssa.num_components,
> > >>> - src);
> > >>> + bit_size, src);
> > >>>
> > >>> nir_load_const_instr *new_instr =
> > >>> nir_load_const_instr_create(mem_ctx,
> > >>>
> > >>> instr->dest.dest.ssa.num_components);
> > >>>
> > >>> + new_instr->def.bit_size =
> instr->dest.dest.ssa.bit_size;
> > >>> new_instr->value = dest;
> > >>>
> > >>> nir_instr_insert_before(&instr->instr,
> &new_instr->instr);
> > >>> --
> > >>> 2.7.0
> > >>>
> > >>> _______________________________________________
> > >>> mesa-dev mailing list
> > >>> mesa-dev at lists.freedesktop.org
> > >>> https://lists.freedesktop.org/mailman/listinfo/mesa-dev
> > >>>
> > >>>
>
>