[Beignet] [PATCH V2 6/9] Backend: Add half type support for sub group functions

Xiuli Pan xiuli.pan at intel.com
Mon Aug 8 03:31:24 UTC 2016


From: Pan Xiuli <xiuli.pan at intel.com>

Sub group functions support the half type when FP16 is supported, i.e. on
Gen8+ devices.

Signed-off-by: Pan Xiuli <xiuli.pan at intel.com>
---
 backend/src/backend/gen8_context.cpp       |  6 ++++++
 backend/src/backend/gen_insn_selection.cpp |  6 ++++--
 backend/src/libocl/tmpl/ocl_simd.tmpl.cl   | 10 ++++++++++
 backend/src/libocl/tmpl/ocl_simd.tmpl.h    | 13 +++++++++++++
 4 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/backend/src/backend/gen8_context.cpp b/backend/src/backend/gen8_context.cpp
index 7ddb95a..5809835 100644
--- a/backend/src/backend/gen8_context.cpp
+++ b/backend/src/backend/gen8_context.cpp
@@ -1343,6 +1343,8 @@ namespace gbe
         p->MOV(dataReg, GenRegister::immd(0x0));
       else if (dataReg.type == GEN_TYPE_UD)
         p->MOV(dataReg, GenRegister::immud(0x0));
+      else if (dataReg.type == GEN_TYPE_HF)
+        p->MOV(dataReg, GenRegister::immh(0x0));
       else if (dataReg.type == GEN_TYPE_F)
         p->MOV(dataReg, GenRegister::immf(0x0));
       else if (dataReg.type == GEN_TYPE_L)
@@ -1361,6 +1363,8 @@ namespace gbe
         p->MOV(dataReg, GenRegister::immd(0x7FFFFFFF));
       else if (dataReg.type == GEN_TYPE_UD)
         p->MOV(dataReg, GenRegister::immud(0xFFFFFFFF));
+      else if (dataReg.type == GEN_TYPE_HF)
+        p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UW), GenRegister::immuw(0x7C00));
       else if (dataReg.type == GEN_TYPE_F)
         p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0x7F800000));
       else if (dataReg.type == GEN_TYPE_L)
@@ -1379,6 +1383,8 @@ namespace gbe
         p->MOV(dataReg, GenRegister::immd(0x80000000));
       else if (dataReg.type == GEN_TYPE_UD)
         p->MOV(dataReg, GenRegister::immud(0x0));
+      else if (dataReg.type == GEN_TYPE_HF)
+        p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UW), GenRegister::immuw(0xFC00));
       else if (dataReg.type == GEN_TYPE_F)
         p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0xFF800000));
       else if (dataReg.type == GEN_TYPE_L)
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index e342161..f61e4ed 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -5993,7 +5993,8 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
         }
         else {
           GenRegister shiftL = sel.selReg(sel.reg(FAMILY_DWORD), TYPE_U32);
-          sel.SHL(shiftL, src1, GenRegister::immud(0x2));
+          uint32_t SHLimm = typeSize(getGenType(type)) == 2 ? 1 : (typeSize(getGenType(type)) == 4 ? 2 : 8);
+          sel.SHL(shiftL, src1, GenRegister::immud(SHLimm));
           sel.SIMD_SHUFFLE(dst, src0, shiftL);
         }
       }
@@ -6658,7 +6659,8 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
           sel.MOV(dst, reg);
       } else {
         GenRegister shiftL = sel.selReg(sel.reg(FAMILY_DWORD), TYPE_U32);
-        sel.SHL(shiftL, src1, GenRegister::immud(0x2));
+        uint32_t SHLimm = typeSize(getGenType(type)) == 2 ? 1 : (typeSize(getGenType(type)) == 4 ? 2 : 8);
+        sel.SHL(shiftL, src1, GenRegister::immud(SHLimm));
         sel.SIMD_SHUFFLE(dst, src0, shiftL);
       }
       } sel.pop();
diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
index a42ddc9..9c09b21 100644
--- a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
+++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
@@ -55,6 +55,7 @@ BROADCAST_IMPL(int)
 BROADCAST_IMPL(uint)
 BROADCAST_IMPL(long)
 BROADCAST_IMPL(ulong)
+BROADCAST_IMPL(half)
 BROADCAST_IMPL(float)
 BROADCAST_IMPL(double)
 #undef BROADCAST_IMPL
@@ -71,6 +72,7 @@ RANGE_OP(reduce, add, int, true)
 RANGE_OP(reduce, add, uint, false)
 RANGE_OP(reduce, add, long, true)
 RANGE_OP(reduce, add, ulong, false)
+RANGE_OP(reduce, add, half, true)
 RANGE_OP(reduce, add, float, true)
 RANGE_OP(reduce, add, double, true)
 /* reduce min */
@@ -78,6 +80,7 @@ RANGE_OP(reduce, min, int, true)
 RANGE_OP(reduce, min, uint, false)
 RANGE_OP(reduce, min, long, true)
 RANGE_OP(reduce, min, ulong, false)
+RANGE_OP(reduce, min, half, true)
 RANGE_OP(reduce, min, float, true)
 RANGE_OP(reduce, min, double, true)
 /* reduce max */
@@ -85,6 +88,7 @@ RANGE_OP(reduce, max, int, true)
 RANGE_OP(reduce, max, uint, false)
 RANGE_OP(reduce, max, long, true)
 RANGE_OP(reduce, max, ulong, false)
+RANGE_OP(reduce, max, half, true)
 RANGE_OP(reduce, max, float, true)
 RANGE_OP(reduce, max, double, true)
 
@@ -93,6 +97,7 @@ RANGE_OP(scan_inclusive, add, int, true)
 RANGE_OP(scan_inclusive, add, uint, false)
 RANGE_OP(scan_inclusive, add, long, true)
 RANGE_OP(scan_inclusive, add, ulong, false)
+RANGE_OP(scan_inclusive, add, half, true)
 RANGE_OP(scan_inclusive, add, float, true)
 RANGE_OP(scan_inclusive, add, double, true)
 /* scan_inclusive min */
@@ -100,6 +105,7 @@ RANGE_OP(scan_inclusive, min, int, true)
 RANGE_OP(scan_inclusive, min, uint, false)
 RANGE_OP(scan_inclusive, min, long, true)
 RANGE_OP(scan_inclusive, min, ulong, false)
+RANGE_OP(scan_inclusive, min, half, true)
 RANGE_OP(scan_inclusive, min, float, true)
 RANGE_OP(scan_inclusive, min, double, true)
 /* scan_inclusive max */
@@ -107,6 +113,7 @@ RANGE_OP(scan_inclusive, max, int, true)
 RANGE_OP(scan_inclusive, max, uint, false)
 RANGE_OP(scan_inclusive, max, long, true)
 RANGE_OP(scan_inclusive, max, ulong, false)
+RANGE_OP(scan_inclusive, max, half, true)
 RANGE_OP(scan_inclusive, max, float, true)
 RANGE_OP(scan_inclusive, max, double, true)
 
@@ -115,6 +122,7 @@ RANGE_OP(scan_exclusive, add, int, true)
 RANGE_OP(scan_exclusive, add, uint, false)
 RANGE_OP(scan_exclusive, add, long, true)
 RANGE_OP(scan_exclusive, add, ulong, false)
+RANGE_OP(scan_exclusive, add, half, true)
 RANGE_OP(scan_exclusive, add, float, true)
 RANGE_OP(scan_exclusive, add, double, true)
 /* scan_exclusive min */
@@ -122,6 +130,7 @@ RANGE_OP(scan_exclusive, min, int, true)
 RANGE_OP(scan_exclusive, min, uint, false)
 RANGE_OP(scan_exclusive, min, long, true)
 RANGE_OP(scan_exclusive, min, ulong, false)
+RANGE_OP(scan_exclusive, min, half, true)
 RANGE_OP(scan_exclusive, min, float, true)
 RANGE_OP(scan_exclusive, min, double, true)
 /* scan_exclusive max */
@@ -129,6 +138,7 @@ RANGE_OP(scan_exclusive, max, int, true)
 RANGE_OP(scan_exclusive, max, uint, false)
 RANGE_OP(scan_exclusive, max, long, true)
 RANGE_OP(scan_exclusive, max, ulong, false)
+RANGE_OP(scan_exclusive, max, half, true)
 RANGE_OP(scan_exclusive, max, float, true)
 RANGE_OP(scan_exclusive, max, double, true)
 
diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.h b/backend/src/libocl/tmpl/ocl_simd.tmpl.h
index 15da0e7..ae3b379 100644
--- a/backend/src/libocl/tmpl/ocl_simd.tmpl.h
+++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.h
@@ -39,6 +39,7 @@ OVERLOADABLE int sub_group_broadcast(int a, size_t local_id);
 OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id);
 OVERLOADABLE long sub_group_broadcast(long a, size_t local_id);
 OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id);
+OVERLOADABLE half sub_group_broadcast(half a, size_t local_id);
 OVERLOADABLE float sub_group_broadcast(float a, size_t local_id);
 OVERLOADABLE double sub_group_broadcast(double a, size_t local_id);
 
@@ -46,6 +47,7 @@ OVERLOADABLE int sub_group_broadcast(int a, size_t local_id_x, size_t local_id_y
 OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id_x, size_t local_id_y);
 OVERLOADABLE long sub_group_broadcast(long a, size_t local_id_x, size_t local_id_y);
 OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id_x, size_t local_id_y);
+OVERLOADABLE half sub_group_broadcast(half a, size_t local_id_x, size_t local_id_y);
 OVERLOADABLE float sub_group_broadcast(float a, size_t local_id_x, size_t local_id_y);
 OVERLOADABLE double sub_group_broadcast(double a, size_t local_id_x, size_t local_id_y);
 
@@ -53,6 +55,7 @@ OVERLOADABLE int sub_group_broadcast(int a, size_t local_id_x, size_t local_id_y
 OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
 OVERLOADABLE long sub_group_broadcast(long a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
 OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+OVERLOADABLE half sub_group_broadcast(half a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
 OVERLOADABLE float sub_group_broadcast(float a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
 OVERLOADABLE double sub_group_broadcast(double a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
 
@@ -61,6 +64,7 @@ OVERLOADABLE int sub_group_reduce_add(int x);
 OVERLOADABLE uint sub_group_reduce_add(uint x);
 OVERLOADABLE long sub_group_reduce_add(long x);
 OVERLOADABLE ulong sub_group_reduce_add(ulong x);
+OVERLOADABLE half sub_group_reduce_add(half x);
 OVERLOADABLE float sub_group_reduce_add(float x);
 OVERLOADABLE double sub_group_reduce_add(double x);
 
@@ -69,6 +73,7 @@ OVERLOADABLE int sub_group_reduce_min(int x);
 OVERLOADABLE uint sub_group_reduce_min(uint x);
 OVERLOADABLE long sub_group_reduce_min(long x);
 OVERLOADABLE ulong sub_group_reduce_min(ulong x);
+OVERLOADABLE half sub_group_reduce_min(half x);
 OVERLOADABLE float sub_group_reduce_min(float x);
 OVERLOADABLE double sub_group_reduce_min(double x);
 
@@ -77,6 +82,7 @@ OVERLOADABLE int sub_group_reduce_max(int x);
 OVERLOADABLE uint sub_group_reduce_max(uint x);
 OVERLOADABLE long sub_group_reduce_max(long x);
 OVERLOADABLE ulong sub_group_reduce_max(ulong x);
+OVERLOADABLE half sub_group_reduce_max(half x);
 OVERLOADABLE float sub_group_reduce_max(float x);
 OVERLOADABLE double sub_group_reduce_max(double x);
 
@@ -85,6 +91,7 @@ OVERLOADABLE int sub_group_scan_inclusive_add(int x);
 OVERLOADABLE uint sub_group_scan_inclusive_add(uint x);
 OVERLOADABLE long sub_group_scan_inclusive_add(long x);
 OVERLOADABLE ulong sub_group_scan_inclusive_add(ulong x);
+OVERLOADABLE half sub_group_scan_inclusive_add(half x);
 OVERLOADABLE float sub_group_scan_inclusive_add(float x);
 OVERLOADABLE double sub_group_scan_inclusive_add(double x);
 
@@ -93,6 +100,7 @@ OVERLOADABLE int sub_group_scan_inclusive_min(int x);
 OVERLOADABLE uint sub_group_scan_inclusive_min(uint x);
 OVERLOADABLE long sub_group_scan_inclusive_min(long x);
 OVERLOADABLE ulong sub_group_scan_inclusive_min(ulong x);
+OVERLOADABLE half sub_group_scan_inclusive_min(half x);
 OVERLOADABLE float sub_group_scan_inclusive_min(float x);
 OVERLOADABLE double sub_group_scan_inclusive_min(double x);
 
@@ -101,6 +109,7 @@ OVERLOADABLE int sub_group_scan_inclusive_max(int x);
 OVERLOADABLE uint sub_group_scan_inclusive_max(uint x);
 OVERLOADABLE long sub_group_scan_inclusive_max(long x);
 OVERLOADABLE ulong sub_group_scan_inclusive_max(ulong x);
+OVERLOADABLE half sub_group_scan_inclusive_max(half x);
 OVERLOADABLE float sub_group_scan_inclusive_max(float x);
 OVERLOADABLE double sub_group_scan_inclusive_max(double x);
 
@@ -109,6 +118,7 @@ OVERLOADABLE int sub_group_scan_exclusive_add(int x);
 OVERLOADABLE uint sub_group_scan_exclusive_add(uint x);
 OVERLOADABLE long sub_group_scan_exclusive_add(long x);
 OVERLOADABLE ulong sub_group_scan_exclusive_add(ulong x);
+OVERLOADABLE half sub_group_scan_exclusive_add(half x);
 OVERLOADABLE float sub_group_scan_exclusive_add(float x);
 OVERLOADABLE double sub_group_scan_exclusive_add(double x);
 
@@ -117,6 +127,7 @@ OVERLOADABLE int sub_group_scan_exclusive_min(int x);
 OVERLOADABLE uint sub_group_scan_exclusive_min(uint x);
 OVERLOADABLE long sub_group_scan_exclusive_min(long x);
 OVERLOADABLE ulong sub_group_scan_exclusive_min(ulong x);
+OVERLOADABLE half sub_group_scan_exclusive_min(half x);
 OVERLOADABLE float sub_group_scan_exclusive_min(float x);
 OVERLOADABLE double sub_group_scan_exclusive_min(double x);
 
@@ -125,10 +136,12 @@ OVERLOADABLE int sub_group_scan_exclusive_max(int x);
 OVERLOADABLE uint sub_group_scan_exclusive_max(uint x);
 OVERLOADABLE long sub_group_scan_exclusive_max(long x);
 OVERLOADABLE ulong sub_group_scan_exclusive_max(ulong x);
+OVERLOADABLE half sub_group_scan_exclusive_max(half x);
 OVERLOADABLE float sub_group_scan_exclusive_max(float x);
 OVERLOADABLE double sub_group_scan_exclusive_max(double x);
 
 /* shuffle */
+OVERLOADABLE half intel_sub_group_shuffle(half x, uint c);
 OVERLOADABLE float intel_sub_group_shuffle(float x, uint c);
 OVERLOADABLE int intel_sub_group_shuffle(int x, uint c);
 OVERLOADABLE uint intel_sub_group_shuffle(uint x, uint c);
-- 
2.7.4



More information about the Beignet mailing list