[Mesa-dev] [PATCH v2] nir: Get rid of nir_constant_data

Jason Ekstrand jason at jlekstrand.net
Fri Dec 2 00:07:24 UTC 2016


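This has bothered me for about as long as NIR has been around.  Why do we
have two different unions for constants?  No good reason other than one of
them is a direct port from GLSL IR.

Drop union nir_constant_data and store the data of a nir_constant as an
array of four nir_const_value instead, one per matrix column (scalars and
vectors use values[0]).  Booleans are now stored as NIR_TRUE/NIR_FALSE in
the u32 member rather than as C bools.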
---
 src/compiler/glsl/glsl_to_nir.cpp  | 35 +++++++++++++-------
 src/compiler/nir/nir.c             | 32 +++++++-----------
 src/compiler/nir/nir.h             | 30 ++++++-----------
 src/compiler/nir/nir_clone.c       |  2 +-
 src/compiler/nir/nir_print.c       | 29 ++++++++++-------
 src/compiler/spirv/spirv_to_nir.c  | 67 +++++++++++++++++---------------------
 src/compiler/spirv/vtn_variables.c |  8 ++---
 7 files changed, 98 insertions(+), 105 deletions(-)

diff --git a/src/compiler/glsl/glsl_to_nir.cpp b/src/compiler/glsl/glsl_to_nir.cpp
index 628f8de..0b74b7e 100644
--- a/src/compiler/glsl/glsl_to_nir.cpp
+++ b/src/compiler/glsl/glsl_to_nir.cpp
@@ -198,34 +198,47 @@ constant_copy(ir_constant *ir, void *mem_ctx)
 
    nir_constant *ret = ralloc(mem_ctx, nir_constant);
 
-   unsigned total_elems = ir->type->components();
+   const unsigned rows = ir->type->vector_elements;
+   const unsigned cols = ir->type->matrix_columns;
    unsigned i;
 
    ret->num_elements = 0;
    switch (ir->type->base_type) {
    case GLSL_TYPE_UINT:
-      for (i = 0; i < total_elems; i++)
-         ret->value.u[i] = ir->value.u[i];
+      for (unsigned c = 0; c < cols; c++) {
+         for (unsigned r = 0; r < rows; r++)
+            ret->values[c].u32[r] = ir->value.u[c * rows + r];
+      }
       break;
 
    case GLSL_TYPE_INT:
-      for (i = 0; i < total_elems; i++)
-         ret->value.i[i] = ir->value.i[i];
+      for (unsigned c = 0; c < cols; c++) {
+         for (unsigned r = 0; r < rows; r++)
+            ret->values[c].i32[r] = ir->value.i[c * rows + r];
+      }
       break;
 
    case GLSL_TYPE_FLOAT:
-      for (i = 0; i < total_elems; i++)
-         ret->value.f[i] = ir->value.f[i];
+      for (unsigned c = 0; c < cols; c++) {
+         for (unsigned r = 0; r < rows; r++)
+            ret->values[c].f32[r] = ir->value.f[c * rows + r];
+      }
       break;
 
    case GLSL_TYPE_DOUBLE:
-      for (i = 0; i < total_elems; i++)
-         ret->value.d[i] = ir->value.d[i];
+      for (unsigned c = 0; c < cols; c++) {
+         for (unsigned r = 0; r < rows; r++)
+            ret->values[c].f64[r] = ir->value.d[c * rows + r];
+      }
       break;
 
    case GLSL_TYPE_BOOL:
-      for (i = 0; i < total_elems; i++)
-         ret->value.b[i] = ir->value.b[i];
+      for (unsigned c = 0; c < cols; c++) {
+         for (unsigned r = 0; r < rows; r++) {
+            ret->values[c].u32[r] = ir->value.b[c * rows + r] ?
+                                    NIR_TRUE : NIR_FALSE;
+         }
+      }
       break;
 
    case GLSL_TYPE_STRUCT:
diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index cfb032c..2d882f7 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -806,7 +806,7 @@ nir_deref_get_const_initializer_load(nir_shader *shader, nir_deref_var *deref)
    assert(constant);
 
    const nir_deref *tail = &deref->deref;
-   unsigned matrix_offset = 0;
+   unsigned matrix_col = 0;
    while (tail->child) {
       switch (tail->child->deref_type) {
       case nir_deref_type_array: {
@@ -814,7 +814,7 @@ nir_deref_get_const_initializer_load(nir_shader *shader, nir_deref_var *deref)
          assert(arr->deref_array_type == nir_deref_array_type_direct);
          if (glsl_type_is_matrix(tail->type)) {
             assert(arr->deref.child == NULL);
-            matrix_offset = arr->base_offset;
+            matrix_col = arr->base_offset;
          } else {
             constant = constant->elements[arr->base_offset];
          }
@@ -838,24 +838,16 @@ nir_deref_get_const_initializer_load(nir_shader *shader, nir_deref_var *deref)
       nir_load_const_instr_create(shader, glsl_get_vector_elements(tail->type),
                                   bit_size);
 
-   matrix_offset *= load->def.num_components;
-   for (unsigned i = 0; i < load->def.num_components; i++) {
-      switch (glsl_get_base_type(tail->type)) {
-      case GLSL_TYPE_FLOAT:
-      case GLSL_TYPE_INT:
-      case GLSL_TYPE_UINT:
-         load->value.u32[i] = constant->value.u[matrix_offset + i];
-         break;
-      case GLSL_TYPE_DOUBLE:
-         load->value.f64[i] = constant->value.d[matrix_offset + i];
-         break;
-      case GLSL_TYPE_BOOL:
-         load->value.u32[i] = constant->value.b[matrix_offset + i] ?
-                             NIR_TRUE : NIR_FALSE;
-         break;
-      default:
-         unreachable("Invalid immediate type");
-      }
+   switch (glsl_get_base_type(tail->type)) {
+   case GLSL_TYPE_FLOAT:
+   case GLSL_TYPE_INT:
+   case GLSL_TYPE_UINT:
+   case GLSL_TYPE_DOUBLE:
+   case GLSL_TYPE_BOOL:
+      load->value = constant->values[matrix_col];
+      break;
+   default:
+      unreachable("Invalid immediate type");
    }
 
    return load;
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 3e6d168..9e8ed2c 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -97,16 +97,15 @@ typedef enum {
    nir_var_all             = ~0,
 } nir_variable_mode;
 
-/**
- * Data stored in an nir_constant
- */
-union nir_constant_data {
-   unsigned u[16];
-   int i[16];
-   float f[16];
-   bool b[16];
-   double d[16];
-};
+
+typedef union {
+   float f32[4];
+   double f64[4];
+   int32_t i32[4];
+   uint32_t u32[4];
+   int64_t i64[4];
+   uint64_t u64[4];
+} nir_const_value;
 
 typedef struct nir_constant {
    /**
@@ -116,7 +115,7 @@ typedef struct nir_constant {
     * by the type associated with the \c nir_variable.  Constants may be
     * scalars, vectors, or matrices.
     */
-   union nir_constant_data value;
+   nir_const_value values[4];
 
    /* we could get this from the var->type but makes clone *much* easier to
     * not have to care about the type.
@@ -1345,15 +1344,6 @@ nir_tex_instr_src_index(nir_tex_instr *instr, nir_tex_src_type type)
 
 void nir_tex_instr_remove_src(nir_tex_instr *tex, unsigned src_idx);
 
-typedef union {
-   float f32[4];
-   double f64[4];
-   int32_t i32[4];
-   uint32_t u32[4];
-   int64_t i64[4];
-   uint64_t u64[4];
-} nir_const_value;
-
 typedef struct {
    nir_instr instr;
 
diff --git a/src/compiler/nir/nir_clone.c b/src/compiler/nir/nir_clone.c
index 4f7bdd9..be89426 100644
--- a/src/compiler/nir/nir_clone.c
+++ b/src/compiler/nir/nir_clone.c
@@ -114,7 +114,7 @@ nir_constant_clone(const nir_constant *c, nir_variable *nvar)
 {
    nir_constant *nc = ralloc(nvar, nir_constant);
 
-   nc->value = c->value;
+   memcpy(nc->values, c->values, sizeof(nc->values));
    nc->num_elements = c->num_elements;
    nc->elements = ralloc_array(nvar, nir_constant *, c->num_elements);
    for (unsigned i = 0; i < c->num_elements; i++) {
diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c
index a5b2909..eb5f57f 100644
--- a/src/compiler/nir/nir_print.c
+++ b/src/compiler/nir/nir_print.c
@@ -295,30 +295,37 @@ static void
 print_constant(nir_constant *c, const struct glsl_type *type, print_state *state)
 {
    FILE *fp = state->fp;
-   unsigned total_elems = glsl_get_components(type);
-   unsigned i;
+   const unsigned rows = glsl_get_vector_elements(type);
+   const unsigned cols = glsl_get_matrix_columns(type);
+   unsigned i, j;
 
    switch (glsl_get_base_type(type)) {
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_INT:
    case GLSL_TYPE_BOOL:
-      for (i = 0; i < total_elems; i++) {
-         if (i > 0) fprintf(fp, ", ");
-         fprintf(fp, "0x%08x", c->value.u[i]);
+      for (i = 0; i < cols; i++) {
+         for (j = 0; j < rows; j++) {
+            if (i + j > 0) fprintf(fp, ", ");
+            fprintf(fp, "0x%08x", c->values[i].u32[j]);
+         }
       }
       break;
 
    case GLSL_TYPE_FLOAT:
-      for (i = 0; i < total_elems; i++) {
-         if (i > 0) fprintf(fp, ", ");
-         fprintf(fp, "%f", c->value.f[i]);
+      for (i = 0; i < cols; i++) {
+         for (j = 0; j < rows; j++) {
+            if (i + j > 0) fprintf(fp, ", ");
+            fprintf(fp, "%f", c->values[i].f32[j]);
+         }
       }
       break;
 
    case GLSL_TYPE_DOUBLE:
-      for (i = 0; i < total_elems; i++) {
-         if (i > 0) fprintf(fp, ", ");
-         fprintf(fp, "%f", c->value.d[i]);
+      for (i = 0; i < cols; i++) {
+         for (j = 0; j < rows; j++) {
+            if (i + j > 0) fprintf(fp, ", ");
+            fprintf(fp, "%f", c->values[i].f64[j]);
+         }
       }
       break;
 
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 34968a4..f60c6e6 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -104,8 +104,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
          nir_load_const_instr *load =
             nir_load_const_instr_create(b->shader, num_components, 32);
 
-         for (unsigned i = 0; i < num_components; i++)
-            load->value.u32[i] = constant->value.u[i];
+         load->value = constant->values[0];
 
          nir_instr_insert_before_cf_list(&b->impl->body, &load->instr);
          val->def = &load->def;
@@ -121,8 +120,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
             nir_load_const_instr *load =
                nir_load_const_instr_create(b->shader, rows, 32);
 
-            for (unsigned j = 0; j < rows; j++)
-               load->value.u32[j] = constant->value.u[rows * i + j];
+            load->value = constant->values[i];
 
             nir_instr_insert_before_cf_list(&b->impl->body, &load->instr);
             col_val->def = &load->def;
@@ -752,7 +750,7 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
          length = 0;
       } else {
          length =
-            vtn_value(b, w[3], vtn_value_type_constant)->constant->value.u[0];
+            vtn_value(b, w[3], vtn_value_type_constant)->constant->values[0].u32[0];
       }
 
       val->type->type = glsl_array_type(array_element->type, length);
@@ -972,9 +970,9 @@ handle_workgroup_size_decoration_cb(struct vtn_builder *b,
 
    assert(val->const_type == glsl_vector_type(GLSL_TYPE_UINT, 3));
 
-   b->shader->info->cs.local_size[0] = val->constant->value.u[0];
-   b->shader->info->cs.local_size[1] = val->constant->value.u[1];
-   b->shader->info->cs.local_size[2] = val->constant->value.u[2];
+   b->shader->info->cs.local_size[0] = val->constant->values[0].u32[0];
+   b->shader->info->cs.local_size[1] = val->constant->values[0].u32[1];
+   b->shader->info->cs.local_size[2] = val->constant->values[0].u32[2];
 }
 
 static void
@@ -987,11 +985,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
    switch (opcode) {
    case SpvOpConstantTrue:
       assert(val->const_type == glsl_bool_type());
-      val->constant->value.u[0] = NIR_TRUE;
+      val->constant->values[0].u32[0] = NIR_TRUE;
       break;
    case SpvOpConstantFalse:
       assert(val->const_type == glsl_bool_type());
-      val->constant->value.u[0] = NIR_FALSE;
+      val->constant->values[0].u32[0] = NIR_FALSE;
       break;
 
    case SpvOpSpecConstantTrue:
@@ -999,17 +997,17 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       assert(val->const_type == glsl_bool_type());
       uint32_t int_val =
          get_specialization(b, val, (opcode == SpvOpSpecConstantTrue));
-      val->constant->value.u[0] = int_val ? NIR_TRUE : NIR_FALSE;
+      val->constant->values[0].u32[0] = int_val ? NIR_TRUE : NIR_FALSE;
       break;
    }
 
    case SpvOpConstant:
       assert(glsl_type_is_scalar(val->const_type));
-      val->constant->value.u[0] = w[3];
+      val->constant->values[0].u32[0] = w[3];
       break;
    case SpvOpSpecConstant:
       assert(glsl_type_is_scalar(val->const_type));
-      val->constant->value.u[0] = get_specialization(b, val, w[3]);
+      val->constant->values[0].u32[0] = get_specialization(b, val, w[3]);
       break;
    case SpvOpSpecConstantComposite:
    case SpvOpConstantComposite: {
@@ -1024,16 +1022,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       case GLSL_TYPE_FLOAT:
       case GLSL_TYPE_BOOL:
          if (glsl_type_is_matrix(val->const_type)) {
-            unsigned rows = glsl_get_vector_elements(val->const_type);
             assert(glsl_get_matrix_columns(val->const_type) == elem_count);
             for (unsigned i = 0; i < elem_count; i++)
-               for (unsigned j = 0; j < rows; j++)
-                  val->constant->value.u[rows * i + j] = elems[i]->value.u[j];
+               val->constant->values[i] = elems[i]->values[0];
          } else {
             assert(glsl_type_is_vector(val->const_type));
             assert(glsl_get_vector_elements(val->const_type) == elem_count);
             for (unsigned i = 0; i < elem_count; i++)
-               val->constant->value.u[i] = elems[i]->value.u[0];
+               val->constant->values[0].u32[i] = elems[i]->values[0].u32[0];
          }
          ralloc_free(elems);
          break;
@@ -1062,16 +1058,16 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
 
          uint32_t u[8];
          for (unsigned i = 0; i < len0; i++)
-            u[i] = v0->constant->value.u[i];
+            u[i] = v0->constant->values[0].u32[i];
          for (unsigned i = 0; i < len1; i++)
-            u[len0 + i] = v1->constant->value.u[i];
+            u[len0 + i] = v1->constant->values[0].u32[i];
 
          for (unsigned i = 0; i < count - 6; i++) {
             uint32_t comp = w[i + 6];
             if (comp == (uint32_t)-1) {
-               val->constant->value.u[i] = 0xdeadbeef;
+               val->constant->values[0].u32[i] = 0xdeadbeef;
             } else {
-               val->constant->value.u[i] = u[comp];
+               val->constant->values[0].u32[i] = u[comp];
             }
          }
          break;
@@ -1095,6 +1091,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
          }
 
          int elem = -1;
+         int col = 0;
          const struct glsl_type *type = comp->const_type;
          for (unsigned i = deref_start; i < count; i++) {
             switch (glsl_get_base_type(type)) {
@@ -1103,15 +1100,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             case GLSL_TYPE_FLOAT:
             case GLSL_TYPE_BOOL:
                /* If we hit this granularity, we're picking off an element */
-               if (elem < 0)
-                  elem = 0;
-
                if (glsl_type_is_matrix(type)) {
-                  elem += w[i] * glsl_get_vector_elements(type);
+                  assert(col == 0 && elem == -1);
+                  col = w[i];
+                  elem = 0;
                   type = glsl_get_column_type(type);
                } else {
-                  assert(glsl_type_is_vector(type));
-                  elem += w[i];
+                  assert(elem <= 0 && glsl_type_is_vector(type));
+                  elem = w[i];
                   type = glsl_scalar_type(glsl_get_base_type(type));
                }
                continue;
@@ -1137,7 +1133,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             } else {
                unsigned num_components = glsl_get_vector_elements(type);
                for (unsigned i = 0; i < num_components; i++)
-                  val->constant->value.u[i] = (*c)->value.u[elem + i];
+                  val->constant->values[0].u32[i] = (*c)->values[col].u32[elem + i];
             }
          } else {
             struct vtn_value *insert =
@@ -1148,7 +1144,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             } else {
                unsigned num_components = glsl_get_vector_elements(type);
                for (unsigned i = 0; i < num_components; i++)
-                  (*c)->value.u[elem + i] = insert->constant->value.u[i];
+                  (*c)->values[col].u32[elem + i] = insert->constant->values[0].u32[i];
             }
          }
          break;
@@ -1170,16 +1166,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
 
             unsigned j = swap ? 1 - i : i;
             assert(bit_size == 32);
-            for (unsigned k = 0; k < num_components; k++)
-               src[j].u32[k] = c->value.u[k];
+            src[j] = c->values[0];
          }
 
-         nir_const_value res = nir_eval_const_opcode(op, num_components,
-                                                     bit_size, src);
-
-         for (unsigned k = 0; k < num_components; k++)
-            val->constant->value.u[k] = res.u32[k];
-
+         val->constant->values[0] =
+            nir_eval_const_opcode(op, num_components, bit_size, src);
          break;
       } /* default */
       }
@@ -1475,7 +1466,7 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
    case SpvOpImageGather:
       /* This has a component as its next source */
       gather_component =
-         vtn_value(b, w[idx++], vtn_value_type_constant)->constant->value.u[0];
+         vtn_value(b, w[idx++], vtn_value_type_constant)->constant->values[0].u32[0];
       break;
 
    default:
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 14366dc..917aa9d 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -938,9 +938,9 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
          nir_var->data.read_only = true;
 
          nir_constant *c = rzalloc(nir_var, nir_constant);
-         c->value.u[0] = b->shader->info->cs.local_size[0];
-         c->value.u[1] = b->shader->info->cs.local_size[1];
-         c->value.u[2] = b->shader->info->cs.local_size[2];
+         c->values[0].u32[0] = b->shader->info->cs.local_size[0];
+         c->values[0].u32[1] = b->shader->info->cs.local_size[1];
+         c->values[0].u32[2] = b->shader->info->cs.local_size[2];
          nir_var->constant_initializer = c;
          break;
       }
@@ -1388,7 +1388,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
          if (link_val->value_type == vtn_value_type_constant) {
             chain->link[idx].mode = vtn_access_mode_literal;
-            chain->link[idx].id = link_val->constant->value.u[0];
+            chain->link[idx].id = link_val->constant->values[0].u32[0];
          } else {
             chain->link[idx].mode = vtn_access_mode_id;
             chain->link[idx].id = w[i];
-- 
2.5.0.400.gff86faf



More information about the mesa-dev mailing list