[Mesa-dev] [PATCH v2 19/29] nir: Add support for 1-bit data types

Jason Ekstrand jason at jlekstrand.net
Thu Dec 6 19:45:10 UTC 2018


This commit adds support for 1-bit Booleans and integers.  Booleans
obviously take a value of true or false.  Because we have to define the
semantics of 1-bit signed and unsigned integers, we define uint1_t to
take values of 0 and 1 and int1_t to take values of 0 and -1.  1-bit
arithmetic is then well-defined in the usual way, just with fewer bits.
The definition of int1_t and uint1_t doesn't usually matter, but we do
need one for the purposes of constant folding.
---
 src/compiler/nir/nir.c                        | 15 +++++------
 src/compiler/nir/nir.h                        | 21 +++++++++++-----
 src/compiler/nir/nir_builder.h                | 12 ++++++++-
 src/compiler/nir/nir_constant_expressions.py  | 25 ++++++++++++++++---
 src/compiler/nir/nir_instr_set.c              | 23 ++++++++++++++---
 .../nir/nir_lower_load_const_to_scalar.c      |  3 +++
 src/compiler/nir/nir_opt_constant_folding.c   |  3 +++
 src/compiler/nir/nir_print.c                  |  3 +++
 src/compiler/nir/nir_search.c                 |  3 ++-
 src/compiler/nir/nir_validate.c               |  2 +-
 src/compiler/spirv/spirv_to_nir.c             |  9 +++++++
 11 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index 249b9357c3f..3c80e03a091 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -638,6 +638,7 @@ const_value_int(int64_t i, unsigned bit_size)
 {
    nir_const_value v;
    switch (bit_size) {
+   case 1:  v.b[0]   = i & 1;  break;
    case 8:  v.i8[0]  = i;  break;
    case 16: v.i16[0] = i;  break;
    case 32: v.i32[0] = i;  break;
@@ -1206,6 +1207,8 @@ nir_src_comp_as_int(nir_src src, unsigned comp)
 
    assert(comp < load->def.num_components);
    switch (load->def.bit_size) {
+   /* int1_t uses 0/-1 convention */
+   case 1:  return -(int)load->value.b[comp];
    case 8:  return load->value.i8[comp];
    case 16: return load->value.i16[comp];
    case 32: return load->value.i32[comp];
@@ -1223,6 +1226,7 @@ nir_src_comp_as_uint(nir_src src, unsigned comp)
 
    assert(comp < load->def.num_components);
    switch (load->def.bit_size) {
+   case 1:  return load->value.b[comp];
    case 8:  return load->value.u8[comp];
    case 16: return load->value.u16[comp];
    case 32: return load->value.u32[comp];
@@ -1235,15 +1239,12 @@ nir_src_comp_as_uint(nir_src src, unsigned comp)
 bool
 nir_src_comp_as_bool(nir_src src, unsigned comp)
 {
-   assert(nir_src_is_const(src));
-   nir_load_const_instr *load = nir_instr_as_load_const(src.ssa->parent_instr);
+   int64_t i = nir_src_comp_as_int(src, comp);
 
-   assert(comp < load->def.num_components);
-   assert(load->def.bit_size == 32);
-   assert(load->value.u32[comp] == NIR_TRUE ||
-          load->value.u32[comp] == NIR_FALSE);
+   /* Booleans of any size use 0/-1 convention */
+   assert(i == 0 || i == -1);
 
-   return load->value.u32[comp];
+   return i;
 }
 
 double
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index a888fbd1516..7cb2b8e97e4 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -118,6 +118,7 @@ typedef enum {
 } nir_rounding_mode;
 
 typedef union {
+   bool b[NIR_MAX_VEC_COMPONENTS];
    float f32[NIR_MAX_VEC_COMPONENTS];
    double f64[NIR_MAX_VEC_COMPONENTS];
    int8_t i8[NIR_MAX_VEC_COMPONENTS];
@@ -779,17 +780,25 @@ typedef struct {
    unsigned write_mask : NIR_MAX_VEC_COMPONENTS; /* ignored if dest.is_ssa is true */
 } nir_alu_dest;
 
+/** NIR sized and unsized types
+ *
+ * The values in this enum are carefully chosen so that the sized type is
+ * just the unsized type OR the number of bits.
+ */
 typedef enum {
    nir_type_invalid = 0, /* Not a valid type */
-   nir_type_float,
-   nir_type_int,
-   nir_type_uint,
-   nir_type_bool,
+   nir_type_int =       2,
+   nir_type_uint =      4,
+   nir_type_bool =      6,
+   nir_type_float =     128,
+   nir_type_bool1 =     1  | nir_type_bool,
    nir_type_bool32 =    32 | nir_type_bool,
+   nir_type_int1 =      1  | nir_type_int,
    nir_type_int8 =      8  | nir_type_int,
    nir_type_int16 =     16 | nir_type_int,
    nir_type_int32 =     32 | nir_type_int,
    nir_type_int64 =     64 | nir_type_int,
+   nir_type_uint1 =     1  | nir_type_uint,
    nir_type_uint8 =     8  | nir_type_uint,
    nir_type_uint16 =    16 | nir_type_uint,
    nir_type_uint32 =    32 | nir_type_uint,
@@ -799,8 +808,8 @@ typedef enum {
    nir_type_float64 =   64 | nir_type_float,
 } nir_alu_type;
 
-#define NIR_ALU_TYPE_SIZE_MASK 0xfffffff8
-#define NIR_ALU_TYPE_BASE_TYPE_MASK 0x00000007
+#define NIR_ALU_TYPE_SIZE_MASK 0x79
+#define NIR_ALU_TYPE_BASE_TYPE_MASK 0x86
 
 static inline unsigned
 nir_alu_type_get_type_size(nir_alu_type type)
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index e0cdcd4ba23..d8abb7fd027 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -332,7 +332,10 @@ nir_imm_intN_t(nir_builder *build, uint64_t x, unsigned bit_size)
 
    memset(&v, 0, sizeof(v));
    assert(bit_size <= 64);
-   v.i64[0] = x & (~0ull >> (64 - bit_size));
+   if (bit_size == 1)
+      v.b[0] = x & 1;
+   else
+      v.i64[0] = x & (~0ull >> (64 - bit_size));
 
    return nir_build_imm(build, 1, bit_size, v);
 }
@@ -351,6 +354,13 @@ nir_imm_ivec4(nir_builder *build, int x, int y, int z, int w)
    return nir_build_imm(build, 4, 32, v);
 }
 
+static inline nir_ssa_def *
+nir_imm_boolN_t(nir_builder *build, bool x, unsigned bit_size)
+{
+   /* We use a 0/-1 convention for all booleans regardless of size */
+   return nir_imm_intN_t(build, -(int)x, bit_size);
+}
+
 static inline nir_ssa_def *
 nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
               nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index 2b8fb640425..0d944fb46a4 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -24,7 +24,9 @@ def op_bit_sizes(op):
     return sorted(list(sizes)) if sizes is not None else None
 
 def get_const_field(type_):
-    if type_base_type(type_) == 'bool':
+    if type_size(type_) == 1:
+        return 'b'
+    elif type_base_type(type_) == 'bool':
         return 'i' + str(type_size(type_))
     elif type_ == "float16":
         return "u16"
@@ -236,9 +238,12 @@ unpack_half_1x16(uint16_t u)
 }
 
 /* Some typed vector structures to make things like src0.y work */
+typedef int8_t int1_t;
+typedef uint8_t uint1_t;
 typedef float float16_t;
 typedef float float32_t;
 typedef double float64_t;
+typedef bool bool1_t;
 typedef bool bool8_t;
 typedef bool bool16_t;
 typedef bool bool32_t;
@@ -273,7 +278,10 @@ struct ${type}${width}_vec {
 
       const struct ${input_types[j]}_vec src${j} = {
       % for k in range(op.input_sizes[j]):
-         % if input_types[j] == "float16":
+         % if input_types[j] == "int1":
+             /* 1-bit integers use a 0/-1 convention */
+             -(int1_t)_src[${j}].b[${k}],
+         % elif input_types[j] == "float16":
             _mesa_half_to_float(_src[${j}].u16[${k}]),
          % else:
             _src[${j}].${get_const_field(input_types[j])}[${k}],
@@ -298,6 +306,9 @@ struct ${type}${width}_vec {
             % elif "src" + str(j) not in op.const_expr:
                ## Avoid unused variable warnings
                <% continue %>
+            % elif input_types[j] == "int1":
+               /* 1-bit integers use a 0/-1 convention */
+               const int1_t src${j} = -(int1_t)_src[${j}].b[_i];
             % elif input_types[j] == "float16":
                const float src${j} =
                   _mesa_half_to_float(_src[${j}].u16[_i]);
@@ -320,7 +331,10 @@ struct ${type}${width}_vec {
 
          ## Store the current component of the actual destination to the
          ## value of dst.
-         % if output_type.startswith("bool"):
+         % if output_type == "int1" or output_type == "uint1":
+            /* 1-bit integers get truncated */
+            _dst_val.b[_i] = dst & 1;
+         % elif output_type.startswith("bool"):
             ## Sanitize the C value to a proper NIR 0/-1 bool
             _dst_val.${get_const_field(output_type)}[_i] = -(int)dst;
          % elif output_type == "float16":
@@ -349,7 +363,10 @@ struct ${type}${width}_vec {
       ## For each component in the destination, copy the value of dst to
       ## the actual destination.
       % for k in range(op.output_size):
-         % if output_type == "bool32":
+         % if output_type == "int1" or output_type == "uint1":
+            /* 1-bit integers get truncated */
+            _dst_val.b[${k}] = dst.${"xyzw"[k]} & 1;
+         % elif output_type.startswith("bool"):
             ## Sanitize the C value to a proper NIR 0/-1 bool
             _dst_val.${get_const_field(output_type)}[${k}] = -(int)dst.${"xyzw"[k]};
          % elif output_type == "float16":
diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c
index 19771fcd9dd..3b535b9009a 100644
--- a/src/compiler/nir/nir_instr_set.c
+++ b/src/compiler/nir/nir_instr_set.c
@@ -117,8 +117,15 @@ hash_load_const(uint32_t hash, const nir_load_const_instr *instr)
 {
    hash = HASH(hash, instr->def.num_components);
 
-   unsigned size = instr->def.num_components * (instr->def.bit_size / 8);
-   hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size);
+   if (instr->def.bit_size == 1) {
+      for (unsigned i = 0; i < instr->def.num_components; i++) {
+         uint8_t b = instr->value.b[i];
+         hash = HASH(hash, b);
+      }
+   } else {
+      unsigned size = instr->def.num_components * (instr->def.bit_size / 8);
+      hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size);
+   }
 
    return hash;
 }
@@ -399,8 +406,16 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2)
       if (load1->def.bit_size != load2->def.bit_size)
          return false;
 
-      return memcmp(load1->value.f32, load2->value.f32,
-                    load1->def.num_components * (load1->def.bit_size / 8u)) == 0;
+      if (load1->def.bit_size == 1) {
+         for (unsigned i = 0; i < load1->def.num_components; i++) {
+            if (load1->value.b[i] != load2->value.b[i])
+               return false;
+         }
+         return true;
+      } else {
+         unsigned size = load1->def.num_components * (load1->def.bit_size / 8);
+         return memcmp(load1->value.f32, load2->value.f32, size) == 0;
+      }
    }
    case nir_instr_type_phi: {
       nir_phi_instr *phi1 = nir_instr_as_phi(instr1);
diff --git a/src/compiler/nir/nir_lower_load_const_to_scalar.c b/src/compiler/nir/nir_lower_load_const_to_scalar.c
index b2e055f7dea..b62d32e483e 100644
--- a/src/compiler/nir/nir_lower_load_const_to_scalar.c
+++ b/src/compiler/nir/nir_lower_load_const_to_scalar.c
@@ -63,6 +63,9 @@ lower_load_const_instr_scalar(nir_load_const_instr *lower)
       case 8:
          load_comp->value.u8[0] = lower->value.u8[i];
          break;
+      case 1:
+         load_comp->value.b[0] = lower->value.b[i];
+         break;
       default:
          assert(!"invalid bit size");
       }
diff --git a/src/compiler/nir/nir_opt_constant_folding.c b/src/compiler/nir/nir_opt_constant_folding.c
index 1fca530af24..c91e7e16855 100644
--- a/src/compiler/nir/nir_opt_constant_folding.c
+++ b/src/compiler/nir/nir_opt_constant_folding.c
@@ -87,6 +87,9 @@ constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
          case 8:
             src[i].u8[j] = load_const->value.u8[instr->src[i].swizzle[j]];
             break;
+         case 1:
+            src[i].b[j] = load_const->value.b[instr->src[i].swizzle[j]];
+            break;
          default:
             unreachable("Invalid bit size");
          }
diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c
index 7124ff09e82..bafbaf20b2d 100644
--- a/src/compiler/nir/nir_print.c
+++ b/src/compiler/nir/nir_print.c
@@ -946,6 +946,9 @@ print_load_const_instr(nir_load_const_instr *instr, print_state *state)
       case 8:
          fprintf(fp, "0x%02x", instr->value.u8[i]);
          break;
+      case 1:
+         fprintf(fp, "%s", instr->value.b[i] ? "true" : "false");
+         break;
       }
    }
 
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index c7954b39415..50f5464cef8 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -476,8 +476,9 @@ construct_value(nir_builder *build,
          break;
 
       case nir_type_bool:
-         cval = nir_imm_bool(build, c->data.u);
+         cval = nir_imm_boolN_t(build, c->data.u, bit_size);
          break;
+
       default:
          unreachable("Invalid alu source type");
       }
diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c
index 62893cad87e..c896b9a8037 100644
--- a/src/compiler/nir/nir_validate.c
+++ b/src/compiler/nir/nir_validate.c
@@ -818,7 +818,7 @@ validate_if(nir_if *if_stmt, validate_state *state)
    nir_cf_node *next_node = nir_cf_node_next(&if_stmt->cf_node);
    validate_assert(state, next_node->type == nir_cf_node_block);
 
-   validate_src(&if_stmt->condition, state, 32, 1);
+   validate_src(&if_stmt->condition, state, 0, 1);
 
    validate_assert(state, !exec_list_is_empty(&if_stmt->then_list));
    validate_assert(state, !exec_list_is_empty(&if_stmt->else_list));
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 22efaa276d9..b539409656f 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -1561,6 +1561,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             case 8:
                val->constant->values[0].u8[i] = elems[i]->values[0].u8[0];
                break;
+            case 1:
+               val->constant->values[0].b[i] = elems[i]->values[0].b[0];
+               break;
             default:
                vtn_fail("Invalid SpvOpConstantComposite bit size");
             }
@@ -1734,6 +1737,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
                   case 8:
                      val->constant->values[0].u8[i] = (*c)->values[col].u8[elem + i];
                      break;
+                  case 1:
+                     val->constant->values[0].b[i] = (*c)->values[col].b[elem + i];
+                     break;
                   default:
                      vtn_fail("Invalid SpvOpCompositeExtract bit size");
                   }
@@ -1761,6 +1767,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
                   case 8:
                      (*c)->values[col].u8[elem + i] = insert->constant->values[0].u8[i];
                      break;
+                  case 1:
+                     (*c)->values[col].b[elem + i] = insert->constant->values[0].b[i];
+                     break;
                   default:
                      vtn_fail("Invalid SpvOpCompositeInsert bit size");
                   }
-- 
2.19.2



More information about the mesa-dev mailing list