[Beignet] [PATCH 01/10] OCL20: add device enqueue helper functions in backend.

Yang Rong rong.r.yang at intel.com
Thu Mar 17 10:53:49 UTC 2016


These functions collect all of the block invoke functions used by device
enqueue, store them in the unit, and turn those functions into OpenCL
kernel functions. Because this changes the module's kernel functions, it
must be called before linking; otherwise, the built-in functions called
from the invoke functions may not be materialized.

Signed-off-by: Yang Rong <rong.r.yang at intel.com>
---
 backend/src/CMakeLists.txt               |   1 +
 backend/src/ir/unit.hpp                  |   1 +
 backend/src/llvm/llvm_device_enqueue.cpp | 248 +++++++++++++++++++++++++++++++
 backend/src/llvm/llvm_gen_backend.hpp    |   1 +
 backend/src/llvm/llvm_to_gen.cpp         |   4 +-
 5 files changed, 254 insertions(+), 1 deletion(-)
 create mode 100644 backend/src/llvm/llvm_device_enqueue.cpp

diff --git a/backend/src/CMakeLists.txt b/backend/src/CMakeLists.txt
index f26cc8b..662cce4 100644
--- a/backend/src/CMakeLists.txt
+++ b/backend/src/CMakeLists.txt
@@ -89,6 +89,7 @@ set (GBE_SRC
     llvm/ExpandUtils.cpp
     llvm/PromoteIntegers.cpp
     llvm/ExpandLargeIntegers.cpp
+    llvm/llvm_device_enqueue.cpp
     llvm/StripAttributes.cpp
     llvm/llvm_to_gen.cpp
     llvm/llvm_loadstore_optimization.cpp
diff --git a/backend/src/ir/unit.hpp b/backend/src/ir/unit.hpp
index b8df145..e8fbcb4 100644
--- a/backend/src/ir/unit.hpp
+++ b/backend/src/ir/unit.hpp
@@ -88,6 +88,7 @@ namespace ir {
   {
   public:
     typedef map<std::string, Function*> FunctionSet;
+    /*! Names of the device-enqueue block invoke functions that were turned
+     *  into kernels. A function's position in this vector is the integer
+     *  constant that replaces its (bitcast) function pointer in the IR.
+     */
+    vector<std::string> enqueueFuncs;
     /*! Create an empty unit */
     Unit(PointerSize pointerSize = POINTER_32_BITS);
     /*! Release everything (*including* the function pointers) */
diff --git a/backend/src/llvm/llvm_device_enqueue.cpp b/backend/src/llvm/llvm_device_enqueue.cpp
new file mode 100644
index 0000000..d751b82
--- /dev/null
+++ b/backend/src/llvm/llvm_device_enqueue.cpp
@@ -0,0 +1,248 @@
+/*
+ * Copyright © 2014 Intel Corporation
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library. If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+#include "llvm_includes.hpp"
+
+#include "ir/unit.hpp"
+#include "llvm_gen_backend.hpp"
+#include "ocl_common_defines.h"
+
+#include <list>
+#include <map>
+#include <set>
+#include <string>
+
+using namespace llvm;
+
+namespace gbe {
+  BitCastInst *isInvokeBitcast(Instruction *I);
+
+  /*! Search the constant-expression operands of I for a bitcast of a block
+   *  invoke function (as recognised by isInvokeBitcast).
+   *
+   *  Each ConstantExpr operand is materialized with getAsInstruction() so the
+   *  instruction-based matcher can inspect it; nested constant expressions are
+   *  searched through the mutual recursion with isInvokeBitcast. The
+   *  materialized instructions are never inserted into a basic block: those
+   *  that do not match are deleted here, and the matching one is handed to
+   *  the caller, which then owns it (it has no parent, so it must be deleted,
+   *  not erased).
+   *
+   *  \return the detached matching bitcast, or NULL if no operand matches.
+   */
+  BitCastInst *isInvokeBtConstantExpr(Instruction *I) {
+    for (unsigned index = 0; index < I->getNumOperands(); index++) {
+      ConstantExpr *expr = dyn_cast<ConstantExpr>(I->getOperand(index));
+      if (expr == NULL)
+        continue;
+
+      Instruction *newInst = expr->getAsInstruction();
+      BitCastInst *bt = isInvokeBitcast(newInst);
+      // Either nothing matched, or the match came from a deeper nested
+      // expression (a different temporary): this temporary is not needed.
+      if (bt != newInst)
+        delete newInst;
+      if (bt != NULL)
+        return bt;
+      // Keep scanning: a later operand may hold the invoke bitcast.
+    }
+    return NULL;
+  }
+
+  /*! Return I as a bitcast of a block invoke function, or NULL.
+   *
+   *  A match is a BitCastInst whose source operand is a Function (reached
+   *  through a pointer-to-function type) whose name contains "_invoke".
+   *  When I is not a BitCastInst at all, its constant-expression operands are
+   *  searched instead through isInvokeBtConstantExpr, which may return a
+   *  bitcast that is not inserted into any basic block.
+   */
+  BitCastInst *isInvokeBitcast(Instruction *I) {
+    BitCastInst *castInst = dyn_cast<BitCastInst>(I);
+    if (!castInst)
+      return isInvokeBtConstantExpr(I);
+
+    Value *src = castInst->getOperand(0);
+    PointerType *srcPtrTy = dyn_cast<PointerType>(src->getType());
+    if (!srcPtrTy || !srcPtrTy->getElementType()->isFunctionTy())
+      return NULL;
+
+    Function *callee = dyn_cast<Function>(src);
+    if (!callee)
+      return NULL;
+
+    // Heuristic only: block invoke functions are recognised by the "_invoke"
+    // substring in their name, not by any structural property.
+    const std::string calleeName = callee->getName();
+    if (calleeName.find("_invoke") == std::string::npos)
+      return NULL;
+    return castInst;
+  }
+
+  /*! Move a pointer argument, and every pointer derived from it, to the
+   *  global address space (1).
+   *
+   *  Walks the def-use graph breadth-first starting at arg. Loads and stores
+   *  are not followed; every other user is visited once. Each visited value
+   *  of pointer type has its type mutated in place to the same pointee type in
+   *  address space 1. Non-pointer values are left unchanged, but their users
+   *  are still explored.
+   */
+  void mutateArgAddressSpace(Argument *arg)
+  {
+    std::list<Value *> WorkList;
+    // Values already queued: guards against def-use cycles (e.g. a PHI that
+    // feeds itself through a loop) and against revisiting shared users.
+    std::set<Value *> Visited;
+    WorkList.push_back(arg);
+    Visited.insert(arg);
+
+    while (!WorkList.empty()) {
+      Value *v = WorkList.front();
+      // Pop before any early exit so a non-pointer value cannot stay at the
+      // front of the list and be reprocessed forever.
+      WorkList.pop_front();
+
+      for (Value::use_iterator iter = v->use_begin(); iter != v->use_end(); ++iter) {
+        // After LLVM 3.5, use_iterator points to 'Use' instead of 'User',
+        // which is more straightforward.
+#if (LLVM_VERSION_MAJOR == 3) && (LLVM_VERSION_MINOR < 5)
+        User *theUser = *iter;
+#else
+        User *theUser = iter->getUser();
+#endif
+        // Do not walk through memory accesses: a load/store only uses the
+        // pointer, it does not derive a new pointer from it.
+        if (isa<StoreInst>(theUser) || isa<LoadInst>(theUser))
+          continue;
+
+        if (Visited.insert(theUser).second)
+          WorkList.push_back(theUser);
+      }
+
+      PointerType *ty = dyn_cast<PointerType>(v->getType());
+      if (ty == NULL)
+        continue;
+      v->mutateType(PointerType::get(ty->getPointerElementType(), 1));
+    }
+  }
+
+  /*! Turn a block invoke function into an OpenCL kernel function.
+   *
+   *  A parameter's type cannot be changed in place, so a new function is
+   *  created. Its private (address space 0) pointer parameters become global
+   *  (address space 1) pointers, and every other parameter keeps its type, so
+   *  the new function has exactly as many parameters as Fn. Fn's body is
+   *  cloned into it, and pointers derived from the promoted parameters are
+   *  moved to the global address space. A kernel entry is then appended to the
+   *  module's "opencl.kernels" named metadata.
+   *
+   *  The new function takes over Fn's name. Fn itself is not erased: it is
+   *  still referenced (at least by the bitcast through which the caller found
+   *  it), and destroying a value that still has uses is invalid. Fn is instead
+   *  renamed with a "__d" prefix and left in the module, and its remaining
+   *  uses are left for the caller to redirect.
+   *
+   *  \param mod module owning Fn; receives the new function and its metadata
+   *  \param Fn  the invoke function to convert
+   *  \return the new kernel function
+   */
+  Function* setFunctionAsKernel(Module *mod, Function *Fn)
+  {
+    LLVMContext &Context = mod->getContext();
+    Type *intTy = IntegerType::get(Context, 32);
+    SmallVector<llvm::Metadata *, 5> kernelMDArgs;
+
+    // MDNode for the kernel argument address space qualifiers.
+    SmallVector<llvm::Metadata *, 8> addressQuals;
+    addressQuals.push_back(llvm::MDString::get(Context, "kernel_arg_addr_space"));
+
+    // MDNode for the kernel argument access qualifiers (images only).
+    SmallVector<llvm::Metadata *, 8> accessQuals;
+    accessQuals.push_back(llvm::MDString::get(Context, "kernel_arg_access_qual"));
+
+    // MDNode for the kernel argument type names.
+    SmallVector<llvm::Metadata *, 8> argTypeNames;
+    argTypeNames.push_back(llvm::MDString::get(Context, "kernel_arg_type"));
+
+    // MDNode for the kernel argument base type names.
+    SmallVector<llvm::Metadata *, 8> argBaseTypeNames;
+    argBaseTypeNames.push_back(
+        llvm::MDString::get(Context, "kernel_arg_base_type"));
+
+    // MDNode for the kernel argument type qualifiers.
+    SmallVector<llvm::Metadata *, 8> argTypeQuals;
+    argTypeQuals.push_back(llvm::MDString::get(Context, "kernel_arg_type_qual"));
+
+    // MDNode for the kernel argument names.
+    SmallVector<llvm::Metadata *, 8> argNames;
+    argNames.push_back(llvm::MDString::get(Context, "kernel_arg_name"));
+
+    // New signature: promote private pointers to global and keep every other
+    // parameter type, so parameter counts stay equal for the mapping below.
+    std::vector<Type *> ParamTys;
+    for (Function::arg_iterator I = Fn->arg_begin(), E = Fn->arg_end(); I != E; ++I) {
+      Type *argTy = I->getType();
+      PointerType *ty = dyn_cast<PointerType>(argTy);
+      if (ty != NULL && ty->getAddressSpace() == 0)
+        argTy = PointerType::get(ty->getPointerElementType(), 1);
+      ParamTys.push_back(argTy);
+    }
+    FunctionType* NewFT = FunctionType::get(Fn->getReturnType(), ParamTys, false);
+    Function* NewFn = Function::Create(NewFT, Function::ExternalLinkage, Fn->getName());
+
+    // Map each of Fn's arguments onto the NewFn argument in the same position
+    // so the cloned body refers to the new parameters.
+    ValueToValueMapTy VMap;
+    Function::arg_iterator NewFnArgIt = NewFn->arg_begin();
+    for (Function::arg_iterator I = Fn->arg_begin(), E = Fn->arg_end(); I != E; ++I, ++NewFnArgIt) {
+      NewFnArgIt->setName(I->getName());
+      VMap[&*I] = &*NewFnArgIt;
+    }
+    SmallVector<ReturnInst*, 8> Returns;
+    CloneFunctionInto(NewFn, Fn, VMap, /*ModuleLevelChanges=*/true, Returns);
+
+    // Rename Fn before inserting NewFn, so the module symbol table does not
+    // uniquify NewFn's name on insertion.
+    std::string oldName = Fn->getName();
+    Fn->setName("__d" + oldName);
+    mod->getFunctionList().push_back(NewFn);
+
+    Function::arg_iterator OldArgIt = Fn->arg_begin();
+    for (Function::arg_iterator I = NewFn->arg_begin(), E = NewFn->arg_end(); I != E; ++I, ++OldArgIt) {
+      PointerType *ty = dyn_cast<PointerType>(I->getType());
+      // NOTE(review): non-pointer parameters get no kernel-arg metadata entry,
+      // so the per-argument lists only line up when every parameter is a pointer.
+      if (ty == NULL) continue;
+
+      // Values cloned from Fn keep Fn's private pointer types; move everything
+      // derived from a promoted (formerly private) parameter to global.
+      if (cast<PointerType>(OldArgIt->getType())->getAddressSpace() == 0)
+        mutateArgAddressSpace(&*I);
+
+      addressQuals.push_back(llvm::ConstantAsMetadata::get(ConstantInt::get(intTy, ty->getAddressSpace())));
+      accessQuals.push_back(llvm::MDString::get(Context, "none"));
+      argTypeNames.push_back(llvm::MDString::get(Context, "char*"));
+      argBaseTypeNames.push_back(llvm::MDString::get(Context, "char*"));
+      argTypeQuals.push_back(llvm::MDString::get(Context, ""));
+      argNames.push_back(llvm::MDString::get(Context, I->getName()));
+    }
+
+    kernelMDArgs.push_back(llvm::ConstantAsMetadata::get(NewFn));
+    kernelMDArgs.push_back(llvm::MDNode::get(Context, addressQuals));
+    kernelMDArgs.push_back(llvm::MDNode::get(Context, accessQuals));
+    kernelMDArgs.push_back(llvm::MDNode::get(Context, argTypeNames));
+    kernelMDArgs.push_back(llvm::MDNode::get(Context, argBaseTypeNames));
+    kernelMDArgs.push_back(llvm::MDNode::get(Context, argTypeQuals));
+    kernelMDArgs.push_back(llvm::MDNode::get(Context, argNames));
+
+    llvm::MDNode *kernelMDNode = llvm::MDNode::get(Context, kernelMDArgs);
+    llvm::NamedMDNode *OpenCLKernelMetadata = mod->getOrInsertNamedMetadata("opencl.kernels");
+    OpenCLKernelMetadata->addOperand(kernelMDNode);
+
+    return NewFn;
+  }
+
+  /*! Replace the invoke-function bitcast referenced by I with v.
+   *
+   *  If I is itself a bitcast instruction, all of its uses are redirected to
+   *  v and I is returned so the caller can erase it after iteration.
+   *  Otherwise, the first operand of I that is a bitcast constant expression
+   *  is replaced by v inside I, and NULL is returned (nothing to erase).
+   */
+  Instruction* replaceInst(Instruction *I, Value *v)
+  {
+    // The bitcast is an instruction.
+    if (BitCastInst *bt = dyn_cast<BitCastInst>(I)) {
+      bt->replaceAllUsesWith(v);
+      return bt;
+    }
+
+    // The bitcast is a constant expression operand of I.
+    for (unsigned index = 0; index < I->getNumOperands(); index++) {
+      ConstantExpr *expr = dyn_cast<ConstantExpr>(I->getOperand(index));
+      // Test the opcode directly: materializing the expression with
+      // getAsInstruction() only to dyn_cast it leaks an instruction per probe.
+      if (expr != NULL && expr->getOpcode() == Instruction::BitCast) {
+        I->replaceUsesOfWith(expr, v);
+        return NULL;
+      }
+    }
+    return NULL;
+  }
+
+  /*! Turn every block invoke function reached through a bitcast in mod into
+   *  an OpenCL kernel, and record its name in unit.enqueueFuncs.
+   *
+   *  Each such bitcast is replaced by an integer constant, converted to the
+   *  bitcast's pointer type, holding the kernel's position in
+   *  unit.enqueueFuncs. The module's function list is modified, so this must
+   *  run before the bitcode linker.
+   */
+  void collectDeviceEnqueueInfo(Module *mod, ir::Unit &unit)
+  {
+    std::set<Instruction*> deadInsnSet;
+    // Invoke function -> kernel created for it. A block enqueued from several
+    // places must be cloned and registered as a kernel only once.
+    std::map<Function*, Function*> kernelOf;
+
+    for (Module::iterator SF = mod->begin(), SE = mod->end(); SF != SE; ++SF) {
+      Function *f = &*SF;
+      if (f->isDeclaration()) continue;
+
+      for (inst_iterator I = inst_begin(f), IE = inst_end(f); I != IE; ++I) {
+        BitCastInst *bt = isInvokeBitcast(&*I);
+        if (bt == NULL) continue;
+
+        Function *Fn = cast<Function>(bt->getOperand(0));
+        Type *castTy = bt->getType();
+        // For a constant-expression bitcast, bt is a detached instruction
+        // materialized only for matching; it belongs to us and is no longer
+        // needed (deleting it also drops its use of Fn).
+        if (bt->getParent() == NULL)
+          delete bt;
+
+        Function *kernel;
+        std::map<Function*, Function*>::iterator known = kernelOf.find(Fn);
+        if (known != kernelOf.end()) {
+          kernel = known->second;
+        } else {
+          kernel = setFunctionAsKernel(mod, Fn);
+          kernelOf[Fn] = kernel;
+        }
+
+        std::string fnName = kernel->getName();
+        size_t index = 0;
+        while (index < unit.enqueueFuncs.size() && unit.enqueueFuncs[index] != fnName)
+          index++;
+        if (index == unit.enqueueFuncs.size())
+          unit.enqueueFuncs.push_back(fnName);
+
+        Value *v = Constant::getIntegerValue(castTy, APInt(unit.getPointerSize(), index));
+        if (Instruction *inst = replaceInst(&*I, v))
+          deadInsnSet.insert(inst);
+      }
+    }
+
+    // Erase replaced bitcast instructions only after iteration is finished.
+    for (auto dead: deadInsnSet)
+      dead->eraseFromParent();
+  }
+};
diff --git a/backend/src/llvm/llvm_gen_backend.hpp b/backend/src/llvm/llvm_gen_backend.hpp
index 94a377b..31e6092 100644
--- a/backend/src/llvm/llvm_gen_backend.hpp
+++ b/backend/src/llvm/llvm_gen_backend.hpp
@@ -150,6 +150,7 @@ namespace gbe
 
   /*! Add all the function call of ocl to our bitcode. */
   llvm::Module* runBitCodeLinker(llvm::Module *mod, bool strictMath);
+  /*! Convert the block invoke functions used for device enqueue in mod into
+   *  OpenCL kernels and record their names in unit.enqueueFuncs. It changes
+   *  the module's functions, so it must be called before runBitCodeLinker. */
+  void collectDeviceEnqueueInfo(llvm::Module *mod, ir::Unit &unit);
 
   void* getPrintfInfo(llvm::CallInst* inst);
 } /* namespace gbe */
diff --git a/backend/src/llvm/llvm_to_gen.cpp b/backend/src/llvm/llvm_to_gen.cpp
index 11cb79f..e8362d5 100644
--- a/backend/src/llvm/llvm_to_gen.cpp
+++ b/backend/src/llvm/llvm_to_gen.cpp
@@ -272,10 +272,12 @@ namespace gbe
     if (!cl_mod) return false;
 
     OUTPUT_BITCODE(BEFORE_LINK, (*cl_mod));
+    /* Must call before materialize when link */
+    collectDeviceEnqueueInfo(cl_mod, unit);
 
     std::unique_ptr<Module> M;
 
-    /* Before do any thing, we first filter in all CL functions in bitcode. */ 
+    /* Before do any thing, we first filter in all CL functions in bitcode. */
     M.reset(runBitCodeLinker(cl_mod, strictMath));
     if (!module)
       delete cl_mod;
-- 
1.9.1



More information about the Beignet mailing list