Mesa (master): nir: Update saturated float->int/uint conversion algorithm

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Tue Jan 5 20:06:40 UTC 2021


Module: Mesa
Branch: master
Commit: 4d83306a9aabb5f9ea7e6a54d0e25c0f82805965
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=4d83306a9aabb5f9ea7e6a54d0e25c0f82805965

Author: Jesse Natalie <jenatali at microsoft.com>
Date:   Mon Dec 28 15:45:58 2020 -0800

nir: Update saturated float->int/uint conversion algorithm

The mantissa for a float doesn't contain enough data to accurately represent
the min/max values for some destination types. Instead of clamping before
converting, clamp after converting when coming from floats. This improves
conformance of CL conversions, specifically for float -> long/ulong with
int64 emulation enabled.

Refactors the limit determination out of the clamp, so we can determine the
limits of the dest type (int/uint) expressed in both the source (float) type
and the dest type. The limit as a float is used for the comparison, while the
limit as a dest type is used for the bcsel.

An important note is that the comparison is inverted to fge instead of flt,
so that the bcsel chooses the int/uint limit directly over the converted
float in the case where the comparison comes up equal but the conversion
can't produce the exact min/max value.

Reviewed-by: Jason Ekstrand <jason at jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/8256>

---

 src/compiler/nir/nir_conversion_builder.h | 133 +++++++++++++++++++-----------
 1 file changed, 87 insertions(+), 46 deletions(-)

diff --git a/src/compiler/nir/nir_conversion_builder.h b/src/compiler/nir/nir_conversion_builder.h
index 78e41bfb690..c124e2650f0 100644
--- a/src/compiler/nir/nir_conversion_builder.h
+++ b/src/compiler/nir/nir_conversion_builder.h
@@ -222,28 +222,26 @@ nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
 }
 
 /**
- * Clamp the source value into the widest representatble range of the
- * destination type with cmp + bcsel.
+ * Retrieves limits used for clamping a value of the src type into
+ * the widest representable range of the dst type via cmp + bcsel
  */
-static inline nir_ssa_def *
-nir_clamp_to_type_range(nir_builder *b,
-                        nir_ssa_def *src, nir_alu_type src_type,
-                        nir_alu_type dest_type)
+static inline void
+nir_get_clamp_limits(nir_builder *b,
+                     nir_alu_type src_type,
+                     nir_alu_type dest_type,
+                     nir_ssa_def **low, nir_ssa_def **high)
 {
-   assert(nir_alu_type_get_type_size(src_type) == 0 ||
-          nir_alu_type_get_type_size(src_type) == src->bit_size);
-   src_type |= src->bit_size;
-   if (nir_alu_type_range_contains_type_range(dest_type, src_type))
-      return src;
-
    /* Split types from bit sizes */
    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
+   unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
-   assert(dest_bit_size != 0);
+   assert(dest_bit_size != 0 && src_bit_size != 0);
+
+   *low = NULL;
+   *high = NULL;
 
    /* limits of the destination type, expressed in the source type */
-   nir_ssa_def *low = NULL, *high = NULL;
    switch (dest_base_type) {
    case nir_type_int: {
       int64_t ilow, ihigh;
@@ -256,14 +254,14 @@ nir_clamp_to_type_range(nir_builder *b,
       }
 
       if (src_base_type == nir_type_int) {
-         low = nir_imm_intN_t(b, ilow, src->bit_size);
-         high = nir_imm_intN_t(b, ihigh, src->bit_size);
+         *low = nir_imm_intN_t(b, ilow, src_bit_size);
+         *high = nir_imm_intN_t(b, ihigh, src_bit_size);
       } else if (src_base_type == nir_type_uint) {
-         assert(src->bit_size >= dest_bit_size);
-         high = nir_imm_intN_t(b, ihigh, src->bit_size);
+         assert(src_bit_size >= dest_bit_size);
+         *high = nir_imm_intN_t(b, ihigh, src_bit_size);
       } else {
-         low = nir_imm_floatN_t(b, ilow, src->bit_size);
-         high = nir_imm_floatN_t(b, ihigh, src->bit_size);
+         *low = nir_imm_floatN_t(b, ilow, src_bit_size);
+         *high = nir_imm_floatN_t(b, ihigh, src_bit_size);
       }
       break;
    }
@@ -271,12 +269,12 @@ nir_clamp_to_type_range(nir_builder *b,
       uint64_t uhigh = dest_bit_size == 64 ?
          ~0ull : (1ull << dest_bit_size) - 1;
       if (src_base_type != nir_type_float) {
-         low = nir_imm_intN_t(b, 0, src->bit_size);
-         if (src_base_type == nir_type_uint || src->bit_size > dest_bit_size)
-            high = nir_imm_intN_t(b, uhigh, src->bit_size);
+         *low = nir_imm_intN_t(b, 0, src_bit_size);
+         if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size)
+            *high = nir_imm_intN_t(b, uhigh, src_bit_size);
       } else {
-         low = nir_imm_floatN_t(b, 0.0f, src->bit_size);
-         high = nir_imm_floatN_t(b, uhigh, src->bit_size);
+         *low = nir_imm_floatN_t(b, 0.0f, src_bit_size);
+         *high = nir_imm_floatN_t(b, uhigh, src_bit_size);
       }
       break;
    }
@@ -302,29 +300,29 @@ nir_clamp_to_type_range(nir_builder *b,
       switch (src_base_type) {
       case nir_type_int: {
          int64_t src_ilow, src_ihigh;
-         if (src->bit_size == 64) {
+         if (src_bit_size == 64) {
             src_ilow = INT64_MIN;
             src_ihigh = INT64_MAX;
          } else {
-            src_ilow = -(1ll << (src->bit_size - 1));
-            src_ihigh = (1ll << (src->bit_size - 1)) - 1;
+            src_ilow = -(1ll << (src_bit_size - 1));
+            src_ihigh = (1ll << (src_bit_size - 1)) - 1;
          }
          if (src_ilow < flow)
-            low = nir_imm_intN_t(b, flow, src->bit_size);
+            *low = nir_imm_intN_t(b, flow, src_bit_size);
          if (src_ihigh > fhigh)
-            high = nir_imm_intN_t(b, fhigh, src->bit_size);
+            *high = nir_imm_intN_t(b, fhigh, src_bit_size);
          break;
       }
       case nir_type_uint: {
-         uint64_t src_uhigh = src->bit_size == 64 ?
-            ~0ull : (1ull << src->bit_size) - 1;
+         uint64_t src_uhigh = src_bit_size == 64 ?
+            ~0ull : (1ull << src_bit_size) - 1;
          if (src_uhigh > fhigh)
-            high = nir_imm_intN_t(b, fhigh, src->bit_size);
+            *high = nir_imm_intN_t(b, fhigh, src_bit_size);
          break;
       }
       case nir_type_float:
-         low = nir_imm_floatN_t(b, flow, src->bit_size);
-         high = nir_imm_floatN_t(b, fhigh, src->bit_size);
+         *low = nir_imm_floatN_t(b, flow, src_bit_size);
+         *high = nir_imm_floatN_t(b, fhigh, src_bit_size);
          break;
       default:
          unreachable("Clamping from unknown type");
@@ -335,9 +333,34 @@ nir_clamp_to_type_range(nir_builder *b,
       unreachable("clamping to unknown type");
       break;
    }
+}
+
+/**
+ * Clamp the value into the widest representable range of the
+ * destination type with cmp + bcsel.
+ * 
+ * val/val_type: The variables used for bcsel
+ * src/src_type: The variables used for comparison
+ * dest_type: The type which determines the range used for comparison
+ */
+static inline nir_ssa_def *
+nir_clamp_to_type_range(nir_builder *b,
+                        nir_ssa_def *val, nir_alu_type val_type,
+                        nir_ssa_def *src, nir_alu_type src_type,
+                        nir_alu_type dest_type)
+{
+   assert(nir_alu_type_get_type_size(src_type) == 0 ||
+          nir_alu_type_get_type_size(src_type) == src->bit_size);
+   src_type |= src->bit_size;
+   if (nir_alu_type_range_contains_type_range(dest_type, src_type))
+      return val;
+
+   /* limits of the destination type, expressed in the source type */
+   nir_ssa_def *low = NULL, *high = NULL;
+   nir_get_clamp_limits(b, src_type, dest_type, &low, &high);
 
    nir_ssa_def *low_cond = NULL, *high_cond = NULL;
-   switch (src_base_type) {
+   switch (nir_alu_type_get_base_type(src_type)) {
    case nir_type_int:
       low_cond = low ? nir_ilt(b, src, low) : NULL;
       high_cond = high ? nir_ilt(b, high, src) : NULL;
@@ -347,18 +370,23 @@ nir_clamp_to_type_range(nir_builder *b,
       high_cond = high ? nir_ult(b, high, src) : NULL;
       break;
    case nir_type_float:
-      low_cond = low ? nir_flt(b, src, low) : NULL;
-      high_cond = high ? nir_flt(b, high, src) : NULL;
+      low_cond = low ? nir_fge(b, low, src) : NULL;
+      high_cond = high ? nir_fge(b, src, high) : NULL;
       break;
    default:
       unreachable("clamping from unknown type");
    }
 
-   nir_ssa_def *res = src;
-   if (low_cond)
-      res = nir_bcsel(b, low_cond, low, res);
-   if (high_cond)
-      res = nir_bcsel(b, high_cond, high, res);
+   nir_ssa_def *val_low = low, *val_high = high;
+   if (val_type != src_type) {
+      nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high);
+   }
+
+   nir_ssa_def *res = val;
+   if (low_cond && val_low)
+      res = nir_bcsel(b, low_cond, val_low, res);
+   if (high_cond && val_high)
+      res = nir_bcsel(b, high_cond, val_high, res);
 
    return res;
 }
@@ -425,6 +453,14 @@ nir_convert_with_rounding(nir_builder *b,
       !nir_alu_type_range_contains_type_range(dest_type, src_type);
    round = nir_simplify_conversion_rounding(src_type, dest_type, round);
 
+   /* For float -> int/uint conversions, we might not be able to represent
+    * the destination range in the source float accurately. For these cases,
+    * do the comparison in float range, but the bcsel in the destination range.
+    */
+   bool clamp_after_conversion = clamp &&
+      src_base_type == nir_type_float &&
+      dest_base_type != nir_type_float;
+
    /*
     * If we don't care about rounding and clamping, we can just use NIR's
     * built-in ops. There is also a special case for SPIR-V in shaders, where
@@ -452,8 +488,8 @@ nir_convert_with_rounding(nir_builder *b,
    nir_ssa_def *dest = src;
 
    /* clamp the result into range */
-   if (clamp)
-      dest = nir_clamp_to_type_range(b, dest, src_type, dest_type);
+   if (clamp && !clamp_after_conversion)
+      dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type);
 
    /* round with selected rounding mode */
    if (!trivial_convert && round != nir_rounding_mode_undef) {
@@ -472,7 +508,12 @@ nir_convert_with_rounding(nir_builder *b,
 
    /* now we can convert the value */
    nir_op op = nir_type_conversion_op(src_type, dest_type, round);
-   return nir_build_alu(b, op, dest, NULL, NULL, NULL);
+   dest = nir_build_alu(b, op, dest, NULL, NULL, NULL);
+
+   if (clamp_after_conversion)
+      dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type);
+
+   return dest;
 }
 
 #ifdef __cplusplus



More information about the mesa-commit mailing list