[Beignet] [PATCH 3/4] Backend: Add sub_group built-in functions for intel extension
Xiuli Pan
xiuli.pan at intel.com
Tue May 17 00:20:32 UTC 2016
From: Pan Xiuli <xiuli.pan at intel.com>
Add the sub_group_reduce/scan_exclusive/scan_inclusive add/min/max built-in functions.
They share the in-thread algorithm with the work-group functions.
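For reference, a minimal usage sketch of the new built-ins from the kernel side
(not part of this patch; kernel and argument names are made up, and the host is
assumed to enable cl_intel_subgroups):

  #pragma OPENCL EXTENSION cl_intel_subgroups : enable

  kernel void subgroup_demo(global const int *in, global int *out)
  {
    const uint gid = get_global_id(0);
    const int v = in[gid];

    /* sum of v over the whole sub-group (same result in every work item) */
    const int total = sub_group_reduce_add(v);
    /* running sum of the work items that precede this one in the sub-group */
    const int prefix = sub_group_scan_exclusive_add(v);
    /* value held by the first work item of the sub-group */
    const int first = sub_group_broadcast(v, 0);

    out[gid] = total + prefix + first;
  }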
Signed-off-by: Pan Xiuli <xiuli.pan at intel.com>
---
backend/src/backend/gen8_context.cpp | 23 ++++
backend/src/backend/gen8_context.hpp | 1 +
backend/src/backend/gen_context.cpp | 23 ++++
backend/src/backend/gen_context.hpp | 1 +
.../src/backend/gen_insn_gen7_schedule_info.hxx | 1 +
backend/src/backend/gen_insn_selection.cpp | 116 +++++++++++++++++
backend/src/backend/gen_insn_selection.hxx | 1 +
backend/src/ir/instruction.cpp | 144 +++++++++++++++++++++
backend/src/ir/instruction.hpp | 11 ++
backend/src/ir/instruction.hxx | 1 +
backend/src/libocl/tmpl/ocl_simd.tmpl.cl | 98 ++++++++++++++
backend/src/libocl/tmpl/ocl_simd.tmpl.h | 95 ++++++++++++++
backend/src/llvm/llvm_gen_backend.cpp | 74 +++++++++++
backend/src/llvm/llvm_gen_ocl_function.hxx | 15 +++
14 files changed, 604 insertions(+)
diff --git a/backend/src/backend/gen8_context.cpp b/backend/src/backend/gen8_context.cpp
index 477b22b..7ddb95a 100644
--- a/backend/src/backend/gen8_context.cpp
+++ b/backend/src/backend/gen8_context.cpp
@@ -1845,4 +1845,27 @@ namespace gbe
}
}
+ void Gen8Context::emitSubGroupOpInstruction(const SelectionInstruction &insn){
+ const GenRegister dst = ra->genReg(insn.dst(0));
+ const GenRegister tmp = GenRegister::retype(ra->genReg(insn.dst(1)), dst.type);
+ const GenRegister theVal = GenRegister::retype(ra->genReg(insn.src(0)), dst.type);
+ GenRegister threadData = ra->genReg(insn.src(1));
+
+ uint32_t wg_op = insn.extra.workgroupOp;
+ uint32_t simd = p->curr.execWidth;
+
+ /* masked elements should be properly set to init value */
+ p->push(); {
+ p->curr.noMask = 1;
+ wgOpInitValue(p, tmp, wg_op);
+ p->curr.noMask = 0;
+ p->MOV(tmp, theVal);
+ p->curr.noMask = 1;
+ p->MOV(theVal, tmp);
+ } p->pop();
+
+ /* do some calculation within each thread */
+ wgOpPerformThread(dst, theVal, threadData, tmp, simd, wg_op, p);
+ }
+
}
diff --git a/backend/src/backend/gen8_context.hpp b/backend/src/backend/gen8_context.hpp
index 771e20b..ec1358c 100644
--- a/backend/src/backend/gen8_context.hpp
+++ b/backend/src/backend/gen8_context.hpp
@@ -77,6 +77,7 @@ namespace gbe
virtual void emitF64DIVInstruction(const SelectionInstruction &insn);
virtual void emitWorkGroupOpInstruction(const SelectionInstruction &insn);
+ virtual void emitSubGroupOpInstruction(const SelectionInstruction &insn);
static GenRegister unpacked_ud(GenRegister reg, uint32_t offset = 0);
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index 4e24816..4d0a3f3 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -3374,6 +3374,29 @@ namespace gbe
}
}
+ void GenContext::emitSubGroupOpInstruction(const SelectionInstruction &insn){
+ const GenRegister dst = ra->genReg(insn.dst(0));
+ const GenRegister tmp = GenRegister::retype(ra->genReg(insn.dst(1)), dst.type);
+ const GenRegister theVal = GenRegister::retype(ra->genReg(insn.src(0)), dst.type);
+ GenRegister threadData = ra->genReg(insn.src(1));
+
+ uint32_t wg_op = insn.extra.workgroupOp;
+ uint32_t simd = p->curr.execWidth;
+
+ /* masked elements should be properly set to init value */
+ p->push(); {
+ p->curr.noMask = 1;
+ wgOpInitValue(p, tmp, wg_op);
+ p->curr.noMask = 0;
+ p->MOV(tmp, theVal);
+ p->curr.noMask = 1;
+ p->MOV(theVal, tmp);
+ } p->pop();
+
+ /* do some calculation within each thread */
+ wgOpPerformThread(dst, theVal, threadData, tmp, simd, wg_op, p);
+ }
+
void GenContext::emitPrintfLongInstruction(GenRegister& addr, GenRegister& data,
GenRegister& src, uint32_t bti) {
p->MOV(GenRegister::retype(data, GEN_TYPE_UD), src.bottom_half());
diff --git a/backend/src/backend/gen_context.hpp b/backend/src/backend/gen_context.hpp
index ebc55e6..4c43ccb 100644
--- a/backend/src/backend/gen_context.hpp
+++ b/backend/src/backend/gen_context.hpp
@@ -181,6 +181,7 @@ namespace gbe
void emitCalcTimestampInstruction(const SelectionInstruction &insn);
void emitStoreProfilingInstruction(const SelectionInstruction &insn);
virtual void emitWorkGroupOpInstruction(const SelectionInstruction &insn);
+ virtual void emitSubGroupOpInstruction(const SelectionInstruction &insn);
void emitPrintfInstruction(const SelectionInstruction &insn);
void scratchWrite(const GenRegister header, uint32_t offset, uint32_t reg_num, uint32_t reg_type, uint32_t channel_mode);
void scratchRead(const GenRegister dst, const GenRegister header, uint32_t offset, uint32_t reg_num, uint32_t reg_type, uint32_t channel_mode);
diff --git a/backend/src/backend/gen_insn_gen7_schedule_info.hxx b/backend/src/backend/gen_insn_gen7_schedule_info.hxx
index 112df32..cb5c4f1 100644
--- a/backend/src/backend/gen_insn_gen7_schedule_info.hxx
+++ b/backend/src/backend/gen_insn_gen7_schedule_info.hxx
@@ -48,4 +48,5 @@ DECL_GEN7_SCHEDULE(F64DIV, 20, 40, 20)
DECL_GEN7_SCHEDULE(CalcTimestamp, 80, 1, 1)
DECL_GEN7_SCHEDULE(StoreProfiling, 80, 1, 1)
DECL_GEN7_SCHEDULE(WorkGroupOp, 80, 1, 1)
+DECL_GEN7_SCHEDULE(SubGroupOp, 80, 1, 1)
DECL_GEN7_SCHEDULE(Printf, 80, 1, 1)
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 83b35cf..596e70b 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -694,6 +694,9 @@ namespace gbe
GenRegister tmpData2, GenRegister slmOff,
vector<GenRegister> msg, uint32_t msgSizeReq,
GenRegister localBarrier);
+ /*! Sub Group Operations */
+ void SUBGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src,
+ GenRegister tmpData1, GenRegister tmpData2);
/* common functions for both binary instruction and sel_cmp and compare instruction.
It will handle the IMM or normal register assignment, and will try to avoid LOADI
as much as possible. */
@@ -1995,6 +1998,23 @@ namespace gbe
insn->src(5) = localBarrier;
}
+ void Selection::Opaque::SUBGROUP_OP(uint32_t wg_op,
+ Reg dst,
+ GenRegister src,
+ GenRegister tmpData1,
+ GenRegister tmpData2)
+ {
+ SelectionInstruction *insn = this->appendInsn(SEL_OP_SUBGROUP_OP, 2, 2);
+
+ insn->extra.workgroupOp = wg_op;
+
+ insn->dst(0) = dst;
+ insn->dst(1) = tmpData1;
+
+ insn->src(0) = src;
+ insn->src(1) = tmpData2;
+ }
+
// Boiler plate to initialize the selection library at c++ pre-main
static SelectionLibrary *selLib = NULL;
static void destroySelectionLibrary(void) { GBE_DELETE(selLib); }
@@ -6400,6 +6420,101 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
DECL_CTOR(WorkGroupInstruction, 1, 1);
};
+ /*! SubGroup instruction pattern */
+ class SubGroupInstructionPattern : public SelectionPattern
+ {
+ public:
+ SubGroupInstructionPattern(void) : SelectionPattern(1,1) {
+ for (uint32_t op = 0; op < ir::OP_INVALID; ++op)
+ if (ir::isOpcodeFrom<ir::SubGroupInstruction>(ir::Opcode(op)) == true)
+ this->opcodes.push_back(ir::Opcode(op));
+ }
+
+ /* SUBGROUP OP: ALL, ANY, REDUCE, SCAN INCLUSIVE, SCAN EXCLUSIVE
+     * Shared algorithm with the work-group in-thread code */
+ INLINE bool emitSGReduce(Selection::Opaque &sel, const ir::SubGroupInstruction &insn) const
+ {
+ using namespace ir;
+
+ GBE_ASSERT(insn.getSrcNum() == 1);
+
+ const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
+ const Type type = insn.getType();
+ GenRegister dst = sel.selReg(insn.getDst(0), type);
+ GenRegister src = sel.selReg(insn.getSrc(0), type);
+ GenRegister tmpData1 = GenRegister::retype(sel.selReg(sel.reg(FAMILY_QWORD)), type);
+ GenRegister tmpData2 = GenRegister::retype(sel.selReg(sel.reg(FAMILY_QWORD)), type);
+
+ /* Perform workgroup op */
+ sel.SUBGROUP_OP(workGroupOp, dst, src, tmpData1, tmpData2);
+
+ return true;
+ }
+
+    /* SUBGROUP OP: BROADCAST
+ * Shared algorithm with simd shuffle */
+ INLINE bool emitSGBroadcast(Selection::Opaque &sel, const ir::SubGroupInstruction &insn, SelectionDAG &dag) const
+ {
+ using namespace ir;
+
+ GBE_ASSERT(insn.getSrcNum() == 2);
+
+ const Type type = insn.getType();
+ const GenRegister src0 = sel.selReg(insn.getSrc(0), type);
+ const GenRegister dst = sel.selReg(insn.getDst(0), type);
+ GenRegister src1;
+
+ SelectionDAG *dag0 = dag.child[0];
+ SelectionDAG *dag1 = dag.child[1];
+ if (dag1 != NULL && dag1->insn.getOpcode() == OP_LOADI && canGetRegisterFromImmediate(dag1->insn)) {
+ const auto &childInsn = cast<LoadImmInstruction>(dag1->insn);
+ src1 = getRegisterFromImmediate(childInsn.getImmediate(), TYPE_U32);
+ if (dag0) dag0->isRoot = 1;
+ } else {
+ markAllChildren(dag);
+ src1 = sel.selReg(insn.getSrc(1), TYPE_U32);
+ }
+
+ sel.push(); {
+ if (src1.file == GEN_IMMEDIATE_VALUE) {
+ uint32_t offset = src1.value.ud % sel.curr.execWidth;
+ GenRegister reg = GenRegister::subphysicaloffset(src0, offset);
+ reg.vstride = GEN_VERTICAL_STRIDE_0;
+ reg.hstride = GEN_HORIZONTAL_STRIDE_0;
+ reg.width = GEN_WIDTH_1;
+ sel.MOV(dst, reg);
+ } else {
+ GenRegister shiftL = sel.selReg(sel.reg(FAMILY_DWORD), TYPE_U32);
+ sel.SHL(shiftL, src1, GenRegister::immud(0x2));
+ sel.SIMD_SHUFFLE(dst, src0, shiftL);
+ }
+ } sel.pop();
+
+ return true;
+ }
+
+ INLINE bool emit(Selection::Opaque &sel, SelectionDAG &dag) const
+ {
+ using namespace ir;
+ const ir::SubGroupInstruction &insn = cast<SubGroupInstruction>(dag.insn);
+ const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode();
+
+ if (workGroupOp == WORKGROUP_OP_BROADCAST){
+ return emitSGBroadcast(sel, insn, dag);
+ }
+ else if (workGroupOp >= WORKGROUP_OP_ANY && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX){
+ if(emitSGReduce(sel, insn))
+ markAllChildren(dag);
+ else
+ return false;
+ }
+ else
+ GBE_ASSERT(0);
+
+ return true;
+ }
+ };
+
/*! Sort patterns */
INLINE bool cmp(const SelectionPattern *p0, const SelectionPattern *p1) {
if (p0->insnNum != p1->insnNum)
@@ -6437,6 +6552,7 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
this->insert<CalcTimestampInstructionPattern>();
this->insert<StoreProfilingInstructionPattern>();
this->insert<WorkGroupInstructionPattern>();
+ this->insert<SubGroupInstructionPattern>();
this->insert<NullaryInstructionPattern>();
this->insert<WaitInstructionPattern>();
this->insert<PrintfInstructionPattern>();
diff --git a/backend/src/backend/gen_insn_selection.hxx b/backend/src/backend/gen_insn_selection.hxx
index 4352490..0e11f9f 100644
--- a/backend/src/backend/gen_insn_selection.hxx
+++ b/backend/src/backend/gen_insn_selection.hxx
@@ -94,4 +94,5 @@ DECL_SELECTION_IR(F64DIV, F64DIVInstruction)
DECL_SELECTION_IR(CALC_TIMESTAMP, CalcTimestampInstruction)
DECL_SELECTION_IR(STORE_PROFILING, StoreProfilingInstruction)
DECL_SELECTION_IR(WORKGROUP_OP, WorkGroupOpInstruction)
+DECL_SELECTION_IR(SUBGROUP_OP, SubGroupOpInstruction)
DECL_SELECTION_IR(PRINTF, PrintfInstruction)
diff --git a/backend/src/ir/instruction.cpp b/backend/src/ir/instruction.cpp
index d9051ab..47606b2 100644
--- a/backend/src/ir/instruction.cpp
+++ b/backend/src/ir/instruction.cpp
@@ -994,6 +994,33 @@ namespace ir {
Register dst[1];
};
+ class ALIGNED_INSTRUCTION SubGroupInstruction :
+ public BasePolicy,
+ public TupleSrcPolicy<SubGroupInstruction>,
+ public NDstPolicy<SubGroupInstruction, 1>
+ {
+ public:
+ INLINE SubGroupInstruction(WorkGroupOps opcode, Register dst,
+ Tuple srcTuple, uint8_t srcNum, Type type) {
+ this->opcode = OP_SUBGROUP;
+ this->workGroupOp = opcode;
+ this->type = type;
+ this->dst[0] = dst;
+ this->src = srcTuple;
+ this->srcNum = srcNum;
+ }
+ INLINE Type getType(void) const { return this->type; }
+ INLINE bool wellFormed(const Function &fn, std::string &whyNot) const;
+ INLINE void out(std::ostream &out, const Function &fn) const;
+ INLINE WorkGroupOps getWorkGroupOpcode(void) const { return this->workGroupOp; }
+
+ WorkGroupOps workGroupOp:5;
+ uint32_t srcNum:3; //!< Source Number
+ Type type; //!< Type of the instruction
+ Tuple src;
+ Register dst[1];
+ };
+
class ALIGNED_INSTRUCTION PrintfInstruction :
public BasePolicy,
public TupleSrcPolicy<PrintfInstruction>,
@@ -1505,6 +1532,52 @@ namespace ir {
return true;
}
+ INLINE bool SubGroupInstruction::wellFormed(const Function &fn, std::string &whyNot) const {
+ const RegisterFamily family = getFamily(this->type);
+
+ if (UNLIKELY(checkSpecialRegForWrite(dst[0], fn, whyNot) == false))
+ return false;
+ if (UNLIKELY(checkRegisterData(family, dst[0], fn, whyNot) == false))
+ return false;
+
+ switch (this->workGroupOp) {
+ case WORKGROUP_OP_ANY:
+ case WORKGROUP_OP_ALL:
+ case WORKGROUP_OP_REDUCE_ADD:
+ case WORKGROUP_OP_REDUCE_MIN:
+ case WORKGROUP_OP_REDUCE_MAX:
+ case WORKGROUP_OP_INCLUSIVE_ADD:
+ case WORKGROUP_OP_INCLUSIVE_MIN:
+ case WORKGROUP_OP_INCLUSIVE_MAX:
+ case WORKGROUP_OP_EXCLUSIVE_ADD:
+ case WORKGROUP_OP_EXCLUSIVE_MIN:
+ case WORKGROUP_OP_EXCLUSIVE_MAX:
+ if (this->srcNum != 1) {
+          whyNot = "Wrong number of sources.";
+ return false;
+ }
+ break;
+ case WORKGROUP_OP_BROADCAST:
+ if (this->srcNum != 2) {
+          whyNot = "Wrong number of sources.";
+ return false;
+ } else {
+ const RegisterFamily fam = fn.getPointerFamily();
+ for (uint32_t srcID = 1; srcID < this->srcNum; ++srcID) {
+ const Register regID = fn.getRegister(src, srcID);
+ if (UNLIKELY(checkRegisterData(fam, regID, fn, whyNot) == false))
+ return false;
+ }
+ }
+ break;
+ default:
+ whyNot = "No such sub group function.";
+ return false;
+ }
+
+ return true;
+ }
+
INLINE bool PrintfInstruction::wellFormed(const Function &fn, std::string &whyNot) const {
return true;
}
@@ -1739,6 +1812,67 @@ namespace ir {
out << "TheadID Map at SLM: " << this->slmAddr;
}
+ INLINE void SubGroupInstruction::out(std::ostream &out, const Function &fn) const {
+ this->outOpcode(out);
+
+ switch (this->workGroupOp) {
+ case WORKGROUP_OP_ANY:
+ out << "_" << "ANY";
+ break;
+ case WORKGROUP_OP_ALL:
+ out << "_" << "ALL";
+ break;
+ case WORKGROUP_OP_REDUCE_ADD:
+ out << "_" << "REDUCE_ADD";
+ break;
+ case WORKGROUP_OP_REDUCE_MIN:
+ out << "_" << "REDUCE_MIN";
+ break;
+ case WORKGROUP_OP_REDUCE_MAX:
+ out << "_" << "REDUCE_MAX";
+ break;
+ case WORKGROUP_OP_INCLUSIVE_ADD:
+ out << "_" << "INCLUSIVE_ADD";
+ break;
+ case WORKGROUP_OP_INCLUSIVE_MIN:
+ out << "_" << "INCLUSIVE_MIN";
+ break;
+ case WORKGROUP_OP_INCLUSIVE_MAX:
+ out << "_" << "INCLUSIVE_MAX";
+ break;
+ case WORKGROUP_OP_EXCLUSIVE_ADD:
+ out << "_" << "EXCLUSIVE_ADD";
+ break;
+ case WORKGROUP_OP_EXCLUSIVE_MIN:
+ out << "_" << "EXCLUSIVE_MIN";
+ break;
+ case WORKGROUP_OP_EXCLUSIVE_MAX:
+ out << "_" << "EXCLUSIVE_MAX";
+ break;
+ case WORKGROUP_OP_BROADCAST:
+ out << "_" << "BROADCAST";
+ break;
+ default:
+ GBE_ASSERT(0);
+ }
+
+ out << " %" << this->getDst(fn, 0);
+ out << " %" << this->getSrc(fn, 0);
+
+ if (this->workGroupOp == WORKGROUP_OP_BROADCAST) {
+ do {
+ int localN = srcNum - 1;
+ GBE_ASSERT(localN);
+ out << " Local ID:";
+ out << " %" << this->getSrc(fn, 1);
+ localN--;
+ if (!localN)
+ break;
+ } while(0);
+ }
+
+ }
+
INLINE void PrintfInstruction::out(std::ostream &out, const Function &fn) const {
this->outOpcode(out);
}
@@ -1903,6 +2037,10 @@ START_INTROSPECTION(WorkGroupInstruction)
#include "ir/instruction.hxx"
END_INTROSPECTION(WorkGroupInstruction)
+START_INTROSPECTION(SubGroupInstruction)
+#include "ir/instruction.hxx"
+END_INTROSPECTION(SubGroupInstruction)
+
START_INTROSPECTION(PrintfInstruction)
#include "ir/instruction.hxx"
END_INTROSPECTION(PrintfInstruction)
@@ -2117,6 +2255,8 @@ DECL_MEM_FN(StoreProfilingInstruction, uint32_t, getBTI(void), getBTI())
DECL_MEM_FN(WorkGroupInstruction, Type, getType(void), getType())
DECL_MEM_FN(WorkGroupInstruction, WorkGroupOps, getWorkGroupOpcode(void), getWorkGroupOpcode())
DECL_MEM_FN(WorkGroupInstruction, uint32_t, getSlmAddr(void), getSlmAddr())
+DECL_MEM_FN(SubGroupInstruction, Type, getType(void), getType())
+DECL_MEM_FN(SubGroupInstruction, WorkGroupOps, getWorkGroupOpcode(void), getWorkGroupOpcode())
DECL_MEM_FN(PrintfInstruction, uint32_t, getNum(void), getNum())
DECL_MEM_FN(PrintfInstruction, uint32_t, getBti(void), getBti())
DECL_MEM_FN(PrintfInstruction, Type, getType(const Function& fn, uint32_t ID), getType(fn, ID))
@@ -2418,6 +2558,10 @@ DECL_MEM_FN(MemInstruction, void, setBtiReg(Register reg), setBtiReg(reg))
return internal::WorkGroupInstruction(opcode, slmAddr, dst, srcTuple, srcNum, type).convert();
}
+ Instruction SUBGROUP(WorkGroupOps opcode, Register dst, Tuple srcTuple, uint8_t srcNum, Type type) {
+ return internal::SubGroupInstruction(opcode, dst, srcTuple, srcNum, type).convert();
+ }
+
Instruction PRINTF(Register dst, Tuple srcTuple, Tuple typeTuple, uint8_t srcNum, uint8_t bti, uint16_t num) {
return internal::PrintfInstruction(dst, srcTuple, typeTuple, srcNum, bti, num).convert();
}
diff --git a/backend/src/ir/instruction.hpp b/backend/src/ir/instruction.hpp
index 9cc926d..799a7bf 100644
--- a/backend/src/ir/instruction.hpp
+++ b/backend/src/ir/instruction.hpp
@@ -611,6 +611,15 @@ namespace ir {
uint32_t getSlmAddr(void) const;
};
+ /*! Related to Sub Group. */
+ class SubGroupInstruction : public Instruction {
+ public:
+ /*! Return true if the given instruction is an instance of this class */
+ static bool isClassOf(const Instruction &insn);
+ Type getType(void) const;
+ WorkGroupOps getWorkGroupOpcode(void) const;
+ };
+
/*! Printf instruction. */
class PrintfInstruction : public Instruction {
public:
@@ -850,6 +859,8 @@ namespace ir {
/*! work group */
Instruction WORKGROUP(WorkGroupOps opcode, uint32_t slmAddr, Register dst, Tuple srcTuple, uint8_t srcNum, Type type);
+ /*! sub group */
+ Instruction SUBGROUP(WorkGroupOps opcode, Register dst, Tuple srcTuple, uint8_t srcNum, Type type);
/*! printf */
Instruction PRINTF(Register dst, Tuple srcTuple, Tuple typeTuple, uint8_t srcNum, uint8_t bti, uint16_t num);
} /* namespace ir */
diff --git a/backend/src/ir/instruction.hxx b/backend/src/ir/instruction.hxx
index 651ed64..57e13eb 100644
--- a/backend/src/ir/instruction.hxx
+++ b/backend/src/ir/instruction.hxx
@@ -112,4 +112,5 @@ DECL_INSN(CALC_TIMESTAMP, CalcTimestampInstruction)
DECL_INSN(STORE_PROFILING, StoreProfilingInstruction)
DECL_INSN(WAIT, WaitInstruction)
DECL_INSN(WORKGROUP, WorkGroupInstruction)
+DECL_INSN(SUBGROUP, SubGroupInstruction)
DECL_INSN(PRINTF, PrintfInstruction)
diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
index c2e22c1..a25dcef 100644
--- a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
+++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl
@@ -35,3 +35,101 @@ uint get_sub_group_size(void)
else
return get_max_sub_group_size();
}
+
+/* broadcast */
+#define BROADCAST_IMPL(GEN_TYPE) \
+ OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id); \
+ OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id) { \
+ return __gen_ocl_sub_group_broadcast(a, local_id); \
+ } \
+ OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y); \
+ OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y) { \
+ return __gen_ocl_sub_group_broadcast(a, local_id_x, local_id_y); \
+ } \
+ OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y, size_t local_id_z); \
+ OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y, size_t local_id_z) { \
+ return __gen_ocl_sub_group_broadcast(a, local_id_x, local_id_y, local_id_z); \
+ }
+
+BROADCAST_IMPL(int)
+BROADCAST_IMPL(uint)
+BROADCAST_IMPL(long)
+BROADCAST_IMPL(ulong)
+BROADCAST_IMPL(float)
+BROADCAST_IMPL(double)
+#undef BROADCAST_IMPL
+
+
+#define RANGE_OP(RANGE, OP, GEN_TYPE, SIGN) \
+ OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_##RANGE##_##OP(bool sign, GEN_TYPE x); \
+ OVERLOADABLE GEN_TYPE sub_group_##RANGE##_##OP(GEN_TYPE x) { \
+ return __gen_ocl_sub_group_##RANGE##_##OP(SIGN, x); \
+ }
+
+/* reduce add */
+RANGE_OP(reduce, add, int, true)
+RANGE_OP(reduce, add, uint, false)
+RANGE_OP(reduce, add, long, true)
+RANGE_OP(reduce, add, ulong, false)
+RANGE_OP(reduce, add, float, true)
+RANGE_OP(reduce, add, double, true)
+/* reduce min */
+RANGE_OP(reduce, min, int, true)
+RANGE_OP(reduce, min, uint, false)
+RANGE_OP(reduce, min, long, true)
+RANGE_OP(reduce, min, ulong, false)
+RANGE_OP(reduce, min, float, true)
+RANGE_OP(reduce, min, double, true)
+/* reduce max */
+RANGE_OP(reduce, max, int, true)
+RANGE_OP(reduce, max, uint, false)
+RANGE_OP(reduce, max, long, true)
+RANGE_OP(reduce, max, ulong, false)
+RANGE_OP(reduce, max, float, true)
+RANGE_OP(reduce, max, double, true)
+
+/* scan_inclusive add */
+RANGE_OP(scan_inclusive, add, int, true)
+RANGE_OP(scan_inclusive, add, uint, false)
+RANGE_OP(scan_inclusive, add, long, true)
+RANGE_OP(scan_inclusive, add, ulong, false)
+RANGE_OP(scan_inclusive, add, float, true)
+RANGE_OP(scan_inclusive, add, double, true)
+/* scan_inclusive min */
+RANGE_OP(scan_inclusive, min, int, true)
+RANGE_OP(scan_inclusive, min, uint, false)
+RANGE_OP(scan_inclusive, min, long, true)
+RANGE_OP(scan_inclusive, min, ulong, false)
+RANGE_OP(scan_inclusive, min, float, true)
+RANGE_OP(scan_inclusive, min, double, true)
+/* scan_inclusive max */
+RANGE_OP(scan_inclusive, max, int, true)
+RANGE_OP(scan_inclusive, max, uint, false)
+RANGE_OP(scan_inclusive, max, long, true)
+RANGE_OP(scan_inclusive, max, ulong, false)
+RANGE_OP(scan_inclusive, max, float, true)
+RANGE_OP(scan_inclusive, max, double, true)
+
+/* scan_exclusive add */
+RANGE_OP(scan_exclusive, add, int, true)
+RANGE_OP(scan_exclusive, add, uint, false)
+RANGE_OP(scan_exclusive, add, long, true)
+RANGE_OP(scan_exclusive, add, ulong, false)
+RANGE_OP(scan_exclusive, add, float, true)
+RANGE_OP(scan_exclusive, add, double, true)
+/* scan_exclusive min */
+RANGE_OP(scan_exclusive, min, int, true)
+RANGE_OP(scan_exclusive, min, uint, false)
+RANGE_OP(scan_exclusive, min, long, true)
+RANGE_OP(scan_exclusive, min, ulong, false)
+RANGE_OP(scan_exclusive, min, float, true)
+RANGE_OP(scan_exclusive, min, double, true)
+/* scan_exclusive max */
+RANGE_OP(scan_exclusive, max, int, true)
+RANGE_OP(scan_exclusive, max, uint, false)
+RANGE_OP(scan_exclusive, max, long, true)
+RANGE_OP(scan_exclusive, max, ulong, false)
+RANGE_OP(scan_exclusive, max, float, true)
+RANGE_OP(scan_exclusive, max, double, true)
+
+#undef RANGE_OP
diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.h b/backend/src/libocl/tmpl/ocl_simd.tmpl.h
index 96337cd..355ee30 100644
--- a/backend/src/libocl/tmpl/ocl_simd.tmpl.h
+++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.h
@@ -34,6 +34,101 @@ uint get_num_sub_groups(void);
uint get_sub_group_id(void);
uint get_sub_group_local_id(void);
+/* broadcast */
+OVERLOADABLE int sub_group_broadcast(int a, size_t local_id);
+OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id);
+OVERLOADABLE long sub_group_broadcast(long a, size_t local_id);
+OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id);
+OVERLOADABLE float sub_group_broadcast(float a, size_t local_id);
+OVERLOADABLE double sub_group_broadcast(double a, size_t local_id);
+
+OVERLOADABLE int sub_group_broadcast(int a, size_t local_id_x, size_t local_id_y);
+OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id_x, size_t local_id_y);
+OVERLOADABLE long sub_group_broadcast(long a, size_t local_id_x, size_t local_id_y);
+OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id_x, size_t local_id_y);
+OVERLOADABLE float sub_group_broadcast(float a, size_t local_id_x, size_t local_id_y);
+OVERLOADABLE double sub_group_broadcast(double a, size_t local_id_x, size_t local_id_y);
+
+OVERLOADABLE int sub_group_broadcast(int a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+OVERLOADABLE long sub_group_broadcast(long a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+OVERLOADABLE float sub_group_broadcast(float a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+OVERLOADABLE double sub_group_broadcast(double a, size_t local_id_x, size_t local_id_y, size_t local_id_z);
+
+/* reduce add */
+OVERLOADABLE int sub_group_reduce_add(int x);
+OVERLOADABLE uint sub_group_reduce_add(uint x);
+OVERLOADABLE long sub_group_reduce_add(long x);
+OVERLOADABLE ulong sub_group_reduce_add(ulong x);
+OVERLOADABLE float sub_group_reduce_add(float x);
+OVERLOADABLE double sub_group_reduce_add(double x);
+
+/* reduce min */
+OVERLOADABLE int sub_group_reduce_min(int x);
+OVERLOADABLE uint sub_group_reduce_min(uint x);
+OVERLOADABLE long sub_group_reduce_min(long x);
+OVERLOADABLE ulong sub_group_reduce_min(ulong x);
+OVERLOADABLE float sub_group_reduce_min(float x);
+OVERLOADABLE double sub_group_reduce_min(double x);
+
+/* reduce max */
+OVERLOADABLE int sub_group_reduce_max(int x);
+OVERLOADABLE uint sub_group_reduce_max(uint x);
+OVERLOADABLE long sub_group_reduce_max(long x);
+OVERLOADABLE ulong sub_group_reduce_max(ulong x);
+OVERLOADABLE float sub_group_reduce_max(float x);
+OVERLOADABLE double sub_group_reduce_max(double x);
+
+/* scan_inclusive add */
+OVERLOADABLE int sub_group_scan_inclusive_add(int x);
+OVERLOADABLE uint sub_group_scan_inclusive_add(uint x);
+OVERLOADABLE long sub_group_scan_inclusive_add(long x);
+OVERLOADABLE ulong sub_group_scan_inclusive_add(ulong x);
+OVERLOADABLE float sub_group_scan_inclusive_add(float x);
+OVERLOADABLE double sub_group_scan_inclusive_add(double x);
+
+/* scan_inclusive min */
+OVERLOADABLE int sub_group_scan_inclusive_min(int x);
+OVERLOADABLE uint sub_group_scan_inclusive_min(uint x);
+OVERLOADABLE long sub_group_scan_inclusive_min(long x);
+OVERLOADABLE ulong sub_group_scan_inclusive_min(ulong x);
+OVERLOADABLE float sub_group_scan_inclusive_min(float x);
+OVERLOADABLE double sub_group_scan_inclusive_min(double x);
+
+/* scan_inclusive max */
+OVERLOADABLE int sub_group_scan_inclusive_max(int x);
+OVERLOADABLE uint sub_group_scan_inclusive_max(uint x);
+OVERLOADABLE long sub_group_scan_inclusive_max(long x);
+OVERLOADABLE ulong sub_group_scan_inclusive_max(ulong x);
+OVERLOADABLE float sub_group_scan_inclusive_max(float x);
+OVERLOADABLE double sub_group_scan_inclusive_max(double x);
+
+/* scan_exclusive add */
+OVERLOADABLE int sub_group_scan_exclusive_add(int x);
+OVERLOADABLE uint sub_group_scan_exclusive_add(uint x);
+OVERLOADABLE long sub_group_scan_exclusive_add(long x);
+OVERLOADABLE ulong sub_group_scan_exclusive_add(ulong x);
+OVERLOADABLE float sub_group_scan_exclusive_add(float x);
+OVERLOADABLE double sub_group_scan_exclusive_add(double x);
+
+/* scan_exclusive min */
+OVERLOADABLE int sub_group_scan_exclusive_min(int x);
+OVERLOADABLE uint sub_group_scan_exclusive_min(uint x);
+OVERLOADABLE long sub_group_scan_exclusive_min(long x);
+OVERLOADABLE ulong sub_group_scan_exclusive_min(ulong x);
+OVERLOADABLE float sub_group_scan_exclusive_min(float x);
+OVERLOADABLE double sub_group_scan_exclusive_min(double x);
+
+/* scan_exclusive max */
+OVERLOADABLE int sub_group_scan_exclusive_max(int x);
+OVERLOADABLE uint sub_group_scan_exclusive_max(uint x);
+OVERLOADABLE long sub_group_scan_exclusive_max(long x);
+OVERLOADABLE ulong sub_group_scan_exclusive_max(ulong x);
+OVERLOADABLE float sub_group_scan_exclusive_max(float x);
+OVERLOADABLE double sub_group_scan_exclusive_max(double x);
+
+/* shuffle */
OVERLOADABLE float intel_sub_group_shuffle(float x, uint c);
OVERLOADABLE int intel_sub_group_shuffle(int x, uint c);
OVERLOADABLE uint intel_sub_group_shuffle(uint x, uint c);
diff --git a/backend/src/llvm/llvm_gen_backend.cpp b/backend/src/llvm/llvm_gen_backend.cpp
index b57cf88..3ddbfcc 100644
--- a/backend/src/llvm/llvm_gen_backend.cpp
+++ b/backend/src/llvm/llvm_gen_backend.cpp
@@ -695,6 +695,8 @@ namespace gbe
void emitAtomicInst(CallInst &I, CallSite &CS, ir::AtomicOps opcode);
// Emit workgroup instructions
void emitWorkGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode);
+ // Emit subgroup instructions
+ void emitSubGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode);
uint8_t appendSampler(CallSite::arg_iterator AI);
uint8_t getImageID(CallInst &I);
@@ -3715,6 +3717,16 @@ namespace gbe
case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_ADD:
case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MAX:
case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MIN:
+ case GEN_OCL_SUB_GROUP_BROADCAST:
+ case GEN_OCL_SUB_GROUP_REDUCE_ADD:
+ case GEN_OCL_SUB_GROUP_REDUCE_MAX:
+ case GEN_OCL_SUB_GROUP_REDUCE_MIN:
+ case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_ADD:
+ case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MAX:
+ case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MIN:
+ case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_ADD:
+ case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MAX:
+ case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MIN:
case GEN_OCL_LRP:
this->newRegister(&I);
break;
@@ -3884,6 +3896,48 @@ namespace gbe
GBE_ASSERT(AI == AE);
}
+ void GenWriter::emitSubGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode) {
+ CallSite::arg_iterator AI = CS.arg_begin();
+ CallSite::arg_iterator AE = CS.arg_end();
+ GBE_ASSERT(AI != AE);
+
+ if (opcode == ir::WORKGROUP_OP_ALL || opcode == ir::WORKGROUP_OP_ANY) {
+ GBE_ASSERT(getType(ctx, (*AI)->getType()) == ir::TYPE_S32);
+ ir::Register src[3];
+ src[0] = this->getRegister(*(AI++));
+ const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], 1);
+ ctx.SUBGROUP(opcode, getRegister(&I), srcTuple, 1, ir::TYPE_S32);
+ } else if (opcode == ir::WORKGROUP_OP_BROADCAST) {
+ int argNum = CS.arg_size();
+ std::vector<ir::Register> src(argNum);
+ for (int i = 0; i < argNum; i++) {
+ src[i] = this->getRegister(*(AI++));
+ }
+ const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], argNum);
+ ctx.SUBGROUP(ir::WORKGROUP_OP_BROADCAST, getRegister(&I), srcTuple, argNum,
+ getType(ctx, (*CS.arg_begin())->getType()));
+ } else {
+ ConstantInt *sign = dyn_cast<ConstantInt>(AI);
+ GBE_ASSERT(sign);
+ bool isSign = sign->getZExtValue();
+ AI++;
+ ir::Type ty;
+ if (isSign) {
+ ty = getType(ctx, (*AI)->getType());
+
+ } else {
+ ty = getUnsignedType(ctx, (*AI)->getType());
+ }
+
+ ir::Register src[3];
+ src[0] = this->getRegister(*(AI++));
+ const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], 1);
+ ctx.SUBGROUP(opcode, getRegister(&I), srcTuple, 1, ty);
+ }
+
+ GBE_ASSERT(AI == AE);
+ }
+
/* append a new sampler. should be called before any reference to
* a sampler_t value. */
uint8_t GenWriter::appendSampler(CallSite::arg_iterator AI) {
@@ -4676,6 +4730,26 @@ namespace gbe
this->emitWorkGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MAX); break;
case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MIN:
this->emitWorkGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MIN); break;
+ case GEN_OCL_SUB_GROUP_BROADCAST:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_BROADCAST); break;
+ case GEN_OCL_SUB_GROUP_REDUCE_ADD:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_ADD); break;
+ case GEN_OCL_SUB_GROUP_REDUCE_MAX:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_MAX); break;
+ case GEN_OCL_SUB_GROUP_REDUCE_MIN:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_MIN); break;
+ case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_ADD:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_ADD); break;
+ case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MAX:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_MAX); break;
+ case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MIN:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_MIN); break;
+ case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_ADD:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_ADD); break;
+ case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MAX:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MAX); break;
+ case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MIN:
+ this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MIN); break;
case GEN_OCL_LRP:
{
const ir::Register dst = this->getRegister(&I);
diff --git a/backend/src/llvm/llvm_gen_ocl_function.hxx b/backend/src/llvm/llvm_gen_ocl_function.hxx
index cff4d61..213ead0 100644
--- a/backend/src/llvm/llvm_gen_ocl_function.hxx
+++ b/backend/src/llvm/llvm_gen_ocl_function.hxx
@@ -202,5 +202,20 @@ DECL_LLVM_GEN_FUNCTION(WORK_GROUP_SCAN_INCLUSIVE_MIN, __gen_ocl_work_group_scan_
DECL_LLVM_GEN_FUNCTION(WORK_GROUP_ALL, __gen_ocl_work_group_all)
DECL_LLVM_GEN_FUNCTION(WORK_GROUP_ANY, __gen_ocl_work_group_any)
+// sub group functions
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_BROADCAST, __gen_ocl_sub_group_broadcast)
+
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_ADD, __gen_ocl_sub_group_reduce_add)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_MAX, __gen_ocl_sub_group_reduce_max)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_MIN, __gen_ocl_sub_group_reduce_min)
+
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_ADD, __gen_ocl_sub_group_scan_exclusive_add)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_MAX, __gen_ocl_sub_group_scan_exclusive_max)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_MIN, __gen_ocl_sub_group_scan_exclusive_min)
+
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_ADD, __gen_ocl_sub_group_scan_inclusive_add)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_MAX, __gen_ocl_sub_group_scan_inclusive_max)
+DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_MIN, __gen_ocl_sub_group_scan_inclusive_min)
+
// common function
DECL_LLVM_GEN_FUNCTION(LRP, __gen_ocl_lrp)
--
2.7.4