[Mesa-dev] [PATCH 2/2] nir: Add support for 8 and 16-bit types

Iago Toral itoral at igalia.com
Wed Mar 29 06:07:51 UTC 2017


On Tue, 2017-03-28 at 08:28 -0700, Jason Ekstrand wrote:
> Sorry I haven't gotten back to you on this.  It got lost somehow.
> 
> On Fri, Mar 10, 2017 at 1:56 AM, Iago Toral <itoral at igalia.com>
> wrote:
> > On Thu, 2017-03-09 at 14:05 -0800, Jason Ekstrand wrote:
> > > ---
> > >  src/compiler/nir/nir.h                       |  4 ++++
> > >  src/compiler/nir/nir_constant_expressions.py | 16 +++++++++++++++-
> > >  src/compiler/nir/nir_opcodes.py              |  6 +++++-
> > >  3 files changed, 24 insertions(+), 2 deletions(-)
> > >
> > > diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
> > > index 57b8be3..eaa103d 100644
> > > --- a/src/compiler/nir/nir.h
> > > +++ b/src/compiler/nir/nir.h
> > > @@ -105,6 +105,10 @@ typedef enum {
> > >  typedef union {
> > >     float f32[4];
> > >     double f64[4];
> > > +   int8_t i8[4];
> > > +   uint8_t u8[4];
> > > +   int16_t i16[4];
> > > +   uint16_t u16[4];
> > >     int32_t i32[4];
> > >     uint32_t u32[4];
> > >     int64_t i64[4];
> > > 
> > > diff --git a/src/compiler/nir/nir_constant_expressions.py
> > > b/src/compiler/nir/nir_constant_expressions.py
> > > index aecca8b..cbda4b1 100644
> > > --- a/src/compiler/nir/nir_constant_expressions.py
> > > +++ b/src/compiler/nir/nir_constant_expressions.py
> > > @@ -14,8 +14,10 @@ def type_size(type_):
> > >  def type_sizes(type_):
> > >      if type_has_size(type_):
> > >          return [type_size(type_)]
> > > +    elif type_ == 'float':
> > > +        return [16, 32, 64]
> > >      else:
> > > -        return [32, 64]
> > > +        return [8, 16, 32, 64]
> > >  
> > >  def type_add_size(type_, size):
> > >      if type_has_size(type_):
> > > @@ -34,6 +36,8 @@ def op_bit_sizes(op):
> > >  def get_const_field(type_):
> > >      if type_ == "bool32":
> > >          return "u32"
> > > +    elif type_ == "float16":
> > > +        return "u16"
> > 
> > I was wondering why we don't have an f16 field with type float16_t
> > in nir_const_value instead, but reading the rest of the patch it
> > seems we really want to work with 32-bit floats and then encode the
> > results back to 16-bit values? Is it because there are some
> > operations that we can't implement directly with half-floats?
> > 
> Yes.  CPUs cannot, in general, work with 16-bit floats, so there is no
> actual float16_t type in C.  We typedef float to float16_t below just
> because it makes the codegen easier.  In reality, all 16-bit float
> operations have to be done in 32-bit float with a manual conversion
> to/from 16-bit float on both sides.

Thanks for the explanation; that makes sense.

Reviewed-by: Iago Toral Quiroga <itoral at igalia.com>

 
> > Iago
> > 
> > >      else:
> > >          m = type_split_re.match(type_)
> > >          if not m:
> > > @@ -246,6 +250,7 @@ unpack_half_1x16(uint16_t u)
> > >  }
> > >  
> > >  /* Some typed vector structures to make things like src0.y work
> > */
> > > +typedef float float16_t;
> > >  typedef float float32_t;
> > >  typedef double float64_t;
> > >  typedef bool bool32_t;
> > > @@ -297,6 +302,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> > > num_components, unsigned bit_size,
> > >           % for k in range(op.input_sizes[j]):
> > >              % if input_types[j] == "bool32":
> > >                 _src[${j}].u32[${k}] != 0,
> > > +            % elif input_types[j] == "float16":
> > > +               _mesa_half_to_float(_src[${j}].u16[${k}]),
> > >              % else:
> > >
> >                 _src[${j}].${get_const_field(input_types[j])}[${k}]
> > ,
> > >              % endif
> > > @@ -322,6 +329,9 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> > > num_components, unsigned bit_size,
> > >                    <% continue %>
> > >                 % elif input_types[j] == "bool32":
> > >                    const bool src${j} = _src[${j}].u32[_i] != 0;
> > > +               % elif input_types[j] == "float16":
> > > +                  const float src${j} =
> > > +                     _mesa_half_to_float(_src[${j}].u16[_i]);
> > >                 % else:
> > >                    const ${input_types[j]}_t src${j} =
> > >
> >                       _src[${j}].${get_const_field(input_types[j])}
> > [_
> > > i];
> > > @@ -344,6 +354,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> > > num_components, unsigned bit_size,
> > >              % if output_type == "bool32":
> > >                 ## Sanitize the C value to a proper NIR bool
> > >                 _dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
> > > +            % elif output_type == "float16":
> > > +               _dst_val.u16[_i] = _mesa_float_to_half(dst);
> > >              % else:
> > >                 _dst_val.${get_const_field(output_type)}[_i] =
> > dst;
> > >              % endif
> > > @@ -371,6 +383,8 @@ evaluate_${name}(MAYBE_UNUSED unsigned
> > > num_components, unsigned bit_size,
> > >              % if output_type == "bool32":
> > >                 ## Sanitize the C value to a proper NIR bool
> > >                 _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE
> > :
> > > NIR_FALSE;
> > > +            % elif output_type == "float16":
> > > +               _dst_val.u16[${k}] =
> > > _mesa_float_to_half(dst.${"xyzw"[k]});
> > >              % else:
> > >                 _dst_val.${get_const_field(output_type)}[${k}] =
> > > dst.${"xyzw"[k]};
> > >              % endif
> > > diff --git a/src/compiler/nir/nir_opcodes.py
> > > b/src/compiler/nir/nir_opcodes.py
> > > index 53e9aff..37c655b 100644
> > > --- a/src/compiler/nir/nir_opcodes.py
> > > +++ b/src/compiler/nir/nir_opcodes.py
> > > @@ -175,7 +175,11 @@ for src_t in [tint, tuint, tfloat]:
> > >        dst_types = [tint, tuint, tfloat]
> > >  
> > >     for dst_t in dst_types:
> > > -      for bit_size in [32, 64]:
> > > +      if dst_t == tfloat:
> > > +         bit_sizes = [16, 32, 64]
> > > +      else:
> > > +         bit_sizes = [8, 16, 32, 64]
> > > +      for bit_size in bit_sizes:
> > >           unop_convert("{}2{}{}".format(src_t[0], dst_t[0],
> > > bit_size),
> > >                        dst_t + str(bit_size), src_t, "src0")
> > >  
> > 


More information about the mesa-dev mailing list