[Mesa-dev] [PATCH 2/2] clover: Add support for handling reqd_work_group_size attribute

Tom Stellard thomas.stellard at amd.com
Tue Mar 24 08:25:21 PDT 2015


This patch enables clover to return the correct value for
CL_KERNEL_COMPILE_WORK_GROUP_SIZE and also to verify that the correct
local_work_size is used when enqueuing kernels with this attribute.
---
 src/gallium/state_trackers/clover/api/kernel.cpp   |  9 +++-
 src/gallium/state_trackers/clover/core/kernel.cpp  | 31 +++++++++++++-
 src/gallium/state_trackers/clover/core/kernel.hpp  |  2 +-
 src/gallium/state_trackers/clover/core/module.cpp  |  1 +
 src/gallium/state_trackers/clover/core/module.hpp  | 12 ++++--
 .../state_trackers/clover/llvm/invocation.cpp      | 48 ++++++++++++++++++++++
 6 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/src/gallium/state_trackers/clover/api/kernel.cpp b/src/gallium/state_trackers/clover/api/kernel.cpp
index 05cc392..8d8694c 100644
--- a/src/gallium/state_trackers/clover/api/kernel.cpp
+++ b/src/gallium/state_trackers/clover/api/kernel.cpp
@@ -24,6 +24,8 @@
 #include "core/kernel.hpp"
 #include "core/event.hpp"
 
+#include <functional>
+
 using namespace clover;
 
 CLOVER_API cl_kernel
@@ -161,7 +163,7 @@ clGetKernelWorkGroupInfo(cl_kernel d_kern, cl_device_id d_dev,
       break;
 
    case CL_KERNEL_COMPILE_WORK_GROUP_SIZE:
-      buf.as_vector<size_t>() = kern.required_block_size();
+      buf.as_vector<size_t>() = kern.required_block_size(dev);
       break;
 
    case CL_KERNEL_LOCAL_MEM_SIZE:
@@ -242,12 +244,15 @@ namespace {
 
       if (d_block_size) {
          auto block_size = range(d_block_size, dims);
+         auto reqd_block_size = kern.required_block_size(q.device());
 
          if (any_of(is_zero(), block_size) ||
              any_of(greater(), block_size, q.device().max_block_size()))
             throw error(CL_INVALID_WORK_ITEM_SIZE);
 
-         if (any_of(modulus(), grid_size, block_size))
+         if (any_of(modulus(), grid_size, block_size) ||
+             (fold(multiplies(), 1u, reqd_block_size) &&
+             any_of(std::not_equal_to<size_t>(), block_size, reqd_block_size)))
             throw error(CL_INVALID_WORK_GROUP_SIZE);
 
          if (fold(multiplies(), 1u, block_size) >
diff --git a/src/gallium/state_trackers/clover/core/kernel.cpp b/src/gallium/state_trackers/clover/core/kernel.cpp
index 442762c..f788812 100644
--- a/src/gallium/state_trackers/clover/core/kernel.cpp
+++ b/src/gallium/state_trackers/clover/core/kernel.cpp
@@ -120,8 +120,27 @@ kernel::optimal_block_size(const command_queue &q,
 }
 
 std::vector<size_t>
-kernel::required_block_size() const {
-   return { 0, 0, 0 };
+kernel::required_block_size(const device &dev) const {
+   std::vector<size_t> block_size(3);
+
+   const clover::module &m = program().binary(dev);
+   auto margs = find(name_equals(name()), m.syms).args;
+
+   for (auto &marg : margs) {
+      switch (marg.semantic) {
+      default: break;
+      case module::argument::reqd_work_group_size_x:
+         block_size[0] = marg.value;
+         break;
+      case module::argument::reqd_work_group_size_y:
+         block_size[1] = marg.value;
+         break;
+      case module::argument::reqd_work_group_size_z:
+         block_size[2] = marg.value;
+         break;
+      }
+   }
+   return block_size;
 }
 
 kernel::argument_range
@@ -182,6 +201,14 @@ kernel::exec_context::bind(intrusive_ptr<command_queue> _q,
          }
          break;
       }
+      case module::argument::reqd_work_group_size_x:
+      case module::argument::reqd_work_group_size_y:
+      case module::argument::reqd_work_group_size_z: {
+        auto arg = argument::create(marg);
+        arg->set(sizeof(cl_uint), &marg.value);
+        arg->bind(*this, marg);
+        break;
+      }
       }
    }
 
diff --git a/src/gallium/state_trackers/clover/core/kernel.hpp b/src/gallium/state_trackers/clover/core/kernel.hpp
index d6432a4..4bef5b8 100644
--- a/src/gallium/state_trackers/clover/core/kernel.hpp
+++ b/src/gallium/state_trackers/clover/core/kernel.hpp
@@ -130,7 +130,7 @@ namespace clover {
       optimal_block_size(const command_queue &q,
                          const std::vector<size_t> &grid_size) const;
       std::vector<size_t>
-      required_block_size() const;
+      required_block_size(const device &dev) const;
 
       argument_range args();
       const_argument_range args() const;
diff --git a/src/gallium/state_trackers/clover/core/module.cpp b/src/gallium/state_trackers/clover/core/module.cpp
index be10e35..fbd5dba 100644
--- a/src/gallium/state_trackers/clover/core/module.cpp
+++ b/src/gallium/state_trackers/clover/core/module.cpp
@@ -158,6 +158,7 @@ namespace {
          _proc(s, x.target_align);
          _proc(s, x.ext_type);
          _proc(s, x.semantic);
+         _proc(s, x.value);
       }
    };
 
diff --git a/src/gallium/state_trackers/clover/core/module.hpp b/src/gallium/state_trackers/clover/core/module.hpp
index ee6caf9..866e53f 100644
--- a/src/gallium/state_trackers/clover/core/module.hpp
+++ b/src/gallium/state_trackers/clover/core/module.hpp
@@ -71,16 +71,21 @@ namespace clover {
          enum semantic {
             general,
             grid_dimension,
-            grid_offset
+            grid_offset,
+            reqd_work_group_size_x,
+            reqd_work_group_size_y,
+            reqd_work_group_size_z
+
          };
 
          argument(enum type type, size_t size,
                   size_t target_size, size_t target_align,
                   enum ext_type ext_type,
-                  enum semantic semantic = general) :
+                  enum semantic semantic = general,
+                  size_t value = 0) :
             type(type), size(size),
             target_size(target_size), target_align(target_align),
-            ext_type(ext_type), semantic(semantic) { }
+            value(value), ext_type(ext_type), semantic(semantic) { }
 
          argument(enum type type, size_t size) :
             type(type), size(size),
@@ -95,6 +100,7 @@ namespace clover {
          size_t size;
          size_t target_size;
          size_t target_align;
+         size_t value;
          ext_type ext_type;
          semantic semantic;
       };
diff --git a/src/gallium/state_trackers/clover/llvm/invocation.cpp b/src/gallium/state_trackers/clover/llvm/invocation.cpp
index 28198a5..916d5ad 100644
--- a/src/gallium/state_trackers/clover/llvm/invocation.cpp
+++ b/src/gallium/state_trackers/clover/llvm/invocation.cpp
@@ -29,6 +29,7 @@
 #include <clang/Basic/TargetInfo.h>
 #include <llvm/Bitcode/BitstreamWriter.h>
 #include <llvm/Bitcode/ReaderWriter.h>
+#include <llvm/IR/Constants.h>
 #if HAVE_LLVM < 0x0305
 #include <llvm/Linker.h>
 #else
@@ -131,10 +132,12 @@ namespace {
    struct llvm_kernel {
 
       llvm_kernel() : fn(NULL), offset(0) {
+         memset(reqd_work_group_size, 0, sizeof(reqd_work_group_size));
       }
 
       llvm::Function *fn;
       size_t offset;
+      int reqd_work_group_size[3];
    };
 
    void debug_log(const std::string &msg, const std::string &suffix) {
@@ -306,6 +309,35 @@ namespace {
          kernel.fn = llvm::dyn_cast<llvm::Function>(
 #endif
                                     kernel_node->getOperand(i)->getOperand(0));
+         for (unsigned md_idx = 1,
+              md_e = kernel_node->getOperand(i)->getNumOperands();
+              md_idx != md_e; ++md_idx) {
+            const llvm::MDNode *md_node = llvm::dyn_cast<llvm::MDNode>(
+                  kernel_node->getOperand(i)->getOperand(md_idx).get());
+
+            const llvm::MDString *md_name =
+                  llvm::dyn_cast<llvm::MDString>(md_node->getOperand(0));
+            if (!md_name)
+               continue;
+
+            if (!md_name->getString().equals("reqd_work_group_size"))
+               continue;
+
+            for (unsigned reqd_idx = 0; reqd_idx < 3; ++reqd_idx) {
+               const llvm::ConstantInt *reqd_size =
+#if HAVE_LLVM >= 0x0306
+               llvm::mdconst::dyn_extract<llvm::ConstantInt>(
+#else
+               llvm::dyn_cast<llvm::ConstantInt>(
+#endif
+                  md_node->getOperand(reqd_idx + 1).get());
+
+               if (!reqd_size)
+                  break;
+               kernel.reqd_work_group_size[reqd_idx] =
+                     reqd_size->getZExtValue();
+            }
+         }
       }
    }
 
@@ -455,6 +487,22 @@ namespace {
                           module::argument::zero_ext,
                           module::argument::grid_offset));
 
+      static const enum module::argument::semantic reqd_work_semantics[] = {
+         module::argument::reqd_work_group_size_x,
+         module::argument::reqd_work_group_size_y,
+         module::argument::reqd_work_group_size_z
+      };
+
+      for (unsigned i = 0; i < 3; ++i) {
+         args.push_back(
+            module::argument(module::argument::scalar, sizeof(cl_uint),
+                             TD.getTypeStoreSize(size_type),
+                             TD.getABITypeAlignment(size_type),
+                             module::argument::zero_ext,
+                             reqd_work_semantics[i],
+                             kernel.reqd_work_group_size[i]));
+      }
+
       return args;
    }
 
-- 
2.0.4



More information about the mesa-dev mailing list