[Beignet] [PATCH 07/12] Backend: Add sub_group built-in functions for intel extension

Xiuli Pan xiuli.pan at intel.com
Thu May 26 03:14:22 UTC 2016


From: Pan Xiuli <xiuli.pan at intel.com>

Add sub_group_broadcast and sub_group_reduce/scan_exclusive/scan_inclusive
add/min/max built-in functions. They share the in-thread algorithm of the work group functions.

Signed-off-by: Pan Xiuli <xiuli.pan at intel.com>
---
 backend/src/backend/gen8_context.cpp               |  23 ++++
 backend/src/backend/gen8_context.hpp               |   1 +
 backend/src/backend/gen_context.cpp                |  23 ++++
 backend/src/backend/gen_context.hpp                |   1 +
 .../src/backend/gen_insn_gen7_schedule_info.hxx    |   1 +
 backend/src/backend/gen_insn_selection.cpp         | 116 +++++++++++++++++
 backend/src/backend/gen_insn_selection.hxx         |   1 +
 backend/src/ir/instruction.cpp                     | 144 +++++++++++++++++++++
 backend/src/ir/instruction.hpp                     |  11 ++
 backend/src/ir/instruction.hxx                     |   1 +
 backend/src/libocl/tmpl/ocl_simd.tmpl.cl           |  98 ++++++++++++++
 backend/src/libocl/tmpl/ocl_simd.tmpl.h            |  95 ++++++++++++++
 backend/src/llvm/llvm_gen_backend.cpp              |  74 +++++++++++
 backend/src/llvm/llvm_gen_ocl_function.hxx         |  15 +++
 14 files changed, 604 insertions(+)

diff --git a/backend/src/backend/gen8_context.cpp b/backend/src/backend/gen8_context.cpp
index 477b22b..7ddb95a 100644
--- a/backend/src/backend/gen8_context.cpp
+++ b/backend/src/backend/gen8_context.cpp
@@ -1845,4 +1845,27 @@ namespace gbe
     }
   }
 
+  void Gen8Context::emitSubGroupOpInstruction(const SelectionInstruction &insn){ /* Gen8: emit a sub group any/all/reduce/scan using the in-thread work group algorithm */
+    const GenRegister dst = ra->genReg(insn.dst(0));
+    const GenRegister tmp = GenRegister::retype(ra->genReg(insn.dst(1)), dst.type); /* scratch register, dst(1) */
+    const GenRegister theVal = GenRegister::retype(ra->genReg(insn.src(0)), dst.type); /* per-lane input value, src(0) */
+    GenRegister threadData = ra->genReg(insn.src(1)); /* scratch thread data, src(1) */
+    /* NOTE(review): kept separate from GenContext's identical-looking copy, presumably because the wgOp* helpers used here differ on Gen8 — confirm */
+    uint32_t wg_op = insn.extra.workgroupOp;
+    uint32_t simd = p->curr.execWidth;
+
+    /* masked-off lanes must hold the op's init value: fill tmp with it in all lanes (noMask), copy only the enabled lanes of theVal over it, then write all of tmp back to theVal */
+    p->push(); {
+      p->curr.noMask = 1;
+      wgOpInitValue(p, tmp, wg_op);
+      p->curr.noMask = 0;
+      p->MOV(tmp, theVal);
+      p->curr.noMask = 1;
+      p->MOV(theVal, tmp); /* NOTE(review): this overwrites the instruction's source register src(0) in place */
+    } p->pop();
+
+    /* combine the lanes of this thread into dst according to wg_op (algorithm shared with the work group ops) */
+    wgOpPerformThread(dst, theVal, threadData, tmp, simd, wg_op, p);
+  }
+
 }
diff --git a/backend/src/backend/gen8_context.hpp b/backend/src/backend/gen8_context.hpp
index 771e20b..ec1358c 100644
--- a/backend/src/backend/gen8_context.hpp
+++ b/backend/src/backend/gen8_context.hpp
@@ -77,6 +77,7 @@ namespace gbe
     virtual void emitF64DIVInstruction(const SelectionInstruction &insn);
 
     virtual void emitWorkGroupOpInstruction(const SelectionInstruction &insn);
+    virtual void emitSubGroupOpInstruction(const SelectionInstruction &insn); /* sub group any/all/reduce/scan (SEL_OP_SUBGROUP_OP) */
 
     static GenRegister unpacked_ud(GenRegister reg, uint32_t offset = 0);
 
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index 4e24816..4d0a3f3 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -3374,6 +3374,29 @@ namespace gbe
     }
   }
 
+  void GenContext::emitSubGroupOpInstruction(const SelectionInstruction &insn){ /* emit a sub group any/all/reduce/scan using the in-thread work group algorithm */
+    const GenRegister dst = ra->genReg(insn.dst(0));
+    const GenRegister tmp = GenRegister::retype(ra->genReg(insn.dst(1)), dst.type); /* scratch register, dst(1) */
+    const GenRegister theVal = GenRegister::retype(ra->genReg(insn.src(0)), dst.type); /* per-lane input value, src(0) */
+    GenRegister threadData = ra->genReg(insn.src(1)); /* scratch thread data, src(1) */
+    /* NOTE(review): Gen8Context overrides this with its own copy of the same sequence */
+    uint32_t wg_op = insn.extra.workgroupOp;
+    uint32_t simd = p->curr.execWidth;
+
+    /* masked-off lanes must hold the op's init value: fill tmp with it in all lanes (noMask), copy only the enabled lanes of theVal over it, then write all of tmp back to theVal */
+    p->push(); {
+      p->curr.noMask = 1;
+      wgOpInitValue(p, tmp, wg_op);
+      p->curr.noMask = 0;
+      p->MOV(tmp, theVal);
+      p->curr.noMask = 1;
+      p->MOV(theVal, tmp); /* NOTE(review): this overwrites the instruction's source register src(0) in place */
+    } p->pop();
+
+    /* combine the lanes of this thread into dst according to wg_op (algorithm shared with the work group ops) */
+    wgOpPerformThread(dst, theVal, threadData, tmp, simd, wg_op, p);
+  }
+
   void GenContext::emitPrintfLongInstruction(GenRegister& addr, GenRegister& data,
                                              GenRegister& src, uint32_t bti) {
     p->MOV(GenRegister::retype(data, GEN_TYPE_UD), src.bottom_half());
diff --git a/backend/src/backend/gen_context.hpp b/backend/src/backend/gen_context.hpp
index ebc55e6..4c43ccb 100644
--- a/backend/src/backend/gen_context.hpp
+++ b/backend/src/backend/gen_context.hpp
@@ -181,6 +181,7 @@ namespace gbe
     void emitCalcTimestampInstruction(const SelectionInstruction &insn);
     void emitStoreProfilingInstruction(const SelectionInstruction &insn);
     virtual void emitWorkGroupOpInstruction(const SelectionInstruction &insn);
+    virtual void emitSubGroupOpInstruction(const SelectionInstruction &insn); /* sub group any/all/reduce/scan; Gen8Context overrides it */
     void emitPrintfInstruction(const SelectionInstruction &insn);
     void scratchWrite(const GenRegister header, uint32_t offset, uint32_t reg_num, uint32_t reg_type, uint32_t channel_mode);
     void scratchRead(const GenRegister dst, const GenRegister header, uint32_t offset, uint32_t reg_num, uint32_t reg_type, uint32_t channel_mode);
diff --git a/backend/src/backend/gen_insn_gen7_schedule_info.hxx b/backend/src/backend/gen_insn_gen7_schedule_info.hxx
index 112df32..cb5c4f1 100644
--- a/backend/src/backend/gen_insn_gen7_schedule_info.hxx
+++ b/backend/src/backend/gen_insn_gen7_schedule_info.hxx
@@ -48,4 +48,5 @@ DECL_GEN7_SCHEDULE(F64DIV,          20,        40,      20)
 DECL_GEN7_SCHEDULE(CalcTimestamp,   80,        1,        1)
 DECL_GEN7_SCHEDULE(StoreProfiling,  80,        1,        1)
 DECL_GEN7_SCHEDULE(WorkGroupOp,     80,        1,        1)
+DECL_GEN7_SCHEDULE(SubGroupOp,      80,        1,        1)
 DECL_GEN7_SCHEDULE(Printf,          80,        1,        1)
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 09f459a..855c39d 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -694,6 +694,9 @@ namespace gbe
                       GenRegister tmpData2, GenRegister slmOff,
                       vector<GenRegister> msg, uint32_t msgSizeReq,
                       GenRegister localBarrier);
+    /*! Sub group any/all/reduce/scan: dst, tmpData1 become dst(0), dst(1); src, tmpData2 become src(0), src(1) */
+    void SUBGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src,
+                      GenRegister tmpData1, GenRegister tmpData2); // tmpData1/tmpData2 are scratch registers
     /* common functions for both binary instruction and sel_cmp and compare instruction.
        It will handle the IMM or normal register assignment, and will try to avoid LOADI
        as much as possible. */
@@ -1995,6 +1998,23 @@ namespace gbe
     insn->src(5) = localBarrier;
   }
 
+  void Selection::Opaque::SUBGROUP_OP(uint32_t wg_op,
+                                       Reg dst,
+                                       GenRegister src,
+                                       GenRegister tmpData1,
+                                       GenRegister tmpData2)
+  {
+    /* 2 destinations + 2 sources; the operation kind travels in extra.workgroupOp */
+    SelectionInstruction *sgInsn = this->appendInsn(SEL_OP_SUBGROUP_OP, 2, 2);
+    sgInsn->extra.workgroupOp = wg_op;
+    /* sources: the per-lane value, then a scratch register */
+    sgInsn->src(0) = src;
+    sgInsn->src(1) = tmpData2;
+    /* destinations: the result, then a second scratch register */
+    sgInsn->dst(0) = dst;
+    sgInsn->dst(1) = tmpData1;
+  }
+
   // Boiler plate to initialize the selection library at c++ pre-main
   static SelectionLibrary *selLib = NULL;
   static void destroySelectionLibrary(void) { GBE_DELETE(selLib); }
@@ -6399,6 +6419,101 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
     DECL_CTOR(WorkGroupInstruction, 1, 1);
   };
 
+  /*! SubGroup instruction pattern: lowers ir::SubGroupInstruction (broadcast, any/all, reduce, scans) */
+  class SubGroupInstructionPattern : public SelectionPattern
+  {
+  public:
+    SubGroupInstructionPattern(void) : SelectionPattern(1,1) {
+      for (uint32_t op = 0; op < ir::OP_INVALID; ++op)
+        if (ir::isOpcodeFrom<ir::SubGroupInstruction>(ir::Opcode(op)) == true)
+          this->opcodes.push_back(ir::Opcode(op));
+    }
+
+    /* SUBGROUP OP: ALL, ANY, REDUCE, SCAN INCLUSIVE, SCAN EXCLUSIVE
+     * Shares the in-thread algorithm of the work group ops */
+    INLINE bool emitSGReduce(Selection::Opaque &sel, const ir::SubGroupInstruction &insn) const
+    {
+      using namespace ir;
+      /* the value to combine is the only source */
+      GBE_ASSERT(insn.getSrcNum() == 1);
+
+      const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
+      const Type type = insn.getType();
+      GenRegister dst = sel.selReg(insn.getDst(0), type);
+      GenRegister src = sel.selReg(insn.getSrc(0), type);
+      /* QWORD-sized scratch registers, large enough for 64-bit types */
+      GenRegister tmpData1 = GenRegister::retype(sel.selReg(sel.reg(FAMILY_QWORD)), type);
+      GenRegister tmpData2 = GenRegister::retype(sel.selReg(sel.reg(FAMILY_QWORD)), type);
+      /* Perform the sub group op */
+      sel.SUBGROUP_OP(workGroupOp, dst, src, tmpData1, tmpData2);
+
+      return true;
+    }
+
+    /* SUBGROUP OP: BROADCAST
+     * Shares the algorithm of simd shuffle */
+    INLINE bool emitSGBroadcast(Selection::Opaque &sel, const ir::SubGroupInstruction &insn, SelectionDAG &dag) const
+    {
+      using namespace ir;
+      /* src(0) is the value to broadcast, src(1) the source lane index */
+      GBE_ASSERT(insn.getSrcNum() == 2);
+
+      const Type type = insn.getType();
+      const GenRegister src0 = sel.selReg(insn.getSrc(0), type);
+      const GenRegister dst = sel.selReg(insn.getDst(0), type);
+      GenRegister src1;
+      /* fold a LOADI lane index into an immediate operand */
+      SelectionDAG *dag0 = dag.child[0];
+      SelectionDAG *dag1 = dag.child[1];
+      if (dag1 != NULL && dag1->insn.getOpcode() == OP_LOADI && canGetRegisterFromImmediate(dag1->insn)) {
+        const auto &childInsn = cast<LoadImmInstruction>(dag1->insn);
+        src1 = getRegisterFromImmediate(childInsn.getImmediate(), TYPE_U32);
+        if (dag0) dag0->isRoot = 1;
+      } else {
+        markAllChildren(dag);
+        src1 = sel.selReg(insn.getSrc(1), TYPE_U32);
+      }
+
+      sel.push(); {
+      if (src1.file == GEN_IMMEDIATE_VALUE) {
+          /* constant lane: read it through a <0;1,0> scalar region */
+          uint32_t offset = src1.value.ud % sel.curr.execWidth;
+          GenRegister reg = GenRegister::subphysicaloffset(src0, offset);
+          reg.vstride = GEN_VERTICAL_STRIDE_0;
+          reg.hstride = GEN_HORIZONTAL_STRIDE_0;
+          reg.width = GEN_WIDTH_1;
+          sel.MOV(dst, reg);
+      } else {
+        /* SIMD_SHUFFLE takes byte offsets: scale the lane index by the element size, not always by 4 */
+        const int elemSize = typeSize(src0.type);
+        const uint32_t shift = elemSize == 8 ? 3 : (elemSize == 2 ? 1 : 2);
+        GenRegister shiftL = sel.selReg(sel.reg(FAMILY_DWORD), TYPE_U32);
+        sel.SHL(shiftL, src1, GenRegister::immud(shift));
+        sel.SIMD_SHUFFLE(dst, src0, shiftL);
+      }
+      } sel.pop();
+
+      return true;
+    }
+
+    INLINE bool emit(Selection::Opaque &sel, SelectionDAG &dag) const
+    {
+      using namespace ir;
+      const ir::SubGroupInstruction &insn = cast<SubGroupInstruction>(dag.insn);
+      const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
+      /* broadcast has its own lowering; any/all/reduce/scan share emitSGReduce */
+      if (workGroupOp == WORKGROUP_OP_BROADCAST) {
+        return emitSGBroadcast(sel, insn, dag);
+      } else if (workGroupOp >= WORKGROUP_OP_ANY && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
+        if (!emitSGReduce(sel, insn))
+          return false;
+        markAllChildren(dag);
+      } else
+        GBE_ASSERT(0);
+      return true;
+    }
+  };
+
   /*! Sort patterns */
   INLINE bool cmp(const SelectionPattern *p0, const SelectionPattern *p1) {
     if (p0->insnNum != p1->insnNum)
@@ -6436,6 +6551,7 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
     this->insert<CalcTimestampInstructionPattern>();
     this->insert<StoreProfilingInstructionPattern>();
     this->insert<WorkGroupInstructionPattern>();
+    this->insert<SubGroupInstructionPattern>();
     this->insert<NullaryInstructionPattern>();
     this->insert<WaitInstructionPattern>();
     this->insert<PrintfInstructionPattern>();
diff --git a/backend/src/backend/gen_insn_selection.hxx b/backend/src/backend/gen_insn_selection.hxx
index 4352490..0e11f9f 100644
--- a/backend/src/backend/gen_insn_selection.hxx
+++ b/backend/src/backend/gen_insn_selection.hxx
@@ -94,4 +94,5 @@ DECL_SELECTION_IR(F64DIV, F64DIVInstruction)
 DECL_SELECTION_IR(CALC_TIMESTAMP, CalcTimestampInstruction)
 DECL_SELECTION_IR(STORE_PROFILING, StoreProfilingInstruction)
 DECL_SELECTION_IR(WORKGROUP_OP, WorkGroupOpInstruction)
+DECL_SELECTION_IR(SUBGROUP_OP, SubGroupOpInstruction)
 DECL_SELECTION_IR(PRINTF, PrintfInstruction)
diff --git a/backend/src/ir/instruction.cpp b/backend/src/ir/instruction.cpp
index d9051ab..47606b2 100644
--- a/backend/src/ir/instruction.cpp
+++ b/backend/src/ir/instruction.cpp
@@ -994,6 +994,33 @@ namespace ir {
         Register dst[1];
     };
 
+    class ALIGNED_INSTRUCTION SubGroupInstruction : /* OP_SUBGROUP: one destination, srcNum sources stored as a tuple */
+      public BasePolicy,
+      public TupleSrcPolicy<SubGroupInstruction>,
+      public NDstPolicy<SubGroupInstruction, 1>
+    {
+      public:
+        INLINE SubGroupInstruction(WorkGroupOps opcode, Register dst, /* opcode is the sub group operation; the IR opcode is always OP_SUBGROUP */
+            Tuple srcTuple, uint8_t srcNum, Type type) {
+          this->opcode = OP_SUBGROUP;
+          this->workGroupOp = opcode;
+          this->type = type;
+          this->dst[0] = dst;
+          this->src = srcTuple;
+          this->srcNum = srcNum;
+        }
+        INLINE Type getType(void) const { return this->type; }
+        INLINE bool wellFormed(const Function &fn, std::string &whyNot) const; //!< Checks dst, the per-operation source count and broadcast index registers
+        INLINE void out(std::ostream &out, const Function &fn) const;
+        INLINE WorkGroupOps getWorkGroupOpcode(void) const { return this->workGroupOp; }
+        /* NOTE(review): the fields below are bit-packed; every WorkGroupOps value must fit in 5 bits and srcNum in 3 */
+        WorkGroupOps workGroupOp:5; //!< Sub group operation performed
+        uint32_t srcNum:3;          //!< Number of sources (1, or 2 for broadcast)
+        Type type;                  //!< Type of the instruction
+        Tuple src;                  //!< Source registers
+        Register dst[1];            //!< Destination register
+    };
+
     class ALIGNED_INSTRUCTION PrintfInstruction :
       public BasePolicy,
       public TupleSrcPolicy<PrintfInstruction>,
@@ -1505,6 +1532,52 @@ namespace ir {
       return true;
     }
 
+    INLINE bool SubGroupInstruction::wellFormed(const Function &fn, std::string &whyNot) const {
+      /* the destination must be writable and belong to the family of the instruction type */
+      const RegisterFamily dstFamily = getFamily(this->type);
+      if (UNLIKELY(checkSpecialRegForWrite(dst[0], fn, whyNot) == false))
+        return false;
+      if (UNLIKELY(checkRegisterData(dstFamily, dst[0], fn, whyNot) == false))
+        return false;
+      /* broadcast takes the value plus a lane index; every other op only the value */
+      uint32_t expectedSrcNum = 0;
+      switch (this->workGroupOp) {
+        case WORKGROUP_OP_ANY:
+        case WORKGROUP_OP_ALL:
+        case WORKGROUP_OP_REDUCE_ADD:
+        case WORKGROUP_OP_REDUCE_MIN:
+        case WORKGROUP_OP_REDUCE_MAX:
+        case WORKGROUP_OP_INCLUSIVE_ADD:
+        case WORKGROUP_OP_INCLUSIVE_MIN:
+        case WORKGROUP_OP_INCLUSIVE_MAX:
+        case WORKGROUP_OP_EXCLUSIVE_ADD:
+        case WORKGROUP_OP_EXCLUSIVE_MIN:
+        case WORKGROUP_OP_EXCLUSIVE_MAX:
+          expectedSrcNum = 1;
+          break;
+        case WORKGROUP_OP_BROADCAST:
+          expectedSrcNum = 2;
+          break;
+        default:
+          whyNot = "No such sub group function.";
+          return false;
+      }
+      if (this->srcNum != expectedSrcNum) {
+        whyNot = "Wrong number of source.";
+        return false;
+      }
+      /* broadcast: the lane index sources must use the pointer register family */
+      if (this->workGroupOp == WORKGROUP_OP_BROADCAST) {
+        const RegisterFamily idFamily = fn.getPointerFamily();
+        for (uint32_t srcID = 1; srcID < this->srcNum; ++srcID) {
+          const Register regID = fn.getRegister(src, srcID);
+          if (UNLIKELY(checkRegisterData(idFamily, regID, fn, whyNot) == false))
+            return false;
+        }
+      }
+      return true;
+    }
+
     INLINE bool PrintfInstruction::wellFormed(const Function &fn, std::string &whyNot) const {
       return true;
     }
@@ -1739,6 +1812,67 @@ namespace ir {
       out << "TheadID Map at SLM: " << this->slmAddr;
     }
 
+    INLINE void SubGroupInstruction::out(std::ostream &out, const Function &fn) const {
+      this->outOpcode(out);
+
+      /* suffix the opcode with the name of the sub group operation */
+      const char *opName = NULL;
+      switch (this->workGroupOp) {
+        case WORKGROUP_OP_ANY:
+          opName = "ANY";
+          break;
+        case WORKGROUP_OP_ALL:
+          opName = "ALL";
+          break;
+        case WORKGROUP_OP_REDUCE_ADD:
+          opName = "REDUCE_ADD";
+          break;
+        case WORKGROUP_OP_REDUCE_MIN:
+          opName = "REDUCE_MIN";
+          break;
+        case WORKGROUP_OP_REDUCE_MAX:
+          opName = "REDUCE_MAX";
+          break;
+        case WORKGROUP_OP_INCLUSIVE_ADD:
+          opName = "INCLUSIVE_ADD";
+          break;
+        case WORKGROUP_OP_INCLUSIVE_MIN:
+          opName = "INCLUSIVE_MIN";
+          break;
+        case WORKGROUP_OP_INCLUSIVE_MAX:
+          opName = "INCLUSIVE_MAX";
+          break;
+        case WORKGROUP_OP_EXCLUSIVE_ADD:
+          opName = "EXCLUSIVE_ADD";
+          break;
+        case WORKGROUP_OP_EXCLUSIVE_MIN:
+          opName = "EXCLUSIVE_MIN";
+          break;
+        case WORKGROUP_OP_EXCLUSIVE_MAX:
+          opName = "EXCLUSIVE_MAX";
+          break;
+        case WORKGROUP_OP_BROADCAST:
+          opName = "BROADCAST";
+          break;
+        default:
+          GBE_ASSERT(0);
+      }
+      if (opName != NULL)
+        out << "_" << opName;
+
+      /* destination, then the value operand */
+      out << " %" << this->getDst(fn, 0);
+      out << " %" << this->getSrc(fn, 0);
+
+      /* broadcast additionally prints its lane index operand */
+      if (this->workGroupOp == WORKGROUP_OP_BROADCAST) {
+        /* srcNum - 1 lane index operands are expected; the first one is printed */
+        GBE_ASSERT(srcNum - 1);
+        out << " Local ID:";
+        out << " %" << this->getSrc(fn, 1);
+      }
+    }
+
     INLINE void PrintfInstruction::out(std::ostream &out, const Function &fn) const {
       this->outOpcode(out);
     }
@@ -1903,6 +2037,10 @@ START_INTROSPECTION(WorkGroupInstruction)
 #include "ir/instruction.hxx"
 END_INTROSPECTION(WorkGroupInstruction)
 
+START_INTROSPECTION(SubGroupInstruction)
+#include "ir/instruction.hxx"
+END_INTROSPECTION(SubGroupInstruction)
+
 START_INTROSPECTION(PrintfInstruction)
 #include "ir/instruction.hxx"
 END_INTROSPECTION(PrintfInstruction)
@@ -2117,6 +2255,8 @@ DECL_MEM_FN(StoreProfilingInstruction, uint32_t, getBTI(void), getBTI())
 DECL_MEM_FN(WorkGroupInstruction, Type, getType(void), getType())
 DECL_MEM_FN(WorkGroupInstruction, WorkGroupOps, getWorkGroupOpcode(void), getWorkGroupOpcode())
 DECL_MEM_FN(WorkGroupInstruction, uint32_t, getSlmAddr(void), getSlmAddr())
+DECL_MEM_FN(SubGroupInstruction, Type, getType(void), getType())
+DECL_MEM_FN(SubGroupInstruction, WorkGroupOps, getWorkGroupOpcode(void), getWorkGroupOpcode())
 DECL_MEM_FN(PrintfInstruction, uint32_t, getNum(void), getNum())
 DECL_MEM_FN(PrintfInstruction, uint32_t, getBti(void), getBti())
 DECL_MEM_FN(PrintfInstruction, Type, getType(const Function& fn, uint32_t ID), getType(fn, ID))
@@ -2418,6 +2558,10 @@ DECL_MEM_FN(MemInstruction, void,     setBtiReg(Register reg), setBtiReg(reg))
     return internal::WorkGroupInstruction(opcode, slmAddr, dst, srcTuple, srcNum, type).convert();
   }
 
+  Instruction SUBGROUP(WorkGroupOps opcode, Register dst, Tuple srcTuple, uint8_t srcNum, Type type) { // opcode picks the sub group operation; the IR opcode is OP_SUBGROUP
+    return internal::SubGroupInstruction(opcode, dst, srcTuple, srcNum, type).convert();
+  }
+
   Instruction PRINTF(Register dst, Tuple srcTuple, Tuple typeTuple, uint8_t srcNum, uint8_t bti, uint16_t num) {
     return internal::PrintfInstruction(dst, srcTuple, typeTuple, srcNum, bti, num).convert();
   }
diff --git a/backend/src/ir/instruction.hpp b/backend/src/ir/instruction.hpp
index bbdef91..a605f45 100644
--- a/backend/src/ir/instruction.hpp
+++ b/backend/src/ir/instruction.hpp
@@ -611,6 +611,15 @@ namespace ir {
     uint32_t getSlmAddr(void) const;
   };
 
+  /*! Sub group instruction (OP_SUBGROUP): broadcast, any/all, reduce and scans. */
+  class SubGroupInstruction : public Instruction {
+  public:
+    /*! Return true if the given instruction is an instance of this class */
+    static bool isClassOf(const Instruction &insn);
+    Type getType(void) const; //!< Type the operation works on
+    WorkGroupOps getWorkGroupOpcode(void) const; //!< Which sub group operation this is
+  };
+
   /*! Printf instruction. */
   class PrintfInstruction : public Instruction {
   public:
@@ -850,6 +859,8 @@ namespace ir {
 
   /*! work group */
   Instruction WORKGROUP(WorkGroupOps opcode, uint32_t slmAddr, Register dst, Tuple srcTuple, uint8_t srcNum, Type type);
+  /*! sub group: build an OP_SUBGROUP instruction performing the given WorkGroupOps operation */
+  Instruction SUBGROUP(WorkGroupOps opcode, Register dst, Tuple srcTuple, uint8_t srcNum, Type type);
   /*! printf */
   Instruction PRINTF(Register dst, Tuple srcTuple, Tuple typeTuple, uint8_t srcNum, uint8_t bti, uint16_t num);
 } /* namespace ir */
diff --git a/backend/src/ir/instruction.hxx b/backend/src/ir/instruction.hxx
index 651ed64..57e13eb 100644
--- a/backend/src/ir/instruction.hxx
+++ b/backend/src/ir/instruction.hxx
@@ -112,4 +112,5 @@ DECL_INSN(CALC_TIMESTAMP, CalcTimestampInstruction)
 DECL_INSN(STORE_PROFILING, StoreProfilingInstruction)
 DECL_INSN(WAIT, WaitInstruction)
 DECL_INSN(WORKGROUP, WorkGroupInstruction)
+DECL_INSN(SUBGROUP, SubGroupInstruction)
 DECL_INSN(PRINTF, PrintfInstruction)
diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
index c2e22c1..a25dcef 100644
--- a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
+++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
@@ -35,3 +35,94 @@ uint get_sub_group_size(void)
   else
     return get_max_sub_group_size();
 }
+
+/* broadcast: sub groups are one-dimensional, so only the single lane index form
+ * exists; the backend lowers sub group broadcast with exactly two sources */
+#define BROADCAST_IMPL(GEN_TYPE) \
+    OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id); \
+    OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id) { \
+      return __gen_ocl_sub_group_broadcast(a, local_id); \
+    }
+
+BROADCAST_IMPL(int)
+BROADCAST_IMPL(uint)
+BROADCAST_IMPL(long)
+BROADCAST_IMPL(ulong)
+BROADCAST_IMPL(float)
+BROADCAST_IMPL(double)
+#undef BROADCAST_IMPL
+
+/* SIGN tells the backend whether to use the signed or unsigned IR type of GEN_TYPE */
+#define RANGE_OP(RANGE, OP, GEN_TYPE, SIGN) \
+    OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_##RANGE##_##OP(bool sign, GEN_TYPE x); \
+    OVERLOADABLE GEN_TYPE sub_group_##RANGE##_##OP(GEN_TYPE x) { \
+      return __gen_ocl_sub_group_##RANGE##_##OP(SIGN, x);  \
+    }
+
+/* reduce add */
+RANGE_OP(reduce, add, int, true)
+RANGE_OP(reduce, add, uint, false)
+RANGE_OP(reduce, add, long, true)
+RANGE_OP(reduce, add, ulong, false)
+RANGE_OP(reduce, add, float, true)
+RANGE_OP(reduce, add, double, true)
+/* reduce min */
+RANGE_OP(reduce, min, int, true)
+RANGE_OP(reduce, min, uint, false)
+RANGE_OP(reduce, min, long, true)
+RANGE_OP(reduce, min, ulong, false)
+RANGE_OP(reduce, min, float, true)
+RANGE_OP(reduce, min, double, true)
+/* reduce max */
+RANGE_OP(reduce, max, int, true)
+RANGE_OP(reduce, max, uint, false)
+RANGE_OP(reduce, max, long, true)
+RANGE_OP(reduce, max, ulong, false)
+RANGE_OP(reduce, max, float, true)
+RANGE_OP(reduce, max, double, true)
+
+/* scan_inclusive add */
+RANGE_OP(scan_inclusive, add, int, true)
+RANGE_OP(scan_inclusive, add, uint, false)
+RANGE_OP(scan_inclusive, add, long, true)
+RANGE_OP(scan_inclusive, add, ulong, false)
+RANGE_OP(scan_inclusive, add, float, true)
+RANGE_OP(scan_inclusive, add, double, true)
+/* scan_inclusive min */
+RANGE_OP(scan_inclusive, min, int, true)
+RANGE_OP(scan_inclusive, min, uint, false)
+RANGE_OP(scan_inclusive, min, long, true)
+RANGE_OP(scan_inclusive, min, ulong, false)
+RANGE_OP(scan_inclusive, min, float, true)
+RANGE_OP(scan_inclusive, min, double, true)
+/* scan_inclusive max */
+RANGE_OP(scan_inclusive, max, int, true)
+RANGE_OP(scan_inclusive, max, uint, false)
+RANGE_OP(scan_inclusive, max, long, true)
+RANGE_OP(scan_inclusive, max, ulong, false)
+RANGE_OP(scan_inclusive, max, float, true)
+RANGE_OP(scan_inclusive, max, double, true)
+
+/* scan_exclusive add */
+RANGE_OP(scan_exclusive, add, int, true)
+RANGE_OP(scan_exclusive, add, uint, false)
+RANGE_OP(scan_exclusive, add, long, true)
+RANGE_OP(scan_exclusive, add, ulong, false)
+RANGE_OP(scan_exclusive, add, float, true)
+RANGE_OP(scan_exclusive, add, double, true)
+/* scan_exclusive min */
+RANGE_OP(scan_exclusive, min, int, true)
+RANGE_OP(scan_exclusive, min, uint, false)
+RANGE_OP(scan_exclusive, min, long, true)
+RANGE_OP(scan_exclusive, min, ulong, false)
+RANGE_OP(scan_exclusive, min, float, true)
+RANGE_OP(scan_exclusive, min, double, true)
+/* scan_exclusive max */
+RANGE_OP(scan_exclusive, max, int, true)
+RANGE_OP(scan_exclusive, max, uint, false)
+RANGE_OP(scan_exclusive, max, long, true)
+RANGE_OP(scan_exclusive, max, ulong, false)
+RANGE_OP(scan_exclusive, max, float, true)
+RANGE_OP(scan_exclusive, max, double, true)
+
+#undef RANGE_OP
diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.h b/backend/src/libocl/tmpl/ocl_simd.tmpl.h
index 96337cd..355ee30 100644
--- a/backend/src/libocl/tmpl/ocl_simd.tmpl.h
+++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.h
@@ -34,6 +34,87 @@ uint get_num_sub_groups(void);
 uint get_sub_group_id(void);
 uint get_sub_group_local_id(void);
 
+/* broadcast: one lane index per call, as sub groups are one-dimensional */
+OVERLOADABLE int sub_group_broadcast(int a, size_t local_id);
+OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id);
+OVERLOADABLE long sub_group_broadcast(long a, size_t local_id);
+OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id);
+OVERLOADABLE float sub_group_broadcast(float a, size_t local_id);
+OVERLOADABLE double sub_group_broadcast(double a, size_t local_id);
+
+/* reduce add */
+OVERLOADABLE int sub_group_reduce_add(int x);
+OVERLOADABLE uint sub_group_reduce_add(uint x);
+OVERLOADABLE long sub_group_reduce_add(long x);
+OVERLOADABLE ulong sub_group_reduce_add(ulong x);
+OVERLOADABLE float sub_group_reduce_add(float x);
+OVERLOADABLE double sub_group_reduce_add(double x);
+
+/* reduce min */
+OVERLOADABLE int sub_group_reduce_min(int x);
+OVERLOADABLE uint sub_group_reduce_min(uint x);
+OVERLOADABLE long sub_group_reduce_min(long x);
+OVERLOADABLE ulong sub_group_reduce_min(ulong x);
+OVERLOADABLE float sub_group_reduce_min(float x);
+OVERLOADABLE double sub_group_reduce_min(double x);
+
+/* reduce max */
+OVERLOADABLE int sub_group_reduce_max(int x);
+OVERLOADABLE uint sub_group_reduce_max(uint x);
+OVERLOADABLE long sub_group_reduce_max(long x);
+OVERLOADABLE ulong sub_group_reduce_max(ulong x);
+OVERLOADABLE float sub_group_reduce_max(float x);
+OVERLOADABLE double sub_group_reduce_max(double x);
+
+/* scan_inclusive add */
+OVERLOADABLE int sub_group_scan_inclusive_add(int x);
+OVERLOADABLE uint sub_group_scan_inclusive_add(uint x);
+OVERLOADABLE long sub_group_scan_inclusive_add(long x);
+OVERLOADABLE ulong sub_group_scan_inclusive_add(ulong x);
+OVERLOADABLE float sub_group_scan_inclusive_add(float x);
+OVERLOADABLE double sub_group_scan_inclusive_add(double x);
+
+/* scan_inclusive min */
+OVERLOADABLE int sub_group_scan_inclusive_min(int x);
+OVERLOADABLE uint sub_group_scan_inclusive_min(uint x);
+OVERLOADABLE long sub_group_scan_inclusive_min(long x);
+OVERLOADABLE ulong sub_group_scan_inclusive_min(ulong x);
+OVERLOADABLE float sub_group_scan_inclusive_min(float x);
+OVERLOADABLE double sub_group_scan_inclusive_min(double x);
+
+/* scan_inclusive max */
+OVERLOADABLE int sub_group_scan_inclusive_max(int x);
+OVERLOADABLE uint sub_group_scan_inclusive_max(uint x);
+OVERLOADABLE long sub_group_scan_inclusive_max(long x);
+OVERLOADABLE ulong sub_group_scan_inclusive_max(ulong x);
+OVERLOADABLE float sub_group_scan_inclusive_max(float x);
+OVERLOADABLE double sub_group_scan_inclusive_max(double x);
+
+/* scan_exclusive add */
+OVERLOADABLE int sub_group_scan_exclusive_add(int x);
+OVERLOADABLE uint sub_group_scan_exclusive_add(uint x);
+OVERLOADABLE long sub_group_scan_exclusive_add(long x);
+OVERLOADABLE ulong sub_group_scan_exclusive_add(ulong x);
+OVERLOADABLE float sub_group_scan_exclusive_add(float x);
+OVERLOADABLE double sub_group_scan_exclusive_add(double x);
+
+/* scan_exclusive min */
+OVERLOADABLE int sub_group_scan_exclusive_min(int x);
+OVERLOADABLE uint sub_group_scan_exclusive_min(uint x);
+OVERLOADABLE long sub_group_scan_exclusive_min(long x);
+OVERLOADABLE ulong sub_group_scan_exclusive_min(ulong x);
+OVERLOADABLE float sub_group_scan_exclusive_min(float x);
+OVERLOADABLE double sub_group_scan_exclusive_min(double x);
+
+/* scan_exclusive max */
+OVERLOADABLE int sub_group_scan_exclusive_max(int x);
+OVERLOADABLE uint sub_group_scan_exclusive_max(uint x);
+OVERLOADABLE long sub_group_scan_exclusive_max(long x);
+OVERLOADABLE ulong sub_group_scan_exclusive_max(ulong x);
+OVERLOADABLE float sub_group_scan_exclusive_max(float x);
+OVERLOADABLE double sub_group_scan_exclusive_max(double x);
+
+/* shuffle */
 OVERLOADABLE float intel_sub_group_shuffle(float x, uint c);
 OVERLOADABLE int intel_sub_group_shuffle(int x, uint c);
 OVERLOADABLE uint intel_sub_group_shuffle(uint x, uint c);
diff --git a/backend/src/llvm/llvm_gen_backend.cpp b/backend/src/llvm/llvm_gen_backend.cpp
index f5228d2..a091d7c 100644
--- a/backend/src/llvm/llvm_gen_backend.cpp
+++ b/backend/src/llvm/llvm_gen_backend.cpp
@@ -695,6 +695,8 @@ namespace gbe
     void emitAtomicInst(CallInst &I, CallSite &CS, ir::AtomicOps opcode);
     // Emit workgroup instructions
     void emitWorkGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode);
+    // Emit subgroup instructions: lower one sub group built-in call into a single ir SUBGROUP instruction
+    void emitSubGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode);
 
     uint8_t appendSampler(CallSite::arg_iterator AI);
     uint8_t getImageID(CallInst &I);
@@ -3729,6 +3731,16 @@ namespace gbe
       case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_ADD:
       case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MAX:
       case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MIN:
+      case GEN_OCL_SUB_GROUP_BROADCAST:
+      case GEN_OCL_SUB_GROUP_REDUCE_ADD:
+      case GEN_OCL_SUB_GROUP_REDUCE_MAX:
+      case GEN_OCL_SUB_GROUP_REDUCE_MIN:
+      case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_ADD:
+      case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MAX:
+      case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MIN:
+      case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_ADD:
+      case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MAX:
+      case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MIN:
       case GEN_OCL_LRP:
         this->newRegister(&I);
         break;
@@ -3898,6 +3910,48 @@ namespace gbe
     GBE_ASSERT(AI == AE);
   }
 
+  // Lower a __gen_ocl_sub_group_* call into an IR SUBGROUP instruction.
+  // ALL/ANY take one TYPE_S32 operand; BROADCAST forwards every call operand;
+  // all other ops (reduce/scan) take a constant signedness flag followed
+  // by a single value operand.
+  void GenWriter::emitSubGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode) {
+    CallSite::arg_iterator AI = CS.arg_begin();
+    CallSite::arg_iterator AE = CS.arg_end();
+    GBE_ASSERT(AI != AE);
+
+    if (opcode == ir::WORKGROUP_OP_ALL || opcode == ir::WORKGROUP_OP_ANY) {
+      // Predicate form: exactly one TYPE_S32 operand.
+      GBE_ASSERT(getType(ctx, (*AI)->getType()) == ir::TYPE_S32);
+      ir::Register src = this->getRegister(*(AI++));
+      const ir::Tuple srcTuple = ctx.arrayTuple(&src, 1);
+      ctx.SUBGROUP(opcode, getRegister(&I), srcTuple, 1, ir::TYPE_S32);
+    } else if (opcode == ir::WORKGROUP_OP_BROADCAST) {
+      const uint32_t argNum = CS.arg_size();
+      std::vector<ir::Register> src(argNum);
+      for (uint32_t i = 0; i < argNum; ++i)
+        src[i] = this->getRegister(*(AI++));
+      const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], argNum);
+      // The result type is taken from the first call operand.
+      ctx.SUBGROUP(ir::WORKGROUP_OP_BROADCAST, getRegister(&I), srcTuple, argNum,
+          getType(ctx, (*CS.arg_begin())->getType()));
+    } else {
+      // Operand 0 is a ConstantInt flag: nonzero selects the signed IR type of
+      // the value operand, zero selects its unsigned counterpart.
+      ConstantInt *sign = dyn_cast<ConstantInt>(*AI);
+      GBE_ASSERT(sign);
+      const bool isSign = sign->getZExtValue() != 0;
+      ++AI;
+      const ir::Type ty = isSign ? getType(ctx, (*AI)->getType())
+                                 : getUnsignedType(ctx, (*AI)->getType());
+      ir::Register src = this->getRegister(*(AI++));
+      const ir::Tuple srcTuple = ctx.arrayTuple(&src, 1);
+      ctx.SUBGROUP(opcode, getRegister(&I), srcTuple, 1, ty);
+    }
+
+    // Every call operand must have been consumed by one of the branches.
+    GBE_ASSERT(AI == AE);
+  }
+
   /* append a new sampler. should be called before any reference to
    * a sampler_t value. */
   uint8_t GenWriter::appendSampler(CallSite::arg_iterator AI) {
@@ -4690,6 +4744,26 @@ namespace gbe
             this->emitWorkGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MAX); break;
           case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MIN:
             this->emitWorkGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MIN); break;
+          case GEN_OCL_SUB_GROUP_BROADCAST:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_BROADCAST); break;
+          case GEN_OCL_SUB_GROUP_REDUCE_ADD:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_ADD); break;
+          case GEN_OCL_SUB_GROUP_REDUCE_MAX:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_MAX); break;
+          case GEN_OCL_SUB_GROUP_REDUCE_MIN:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_MIN); break;
+          case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_ADD:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_ADD); break;
+          case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MAX:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_MAX); break;
+          case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MIN:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_MIN); break;
+          case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_ADD:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_ADD); break;
+          case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MAX:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MAX); break;
+          case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MIN:
+            this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MIN); break;
           case GEN_OCL_LRP:
           {
             const ir::Register dst  = this->getRegister(&I);
diff --git a/backend/src/llvm/llvm_gen_ocl_function.hxx b/backend/src/llvm/llvm_gen_ocl_function.hxx
index cff4d61..213ead0 100644
--- a/backend/src/llvm/llvm_gen_ocl_function.hxx
+++ b/backend/src/llvm/llvm_gen_ocl_function.hxx
@@ -202,5 +202,20 @@ DECL_LLVM_GEN_FUNCTION(WORK_GROUP_SCAN_INCLUSIVE_MIN, __gen_ocl_work_group_scan_
 DECL_LLVM_GEN_FUNCTION(WORK_GROUP_ALL, __gen_ocl_work_group_all)
 DECL_LLVM_GEN_FUNCTION(WORK_GROUP_ANY, __gen_ocl_work_group_any)
 
+// sub group functions (broadcast, reduce, exclusive/inclusive scan)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_BROADCAST, __gen_ocl_sub_group_broadcast)
+
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_ADD, __gen_ocl_sub_group_reduce_add)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_MAX, __gen_ocl_sub_group_reduce_max)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_MIN, __gen_ocl_sub_group_reduce_min)
+
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_ADD, __gen_ocl_sub_group_scan_exclusive_add)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_MAX, __gen_ocl_sub_group_scan_exclusive_max)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_MIN, __gen_ocl_sub_group_scan_exclusive_min)
+
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_ADD, __gen_ocl_sub_group_scan_inclusive_add)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_MAX, __gen_ocl_sub_group_scan_inclusive_max)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_MIN, __gen_ocl_sub_group_scan_inclusive_min)
+
 // common function
 DECL_LLVM_GEN_FUNCTION(LRP, __gen_ocl_lrp)
-- 
2.7.4



More information about the Beignet mailing list