[Beignet] [PATCH 2/2] Workgroup reduce add optimization using add4 and cache
Grigore Lupescu
grigore.lupescu at intel.com
Thu Jan 21 02:06:30 PST 2016
Signed-off-by: Grigore Lupescu <grigore.lupescu at intel.com>
---
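In rough terms, the optimization has two stages. Within a thread, the SIMD
lanes are reduced with 4-wide region ADDs (<4;4,1> regions) instead of one
scalar ADD per lane. Across threads, the send/wait message chain is replaced
with SLM: each thread writes its partial sum to SLM region1, every 4th thread
folds four region1 entries into region2, and every thread then sums region2
to produce the workgroup result. The sketch below is a plain-C model of that
dataflow (illustrative only, with made-up names and a threadnum <= 64
assumption; the patch itself emits Gen ISA and uses SLM, not C arrays):

#include <stdint.h>

/* Model of the two-stage workgroup reduce-add (illustrative C only). */
uint32_t wg_reduce_add_model(const uint32_t *lanes, int simd, int threadnum)
{
    uint32_t region1[64];        /* per-thread partial sums      */
    uint32_t region2[16] = {0};  /* per-4-thread group sums      */

    /* stage 1: per-thread reduction, 4 lanes at a time (the "add4") */
    for (int t = 0; t < threadnum; t++) {
        uint32_t sum4[4] = {0, 0, 0, 0};
        for (int i = 0; i < simd; i += 4)
            for (int j = 0; j < 4; j++)
                sum4[j] += lanes[t * simd + i + j];
        region1[t] = sum4[0] + sum4[1] + sum4[2] + sum4[3];
    }

    /* stage 2: every 4th thread folds 4 region1 entries into region2 */
    for (int t = 0; t < threadnum; t += 4) {
        uint32_t s = 0;
        for (int j = 0; j < 4 && t + j < threadnum; j++)
            s += region1[t + j];
        region2[t / 4] = s;
    }

    /* final step: sum region2 to obtain the workgroup result */
    uint32_t result = 0;
    for (int g = 0; g < (threadnum + 3) / 4; g++)
        result += region2[g];
    return result;
}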
backend/src/backend/gen_context.cpp | 214 ++++++++++++++++-------------
backend/src/backend/gen_insn_selection.cpp | 138 ++++++++++++++++++-
2 files changed, 251 insertions(+), 101 deletions(-)
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index 0ea0dd0..193494d 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2943,21 +2943,32 @@ namespace gbe
}
}
}
- } else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
+ } else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
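+ /* view tmp as a <4;4,1> region (width 4, hstride 1, vstride 4) so one ADD consumes 4 consecutive dwords */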
+ tmp.hstride = GEN_HORIZONTAL_STRIDE_1;
+ tmp.vstride = GEN_VERTICAL_STRIDE_4;
+ tmp.width = GEN_WIDTH_4;
+
GBE_ASSERT(tmp.type == theVal.type);
- GenRegister v = GenRegister::toUniform(tmp, theVal.type);
- for (uint32_t i = 0; i < simd; i++) {
- p->ADD(threadData, threadData, v);
- v.subnr += typeSize(theVal.type);
- if (v.subnr == 32) {
- v.subnr = 0;
- v.nr++;
- }
+ GenRegister partialSum = tmp;
+
+ /* step through tmp 4 elements at a time, accumulating with 4-wide ADDs */
+ for (uint32_t i = 1; i < simd/4; i++) {
+ tmp = GenRegister::suboffset(tmp, 4);
+ p->push();
+ p->curr.execWidth = 4;
+ p->ADD(partialSum, partialSum, tmp);
+ p->pop();
}
- }
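+ /* fold the 4 partial sums into the per-thread accumulator, one scalar ADD each */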
+ partialSum.width = GEN_WIDTH_1;
+ for (uint32_t i = 0; i < 4; i++) {
+ p->ADD(threadData, threadData, partialSum);
+ partialSum = GenRegister::suboffset(partialSum, 1);
+ }
+ }
p->pop();
- }
+}
#define SEND_RESULT_MSG() \
do { \
@@ -3028,120 +3039,125 @@ do { \
workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
} p->pop();
- /* If we are the only one thread, no need to send msg, just broadcast the result.*/
- p->push(); {
- p->curr.predicate = GEN_PREDICATE_NONE;
- p->curr.noMask = 1;
- p->curr.execWidth = 1;
- p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
- p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
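+ /* for REDUCE_ADD the cross-thread combine is now done through SLM in instruction selection, so each thread just forwards its per-thread sum */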
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
+ p->push(); {
+ p->MOV(dst, threadData);
+ } p->pop();
+ } else {
+ /* If we are the only thread, no need to send a msg, just broadcast the result. */
+ p->push(); {
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+
+ /* Broadcast result. */
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.inversePredicate = 1;
+ p->MOV(flag_save, GenRegister::immuw(0x0));
+ p->curr.inversePredicate = 0;
+ p->MOV(flag_save, GenRegister::immuw(0xffff));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(flagReg, flag_save);
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.execWidth = simd;
+ p->MOV(dst, threadData);
+ }
+
+ /* Bail out. */
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.inversePredicate = 0;
+ p->curr.execWidth = 1;
+ oneThreadJip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ } p->pop();
+
+ p->push(); {
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
- /* Broadcast result. */
- if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
p->curr.predicate = GEN_PREDICATE_NORMAL;
p->curr.inversePredicate = 1;
p->MOV(flag_save, GenRegister::immuw(0x0));
p->curr.inversePredicate = 0;
p->MOV(flag_save, GenRegister::immuw(0xffff));
+
p->curr.predicate = GEN_PREDICATE_NONE;
p->MOV(flagReg, flag_save);
- p->curr.predicate = GEN_PREDICATE_NORMAL;
- p->curr.execWidth = simd;
- p->MOV(dst, threadData);
- }
-
- /* Bail out. */
- p->curr.predicate = GEN_PREDICATE_NORMAL;
- p->curr.inversePredicate = 0;
- p->curr.execWidth = 1;
- oneThreadJip = p->n_instruction();
- p->JMPI(GenRegister::immud(0));
- } p->pop();
-
- p->push(); {
- p->curr.predicate = GEN_PREDICATE_NONE;
- p->curr.noMask = 1;
- p->curr.execWidth = 1;
- p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
- p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
-
- p->curr.predicate = GEN_PREDICATE_NORMAL;
- p->curr.inversePredicate = 1;
- p->MOV(flag_save, GenRegister::immuw(0x0));
- p->curr.inversePredicate = 0;
- p->MOV(flag_save, GenRegister::immuw(0xffff));
-
- p->curr.predicate = GEN_PREDICATE_NONE;
- p->MOV(flagReg, flag_save);
- } p->pop();
-
- p->push(); {
- p->curr.noMask = 1;
- p->curr.execWidth = 1;
-
- /* threadid 0, send the msg and wait */
- p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
- p->curr.inversePredicate = 1;
- p->curr.predicate = GEN_PREDICATE_NORMAL;
- jip = p->n_instruction();
- p->JMPI(GenRegister::immud(0));
- p->curr.predicate = GEN_PREDICATE_NONE;
- p->MOV(msgData, threadData);
- SEND_RESULT_MSG();
- p->WAIT(2);
- p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-
- /* Others wait and send msg, and do something when we get the msg. */
- p->curr.predicate = GEN_PREDICATE_NORMAL;
- p->curr.inversePredicate = 0;
- jip = p->n_instruction();
- p->JMPI(GenRegister::immud(0));
- p->curr.predicate = GEN_PREDICATE_NONE;
- p->WAIT(2);
- workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
- SEND_RESULT_MSG();
- p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-
- /* Restore the flag. */
- p->curr.predicate = GEN_PREDICATE_NONE;
- p->MOV(flagReg, flag_save);
- } p->pop();
+ } p->pop();
- /* Broadcast the result. */
- if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX
- || wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
p->push(); {
- p->curr.predicate = GEN_PREDICATE_NORMAL;
p->curr.noMask = 1;
p->curr.execWidth = 1;
- p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
- p->curr.inversePredicate = 0;
- /* Not the first thread, wait for msg first. */
+ /* threadid 0, send the msg and wait */
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->curr.inversePredicate = 1;
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
jip = p->n_instruction();
p->JMPI(GenRegister::immud(0));
p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(msgData, threadData);
+ SEND_RESULT_MSG();
p->WAIT(2);
p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-
- /* Do something when get the msg. */
- p->curr.execWidth = simd;
- p->MOV(dst, msgData);
-
- p->curr.execWidth = 8;
- p->FWD_GATEWAY_MSG(nextThreadID, 2);
- p->curr.execWidth = 1;
- p->curr.inversePredicate = 1;
+ /* Others wait and send msg, and do something when we get the msg. */
p->curr.predicate = GEN_PREDICATE_NORMAL;
-
- /* The first thread, the last one will notify us. */
+ p->curr.inversePredicate = 0;
jip = p->n_instruction();
p->JMPI(GenRegister::immud(0));
p->curr.predicate = GEN_PREDICATE_NONE;
p->WAIT(2);
+ workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
+ SEND_RESULT_MSG();
p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+ /* Restore the flag. */
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->MOV(flagReg, flag_save);
} p->pop();
+
+ /* Broadcast the result. */
+ if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN
+ || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+ p->push(); {
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+ p->curr.noMask = 1;
+ p->curr.execWidth = 1;
+ p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+ p->curr.inversePredicate = 0;
+
+ /* Not the first thread, wait for msg first. */
+ jip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->WAIT(2);
+ p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+ /* Do something when we get the msg. */
+ p->curr.execWidth = simd;
+ p->MOV(dst, msgData);
+ p->curr.execWidth = 8;
+ p->FWD_GATEWAY_MSG(nextThreadID, 2);
+ p->curr.execWidth = 1;
+ p->curr.inversePredicate = 1;
+ p->curr.predicate = GEN_PREDICATE_NORMAL;
+
+ /* The first thread, the last one will notify us. */
+ jip = p->n_instruction();
+ p->JMPI(GenRegister::immud(0));
+ p->curr.predicate = GEN_PREDICATE_NONE;
+ p->WAIT(2);
+ p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+ } p->pop();
+ }
}
if (oneThreadJip >=0)
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 001a3c5..5eccfc6 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -6192,7 +6192,7 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
if (workGroupOp == WORKGROUP_OP_BROADCAST) {
return emitWGBroadcast(sel, insn);
- } else if (workGroupOp >= WORKGROUP_OP_REDUCE_ADD && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
+ } else if (workGroupOp >= WORKGROUP_OP_REDUCE_MIN && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
const uint32_t slmAddr = insn.getSlmAddr();
/* First, we create the TheadID/localID map, in order to get which thread hold the next 16 workitems. */
@@ -6223,7 +6223,141 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
sel.curr.subFlag = 1;
sel.WORKGROUP_OP(workGroupOp, dst, src, nextThreadID, threadID, threadNum, tmp);
} sel.pop();
- } else {
+ } else if (workGroupOp == WORKGROUP_OP_REDUCE_ADD) {
+ const Type type = insn.getType();
+ GenRegister dst = sel.selReg(insn.getDst(0), type);
+ GenRegister src = sel.selReg(insn.getSrc(2), type);
+ const uint32_t srcNum = insn.getSrcNum();
+ GBE_ASSERT(srcNum == 3);
+ GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
+ GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
+ GenRegister threadID = sel.selReg(ocl::threadid, ir::TYPE_U32);
+ GenRegister threadSEL = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ GenRegister threadNum = sel.selReg(ocl::threadn, ir::TYPE_U32);
+ GenRegister tmp = GenRegister::retype(
+ sel.selReg(sel.reg(FAMILY_DWORD)), type);
+ GenRegister nextThreadID = sel.selReg(sel.reg(FAMILY_WORD), type);
+ GenRegister result = sel.selReg(sel.reg(FAMILY_WORD), type);
+
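+ /* scratch registers: lstPartSum holds the 4 partial sums read back from SLM; fakeTemps are placeholder temporaries for the untyped read/write helpers */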
+ vector<GenRegister> lstPartSum;
+ lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+ lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+ lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+ lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+ vector<GenRegister> fakeTemps;
+ fakeTemps.push_back(sel.selReg(sel.reg(FAMILY_WORD), type));
+ fakeTemps.push_back(sel.selReg(sel.reg(FAMILY_WORD), type));
+ sel.MOV(lstPartSum[0], GenRegister::immud(0));
+ sel.MOV(lstPartSum[1], GenRegister::immud(0));
+ sel.MOV(lstPartSum[2], GenRegister::immud(0));
+ sel.MOV(lstPartSum[3], GenRegister::immud(0));
+
+ /* precompute SLM address offsets */
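+ /* SLM layout (as implied by the offsets below):
+ region1 at slmAddr: one dword per thread (per-thread partial sums)
+ region2 at slmAddr + threadNum*16: one dword per 4-thread group */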
+ GenRegister slm1Reg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ GenRegister slm2Reg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ GenRegister slm2RegOff1 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ GenRegister slm2RegOff4 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ GenRegister slm1RegOff4 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ GenRegister slm1RegOff8 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+ sel.MOV(slm1Reg, GenRegister::immud(insn.getSlmAddr()));
+ sel.MUL(slm2Reg, threadNum, GenRegister::immud(0x10));
+ sel.ADD(slm2Reg, slm2Reg, slm1Reg);
+ sel.MUL(slm1RegOff4, threadID, GenRegister::immud(0x4));
+ sel.MUL(slm1RegOff8, threadID, GenRegister::immud(0x8));
+ sel.ADD(slm2RegOff1, threadID, slm2Reg);
+ sel.ADD(slm2RegOff4, slm1RegOff4, slm2Reg);
+ sel.ADD(slm1RegOff4, slm1RegOff4, slm1Reg);
+ sel.ADD(slm1RegOff8, slm1RegOff8, slm1Reg);
+
+ /* zero-fill the SLM on both region1 and region2 */
+ sel.UNTYPED_WRITE(slm1RegOff8,
+ lstPartSum.data(), 2, GenRegister::immw(0xFE), fakeTemps);
+ sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
+ sel.selReg(sel.reg(FAMILY_DWORD)),
+ SYNC_LOCAL_WRITE_FENCE);
+
+ /* reduce this thread's slice of workitems (e.g. workitems 0-15) */
+ sel.WORKGROUP_OP(workGroupOp, result, src,
+ nextThreadID, threadID, threadNum, tmp);
+
+ /* write the per-thread result to SLM region1, offset by threadID */
+ sel.UNTYPED_WRITE(slm1RegOff4,
+ &result, 1, GenRegister::immw(0xFE), fakeTemps);
+ sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
+ sel.selReg(sel.reg(FAMILY_DWORD)),
+ SYNC_LOCAL_WRITE_FENCE);
+
+ /* select threads, compute/write to SLM region2 */
+ sel.push(); {
+ sel.curr.predicate = GEN_PREDICATE_NONE;
+ sel.curr.noMask = 1;
+ sel.curr.flag = 0;
+ sel.curr.subFlag = 1;
+
+ /* select threads based on threadID%4==0 */
+ sel.MOV(threadSEL, threadID);
+ sel.SHR(threadSEL, threadSEL, GenRegister::immud(2));
+ sel.SHL(threadSEL, threadSEL, GenRegister::immud(2));
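+ /* threadSEL = (threadID >> 2) << 2 clears the low 2 bits, so the CMP selects threads with threadID % 4 == 0 */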
+ sel.CMP(GEN_CONDITIONAL_EQ, threadID, threadSEL, GenRegister::null());
+ sel.curr.predicate = GEN_PREDICATE_NORMAL;
+
+ /* compute sums and write to SLM region2 */
+ sel.MOV(dst, GenRegister::immud(0));
+ sel.UNTYPED_READ(slm1RegOff4, lstPartSum.data(), 4,
+ GenRegister::immw(0xFE), fakeTemps);
+ sel.ADD(dst, dst, lstPartSum[0]);
+ sel.ADD(dst, dst, lstPartSum[1]);
+ sel.ADD(dst, dst, lstPartSum[2]);
+ sel.ADD(dst, dst, lstPartSum[3]);
+ sel.UNTYPED_WRITE(slm2RegOff1,
+ &dst, 1, GenRegister::immw(0xFE), fakeTemps);
+ } sel.pop();
+
+ /* wait for all writes to finish */
+ sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
+ sel.selReg(sel.reg(FAMILY_DWORD)),
+ SYNC_LOCAL_WRITE_FENCE);
+
+ /* special case: if threadNum <= 4, skip the final sum and just read the result from SLM region2 */
+ sel.push(); {
+ sel.curr.predicate = GEN_PREDICATE_NONE;
+ sel.curr.noMask = 1;
+ sel.curr.flag = 0;
+ sel.curr.subFlag = 1;
+ sel.CMP(GEN_CONDITIONAL_LE, threadNum, GenRegister::immud(4), GenRegister::null());
+ sel.curr.predicate = GEN_PREDICATE_NORMAL;
+ sel.UNTYPED_READ(slm2Reg, &dst, 1,
+ GenRegister::immw(0xFE), fakeTemps);
+ } sel.pop();
+
+ /* if threadNum > 4, each thread computes the final sum from SLM region2 */
+ sel.push(); {
+ sel.curr.predicate = GEN_PREDICATE_NONE;
+ sel.curr.noMask = 1;
+ sel.curr.flag = 0;
+ sel.curr.subFlag = 1;
+ sel.CMP(GEN_CONDITIONAL_G, threadNum, GenRegister::immud(4), GenRegister::null());
+ sel.curr.predicate = GEN_PREDICATE_NORMAL;
+
+ /* each thread collects the partial sums and computes the final sum */
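+ /* two reads of 4 dwords each cover up to 8 group sums, i.e. threadNum up to 32 */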
+ sel.MOV(dst, GenRegister::immud(0));
+ for (int i = 0; i < 2; i++) {
+ sel.MOV(lstPartSum[0], GenRegister::immud(0));
+ sel.MOV(lstPartSum[1], GenRegister::immud(0));
+ sel.MOV(lstPartSum[2], GenRegister::immud(0));
+ sel.MOV(lstPartSum[3], GenRegister::immud(0));
+ sel.UNTYPED_READ(slm2Reg, lstPartSum.data(), 4,
+ GenRegister::immw(0xFE), fakeTemps);
+ sel.ADD(dst, dst, lstPartSum[0]);
+ sel.ADD(dst, dst, lstPartSum[1]);
+ sel.ADD(dst, dst, lstPartSum[2]);
+ sel.ADD(dst, dst, lstPartSum[3]);
+ sel.ADD(slm2Reg, slm2Reg, GenRegister::immud(0x10));
+ }
+ } sel.pop();
+ } else {
GBE_ASSERT(0);
}
--
2.1.4