[Beignet] [PATCH 2/2] Add bitcast support between vector and scalar type.

Yang Rong rong.r.yang at intel.com
Tue Nov 12 01:17:14 PST 2013


Signed-off-by: Yang Rong <rong.r.yang at intel.com>
---
 backend/src/backend/gen_insn_selection.cpp | 77 +++++++++++++++++++++++++
 backend/src/ir/instruction.cpp             | 90 ++++++++++++++++++++++++++++++
 backend/src/ir/instruction.hpp             | 17 +++++-
 backend/src/ir/instruction.hxx             |  1 +
 backend/src/llvm/llvm_gen_backend.cpp      | 39 ++++++++++++-
 backend/src/llvm/llvm_scalarize.cpp        | 14 +++++
 6 files changed, 234 insertions(+), 4 deletions(-)

diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp
index 0caaffa..f6f7961 100644
--- a/backend/src/backend/gen_insn_selection.cpp
+++ b/backend/src/backend/gen_insn_selection.cpp
@@ -2529,6 +2529,82 @@ namespace gbe
     }
   };
 
+  /*! Bit cast instruction pattern: copies bits between retyped narrow and wide registers with MOVs */
+  DECL_PATTERN(BitCastInstruction)
+  {
+    INLINE bool emitOne(Selection::Opaque &sel, const ir::BitCastInstruction &insn) const
+    {
+      using namespace ir;
+      const Type dstType = insn.getDstType();
+      const Type srcType = insn.getSrcType();
+      const uint32_t dstNum = insn.getDstNum();
+      const uint32_t srcNum = insn.getSrcNum();
+      int index = 0, multiple, narrowNum;  // multiple: register-count ratio between the two sides
+      bool narrowDst;                      // true when the destination side has more registers
+      Type narrowType;
+
+      if(dstNum > srcNum) {  // more dst than src registers: the dst side is the narrow one
+        multiple = dstNum / srcNum;
+        narrowType = dstType;
+        narrowNum = dstNum;
+        narrowDst = 1;
+      } else {  // otherwise (including equal counts) the src side is treated as narrow
+        multiple = srcNum / dstNum;
+        narrowType = srcType;
+        narrowNum = srcNum;
+        narrowDst = 0;
+      }
+
+      for(int i = 0; i < narrowNum; i++, index++) {  // one pass per narrow register; index always equals i
+        GenRegister narrowReg, wideReg;
+        if(narrowDst) {
+          narrowReg = sel.selReg(insn.getDst(i), narrowType);
+          wideReg = sel.selReg(insn.getSrc(index/multiple), narrowType);  //retype to narrow type
+        } else {
+          wideReg = sel.selReg(insn.getDst(index/multiple), narrowType);
+          narrowReg = sel.selReg(insn.getSrc(i), narrowType);  //retype to narrow type
+        }
+        if(wideReg.hstride != GEN_VERTICAL_STRIDE_0) {  // NOTE(review): hstride is compared with a vertical-stride constant; presumably both encode stride 0 - confirm
+          if(multiple == 2) {
+            wideReg = GenRegister::unpacked_uw(wideReg.reg());
+            wideReg = GenRegister::retype(wideReg, getGenType(narrowType));
+          } else if(multiple == 4) {
+            wideReg = GenRegister::unpacked_ub(wideReg.reg());
+            wideReg = GenRegister::retype(wideReg, getGenType(narrowType));
+          } else if(multiple == 8) {  //need to special-case long to char; NOTE(review): the assert below can never fire in this branch
+            GBE_ASSERT(multiple == 8);
+          }
+        }
+        if(index % multiple) {  // not the first narrow element packed in this wide register
+          wideReg = GenRegister::offset(wideReg, 0, (index % multiple) * typeSize(wideReg.type));  // step to that element's position
+          wideReg.subphysical = 1;
+        }
+        GenRegister xdst = narrowDst ? narrowReg : wideReg;  // MOV destination is whichever side belongs to dst
+        GenRegister xsrc = narrowDst ? wideReg : narrowReg;  // MOV source is the other side
+
+        if((srcType == TYPE_S64 || srcType == TYPE_U64 || srcType == TYPE_DOUBLE) ||
+           (dstType == TYPE_S64 || dstType == TYPE_U64 || dstType == TYPE_DOUBLE)) {  // either side is S64/U64/DOUBLE: split the copy into several MOVs
+          const int simdWidth = sel.curr.execWidth;
+          sel.push();  // push/pop bracket the temporary execWidth change
+            sel.curr.execWidth = 8;
+            xdst.subphysical = 1;
+            xsrc.subphysical = 1;
+            for(int i = 0; i < simdWidth/4; i ++) {  // NOTE(review): this i shadows the outer loop index
+              sel.curr.chooseNib(i);
+              sel.MOV(xdst, xsrc);
+              xdst = GenRegister::offset(xdst, 0, 4 * typeSize(getGenType(dstType)));  // advance by four dst-typed elements
+              xsrc = GenRegister::offset(xsrc, 0, 4 * typeSize(getGenType(srcType)));  // advance by four src-typed elements
+            }
+          sel.pop();
+        } else  // no 64-bit type involved: a single MOV per narrow register
+          sel.MOV(xdst, xsrc);
+      }
+
+      return true;
+    }
+    DECL_CTOR(BitCastInstruction, 1, 1);
+  };
+
   /*! Convert instruction pattern */
   DECL_PATTERN(ConvertInstruction)
   {
@@ -3029,6 +3105,7 @@ namespace gbe
     this->insert<StoreInstructionPattern>();
     this->insert<SelectInstructionPattern>();
     this->insert<CompareInstructionPattern>();
+    this->insert<BitCastInstructionPattern>();
     this->insert<ConvertInstructionPattern>();
     this->insert<AtomicInstructionPattern>();
     this->insert<TernaryInstructionPattern>();
diff --git a/backend/src/ir/instruction.cpp b/backend/src/ir/instruction.cpp
index d86c3c0..61dcd49 100644
--- a/backend/src/ir/instruction.cpp
+++ b/backend/src/ir/instruction.cpp
@@ -243,6 +243,40 @@ namespace ir {
       INLINE bool wellFormed(const Function &fn, std::string &whyNot) const;
     };
 
+    class ALIGNED_INSTRUCTION BitCastInstruction :  //!< Internal representation of OP_BITCAST with tuple sources and destinations
+      public BasePolicy,
+      public TupleSrcPolicy<BitCastInstruction>,
+      public TupleDstPolicy<BitCastInstruction>
+    {
+    public:
+      BitCastInstruction(Type dstType,
+                         Type srcType,
+                         Tuple dst,
+                         Tuple src,
+                         uint8_t dstNum,
+                         uint8_t srcNum)
+      {
+        this->opcode = OP_BITCAST;
+        this->dst = dst;
+        this->src = src;
+        this->dstFamily = getFamily(dstType);  // only the register family is stored, not the Type itself
+        this->srcFamily = getFamily(srcType);
+        GBE_ASSERT(srcNum <= 16 && dstNum <= 16);  // same bound as MAX_SRC_NUM / MAX_DST_NUM (16)
+        this->dstNum = dstNum;
+        this->srcNum = srcNum;
+      }
+      INLINE Type getSrcType(void) const { return getType((RegisterFamily)srcFamily); }  // NOTE(review): Type is rebuilt from the family; presumably the original signedness is not kept - confirm
+      INLINE Type getDstType(void) const { return getType((RegisterFamily)dstFamily); }
+      INLINE bool wellFormed(const Function &fn, std::string &whyNot) const;
+      INLINE void out(std::ostream &out, const Function &fn) const;
+      uint8_t dstFamily:4; //!< family to cast to
+      uint8_t srcFamily:4; //!< family to cast from
+      Tuple dst;          //!< Tuple holding the destination registers
+      Tuple src;          //!< Tuple holding the source registers
+      uint8_t dstNum;     //!< Number of destination registers in dst
+      uint8_t srcNum;     //!< Number of source registers in src
+    };
+
     class ALIGNED_INSTRUCTION ConvertInstruction :
       public BasePolicy,
       public NDstPolicy<ConvertInstruction, 1>,
@@ -809,6 +843,35 @@ namespace ir {
       return true;
     }
 
+    // The total sizes of all sources and all destinations must be identical. Bool is not supported for now; bool handling still needs a double check.
+    INLINE bool BitCastInstruction::wellFormed(const Function &fn, std::string &whyNot) const
+    {
+      for (uint32_t dstID = 0; dstID < dstNum; ++dstID) {  // every destination must pass the special-register write check and the dst-family check
+        if (UNLIKELY(checkSpecialRegForWrite(getDst(fn, dstID), fn, whyNot) == false))
+          return false;
+        if (UNLIKELY(checkRegisterData((RegisterFamily)dstFamily, getDst(fn, dstID), fn, whyNot) == false))
+          return false;
+      }
+      for (uint32_t srcID = 0; srcID < srcNum; ++srcID) {  // every source must pass the src-family check
+        if (UNLIKELY(checkRegisterData((RegisterFamily)srcFamily, getSrc(fn, srcID), fn, whyNot) == false))
+          return false;
+      }
+
+      CHECK_TYPE(getType((RegisterFamily)dstFamily), allButBool);  // presumably rejects a bool destination type - see CHECK_TYPE
+      CHECK_TYPE(getType((RegisterFamily)srcFamily), allButBool);  // presumably rejects a bool source type - see CHECK_TYPE
+
+      uint32_t dstBytes = 0, srcBtyes = 0;  // per-side total: register count * getFamilySize(family)
+      dstBytes = dstNum * getFamilySize((RegisterFamily)dstFamily);
+      srcBtyes = srcNum * getFamilySize((RegisterFamily)srcFamily);
+
+      if(dstBytes != srcBtyes){  // totals differ: explain through whyNot and reject
+        whyNot = " The bit sizes of src and the dst is not identical.";
+        return false;
+      }
+
+      return true;
+    }
+
     // We can convert anything to anything, but types and families must match
     INLINE bool ConvertInstruction::wellFormed(const Function &fn, std::string &whyNot) const
     {
@@ -1020,6 +1083,22 @@ namespace ir {
         out << " %" << this->getSrc(fn, i);
     }
 
+
+    INLINE void BitCastInstruction::out(std::ostream &out, const Function &fn) const {  // prints: <opcode>.<dstType>.<srcType> {%dst ...} {%src ...}
+      this->outOpcode(out);
+      out << "." << this->getDstType()
+          << "." << this->getSrcType();
+      out << " {";
+      for (uint32_t i = 0; i < dstNum; ++i)  // destinations, space separated, no trailing space
+        out << "%" << this->getDst(fn, i) << (i != (dstNum-1u) ? " " : "");
+      out << "}";
+      out << " {";
+      for (uint32_t i = 0; i < srcNum; ++i)  // sources, same layout as the destinations
+        out << "%" << this->getSrc(fn, i) << (i != (srcNum-1u) ? " " : "");
+      out << "}";
+    }
+
+
     INLINE void ConvertInstruction::out(std::ostream &out, const Function &fn) const {
       this->outOpcode(out);
       out << "." << this->getDstType()
@@ -1142,6 +1221,10 @@ START_INTROSPECTION(CompareInstruction)
 #include "ir/instruction.hxx"
 END_INTROSPECTION(CompareInstruction)
 
+START_INTROSPECTION(BitCastInstruction)
+#include "ir/instruction.hxx"
+END_INTROSPECTION(BitCastInstruction)
+
 START_INTROSPECTION(ConvertInstruction)
 #include "ir/instruction.hxx"
 END_INTROSPECTION(ConvertInstruction)
@@ -1346,6 +1429,8 @@ DECL_MEM_FN(BinaryInstruction, bool, commutes(void), commutes())
 DECL_MEM_FN(SelectInstruction, Type, getType(void), getType())
 DECL_MEM_FN(TernaryInstruction, Type, getType(void), getType())
 DECL_MEM_FN(CompareInstruction, Type, getType(void), getType())
+DECL_MEM_FN(BitCastInstruction, Type, getSrcType(void), getSrcType())
+DECL_MEM_FN(BitCastInstruction, Type, getDstType(void), getDstType())
 DECL_MEM_FN(ConvertInstruction, Type, getSrcType(void), getSrcType())
 DECL_MEM_FN(ConvertInstruction, Type, getDstType(void), getDstType())
 DECL_MEM_FN(AtomicInstruction, AddressSpace, getAddressSpace(void), getAddressSpace())
@@ -1468,6 +1553,11 @@ DECL_MEM_FN(GetImageInfoInstruction, uint32_t, getInfoType(void), getInfoType())
 
 #undef DECL_EMIT_FUNCTION
 
+  // BITCAST: build an internal BitCastInstruction from the two tuples and return it as a generic Instruction
+  Instruction BITCAST(Type dstType, Type srcType, Tuple dst, Tuple src, uint8_t dstNum, uint8_t srcNum) {
+    return internal::BitCastInstruction(dstType, srcType, dst, src, dstNum, srcNum).convert();
+  }
+
   // CVT
   Instruction CVT(Type dstType, Type srcType, Register dst, Register src) {
     return internal::ConvertInstruction(dstType, srcType, dst, src).convert();
diff --git a/backend/src/ir/instruction.hpp b/backend/src/ir/instruction.hpp
index ae45a63..b1afd42 100644
--- a/backend/src/ir/instruction.hpp
+++ b/backend/src/ir/instruction.hpp
@@ -176,8 +176,8 @@ namespace ir {
     template <typename T> INLINE bool isMemberOf(void) const {
       return T::isClassOf(*this);
     }
-    static const uint32_t MAX_SRC_NUM = 8;
-    static const uint32_t MAX_DST_NUM = 8;
+    static const uint32_t MAX_SRC_NUM = 16;
+    static const uint32_t MAX_DST_NUM = 16;
   protected:
     BasicBlock *parent;      //!< The basic block containing the instruction
     GBE_CLASS(Instruction);  //!< Use internal allocators
@@ -241,6 +241,17 @@ namespace ir {
     static bool isClassOf(const Instruction &insn);
   };
 
+  /*! BitCast instruction: casts a tuple of source registers to a tuple of destination registers of the same total size */
+  class BitCastInstruction : public Instruction {
+  public:
+    /*! Get the type of the source */
+    Type getSrcType(void) const;
+    /*! Get the type of the destination */
+    Type getDstType(void) const;
+    /*! Return true if the given instruction is an instance of this class */
+    static bool isClassOf(const Instruction &insn);
+  };
+
   /*! Conversion instruction converts from one type to another */
   class ConvertInstruction : public Instruction {
   public:
@@ -623,6 +634,8 @@ namespace ir {
   Instruction GE(Type type, Register dst, Register src0, Register src1);
   /*! ge.type dst src0 src1 */
   Instruction GT(Type type, Register dst, Register src0, Register src1);
+  /*! BITCAST.{dstType <- srcType} dst src */
+  Instruction BITCAST(Type dstType, Type srcType, Tuple dst, Tuple src, uint8_t dstNum, uint8_t srcNum);
   /*! cvt.{dstType <- srcType} dst src */
   Instruction CVT(Type dstType, Type srcType, Register dst, Register src);
   /*! sat_cvt.{dstType <- srcType} dst src */
diff --git a/backend/src/ir/instruction.hxx b/backend/src/ir/instruction.hxx
index 67dc682..83ecd1d 100644
--- a/backend/src/ir/instruction.hxx
+++ b/backend/src/ir/instruction.hxx
@@ -60,6 +60,7 @@ DECL_INSN(LE, CompareInstruction)
 DECL_INSN(LT, CompareInstruction)
 DECL_INSN(GE, CompareInstruction)
 DECL_INSN(GT, CompareInstruction)
+DECL_INSN(BITCAST, BitCastInstruction)
 DECL_INSN(CVT, ConvertInstruction)
 DECL_INSN(SAT_CVT, ConvertInstruction)
 DECL_INSN(ATOMIC, AtomicInstruction)
diff --git a/backend/src/llvm/llvm_gen_backend.cpp b/backend/src/llvm/llvm_gen_backend.cpp
index d620d44..d1d0579 100644
--- a/backend/src/llvm/llvm_gen_backend.cpp
+++ b/backend/src/llvm/llvm_gen_backend.cpp
@@ -1628,7 +1628,13 @@ namespace gbe
       // Bitcast just forward registers
       case Instruction::BitCast:
       {
-        regTranslator.newValueProxy(srcValue, dstValue);
+        Type *srcType = srcValue->getType();
+        Type *dstType = dstValue->getType();
+
+        if(srcType->isVectorTy() || dstType->isVectorTy())  // a vector on either side: the cast can no longer simply alias the source
+          this->newRegister(dstValue);  // so the destination value gets registers of its own
+        else
+          regTranslator.newValueProxy(srcValue, dstValue);  // scalar to scalar: keep forwarding the source register
       }
       break;
       // Various conversion operations -> just allocate registers for them
@@ -1664,7 +1670,36 @@ namespace gbe
         }
       }
       break;
-      case Instruction::BitCast: break; // nothing to emit here
+      case Instruction::BitCast:
+      {
+        Value *srcValue = I.getOperand(0);
+        Value *dstValue = &I;
+        uint32_t srcElemNum = 0, dstElemNum = 0 ;  // presumably filled in by getVectorInfo below
+        ir::Type srcType = getVectorInfo(ctx, srcValue->getType(), srcValue, srcElemNum);
+        ir::Type dstType = getVectorInfo(ctx, dstValue->getType(), dstValue, dstElemNum);
+        if(srcElemNum > 1 || dstElemNum > 1) {  // a BITCAST is emitted only when at least one side has several elements
+          // Collect the per-element registers of the source and destination values
+          vector<ir::Register> srcTupleData;
+          vector<ir::Register> dstTupleData;
+          uint32_t elemID = 0;
+          for (elemID = 0; elemID < srcElemNum; ++elemID) {
+            ir::Register reg;
+            reg = this->getRegister(srcValue, elemID);
+            srcTupleData.push_back(reg);
+          }
+          for (elemID = 0; elemID < dstElemNum; ++elemID) {
+            ir::Register reg;
+            reg = this->getRegister(dstValue, elemID);
+            dstTupleData.push_back(reg);
+          }
+
+          const ir::Tuple srcTuple = ctx.arrayTuple(&srcTupleData[0], srcElemNum);  // turn each register array into an ir::Tuple
+          const ir::Tuple dstTuple = ctx.arrayTuple(&dstTupleData[0], dstElemNum);
+
+          ctx.BITCAST(dstType, srcType, dstTuple, srcTuple, dstElemNum, srcElemNum);
+        }
+      }
+      break; // end of the BitCast case; nothing is emitted when both sides are single-element
       case Instruction::FPToUI:
       case Instruction::FPToSI:
       case Instruction::SIToFP:
diff --git a/backend/src/llvm/llvm_scalarize.cpp b/backend/src/llvm/llvm_scalarize.cpp
index a29bc59..6394909 100644
--- a/backend/src/llvm/llvm_scalarize.cpp
+++ b/backend/src/llvm/llvm_scalarize.cpp
@@ -143,6 +143,7 @@ namespace gbe {
     // Take an instruction that produces a vector, and scalarize it
     bool scalarize(Instruction*);
     bool scalarizePerComponent(Instruction*);
+    bool scalarizeBitCast(BitCastInst *);
     bool scalarizeFuncCall(CallInst *);
     bool scalarizeLoad(LoadInst*);
     bool scalarizeStore(StoreInst*);
@@ -491,6 +492,10 @@ namespace gbe {
     if (IsPerComponentOp(inst))
       return scalarizePerComponent(inst);
 
+    //not Per Component bitcast, for example <2 * i8> -> i16, handle it in backend
+    if (BitCastInst* bt = dyn_cast<BitCastInst>(inst))
+      return scalarizeBitCast(bt);
+
     if (LoadInst* ld = dyn_cast<LoadInst>(inst))
       return scalarizeLoad(ld);
 
@@ -670,6 +675,15 @@ namespace gbe {
     return false;
   }
 
+  bool Scalarize::scalarizeBitCast(BitCastInst* bt)  // handles bitcasts that are not per-component (e.g. <2 x i8> -> i16)
+  {
+    if(bt->getOperand(0)->getType()->isVectorTy())  // vector source: replace the operand with the result of InsertToVector
+      bt->setOperand(0, InsertToVector(bt, bt->getOperand(0)));  // presumably rebuilds the vector from its scalar components - confirm
+    if(bt->getType()->isVectorTy())  // vector result: register it through extractFromVector
+      extractFromVector(bt);
+    return false;  // always false; NOTE(review): meaning of the return value is defined by the caller - confirm
+  }
+
   bool Scalarize::scalarizeLoad(LoadInst* ld)
   {
     extractFromVector(ld);
-- 
1.8.1.2



More information about the Beignet mailing list