[Beignet] [PATCH 11/15] Backend: Implement reduce min and max in gen_context
Pan Xiuli
xiuli.pan at intel.com
Wed Jan 20 22:51:51 PST 2016
From: Junyan He <junyan.he at linux.intel.com>
Signed-off-by: Junyan He <junyan.he at linux.intel.com>
Reviewed-by: Yang Rong <rong.r.yang at intel.com>
backend/src/backend/gen_context.cpp | 277 ++++++++++++++++++++++++++++++++++++
1 file changed, 277 insertions(+)
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index ed6c9f0..fd5503c 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2345,7 +2345,284 @@ namespace gbe
p->TYPED_WRITE(header, true, bti);
+ /* Combine the partial result received from the previous thread (msgData) with this
+    thread's own partial result (threadData), leaving the combined value in msgData so it
+    can be forwarded to the next thread. Only REDUCE_MIN / REDUCE_MAX emit code here.
+    theVal and simd are currently unused. Runs as a single scalar op (execWidth 1, noMask). */
+ static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+                                      uint32_t simd, uint32_t wg_op, GenEncoder *p) {
+   p->push();
+   p->curr.predicate = GEN_PREDICATE_NONE;
+   p->curr.noMask = 1;
+   p->curr.execWidth = 1;
+   if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+     /* Assuming SEL_CMP(cond, dst, src0, src1) writes src0 when `src0 cond src1` holds and
+        src1 otherwise (Gen SEL with a conditional modifier), LE keeps the smaller value and
+        GE the larger one. */
+     const uint32_t cond = (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) ? GEN_CONDITIONAL_LE
+                                                                  : GEN_CONDITIONAL_GE;
+     p->SEL_CMP(cond, msgData, threadData, msgData);
+   }
+   p->pop();
+ }
+ /* Write the identity element of a min/max work-group op into dataReg, so that
+    combining it with any real value yields that value:
+      - GEN_TYPE_UD: 0xFFFFFFFF for *_MIN ops, 0 for *_MAX ops.
+      - GEN_TYPE_F:  +inf (0x7F800000) for *_MIN ops, -inf (0xFF800000) for *_MAX ops,
+                     written as raw bits through a UD view of the register.
+    Any other register type is unsupported and asserts. */
+ static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) {
+   if (dataReg.type == GEN_TYPE_UD) {
+     if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+         || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) {
+       p->MOV(dataReg, GenRegister::immud(0xFFFFFFFF));
+     } else {
+       GBE_ASSERT(wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX
+                  || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX);
+       p->MOV(dataReg, GenRegister::immud(0));
+     }
+   } else if (dataReg.type == GEN_TYPE_F) {
+     if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+         || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) {
+       p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0x7F800000)); // inf
+     } else if (wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX
+                || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) {
+       p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0xFF800000)); // -inf
+     }
+   } else {
+     /* Only UD and F are handled; falling through silently would leave dataReg
+        uninitialized and corrupt the reduction. */
+     GBE_ASSERT(0);
+   }
+ }
+ /* Compute this thread's partial min/max of theVal into threadData (scalar).
+    threadData is first set to the op's identity value (initValue). If theVal is uniform
+    (hstride 0) it is combined directly. Otherwise theVal is copied into tmp under the
+    dispatch mask, with disabled lanes keeping the identity value, and the `simd` elements
+    of tmp are folded into threadData one at a time. msgData is currently unused. */
+ static void workgroupOpInThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+                                 GenRegister tmp, uint32_t simd, uint32_t wg_op, GenEncoder *p) {
+   p->push();
+   p->curr.predicate = GEN_PREDICATE_NONE;
+   p->curr.noMask = 1;
+   p->curr.execWidth = 1;
+   /* Setting the init value here. */
+   threadData = GenRegister::retype(threadData, theVal.type);
+   initValue(p, threadData, wg_op);
+   if (theVal.hstride != GEN_HORIZONTAL_STRIDE_0) {
+     /* Lanes outside the dispatch mask must not affect the result. First fill every lane
+        of tmp with the op's identity value (noMask), then copy theVal on enabled lanes only. */
+     tmp = GenRegister::retype(tmp, theVal.type);
+     p->push();
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->curr.noMask = 1;
+     p->curr.execWidth = simd;
+     initValue(p, tmp, wg_op);
+     p->curr.noMask = 0;
+     p->MOV(tmp, theVal);
+     p->pop();
+   }
+   if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+     /* Assuming SEL_CMP(cond, dst, src0, src1) keeps src0 when `src0 cond src1` holds,
+        LE accumulates the minimum and GE the maximum. */
+     const uint32_t cond = (wg_op == ir::WORKGROUP_OP_REDUCE_MIN) ? GEN_CONDITIONAL_LE
+                                                                  : GEN_CONDITIONAL_GE;
+     if (theVal.hstride == GEN_HORIZONTAL_STRIDE_0) { // an uniform value.
+       p->SEL_CMP(cond, threadData, threadData, theVal);
+     } else {
+       GBE_ASSERT(tmp.type == theVal.type);
+       /* Walk tmp one element at a time. subnr advances by the element size and rolls
+          over into the next register at 32 (presumably the 32-byte GRF size). */
+       GenRegister v = GenRegister::toUniform(tmp, theVal.type);
+       for (uint32_t i = 0; i < simd; i++) {
+         p->SEL_CMP(cond, threadData, threadData, v);
+         v.subnr += typeSize(theVal.type);
+         if (v.subnr == 32) {
+           v.subnr = 0;
+           v.nr++;
+         }
+       }
+     }
+   }
+   p->pop();
+ }
+/* SEND_RESULT_MSG(): forward the nextThreadID register to another thread via
+   FWD_GATEWAY_MSG. It must be expanded where `p` (GenEncoder*), `nextThreadID` and
+   `theVal` (GenRegister) are in scope. Inside its own push/pop (noMask, no predicate) it:
+     - writes (szEnc << 8) | (nextThreadID.nr << 21) into the dword at byte 20 of
+       nextThreadID, where szEnc = typeSize(theVal.type) >> 1 with 4 remapped to 3,
+       i.e. sizes 1, 2, 4, 8 map to 0, 1, 2, 3;
+     - shifts the dword at byte 16 of nextThreadID left by 16, in place;
+     - issues FWD_GATEWAY_MSG(nextThreadID, 2) at execWidth 8.
+   NOTE(review): the meaning of these fields (data size / register number at byte 20,
+   thread/EU id at byte 16) is inferred from the variable names; confirm against the
+   message-gateway payload layout. Because the SHL is applied in place, expanding this
+   macro twice on the same register shifts that dword twice. */
+#define SEND_RESULT_MSG() \
do { \
+ p->push(); { /* then send msg. */ \
+ p->curr.noMask = 1; \
+ p->curr.predicate = GEN_PREDICATE_NONE; \
+ p->curr.execWidth = 1; \
+ GenRegister offLen = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 20), GEN_TYPE_UD); \
+ offLen.vstride = GEN_VERTICAL_STRIDE_0; \
+ offLen.width = GEN_WIDTH_1; \
+ offLen.hstride = GEN_HORIZONTAL_STRIDE_0; \
+ uint32_t szEnc = typeSize(theVal.type) >> 1; \
+ if (szEnc == 4) { \
+ szEnc = 3; \
+ } \
+ p->MOV(offLen, GenRegister::immud((szEnc << 8) | (nextThreadID.nr << 21))); \
+ \
+ GenRegister tidEuid = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 16), GEN_TYPE_UD); \
+ tidEuid.vstride = GEN_VERTICAL_STRIDE_0; \
+ tidEuid.width = GEN_WIDTH_1; \
+ tidEuid.hstride = GEN_HORIZONTAL_STRIDE_0; \
+ p->SHL(tidEuid, tidEuid, GenRegister::immud(16)); \
+ \
+ p->curr.execWidth = 8; \
+ p->FWD_GATEWAY_MSG(nextThreadID, 2); \
+ } p->pop(); \
+} while(0)
+ /* The basic idea is like this:
1. Each thread first calculates the max/min/add value within itself, i.e. it finds
the max/min/add value over its 16 work items when SIMD == 16.
2. The thread with logical ID 0 starts by sending a message to thread 1; that message contains the
result calculated in the first step. All threads other than thread 0 wait on n0.2 for the forwarded message.
3. Each thread is woken up when it receives the forwarded message from thread (thread_id - 1). It then
compares the result in the message with its own result and forwards the combined result to
the next thread by sending a message again. The last thread sends it back to thread 0.
4. Thread 0 finally gets the message from the last thread and broadcasts the final result. */
void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) {
+   /* Emit a work-group REDUCE_MIN / REDUCE_MAX across the hardware threads of a work group.
+      Operands:
+        - dst(0) receives the broadcast result.
+        - dst(2) is scratch space for the per-lane values.
+        - src(0) is each work item's input value.
+        - src(1) ("nextThreadID") is the register forwarded through the message gateway;
+          it also holds the scalar fields set up below. */
+   const GenRegister dst = ra->genReg(insn.dst(0));
+   const GenRegister tmp = ra->genReg(insn.dst(2));
+   GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag);
+   GenRegister nextThreadID = ra->genReg(insn.src(1));
+   const GenRegister theVal = ra->genReg(insn.src(0));
+   GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid));
+   GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn));
+
+   /* Scalar (stride-0) views into nextThreadID:
+        byte 0  - msgData:    the value forwarded between threads
+        byte 8  - flag_save:  a saved copy of the flag register
+        byte 24 - threadData: this thread's partial result */
+   GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // The data forward.
+   msgData.vstride = GEN_VERTICAL_STRIDE_0;
+   msgData.width = GEN_WIDTH_1;
+   msgData.hstride = GEN_HORIZONTAL_STRIDE_0;
+   GenRegister threadData =
+     GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // Res within thread.
+   threadData.vstride = GEN_VERTICAL_STRIDE_0;
+   threadData.width = GEN_WIDTH_1;
+   threadData.hstride = GEN_HORIZONTAL_STRIDE_0;
+   uint32_t wg_op = insn.extra.workgroupOp;
+   uint32_t simd = p->curr.execWidth;
+   GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW);
+   flag_save.vstride = GEN_VERTICAL_STRIDE_0;
+   flag_save.width = GEN_WIDTH_1;
+   flag_save.hstride = GEN_HORIZONTAL_STRIDE_0;
+   int32_t jip;
+   int32_t oneThreadJip = -1;
+
+   p->push(); { /* First, do something within thread. */
+     p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+     /* Do some calculation within each thread. */
+     workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
+   } p->pop();
+
+   /* If we are the only one thread, no need to send msg, just broadcast the result.*/
+   p->push(); {
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->curr.noMask = 1;
+     p->curr.execWidth = 1;
+     p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+     p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+     /* Broadcast result. This must cover the same ops as the multi-thread broadcast
+        at the end (REDUCE_MIN and REDUCE_MAX). Otherwise the single-thread jump below
+        skips everything and dst is never written. */
+     if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+       /* Turn the scalar CMP result into an all-ones / all-zeros flag value (via
+          flag_save), so the SIMD-wide predicated MOV below writes every lane when
+          threadnum == 1 and no lane otherwise. */
+       p->curr.predicate = GEN_PREDICATE_NORMAL;
+       p->curr.inversePredicate = 1;
+       p->MOV(flag_save, GenRegister::immuw(0x0));
+       p->curr.inversePredicate = 0;
+       p->MOV(flag_save, GenRegister::immuw(0xffff));
+       p->curr.predicate = GEN_PREDICATE_NONE;
+       p->MOV(flagReg, flag_save);
+       p->curr.predicate = GEN_PREDICATE_NORMAL;
+       p->curr.execWidth = simd;
+       p->MOV(dst, threadData);
+     }
+     /* Bail out: jump to the end (patched below) when threadnum == 1. */
+     p->curr.predicate = GEN_PREDICATE_NORMAL;
+     p->curr.inversePredicate = 0;
+     p->curr.execWidth = 1;
+     oneThreadJip = p->n_instruction();
+     p->JMPI(GenRegister::immud(0));
+   } p->pop();
+
+   /* flag = (threadid == 0), expanded to all-ones / all-zeros. flag_save keeps a copy so
+      the flag can be restored after the send/wait sequence below. */
+   p->push(); {
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->curr.noMask = 1;
+     p->curr.execWidth = 1;
+     p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+     p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
+     p->curr.predicate = GEN_PREDICATE_NORMAL;
+     p->curr.inversePredicate = 1;
+     p->MOV(flag_save, GenRegister::immuw(0x0));
+     p->curr.inversePredicate = 0;
+     p->MOV(flag_save, GenRegister::immuw(0xffff));
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->MOV(flagReg, flag_save);
+   } p->pop();
+
+   p->push(); {
+     p->curr.noMask = 1;
+     p->curr.execWidth = 1;
+     /* threadid 0, send the msg and wait */
+     p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+     p->curr.inversePredicate = 1;
+     p->curr.predicate = GEN_PREDICATE_NORMAL;
+     jip = p->n_instruction();
+     p->JMPI(GenRegister::immud(0)); // threads other than 0 skip this section
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->MOV(msgData, threadData);
+     /* Start the chain: thread 0 forwards its partial result, then waits until the
+        last thread sends back the fully combined value. */
+     SEND_RESULT_MSG();
+     p->WAIT(2);
+     p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+     /* Others wait and send msg, and do something when we get the msg. */
+     p->curr.predicate = GEN_PREDICATE_NORMAL;
+     p->curr.inversePredicate = 0;
+     jip = p->n_instruction();
+     p->JMPI(GenRegister::immud(0)); // thread 0 skips this section
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->WAIT(2);
+     /* Fold the received value with our own partial result, then pass it on. */
+     workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
+     SEND_RESULT_MSG();
+     p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+     /* Restore the flag. */
+     p->curr.predicate = GEN_PREDICATE_NONE;
+     p->MOV(flagReg, flag_save);
+   } p->pop();
+
+   /* Broadcast the result: thread 0 holds the final value in msgData and forwards it
+      down the same chain; every thread copies it to dst before forwarding it on. */
+   if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+     p->push(); {
+       p->curr.predicate = GEN_PREDICATE_NORMAL;
+       p->curr.noMask = 1;
+       p->curr.execWidth = 1;
+       p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+       p->curr.inversePredicate = 0;
+       /* Not the first thread, wait for msg first. */
+       jip = p->n_instruction();
+       p->JMPI(GenRegister::immud(0));
+       p->curr.predicate = GEN_PREDICATE_NONE;
+       p->WAIT(2);
+       p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+       /* Do something when get the msg. */
+       p->curr.execWidth = simd;
+       p->MOV(dst, msgData);
+       p->curr.execWidth = 8;
+       p->FWD_GATEWAY_MSG(nextThreadID, 2);
+       p->curr.execWidth = 1;
+       p->curr.inversePredicate = 1;
+       p->curr.predicate = GEN_PREDICATE_NORMAL;
+       /* The first thread, the last one will notify us. */
+       jip = p->n_instruction();
+       p->JMPI(GenRegister::immud(0));
+       p->curr.predicate = GEN_PREDICATE_NONE;
+       p->WAIT(2);
+       p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+     } p->pop();
+   }
+
+   /* The single-thread early exit lands here. */
+   if (oneThreadJip >=0)
+     p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0);
}
void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) {
More information about the Beignet mailing list