Mesa (main): spirv_to_nir: Cast RelaxedPrecision ALU op dests to mediump.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Thu May 19 20:15:08 UTC 2022


Module: Mesa
Branch: main
Commit: 260559050afc11b50e05209203bc7e77bd42bcfe
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=260559050afc11b50e05209203bc7e77bd42bcfe

Author: Emma Anholt <emma at anholt.net>
Date:   Tue Apr 26 16:29:04 2022 -0700

spirv_to_nir: Cast RelaxedPrecision ALU op dests to mediump.

This is controlled by spirv_to_nir_options.mediump_16bit_alu (with
mediump_16bit_derivatives additionally gating derivative ops), because
some drivers don't want it.

This gets us mostly 16-bit math on turnip in vk-5-normal.

Reviewed-by: Matt Turner <mattst88 at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16465>

---

 src/compiler/spirv/nir_spirv.h   |  10 +++
 src/compiler/spirv/vtn_alu.c     | 136 ++++++++++++++++++++++++++++++++++++++-
 src/compiler/spirv/vtn_glsl450.c |  43 +++++++++++++
 src/compiler/spirv/vtn_private.h |   7 ++
 4 files changed, 193 insertions(+), 3 deletions(-)

diff --git a/src/compiler/spirv/nir_spirv.h b/src/compiler/spirv/nir_spirv.h
index 0410c5f7f51..8899eae7623 100644
--- a/src/compiler/spirv/nir_spirv.h
+++ b/src/compiler/spirv/nir_spirv.h
@@ -75,6 +75,16 @@ struct spirv_to_nir_options {
     */
    uint16_t float_controls_execution_mode;
 
+   /* True if ALU operations whose results are decorated RelaxedPrecision
+    * should be performed with 16-bit math.
+    */
+   bool mediump_16bit_alu;
+
+   /* When mediump_16bit_alu is set, determines whether nir_op_fddx/fddy can be
+    * performed in 16-bit math.
+    */
+   bool mediump_16bit_derivatives;
+
    struct spirv_supported_capabilities caps;
 
    /* Address format for various kinds of pointers. */
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 6fd7a05afc2..b35f38bffc5 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -153,6 +153,48 @@ mat_times_scalar(struct vtn_builder *b,
    return dest;
 }
 
+nir_ssa_def *
+vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
+{
+   if (def->bit_size == 16)
+      return def;
+
+   switch (base_type) {
+   case GLSL_TYPE_FLOAT:
+      return nir_f2fmp(&b->nb, def);
+   case GLSL_TYPE_INT:
+   case GLSL_TYPE_UINT:
+      return nir_i2imp(&b->nb, def);
+   default:
+      unreachable("bad relaxed precision input type");
+   }
+}
+
+struct vtn_ssa_value *
+vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
+{
+   if (!src)
+      return src;
+
+   struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
+
+   if (src->transposed) {
+      srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
+   } else {
+      enum glsl_base_type base_type = glsl_get_base_type(src->type);
+
+      if (glsl_type_is_vector_or_scalar(src->type)) {
+         srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
+      } else {
+         assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
+         for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
+            srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
+      }
+   }
+
+   return srcmp;
+}
+
 static struct vtn_ssa_value *
 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
@@ -465,6 +507,84 @@ handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
    }
 }
 
+static void
+vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
+                          struct vtn_value *val, int member,
+                          const struct vtn_decoration *dec, void *void_ctx)
+{
+   bool *relaxed_precision = void_ctx;
+   switch (dec->decoration) {
+   case SpvDecorationRelaxedPrecision:
+      *relaxed_precision = true;
+      break;
+
+   default:
+      break;
+   }
+}
+
+bool
+vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
+{
+   bool result = false;
+   vtn_foreach_decoration(b, val,
+                          vtn_value_is_relaxed_precision_cb, &result);
+   return result;
+}
+
+static bool
+vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
+{
+   if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
+      return false;
+
+   switch (opcode) {
+   case SpvOpDPdx:
+   case SpvOpDPdy:
+   case SpvOpDPdxFine:
+   case SpvOpDPdyFine:
+   case SpvOpDPdxCoarse:
+   case SpvOpDPdyCoarse:
+   case SpvOpFwidth:
+   case SpvOpFwidthFine:
+   case SpvOpFwidthCoarse:
+      return b->options->mediump_16bit_derivatives;
+   default:
+      return true;
+   }
+}
+
+static nir_ssa_def *
+vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
+{
+   if (def->bit_size != 16)
+      return def;
+
+   switch (base_type) {
+   case GLSL_TYPE_FLOAT:
+      return nir_f2f32(&b->nb, def);
+   case GLSL_TYPE_INT:
+      return nir_i2i32(&b->nb, def);
+   case GLSL_TYPE_UINT:
+      return nir_u2u32(&b->nb, def);
+   default:
+      unreachable("bad relaxed precision output type");
+   }
+}
+
+void
+vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
+{
+   enum glsl_base_type base_type = glsl_get_base_type(value->type);
+
+   if (glsl_type_is_vector_or_scalar(value->type)) {
+      value->def = vtn_mediump_upconvert(b, base_type, value->def);
+   } else {
+      for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
+         value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
+   }
+}
+
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
@@ -473,17 +593,25 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
    vtn_handle_no_contraction(b, dest_val);
+   bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
 
    /* Collect the various SSA sources */
    const unsigned num_inputs = count - 3;
    struct vtn_ssa_value *vtn_src[4] = { NULL, };
-   for (unsigned i = 0; i < num_inputs; i++)
+   for (unsigned i = 0; i < num_inputs; i++) {
       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
+      if (mediump_16bit)
+         vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
+   }
 
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
-      vtn_push_ssa_value(b, w[2],
-         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
+      struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
+
+      if (mediump_16bit)
+         vtn_mediump_upconvert_value(b, dest);
+
+      vtn_push_ssa_value(b, w[2], dest);
       b->nb.exact = b->exact;
       return;
    }
@@ -861,6 +989,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   if (mediump_16bit)
+      vtn_mediump_upconvert_value(b, dest);
    vtn_push_ssa_value(b, w[2], dest);
 
    b->nb.exact = b->exact;
diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index a8446f50a34..5ae8c582739 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -277,6 +277,41 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
 {
    struct nir_builder *nb = &b->nb;
    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
+   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
+
+   bool mediump_16bit;
+   switch (entrypoint) {
+   case GLSLstd450PackSnorm4x8:
+   case GLSLstd450PackUnorm4x8:
+   case GLSLstd450PackSnorm2x16:
+   case GLSLstd450PackUnorm2x16:
+   case GLSLstd450PackHalf2x16:
+   case GLSLstd450PackDouble2x32:
+   case GLSLstd450UnpackSnorm4x8:
+   case GLSLstd450UnpackUnorm4x8:
+   case GLSLstd450UnpackSnorm2x16:
+   case GLSLstd450UnpackUnorm2x16:
+   case GLSLstd450UnpackHalf2x16:
+   case GLSLstd450UnpackDouble2x32:
+      /* Asking for relaxed precision snorm 4x8 pack results (for example)
+       * doesn't even make sense.  The NIR opcodes have a fixed output size, so
+       * don't try to reduce precision.
+       */
+      mediump_16bit = false;
+      break;
+
+   case GLSLstd450Frexp:
+   case GLSLstd450FrexpStruct:
+   case GLSLstd450Modf:
+   case GLSLstd450ModfStruct:
+      /* Not sure how to detect the ->elems[i] destinations of these in vtn_mediump_upconvert_value(). */
+      mediump_16bit = false;
+      break;
+
+   default:
+      mediump_16bit = b->options->mediump_16bit_alu && vtn_value_is_relaxed_precision(b, dest_val);
+      break;
+   }
 
    /* Collect the various SSA sources */
    unsigned num_inputs = count - 5;
@@ -287,9 +322,14 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
          continue;
 
       src[i] = vtn_get_nir_ssa(b, w[i + 5]);
+      if (mediump_16bit) {
+         struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[i + 5]);
+         src[i] = vtn_mediump_downconvert(b, glsl_get_base_type(vtn_src->type), src[i]);
+      }
    }
 
    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
+
    vtn_handle_no_contraction(b, vtn_untyped_value(b, w[2]));
    switch (entrypoint) {
    case GLSLstd450Radians:
@@ -589,6 +629,9 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
    }
    b->nb.exact = false;
 
+   if (mediump_16bit)
+      vtn_mediump_upconvert_value(b, dest);
+
    vtn_push_ssa_value(b, w[2], dest);
 }
 
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index f7a82ce6527..6a9d1394760 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -1048,6 +1048,13 @@ SpvMemorySemanticsMask vtn_mode_to_memory_semantics(enum vtn_variable_mode mode)
 void vtn_emit_memory_barrier(struct vtn_builder *b, SpvScope scope,
                              SpvMemorySemanticsMask semantics);
 
+bool vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val);
+nir_ssa_def *
+vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def);
+struct vtn_ssa_value *
+vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src);
+void vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value);
+
 static inline int
 cmp_uint32_t(const void *pa, const void *pb)
 {



More information about the mesa-commit mailing list