Mesa (master): zink: support emitting 16-bit float types

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Fri Apr 30 12:11:51 UTC 2021


Module: Mesa
Branch: master
Commit: 1971efe5ba7c2e3fdeacfe6b019073c237baf55b
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=1971efe5ba7c2e3fdeacfe6b019073c237baf55b

Author: Erik Faye-Lund <erik.faye-lund at collabora.com>
Date:   Wed Apr  7 16:50:22 2021 +0200

zink: support emitting 16-bit float types

This prepares zink to support 16-bit float types in shaders, which may
improve performance in some cases.

Reviewed-By: Mike Blumenkrantz <michael.blumenkrantz at gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10101>

---

 .../drivers/zink/nir_to_spirv/nir_to_spirv.c       | 33 +++++++++++++++++-----
 .../drivers/zink/nir_to_spirv/spirv_builder.c      | 13 ++++++---
 2 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index 75d75bab541..b55e026a950 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -220,7 +220,7 @@ get_atomic_op(nir_intrinsic_op op)
 static SpvId
 emit_float_const(struct ntv_context *ctx, int bit_size, double value)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
    return spirv_builder_const_float(&ctx->builder, bit_size, value);
 }
 
@@ -241,7 +241,7 @@ emit_int_const(struct ntv_context *ctx, int bit_size, int64_t value)
 static SpvId
 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
 
    SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
    if (num_components > 1)
@@ -312,6 +312,9 @@ get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
    case GLSL_TYPE_BOOL:
       return spirv_builder_type_bool(&ctx->builder);
 
+   case GLSL_TYPE_FLOAT16:
+      return spirv_builder_type_float(&ctx->builder, 16);
+
    case GLSL_TYPE_FLOAT:
       return spirv_builder_type_float(&ctx->builder, 32);
 
@@ -1389,7 +1392,7 @@ static SpvId
 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
                   unsigned num_components, double value)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
 
    SpvId result = emit_float_const(ctx, bit_size, value);
    if (num_components == 1)
@@ -1578,12 +1581,15 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    UNOP(nir_op_f2u16, SpvOpConvertFToU)
    UNOP(nir_op_f2i32, SpvOpConvertFToS)
    UNOP(nir_op_f2u32, SpvOpConvertFToU)
+   UNOP(nir_op_i2f16, SpvOpConvertSToF)
    UNOP(nir_op_i2f32, SpvOpConvertSToF)
+   UNOP(nir_op_u2f16, SpvOpConvertUToF)
    UNOP(nir_op_u2f32, SpvOpConvertUToF)
    UNOP(nir_op_i2i16, SpvOpSConvert)
    UNOP(nir_op_i2i32, SpvOpSConvert)
    UNOP(nir_op_u2u16, SpvOpUConvert)
    UNOP(nir_op_u2u32, SpvOpUConvert)
+   UNOP(nir_op_f2f16, SpvOpFConvert)
    UNOP(nir_op_f2f32, SpvOpFConvert)
    UNOP(nir_op_f2i64, SpvOpConvertFToS)
    UNOP(nir_op_f2u64, SpvOpConvertFToU)
@@ -1612,6 +1618,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
                            get_ivec_constant(ctx, bit_size, num_components, 0));
       break;
 
+   case nir_op_b2f16:
    case nir_op_b2f32:
    case nir_op_b2f64:
       assert(nir_op_infos[alu->op].num_inputs == 1);
@@ -1809,10 +1816,19 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
                                                      load_const->value[i].b);
 
       } else {
-         for (int i = 0; i < num_components; i++)
-            components[i] = emit_uint_const(ctx, bit_size,
-                                            bit_size == 64 ? load_const->value[i].u64 : load_const->value[i].u32);
-
+         for (int i = 0; i < num_components; i++) {
+            if (bit_size == 16)
+               components[i] = emit_uint_const(ctx, bit_size,
+                                               load_const->value[i].u16);
+            else if (bit_size == 32)
+               components[i] = emit_uint_const(ctx, bit_size,
+                                               load_const->value[i].u32);
+            else if (bit_size == 64)
+               components[i] = emit_uint_const(ctx, bit_size,
+                                               load_const->value[i].u64);
+            else
+               unreachable("unhandled constant bit size!");
+         }
       }
       constant = spirv_builder_const_composite(&ctx->builder, type,
                                                components, num_components);
@@ -3587,6 +3603,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info, bool spir
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt16);
    if (s->info.bit_sizes_int & 64)
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt64);
+
+   if (s->info.bit_sizes_float & 16)
+      spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat16);
    if (s->info.bit_sizes_float & 64)
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat64);
 
diff --git a/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c b/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c
index a6fe423c0f9..40898d560a7 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c
@@ -28,6 +28,7 @@
 #include "util/ralloc.h"
 #include "util/u_bitcast.h"
 #include "util/u_memory.h"
+#include "util/half_float.h"
 #include "util/hash_table.h"
 #define XXH_INLINE_ALL
 #include "util/xxhash.h"
@@ -1408,7 +1409,7 @@ spirv_builder_const_bool(struct spirv_builder *b, bool val)
 SpvId
 spirv_builder_const_int(struct spirv_builder *b, int width, int64_t val)
 {
-   assert(width >= 32);
+   assert(width >= 16);
    SpvId type = spirv_builder_type_int(b, width);
    if (width <= 32)
       return emit_constant_32(b, type, val);
@@ -1437,12 +1438,16 @@ spirv_builder_spec_const_uint(struct spirv_builder *b, int width)
 SpvId
 spirv_builder_const_float(struct spirv_builder *b, int width, double val)
 {
-   assert(width >= 32);
+   assert(width >= 16);
    SpvId type = spirv_builder_type_float(b, width);
-   if (width <= 32)
+   if (width == 16)
+      return emit_constant_32(b, type, _mesa_float_to_half(val));
+   else if (width == 32)
       return emit_constant_32(b, type, u_bitcast_f2u(val));
-   else
+   else if (width == 64)
       return emit_constant_64(b, type, u_bitcast_d2u(val));
+
+   unreachable("unhandled float-width");
 }
 
 SpvId



More information about the mesa-commit mailing list