[Beignet] [PATCH 11/15] Backend: Implement reduce min and max in gen_context

Pan Xiuli xiuli.pan at intel.com
Wed Jan 20 22:51:51 PST 2016


From: Junyan He <junyan.he at linux.intel.com>

Signed-off-by: Junyan He <junyan.he at linux.intel.com>
Reviewed-by: Yang Rong <rong.r.yang at intel.com>
---
 backend/src/backend/gen_context.cpp | 277 ++++++++++++++++++++++++++++++++++++
 1 file changed, 277 insertions(+)

diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index ed6c9f0..fd5503c 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2345,7 +2345,284 @@ namespace gbe
     p->TYPED_WRITE(header, true, bti);
   }
 
+  static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+      uint32_t simd, uint32_t wg_op, GenEncoder *p) {
+    /* Merge this thread's own result (threadData) into the value received from the
+       previous thread (msgData), in place.  theVal and simd are currently unused. */
+    const bool isMin = (wg_op == ir::WORKGROUP_OP_REDUCE_MIN);
+    const bool isMax = (wg_op == ir::WORKGROUP_OP_REDUCE_MAX);
+
+    p->push();
+    p->curr.predicate = GEN_PREDICATE_NONE;
+    p->curr.noMask = 1;
+    p->curr.execWidth = 1;
+    if (isMin || isMax) {
+      /* SEL_CMP with LE for MIN, GE for MAX; the selection is written back to msgData. */
+      const uint32_t cond = isMin ? GEN_CONDITIONAL_LE : GEN_CONDITIONAL_GE;
+      p->SEL_CMP(cond, msgData, threadData, msgData);
+    }
+    p->pop();
+  }
+
+  static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) {
+    /* Load the initial value of a MIN/MAX reduction into dataReg: the largest value of
+       the type for MIN ops and the smallest one for MAX ops.  The constant is written as
+       a raw 32-bit pattern, so dataReg's own type decides how it is interpreted. */
+    const bool isMin = wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+                       || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN;
+    GBE_ASSERT(isMin || wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX
+               || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX); // No other op has an initial value yet.
+
+    uint32_t bits;
+    if (dataReg.type == GEN_TYPE_UD) {
+      bits = isMin ? 0xFFFFFFFF : 0x0;        // UINT_MAX / 0
+    } else if (dataReg.type == GEN_TYPE_D) {
+      bits = isMin ? 0x7FFFFFFF : 0x80000000; // INT_MAX / INT_MIN
+    } else if (dataReg.type == GEN_TYPE_F) {
+      bits = isMin ? 0x7F800000 : 0xFF800000; // +inf / -inf
+    } else {
+      GBE_ASSERT(0); // Unsupported data type.
+      return;
+    }
+    p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(bits));
+  }
+
+  static void workgroupOpInThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+                                  GenRegister tmp, uint32_t simd, uint32_t wg_op, GenEncoder *p) {
+    /* Reduce the (possibly per-lane) value theVal of this thread into the scalar
+       threadData.  tmp is scratch space; msgData is currently unused. */
+    const bool isUniform = (theVal.hstride == GEN_HORIZONTAL_STRIDE_0);
+    const bool isMin = (wg_op == ir::WORKGROUP_OP_REDUCE_MIN);
+    const bool isMax = (wg_op == ir::WORKGROUP_OP_REDUCE_MAX);
+    p->push();
+    p->curr.predicate = GEN_PREDICATE_NONE;
+    p->curr.noMask = 1;
+    p->curr.execWidth = 1;
+
+    /* Seed the per-thread accumulator with the reduction's initial value. */
+    threadData = GenRegister::retype(threadData, theVal.type);
+    initValue(p, threadData, wg_op);
+
+    if (!isUniform) {
+      /* Fill every lane of tmp with the initial value (noMask), then copy theVal only
+         into the enabled lanes, so lanes outside the dispatch mask keep that value. */
+      tmp = GenRegister::retype(tmp, theVal.type);
+      p->push();
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->curr.noMask = 1;
+      p->curr.execWidth = simd;
+      initValue(p, tmp, wg_op);
+      p->curr.noMask = 0;
+      p->MOV(tmp, theVal);
+      p->pop();
+    }
+
+    if (isMin || isMax) {
+      const uint32_t cond = isMin ? GEN_CONDITIONAL_LE : GEN_CONDITIONAL_GE;
+      if (isUniform) {
+        /* All lanes hold the same value: one comparison is enough. */
+        p->SEL_CMP(cond, threadData, threadData, theVal);
+      } else {
+        /* Visit tmp one element at a time as a scalar region; the byte offset
+           wraps into the next register when it reaches 32. */
+        GBE_ASSERT(tmp.type == theVal.type);
+        const uint32_t elemSize = typeSize(theVal.type);
+        GenRegister v = GenRegister::toUniform(tmp, theVal.type);
+        for (uint32_t lane = 0; lane < simd; ++lane) {
+          p->SEL_CMP(cond, threadData, threadData, v);
+          v.subnr += elemSize;
+          if (v.subnr == 32) { v.subnr = 0; v.nr++; }
+        }
+      }
+    }
+    p->pop();
+  }
+
+#define SEND_RESULT_MSG() /* Expects p, nextThreadID and theVal in the enclosing scope. */ \
+do { \
+  p->push(); { /* Fill in the header fields of nextThreadID, then forward it. */ \
+    p->curr.noMask = 1; \
+    p->curr.predicate = GEN_PREDICATE_NONE; \
+    p->curr.execWidth = 1; \
+    GenRegister offLen = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 20), GEN_TYPE_UD); \
+    offLen.vstride = GEN_VERTICAL_STRIDE_0; \
+    offLen.width = GEN_WIDTH_1; \
+    offLen.hstride = GEN_HORIZONTAL_STRIDE_0; \
+    uint32_t szEnc = typeSize(theVal.type) >> 1; /* 1,2,4 bytes -> 0,1,2 */ \
+    if (szEnc == 4) { \
+      szEnc = 3; /* 8 bytes: 8 >> 1 == 4, fixed up so szEnc is always log2 of the size */ \
+    } \
+    p->MOV(offLen, GenRegister::immud((szEnc << 8) | (nextThreadID.nr << 21))); \
+    /* The MOV above writes the dword at byte 20: size code at bit 8, register number at bit 21 (presumably the gateway offset/length field - TODO confirm against the HW spec). */ \
+    GenRegister tidEuid = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 16), GEN_TYPE_UD); \
+    tidEuid.vstride = GEN_VERTICAL_STRIDE_0; \
+    tidEuid.width = GEN_WIDTH_1; \
+    tidEuid.hstride = GEN_HORIZONTAL_STRIDE_0; \
+    p->SHL(tidEuid, tidEuid, GenRegister::immud(16)); \
+    /* The SHL above moves the low 16 bits of the dword at byte 16 (presumably the target thread/EU id) into its upper half. */ \
+    p->curr.execWidth = 8; \
+    p->FWD_GATEWAY_MSG(nextThreadID, 2); \
+  } p->pop(); \
+} while(0)
+
+
+  /* The basic idea is as follows:
+     1. Every thread first computes the max/min value over its own work items, e.g. over its
+        16 work items when SIMD == 16 (add is not implemented yet).
+     2. Logical thread 0 starts by sending thread 1 a message containing its result from step 1.
+        All other threads wait on n0.2 for a forwarded message.
+     3. Each thread wakes up when it receives the message forwarded by thread (thread_id - 1). It
+        combines the value in the message with its own result and forwards the combined result to
+        the next thread by sending a message again; the last thread sends it back to thread 0.
+     4. Thread 0 finally receives the message from the last thread and broadcasts the result. */
   void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) {
+    const GenRegister dst = ra->genReg(insn.dst(0));
+    const GenRegister tmp = ra->genReg(insn.dst(2));
+    GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag);
+    GenRegister nextThreadID = ra->genReg(insn.src(1));
+    const GenRegister theVal = ra->genReg(insn.src(0));
+    GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid));
+    GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn));
+    GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // Forwarded value (byte 0).
+    msgData.vstride = GEN_VERTICAL_STRIDE_0;
+    msgData.width = GEN_WIDTH_1;
+    msgData.hstride = GEN_HORIZONTAL_STRIDE_0;
+    GenRegister threadData =
+      GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // Own result (byte 24).
+    threadData.vstride = GEN_VERTICAL_STRIDE_0;
+    threadData.width = GEN_WIDTH_1;
+    threadData.hstride = GEN_HORIZONTAL_STRIDE_0;
+    uint32_t wg_op = insn.extra.workgroupOp;
+    uint32_t simd = p->curr.execWidth;
+    GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW);
+    flag_save.vstride = GEN_VERTICAL_STRIDE_0;
+    flag_save.width = GEN_WIDTH_1;
+    flag_save.hstride = GEN_HORIZONTAL_STRIDE_0;
+    int32_t jip;
+    int32_t oneThreadJip = -1;
+    /* msgData, threadData and flag_save alias byte offsets of nextThreadID, the register SEND_RESULT_MSG forwards. */
+    p->push(); { /* First, reduce the values within this thread. */
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      /* Do some calculation within each thread. */
+      workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
+    } p->pop();
+
+    /* If we are the only thread, no message is needed: broadcast the result and jump to the end. */
+    p->push(); {
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->curr.noMask = 1;
+      p->curr.execWidth = 1;
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+      /* Broadcast result: flag_save = (threadnum == 1) ? 0xffff : 0 becomes the lane mask.
+         Must cover every op the final broadcast handles, otherwise dst is never written. */
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.inversePredicate = 1;
+        p->MOV(flag_save, GenRegister::immuw(0x0));
+        p->curr.inversePredicate = 0;
+        p->MOV(flag_save, GenRegister::immuw(0xffff));
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->MOV(flagReg, flag_save);
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.execWidth = simd;
+        p->MOV(dst, threadData);
+      }
+
+      /* Bail out: jump to the end when we are the only thread. */
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      p->curr.inversePredicate = 0;
+      p->curr.execWidth = 1;
+      oneThreadJip = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+    } p->pop();
+
+    p->push(); {
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->curr.noMask = 1;
+      p->curr.execWidth = 1;
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
+      /* flag_save = (threadid == 0) ? 0xffff : 0, then load it into the flag. */
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      p->curr.inversePredicate = 1;
+      p->MOV(flag_save, GenRegister::immuw(0x0));
+      p->curr.inversePredicate = 0;
+      p->MOV(flag_save, GenRegister::immuw(0xffff));
+
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->MOV(flagReg, flag_save);
+    } p->pop();
+
+    p->push(); {
+      p->curr.noMask = 1;
+      p->curr.execWidth = 1;
+
+      /* Thread 0: send its own result to the next thread, then wait for the last thread. */
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      p->curr.inversePredicate = 1;
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      jip = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->MOV(msgData, threadData);
+      SEND_RESULT_MSG();
+      p->WAIT(2);
+      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+      /* Other threads: wait for the message, merge in their own result and forward it. */
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      p->curr.inversePredicate = 0;
+      jip = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->WAIT(2);
+      workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
+      SEND_RESULT_MSG();
+      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+      /* Restore the (threadid == 0) mask from flag_save. */
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->MOV(flagReg, flag_save);
+    } p->pop();
+
+    /* Broadcast the result. */
+    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+      p->push(); {
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.noMask = 1;
+        p->curr.execWidth = 1;
+        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+        p->curr.inversePredicate = 0;
+
+        /* Not the first thread, wait for msg first. */
+        jip = p->n_instruction();
+        p->JMPI(GenRegister::immud(0));
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->WAIT(2);
+        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+        /* Every thread copies the final value to dst and passes the message on. */
+        p->curr.execWidth = simd;
+        p->MOV(dst, msgData);
+
+        p->curr.execWidth = 8;
+        p->FWD_GATEWAY_MSG(nextThreadID, 2);
+
+        p->curr.execWidth = 1;
+        p->curr.inversePredicate = 1;
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+
+        /* Thread 0: wait until the last thread notifies us. */
+        jip = p->n_instruction();
+        p->JMPI(GenRegister::immud(0));
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->WAIT(2);
+        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+      } p->pop();
+    }
+    /* The single-thread bail-out jumps here. */
+    if (oneThreadJip >=0)
+      p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0);
   }
 
   void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) {
-- 
2.5.0



More information about the Beignet mailing list