[Mesa-dev] [PATCH v2 09/53] compiler/spirv: implement 16-bit hyperbolic trigonometric functions

Iago Toral Quiroga itoral at igalia.com
Wed Dec 19 11:50:37 UTC 2018


v2:
 - use nir_fadd_imm and nir_fmul_imm helpers (Jason)
---
 src/compiler/spirv/vtn_glsl450.c | 44 +++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index ec91d9308c5..c8400d6c80f 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -654,17 +654,17 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
    case GLSLstd450Sinh:
       /* 0.5 * (e^x - e^(-x)) */
       val->ssa->def =
-         nir_fmul(nb, nir_imm_float(nb, 0.5f),
-                      nir_fsub(nb, build_exp(nb, src[0]),
-                                   build_exp(nb, nir_fneg(nb, src[0]))));
+         nir_fmul_imm(nb, nir_fsub(nb, build_exp(nb, src[0]),
+                                       build_exp(nb, nir_fneg(nb, src[0]))),
+                          0.5f);
       return;
 
    case GLSLstd450Cosh:
       /* 0.5 * (e^x + e^(-x)) */
       val->ssa->def =
-         nir_fmul(nb, nir_imm_float(nb, 0.5f),
-                      nir_fadd(nb, build_exp(nb, src[0]),
-                                   build_exp(nb, nir_fneg(nb, src[0]))));
+         nir_fmul_imm(nb, nir_fadd(nb, build_exp(nb, src[0]),
+                                       build_exp(nb, nir_fneg(nb, src[0]))),
+                          0.5f);
       return;
 
    case GLSLstd450Tanh: {
@@ -675,30 +675,38 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
        * We clamp x to (-inf, +10] to avoid precision problems.  When x > 10,
        * e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the
        * computation e^2x +/- 1 so it can be ignored.
+       *
+       * For 16-bit precision we clamp x to (-inf, +4.2] instead, since the
+       * largest finite fp16 value is only 65,504 and e^(2*6) exceeds that.
+       * Also, if x > 4.2, tanh(x) will return 1.0 in fp16.
        */
-      nir_ssa_def *x = nir_fmin(nb, src[0], nir_imm_float(nb, 10));
-      nir_ssa_def *exp2x = build_exp(nb, nir_fmul(nb, x, nir_imm_float(nb, 2)));
-      val->ssa->def = nir_fdiv(nb, nir_fsub(nb, exp2x, nir_imm_float(nb, 1)),
-                                   nir_fadd(nb, exp2x, nir_imm_float(nb, 1)));
+      const uint32_t bit_size = src[0]->bit_size;
+      const double clamped_x = bit_size > 16 ? 10.0 : 4.2;
+      nir_ssa_def *x = nir_fmin(nb, src[0],
+                                    nir_imm_floatN_t(nb, clamped_x, bit_size));
+      nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0));
+      val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0),
+                                   nir_fadd_imm(nb, exp2x, 1.0));
       return;
    }
 
    case GLSLstd450Asinh:
       val->ssa->def = nir_fmul(nb, nir_fsign(nb, src[0]),
          build_log(nb, nir_fadd(nb, nir_fabs(nb, src[0]),
-                       nir_fsqrt(nb, nir_fadd(nb, nir_fmul(nb, src[0], src[0]),
-                                                  nir_imm_float(nb, 1.0f))))));
+                       nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
+                                                      1.0f)))));
       return;
    case GLSLstd450Acosh:
       val->ssa->def = build_log(nb, nir_fadd(nb, src[0],
-         nir_fsqrt(nb, nir_fsub(nb, nir_fmul(nb, src[0], src[0]),
-                                    nir_imm_float(nb, 1.0f)))));
+         nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
+                                        -1.0f))));
       return;
    case GLSLstd450Atanh: {
-      nir_ssa_def *one = nir_imm_float(nb, 1.0);
-      val->ssa->def = nir_fmul(nb, nir_imm_float(nb, 0.5f),
-         build_log(nb, nir_fdiv(nb, nir_fadd(nb, one, src[0]),
-                                    nir_fsub(nb, one, src[0]))));
+      nir_ssa_def *one = nir_imm_floatN_t(nb, 1.0, src[0]->bit_size);
+      val->ssa->def =
+         nir_fmul_imm(nb, build_log(nb, nir_fdiv(nb, nir_fadd_imm(nb, src[0], 1.0f),
+                                        nir_fsub(nb, one, src[0]))),
+                          0.5f);
       return;
    }
 
-- 
2.17.1



More information about the mesa-dev mailing list