[Mesa-dev] [RFC] nir/algebraic: support for power-of-two optimizations
Ilia Mirkin
imirkin at alum.mit.edu
Sun May 8 14:41:21 UTC 2016
On Sat, May 7, 2016 at 1:06 PM, Rob Clark <robdclark at gmail.com> wrote:
> From: Rob Clark <robclark at freedesktop.org>
>
> It was kinda sad that we couldn't optimize imul/idiv by power-of-two.
> So I bashed my head against python for a while and this is what I came
> up with. In the search expression, you can use "#a^2" to only match
> constants which are a power of two. The rest is taken care of with the normal
> replacement expression. (Might be nice if we had an ilog2 to avoid the
> float/int conversion stuff.)
>
> Still a couple rough edges and things which should be split out.
> ---
> src/compiler/nir/nir_algebraic.py | 9 ++++--
> src/compiler/nir/nir_opt_algebraic.py | 5 ++++
> src/compiler/nir/nir_search.c | 27 +++++++++++++++++
> src/compiler/nir/nir_search.h | 9 +++++-
> src/gallium/drivers/freedreno/ir3/ir3_nir.c | 45 +++++++++++++++++++----------
> 5 files changed, 77 insertions(+), 18 deletions(-)
>
> diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py
> index 285f853..c2b47fd 100644
> --- a/src/compiler/nir/nir_algebraic.py
> +++ b/src/compiler/nir/nir_algebraic.py
> @@ -83,6 +83,7 @@ static const ${val.c_type} ${val.name} = {
> % elif isinstance(val, Variable):
> ${val.index}, /* ${val.var_name} */
> ${'true' if val.is_constant else 'false'},
> + ${'true' if val.is_power_of_two else 'false'},
> ${val.type() or 'nir_type_invalid' },
> % elif isinstance(val, Expression):
> ${'true' if val.inexact else 'false'},
> @@ -113,7 +114,7 @@ static const ${val.c_type} ${val.name} = {
> Variable=Variable,
> Expression=Expression)
>
> -_constant_re = re.compile(r"(?P<value>[^@]+)(?:@(?P<bits>\d+))?")
> +_constant_re = re.compile(r"(?P<value>[^@\^]+)(?P<PoT>\^2)?(?:@(?P<bits>\d+))?")
>
> class Constant(Value):
> def __init__(self, val, name):
> @@ -123,6 +124,7 @@ class Constant(Value):
> m = _constant_re.match(val)
> self.value = ast.literal_eval(m.group('value'))
> self.bit_size = int(m.group('bits')) if m.group('bits') else 0
> + self.power_of_two = True if m.group('PoT') else False
self.power_of_two = bool(m.group('PoT'))
> else:
> self.value = val
> self.bit_size = 0
> @@ -149,7 +151,7 @@ class Constant(Value):
> elif isinstance(self.value, float):
> return "nir_type_float"
>
> -_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
> +_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)(?P<PoT>\^2)?"
> r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?")
>
> class Variable(Value):
> @@ -161,6 +163,9 @@ class Variable(Value):
>
> self.var_name = m.group('name')
> self.is_constant = m.group('const') is not None
> + self.is_power_of_two = m.group('PoT') is not None
> + if self.is_power_of_two:
> + assert self.is_constant
> self.required_type = m.group('type')
> self.bit_size = int(m.group('bits')) if m.group('bits') else 0
>
> diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
> index 0a95725..e1381b2 100644
> --- a/src/compiler/nir/nir_opt_algebraic.py
> +++ b/src/compiler/nir/nir_opt_algebraic.py
> @@ -62,6 +62,11 @@ d = 'd'
> # constructed value should have that bit-size.
>
> optimizations = [
> +
> + # add 64b variants?
> + (('imul', a, '#b^2@32'), ('ishl', a, ('f2i', ('flog2', ('i2f', b))))),
> + (('idiv', a, '#b^2@32'), ('ishr', a, ('f2i', ('flog2', ('i2f', b))))),
I think you can just use ('ishl', a, ('find_lsb', b)) [Double-check on
the definition of that, but if not that, then find_umsb.]
Also, this only holds for unsigned division (-1 / 2 == 0, but -1 >> 1
== -1), so I think you want udiv here, not idiv.
Lastly, you might want to throw something in for umod (a % b == a & (b - 1)).
-ilia
> +
> (('fneg', ('fneg', a)), a),
> (('ineg', ('ineg', a)), a),
> (('fabs', ('fabs', a)), ('fabs', a)),
> diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
> index 2c2fd92..92af521 100644
> --- a/src/compiler/nir/nir_search.c
> +++ b/src/compiler/nir/nir_search.c
> @@ -70,6 +70,13 @@ alu_instr_is_bool(nir_alu_instr *instr)
> }
> }
>
> +/* helper for this somewhere? */
> +static bool
> +is_power_of_two(unsigned int x)
> +{
> + return ((x != 0) && !(x & (x - 1)));
> +}
> +
> static bool
> match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
> unsigned num_components, const uint8_t *swizzle,
> @@ -127,6 +134,26 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
> instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
> return false;
>
> + if (var->is_power_of_two) {
> + assert(var->is_constant);
> + nir_const_value *val = nir_src_as_const_value(instr->src[src].src);
> + for (unsigned i = 0; i < num_components; i++) {
> + switch (nir_op_infos[instr->op].input_types[src]) {
> + // TODO handle other types??
> + case nir_type_int:
> + if (!is_power_of_two(val->i32[new_swizzle[i]]))
> + return false;
> + break;
> + case nir_type_uint:
> + if (!is_power_of_two(val->u32[new_swizzle[i]]))
> + return false;
> + break;
> + default:
> + return false;
> + }
> + }
> + }
> +
> if (var->type != nir_type_invalid) {
> if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
> return false;
> diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
> index c49eba7..32ed538 100644
> --- a/src/compiler/nir/nir_search.h
> +++ b/src/compiler/nir/nir_search.h
> @@ -52,11 +52,18 @@ typedef struct {
>
> /** Indicates that the given variable must be a constant
> *
> - * This is only alloed in search expressions and indicates that the
> + * This is only allowed in search expressions and indicates that the
> * given variable is only allowed to match constant values.
> */
> bool is_constant;
>
> + /** Indicates that the given constant is a power of two
> + *
> + * This is only allowed in search expressions, and only for constant
> + * variables.
> + */
> + bool is_power_of_two;
> +
> /** Indicates that the given variable must have a certain type
> *
> * This is only allowed in search expressions and indicates that the
> diff --git a/src/gallium/drivers/freedreno/ir3/ir3_nir.c b/src/gallium/drivers/freedreno/ir3/ir3_nir.c
> index 7e3ccc0..44c694a 100644
> --- a/src/gallium/drivers/freedreno/ir3/ir3_nir.c
> +++ b/src/gallium/drivers/freedreno/ir3/ir3_nir.c
> @@ -77,6 +77,27 @@ ir3_key_lowers_nir(const struct ir3_shader_key *key)
>
> #define OPT_V(nir, pass, ...) NIR_PASS_V(nir, pass, ##__VA_ARGS__)
>
> +static void
> +ir3_optimize_loop(nir_shader *s)
> +{
> + bool progress;
> + do {
> + progress = false;
> +
> + OPT_V(s, nir_lower_vars_to_ssa);
> + OPT_V(s, nir_lower_alu_to_scalar);
> + OPT_V(s, nir_lower_phis_to_scalar);
> +
> + progress |= OPT(s, nir_copy_prop);
> + progress |= OPT(s, nir_opt_dce);
> + progress |= OPT(s, nir_opt_cse);
> + progress |= OPT(s, ir3_nir_lower_if_else);
> + progress |= OPT(s, nir_opt_algebraic);
> + progress |= OPT(s, nir_opt_constant_folding);
> +
> + } while (progress);
> +}
> +
> struct nir_shader *
> ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
> const struct ir3_shader_key *key)
> @@ -84,7 +105,6 @@ ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
> struct nir_lower_tex_options tex_options = {
> .lower_rect = 0,
> };
> - bool progress;
>
> if (key) {
> switch (shader->type) {
> @@ -140,24 +160,19 @@ ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
> }
>
> OPT_V(s, nir_lower_tex, &tex_options);
> - OPT_V(s, nir_lower_idiv);
> OPT_V(s, nir_lower_load_const_to_scalar);
>
> - do {
> - progress = false;
> -
> - OPT_V(s, nir_lower_vars_to_ssa);
> - OPT_V(s, nir_lower_alu_to_scalar);
> - OPT_V(s, nir_lower_phis_to_scalar);
> + ir3_optimize_loop(s);
>
> - progress |= OPT(s, nir_copy_prop);
> - progress |= OPT(s, nir_opt_dce);
> - progress |= OPT(s, nir_opt_cse);
> - progress |= OPT(s, ir3_nir_lower_if_else);
> - progress |= OPT(s, nir_opt_algebraic);
> - progress |= OPT(s, nir_opt_constant_folding);
> + /* do idiv lowering after first opt loop to give a chance for
> + * divide by immed power-of-two to be caught first:
> + *
> + * XXX TODO nir_lower_idiv should return progress so we could
> + * skip second loop..
> + */
> + OPT_V(s, nir_lower_idiv);
>
> - } while (progress);
> + ir3_optimize_loop(s);
>
> OPT_V(s, nir_remove_dead_variables, nir_var_local);
>
> --
> 2.5.5
>
> _______________________________________________
> mesa-dev mailing list
> mesa-dev at lists.freedesktop.org
> https://lists.freedesktop.org/mailman/listinfo/mesa-dev
More information about the mesa-dev
mailing list