[Beignet] [PATCH 2/2] Workgroup reduce add optimization using add4 and SLM cache

Grigore Lupescu grigore.lupescu at intel.com
Thu Jan 21 02:06:30 PST 2016


Signed-off-by: Grigore Lupescu <grigore.lupescu at intel.com>
---
 backend/src/backend/gen_context.cpp        | 214 ++++++++++++++++-------------
 backend/src/backend/gen_insn_selection.cpp | 138 ++++++++++++++++++-
 2 files changed, 251 insertions(+), 101 deletions(-)

diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index 0ea0dd0..193494d 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2943,21 +2943,32 @@ namespace gbe
           }
         }
       }
-    } else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
+    }
+    else if (wg_op == ir::WORKGROUP_OP_REDUCE_ADD){
+      tmp.hstride = GEN_HORIZONTAL_STRIDE_1;
+      tmp.vstride = GEN_VERTICAL_STRIDE_4;
+      tmp.width = GEN_WIDTH_4;
+
       GBE_ASSERT(tmp.type == theVal.type);
-      GenRegister v = GenRegister::toUniform(tmp, theVal.type);
-      for (uint32_t i = 0; i < simd; i++) {
-        p->ADD(threadData, threadData, v);
-        v.subnr += typeSize(theVal.type);
-        if (v.subnr == 32) {
-          v.subnr = 0;
-          v.nr++;
-        }
+      GenRegister partialSum = tmp;
+
+      /* adjust offset, compute add with ADD4/ADD */
+      for (uint32_t i = 1; i < simd/4; i++){
+        tmp = tmp.suboffset(tmp, 4);
+        p->push();
+        p->curr.execWidth = GEN_WIDTH_16;
+        p->ADD(partialSum, partialSum, tmp);
+        p->pop();
       }
-    }
 
+      for (uint32_t i = 0; i < 4; i++){
+        partialSum.width = GEN_WIDTH_1;
+        p->ADD(threadData, threadData, partialSum);
+        partialSum = GenRegister::suboffset(partialSum, 1);
+      }
+    }
     p->pop();
-  }
+}
 
 #define SEND_RESULT_MSG() \
 do { \
@@ -3028,120 +3039,125 @@ do { \
       workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
     } p->pop();
 
-    /* If we are the only one thread, no need to send msg, just broadcast the result.*/
-    p->push(); {
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->curr.noMask = 1;
-      p->curr.execWidth = 1;
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+    if(wg_op == ir::WORKGROUP_OP_REDUCE_ADD){
+        p->push(); {
+          p->MOV(dst, threadData);
+        } p->pop();
+    }
+    else {
+      /* If we are the only one thread, no need to send msg, just broadcast the result.*/
+      p->push(); {
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->curr.noMask = 1;
+        p->curr.execWidth = 1;
+        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+        p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+
+        /* Broadcast result. */
+        if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
+          p->curr.predicate = GEN_PREDICATE_NORMAL;
+          p->curr.inversePredicate = 1;
+          p->MOV(flag_save, GenRegister::immuw(0x0));
+          p->curr.inversePredicate = 0;
+          p->MOV(flag_save, GenRegister::immuw(0xffff));
+          p->curr.predicate = GEN_PREDICATE_NONE;
+          p->MOV(flagReg, flag_save);
+          p->curr.predicate = GEN_PREDICATE_NORMAL;
+          p->curr.execWidth = simd;
+          p->MOV(dst, threadData);
+        }
+
+        /* Bail out. */
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.inversePredicate = 0;
+        p->curr.execWidth = 1;
+        oneThreadJip = p->n_instruction();
+        p->JMPI(GenRegister::immud(0));
+      } p->pop();
+
+      p->push(); {
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->curr.noMask = 1;
+        p->curr.execWidth = 1;
+        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+        p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
 
-      /* Broadcast result. */
-      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) {
         p->curr.predicate = GEN_PREDICATE_NORMAL;
         p->curr.inversePredicate = 1;
         p->MOV(flag_save, GenRegister::immuw(0x0));
         p->curr.inversePredicate = 0;
         p->MOV(flag_save, GenRegister::immuw(0xffff));
+
         p->curr.predicate = GEN_PREDICATE_NONE;
         p->MOV(flagReg, flag_save);
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
-        p->curr.execWidth = simd;
-        p->MOV(dst, threadData);
-      }
-
-      /* Bail out. */
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      p->curr.inversePredicate = 0;
-      p->curr.execWidth = 1;
-      oneThreadJip = p->n_instruction();
-      p->JMPI(GenRegister::immud(0));
-    } p->pop();
-
-    p->push(); {
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->curr.noMask = 1;
-      p->curr.execWidth = 1;
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
-
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      p->curr.inversePredicate = 1;
-      p->MOV(flag_save, GenRegister::immuw(0x0));
-      p->curr.inversePredicate = 0;
-      p->MOV(flag_save, GenRegister::immuw(0xffff));
-
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->MOV(flagReg, flag_save);
-    } p->pop();
-
-    p->push(); {
-      p->curr.noMask = 1;
-      p->curr.execWidth = 1;
-
-      /* threadid 0, send the msg and wait */
-      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-      p->curr.inversePredicate = 1;
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      jip = p->n_instruction();
-      p->JMPI(GenRegister::immud(0));
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->MOV(msgData, threadData);
-      SEND_RESULT_MSG();
-      p->WAIT(2);
-      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-
-      /* Others wait and send msg, and do something when we get the msg. */
-      p->curr.predicate = GEN_PREDICATE_NORMAL;
-      p->curr.inversePredicate = 0;
-      jip = p->n_instruction();
-      p->JMPI(GenRegister::immud(0));
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->WAIT(2);
-      workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
-      SEND_RESULT_MSG();
-      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-
-      /* Restore the flag. */
-      p->curr.predicate = GEN_PREDICATE_NONE;
-      p->MOV(flagReg, flag_save);
-    } p->pop();
+      } p->pop();
 
-    /* Broadcast the result. */
-    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX
-        || wg_op == ir::WORKGROUP_OP_REDUCE_ADD) {
       p->push(); {
-        p->curr.predicate = GEN_PREDICATE_NORMAL;
         p->curr.noMask = 1;
         p->curr.execWidth = 1;
-        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
-        p->curr.inversePredicate = 0;
 
-        /* Not the first thread, wait for msg first. */
+        /* threadid 0, send the msg and wait */
+        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+        p->curr.inversePredicate = 1;
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
         jip = p->n_instruction();
         p->JMPI(GenRegister::immud(0));
         p->curr.predicate = GEN_PREDICATE_NONE;
+        p->MOV(msgData, threadData);
+        SEND_RESULT_MSG();
         p->WAIT(2);
         p->patchJMPI(jip, (p->n_instruction() - jip), 0);
-    
-        /* Do something when get the msg. */
-        p->curr.execWidth = simd;
-        p->MOV(dst, msgData);
-
-        p->curr.execWidth = 8;
-        p->FWD_GATEWAY_MSG(nextThreadID, 2);
 
-        p->curr.execWidth = 1;
-        p->curr.inversePredicate = 1;
+        /* Others wait and send msg, and do something when we get the msg. */
         p->curr.predicate = GEN_PREDICATE_NORMAL;
-
-        /* The first thread, the last one will notify us. */
+        p->curr.inversePredicate = 0;
         jip = p->n_instruction();
         p->JMPI(GenRegister::immud(0));
         p->curr.predicate = GEN_PREDICATE_NONE;
         p->WAIT(2);
+        workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
+        SEND_RESULT_MSG();
         p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+        /* Restore the flag. */
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->MOV(flagReg, flag_save);
       } p->pop();
+
+      /* Broadcast the result. */
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN
+          || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+        p->push(); {
+          p->curr.predicate = GEN_PREDICATE_NORMAL;
+          p->curr.noMask = 1;
+          p->curr.execWidth = 1;
+          p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+          p->curr.inversePredicate = 0;
+
+          /* Not the first thread, wait for msg first. */
+          jip = p->n_instruction();
+          p->JMPI(GenRegister::immud(0));
+          p->curr.predicate = GEN_PREDICATE_NONE;
+          p->WAIT(2);
+          p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+          /* Do something when get the msg. */
+          p->curr.execWidth = simd;
+          p->MOV(dst, msgData);
+          p->curr.execWidth = 8;
+          p->FWD_GATEWAY_MSG(nextThreadID, 2);
+          p->curr.execWidth = 1;
+          p->curr.inversePredicate = 1;
+          p->curr.predicate = GEN_PREDICATE_NORMAL;
+
+          /* The first thread, the last one will notify us. */
+          jip = p->n_instruction();
+          p->JMPI(GenRegister::immud(0));
+          p->curr.predicate = GEN_PREDICATE_NONE;
+          p->WAIT(2);
+          p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+        } p->pop();
+      }
     }
 
     if (oneThreadJip >=0)
diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 001a3c5..5eccfc6 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -6192,7 +6192,7 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
 
       if (workGroupOp == WORKGROUP_OP_BROADCAST) {
         return emitWGBroadcast(sel, insn);
-      } else if (workGroupOp >= WORKGROUP_OP_REDUCE_ADD && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
+      } else if (workGroupOp >= WORKGROUP_OP_REDUCE_MIN && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX) {
         const uint32_t slmAddr = insn.getSlmAddr();
         /* First, we create the TheadID/localID map, in order to get which thread hold the next 16 workitems. */
 
@@ -6223,7 +6223,141 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp
           sel.curr.subFlag = 1;
           sel.WORKGROUP_OP(workGroupOp, dst, src, nextThreadID, threadID, threadNum, tmp);
         } sel.pop();
-      } else {
+      }
+      else if (workGroupOp == WORKGROUP_OP_REDUCE_ADD) {
+        const Type type = insn.getType();
+        GenRegister dst = sel.selReg(insn.getDst(0), type);
+        GenRegister src = sel.selReg(insn.getSrc(2), type);
+        const uint32_t srcNum = insn.getSrcNum();
+        GBE_ASSERT(srcNum == 3);
+        GBE_ASSERT(insn.getSrc(0) == ir::ocl::threadn);
+        GBE_ASSERT(insn.getSrc(1) == ir::ocl::threadid);
+        GenRegister threadID = sel.selReg(ocl::threadid, ir::TYPE_U32);
+        GenRegister threadSEL = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        GenRegister threadNum = sel.selReg(ocl::threadn, ir::TYPE_U32);
+        GenRegister tmp = GenRegister::retype(
+            sel.selReg(sel.reg(FAMILY_DWORD)), type);
+        GenRegister nextThreadID = sel.selReg(sel.reg(FAMILY_WORD), type);
+        GenRegister result = sel.selReg(sel.reg(FAMILY_WORD), type);
+
+        vector<GenRegister> lstPartSum;
+        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+        lstPartSum.push_back(sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32));
+        vector<GenRegister> fakeTemps;
+        fakeTemps.push_back(sel.selReg(sel.reg(FAMILY_WORD), type));
+        fakeTemps.push_back(sel.selReg(sel.reg(FAMILY_WORD), type));
+        sel.MOV(lstPartSum[0], GenRegister::immud(0));
+        sel.MOV(lstPartSum[1], GenRegister::immud(0));
+        sel.MOV(lstPartSum[2], GenRegister::immud(0));
+        sel.MOV(lstPartSum[3], GenRegister::immud(0));
+
+        /* precompute SLM address offsets */
+        GenRegister slm1Reg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        GenRegister slm2Reg = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        GenRegister slm2RegOff1 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        GenRegister slm2RegOff4 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        GenRegister slm1RegOff4 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        GenRegister slm1RegOff8 = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32);
+        sel.MOV(slm1Reg, GenRegister::immud(insn.getSlmAddr()));
+        sel.MUL(slm2Reg, threadNum, GenRegister::immud(0x10));
+        sel.ADD(slm2Reg, slm2Reg, slm1Reg);
+        sel.MUL(slm1RegOff4, threadID, GenRegister::immud(0x4));
+        sel.MUL(slm1RegOff8, threadID, GenRegister::immud(0x8));
+        sel.ADD(slm2RegOff1, threadID, slm2Reg);
+        sel.ADD(slm2RegOff4, slm1RegOff4, slm2Reg);
+        sel.ADD(slm1RegOff4, slm1RegOff4, slm1Reg);
+        sel.ADD(slm1RegOff8, slm1RegOff8, slm1Reg);
+
+        /* write the SLM with 0s on both the SLM region1 and region2 */
+        sel.UNTYPED_WRITE(slm1RegOff8,
+                          lstPartSum.data(), 2, GenRegister::immw(0xFE), fakeTemps);
+        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
+                sel.selReg(sel.reg(FAMILY_DWORD)),
+                SYNC_LOCAL_WRITE_FENCE);
+
+        /* compute individual slice of workitems, (e.g. 0->16 workitems) */
+        sel.WORKGROUP_OP(workGroupOp, result, src,
+                         nextThreadID, threadID, threadNum, tmp);
+
+        /* write result data to SLM with offset using threadID*/
+        sel.UNTYPED_WRITE(slm1RegOff4,
+                          &result, 1, GenRegister::immw(0xFE), fakeTemps);
+        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
+                    sel.selReg(sel.reg(FAMILY_DWORD)),
+                    SYNC_LOCAL_WRITE_FENCE);
+
+        /* select threads, compute/write to SLM region2 */
+        sel.push(); {
+          sel.curr.predicate = GEN_PREDICATE_NONE;
+          sel.curr.noMask = 1;
+          sel.curr.flag = 0;
+          sel.curr.subFlag = 1;
+
+          /* select threads based on threadID%4==0 */
+          sel.MOV(threadSEL, threadID);
+          sel.SHR(threadSEL, threadSEL, GenRegister::immud(2));
+          sel.SHL(threadSEL, threadSEL, GenRegister::immud(2));
+          sel.CMP(GEN_CONDITIONAL_EQ, threadID, threadSEL, GenRegister::null());
+          sel.curr.predicate = GEN_PREDICATE_NORMAL;
+
+          /* compute sums and write to SLM region2 */
+          sel.MOV(dst, GenRegister::immud(0));
+          sel.UNTYPED_READ(slm1RegOff4, lstPartSum.data(), 4,
+                             GenRegister::immw(0xFE), fakeTemps);
+          sel.ADD(dst, dst, lstPartSum[0]);
+          sel.ADD(dst, dst, lstPartSum[1]);
+          sel.ADD(dst, dst, lstPartSum[2]);
+          sel.ADD(dst, dst, lstPartSum[3]);
+          sel.UNTYPED_WRITE(slm2RegOff1,
+                            &dst, 1, GenRegister::immw(0xFE), fakeTemps);
+        } sel.pop();
+
+        /* wait for all writes to finish */
+        sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)),
+                    sel.selReg(sel.reg(FAMILY_DWORD)),
+                   SYNC_LOCAL_WRITE_FENCE);
+
+        /* special case if threadnum<=4, skip final sum, just read result SLM */
+        sel.push(); {
+          sel.curr.predicate = GEN_PREDICATE_NONE;
+          sel.curr.noMask = 1;
+          sel.curr.flag = 0;
+          sel.curr.subFlag = 1;
+          sel.CMP(GEN_CONDITIONAL_LE, threadNum, GenRegister::immud(4), GenRegister::null());
+          sel.curr.predicate = GEN_PREDICATE_NORMAL;
+          sel.UNTYPED_READ(slm2Reg, &dst, 1,
+                           GenRegister::immw(0xFE), fakeTemps);
+        } sel.pop();
+
+        /* special case if threadnum>4, do final compute */
+        sel.push(); {
+          sel.curr.predicate = GEN_PREDICATE_NONE;
+          sel.curr.noMask = 1;
+          sel.curr.flag = 0;
+          sel.curr.subFlag = 1;
+          sel.CMP(GEN_CONDITIONAL_G, threadNum, GenRegister::immud(4), GenRegister::null());
+          sel.curr.predicate = GEN_PREDICATE_NORMAL;
+
+          /* each thread collects the partial sums and computes the final sum */
+          sel.MOV(dst, GenRegister::immud(0));
+          for(int i=0; i<2; i++){
+            sel.MOV(lstPartSum[0], GenRegister::immud(0));
+            sel.MOV(lstPartSum[1], GenRegister::immud(0));
+            sel.MOV(lstPartSum[2], GenRegister::immud(0));
+            sel.MOV(lstPartSum[3], GenRegister::immud(0));
+            sel.UNTYPED_READ(slm2Reg, lstPartSum.data(), 4,
+                             GenRegister::immw(0xFE), fakeTemps);
+            sel.ADD(dst, dst, lstPartSum[0]);
+            sel.ADD(dst, dst, lstPartSum[1]);
+            sel.ADD(dst, dst, lstPartSum[2]);
+            sel.ADD(dst, dst, lstPartSum[3]);
+            sel.ADD(slm2Reg, slm2Reg, GenRegister::immud(0x10));
+          }
+        } sel.pop();
+      }
+      else {
         GBE_ASSERT(0);
       }
 
-- 
2.1.4



More information about the Beignet mailing list